Source code for scine_puffin.jobs.templates.scine_observers

# -*- coding: utf-8 -*-
from __future__ import annotations
__copyright__ = """ This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, TYPE_CHECKING, List

from numpy.random import default_rng

from scine_puffin.utilities.imports import module_exists, requires, MissingDependency

if module_exists("scine_database") or TYPE_CHECKING:
    import scine_database as db
else:
    db = MissingDependency("scine_database")
if module_exists("scine_utilities") or TYPE_CHECKING:
    import scine_utilities as utils
else:
    utils = MissingDependency("scine_utilities")


class Observer(ABC):
    """
    Abstract interface of the observer pattern used in a Puffin job in order
    to observe each calculation.

    Implementations receive the data of every observed calculation through
    :meth:`gather` and are handed a database manager in :meth:`finalize`.
    """

    @abstractmethod
    def gather(self, cycle: int, atoms, results, tag: str) -> None:
        """
        Receive the data of one observed calculation.

        Parameters
        ----------
        cycle : int
            Counter identifying the observed cycle, supplied by the caller.
        atoms
            The atoms belonging to the observed calculation.
        results
            The results belonging to the observed calculation.
        tag : str
            Identifier of the kind of task that is being observed.

        Raises
        ------
        NotImplementedError
            Always; the method has to be overridden.
        """
        raise NotImplementedError

    @abstractmethod
    def finalize(self, db_manager: db.Manager, charge: int, multiplicity: int) -> None:
        """
        Conclude the observation once no further data will be gathered.

        Parameters
        ----------
        db_manager : db.Manager
            Database manager made available to the observer.
        charge : int
            Charge associated with the observed structures.
        multiplicity : int
            Spin multiplicity associated with the observed structures.

        Raises
        ------
        NotImplementedError
            Always; the method has to be overridden.
        """
        raise NotImplementedError
class StoreEverythingObserver(Observer):
    """
    Observer implementation that stores every structure and calculated
    properties in the database.

    Data is only collected in memory by :meth:`gather`; nothing is written to
    the database before :meth:`finalize` is called.
    """

    def __init__(self, calculation_id: db.ID, model: db.Model) -> None:
        """
        Parameters
        ----------
        calculation_id : db.ID
            ID that is set as the calculation of every stored property.
        model : db.Model
            Model that every stored property is created with.
        """
        super().__init__()
        # One dictionary per gathered calculation, filled by `gather`.
        self.data: List[Dict[str, Any]] = []
        # Names of the gathered results that `finalize` may store. An empty
        # list disables this filter; `finalize` still only knows how to store
        # 'energy' and 'gradients'.
        self.white_list = ['energy', 'gradients']
        self.calculation_id = calculation_id
        self.model = model

    def gather(self, cycle: int, atoms: utils.AtomCollection, results: utils.Results, tag: str) -> None:
        """
        Remember the atoms, the tag and every available result of one
        calculation for later storage.

        Every attribute of ``results`` whose name does not start with a double
        underscore and whose value is not ``None`` is recorded under its
        attribute name.

        Parameters
        ----------
        cycle : int
            Not used by this implementation.
        atoms : utils.AtomCollection
            Stored under the key ``'atoms'``.
        results : utils.Results
            Object whose attributes are inspected.
        tag : str
            Stored under the key ``'tag'``; must be known to
            :meth:`tag_to_label` by the time :meth:`finalize` runs.
        """
        record: Dict[str, Any] = {
            'atoms': atoms,
            'tag': tag,
        }
        for attribute in dir(results):
            if attribute.startswith('__'):
                continue
            # Evaluate the attribute exactly once so that the value that was
            # checked against None is the very value that gets stored.
            value = getattr(results, attribute)
            if value is not None:
                record[attribute] = value
        self.data.append(record)

    @staticmethod
    @requires('database')
    def tag_to_label(tag: str) -> db.Label:
        """
        Translate the tag of an observed task into a structure label.

        Parameters
        ----------
        tag : str
            One of ``'geometry_optimization'``, ``'ts_optimization'``,
            ``'irc_forward'``, ``'irc_backward'``, ``'afir_scan'``,
            ``'nt1_scan'`` or ``'nt2_scan'``.

        Returns
        -------
        db.Label
            The label assigned to structures gathered under the given tag.

        Raises
        ------
        KeyError
            If the tag is not one of the known tags.
        """
        mapping = {
            'geometry_optimization': db.Label.MINIMUM_GUESS,
            'ts_optimization': db.Label.TS_GUESS,
            'irc_forward': db.Label.ELEMENTARY_STEP_OPTIMIZED,
            'irc_backward': db.Label.ELEMENTARY_STEP_OPTIMIZED,
            'afir_scan': db.Label.REACTIVE_COMPLEX_SCANNED,
            'nt1_scan': db.Label.REACTIVE_COMPLEX_SCANNED,
            'nt2_scan': db.Label.REACTIVE_COMPLEX_SCANNED,
        }
        return mapping[tag]

    @requires('database')
    def finalize(self, db_manager: db.Manager, charge: int, multiplicity: int) -> None:
        """
        Write all gathered data to the database.

        For every gathered entry a new structure is created and labeled
        according to the entry's tag. The entry's ``'energy'`` is attached as
        a number property named ``'electronic_energy'`` and its
        ``'gradients'`` as a dense matrix property named ``'gradients'``,
        provided they were gathered and pass the white list. Any other
        gathered result is skipped.

        Parameters
        ----------
        db_manager : db.Manager
            Manager providing the ``'structures'`` and ``'properties'``
            collections.
        charge : int
            Charge given to every created structure.
        multiplicity : int
            Spin multiplicity given to every created structure.

        Raises
        ------
        KeyError
            If a gathered tag is unknown to :meth:`tag_to_label`. The label is
            resolved before the structure is created, so the offending entry
            does not leave an unlabeled structure behind; entries processed
            earlier have already been written.
        """
        has_white_list: bool = (len(self.white_list) > 0)
        # Keys of a gathered entry that are never stored as properties.
        not_a_property = ('atoms', 'tag', 'successful_calculation')
        structures = db_manager.get_collection('structures')
        properties = db_manager.get_collection('properties')
        for result in self.data:
            # Resolve the label first: an unknown tag must fail before
            # anything is written to the database for this entry.
            label = StoreEverythingObserver.tag_to_label(result['tag'])
            structure = db.Structure.make(result['atoms'], charge, multiplicity, structures)
            structure.set_label(label)
            for property_name, value in result.items():
                if property_name in not_a_property:
                    continue
                if has_white_list and property_name not in self.white_list:
                    continue
                if property_name == 'energy':
                    # The energy is stored under its database name.
                    db_name = 'electronic_energy'
                    new_prop: db.Property = db.NumberProperty.make(db_name, self.model, value, properties)
                elif property_name == 'gradients':
                    db_name = property_name
                    new_prop = db.DenseMatrixProperty.make(db_name, self.model, value, properties)
                else:
                    # No storage routine exists for any other result.
                    continue
                new_prop.set_structure(structure.get_id())
                new_prop.set_calculation(self.calculation_id)
                structure.add_property(db_name, new_prop.get_id())
class StoreWithFrequencyObserver(StoreEverythingObserver):
    """
    Observer implementation that stores every nth structure and calculated
    properties in the database.
    """

    def __init__(self, calculation_id: db.ID, model: db.Model, frequency: float) -> None:
        """
        Parameters
        ----------
        calculation_id : db.ID
            Forwarded to the base class.
        model : db.Model
            Forwarded to the base class.
        frequency : float
            A cycle is gathered if it is a multiple of this value; a value of
            zero disables gathering entirely.
        """
        super().__init__(calculation_id, model)
        self.frequency = frequency

    def gather(self, cycle: int, atoms: utils.AtomCollection, results: utils.Results, tag: str) -> None:
        """
        Gather the given data only if ``cycle`` is a multiple of the
        configured frequency; otherwise do nothing.
        """
        stride = self.frequency
        # The zero check comes first so that the modulo is never evaluated
        # with a zero divisor.
        if stride != 0 and cycle % stride == 0:
            super().gather(cycle, atoms, results, tag)
class StoreWithFractionObserver(StoreEverythingObserver):
    """
    Observer implementation that stores a given fraction of structures and
    their properties in the database. Which structures are stored is
    determined at random.
    """

    def __init__(self, calculation_id: db.ID, model: db.Model, fraction: float) -> None:
        """
        Parameters
        ----------
        calculation_id : db.ID
            Forwarded to the base class.
        model : db.Model
            Forwarded to the base class.
        fraction : float
            Threshold that a random draw has to fall below for a calculation
            to be gathered.
        """
        super().__init__(calculation_id, model)
        self.fraction = fraction
        # Unseeded generator: the selection differs from run to run.
        self.rng = default_rng()

    def gather(self, cycle: int, atoms: utils.AtomCollection, results: utils.Results, tag: str) -> None:
        """
        Gather the given data if a fresh random draw is smaller than the
        configured fraction; otherwise do nothing.
        """
        draw = self.rng.random()
        selected = draw < self.fraction
        if selected:
            super().gather(cycle, atoms, results, tag)