#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__copyright__ = """ This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.
"""
from abc import abstractmethod, ABC
from collections import defaultdict
from json import dumps
from typing import Callable, Dict, Iterator, List, Set, Optional, Tuple, Union
from warnings import warn
from numpy import ndarray
import scine_database as db
from scine_database.queries import stop_on_timeout
from scine_utilities import ValueCollection
# Local application imports
from scine_chemoton.filters.aggregate_filters import AggregateFilter
from scine_chemoton.filters.reactive_site_filters import ReactiveSiteFilter
from scine_chemoton.utilities.place_holder_model import (
construct_place_holder_model,
PlaceHolderModelType
)
from .trial_generator import TrialGenerator
from .trial_generator.bond_based import BondBased
from .. import Gear, _initialize_a_gear_to_a_db
class ElementaryStepGear(Gear, ABC):
"""
Base class for elementary step reaction generators
"""
class Options(Gear.Options):
"""
The options for an ElementaryStepGear.
"""
__slots__ = (
"_parent",
"enable_unimolecular_trials",
"enable_bimolecular_trials",
"run_one_cycle_with_settings_enhancement",
"base_job_settings",
"structure_model",
"looped_collection"
)
def __init__(self, _parent: Optional[Gear] = None) -> None:
self._parent = _parent
super().__init__()
self.enable_unimolecular_trials = True
"""
bool
If `True`, enables the exploration of unimolecular reactions.
"""
self.enable_bimolecular_trials = True
"""
bool
If `True`, enables the exploration of bimolecular reactions.
"""
self.run_one_cycle_with_settings_enhancement = False
"""
bool
If `True`, the next cycle is run with enhanced settings (the cache is cleared first) and the flag is reset to `False` afterwards.
"""
self.base_job_settings: ValueCollection = ValueCollection({})
"""
ValueCollection
The base settings for the jobs. Duplicate keys are overwritten by the settings of the TrialGenerator.
"""
self.structure_model: db.Model = construct_place_holder_model()
"""
db.Model
If set to a model other than the default placeholder, calculations are only started for structures with the given model.
"""
self.looped_collection: str = "compounds"
"""
str
The collection to loop over. Can be "compounds" or "flasks".
"""
def __setattr__(self, item, value) -> None:
"""
Overwritten standard method to synchronize model option
"""
model_case = bool(
item == "model" and
hasattr(self, "model") and
self.model != value and
hasattr(self, "_parent") and
self._parent is not None and
hasattr(self._parent, "trial_generator")
)
if item == "looped_collection" and value not in ["compounds", "flasks"]:
raise ValueError(f"Invalid value for {item}: '{value}'. Only 'compounds' and 'flasks' are allowed.")
if item == "base_job_settings":
if not isinstance(value, ValueCollection):
raise TypeError(f"The {item} must be a ValueCollection.")
if hasattr(self, "_parent") and self._parent is not None and hasattr(self._parent, "trial_generator"):
if self._parent.trial_generator.options.base_job_settings:
warn("The base job settings of the trial generator are overwritten by the gear.")
self._parent.trial_generator.options.base_job_settings = value
super().__setattr__(item, value)
if model_case:
if not isinstance(self._parent.trial_generator.options.model, PlaceHolderModelType): # type: ignore
warn("The model of the trial generator is overwritten by the gear.")
self._parent.trial_generator.options.model = value # type: ignore
self._parent.clear_cache() # type: ignore
options: Options
def __init__(self) -> None:
super().__init__()
self._required_collections = ["calculations", "compounds", "flasks", "properties", "reactions", "structures"]
self.options = self.Options(_parent=self)
self.trial_generator: TrialGenerator = BondBased()
self.trial_generator.options.base_job_settings = self.options.base_job_settings
self.aggregate_filter: AggregateFilter = AggregateFilter()
self._cache: Set[str] = set()
self._rebuild_cache = True
def __setattr__(self, item, value) -> None:
"""
Overwritten standard method to synchronize model option
"""
super().__setattr__(item, value)
if isinstance(value, TrialGenerator):
if isinstance(self.options.model, PlaceHolderModelType) \
and not isinstance(value.options.model, PlaceHolderModelType):
warn("The model of the gear is overwritten by the given trial generator.")
self.options.model = value.options.model
else:
if not isinstance(value.options.model, PlaceHolderModelType):
warn("The model of the trial generator is overwritten by the gear.")
self.trial_generator.options.model = self.options.model
self.trial_generator._parent = self
if hasattr(self, "aggregate_filter"):
self._check_filters_for_flask_compatibility()
self.clear_cache()
if item == "aggregate_filter":
if not isinstance(value, AggregateFilter):
raise TypeError(f"The {item} must be an AggregateFilter.")
if hasattr(self, "options") and self.options.looped_collection == "flasks" and not value.supports_flasks():
raise ValueError(f"The aggregate filter {value.name} does not support flasks.")
def clear_cache(self) -> None:
self._cache = set()
self._rebuild_cache = True
def disable_caching(self) -> None:
self.clear_cache()
self._rebuild_cache = False
def enable_caching(self) -> None:
self._rebuild_cache = True
def unimolecular_coordinates(self, credentials: db.Credentials,
observer: Optional[Callable[[], None]] = None) \
-> Dict[str,
Dict[str,
List[Tuple[List[List[Tuple[int, int]]], int]]]]:
"""
Returns the reaction coordinates allowed for unimolecular reactions for the whole database based on
the set options and filters. This method does not set up new calculations.
The returned object is a dictionary of dictionaries holding lists of tuples.
The outer dictionary is keyed by aggregate ID; the inner dictionary maps each structure of the
aggregate to its reaction coordinate information.
The first entry of each tuple is a list of reaction coordinates.
The second entry of each tuple is the number of dissociations.

Parameters
----------
credentials : db.Credentials
The credentials of the database.
observer : Optional[Callable[[], None]]
A function that is called after each aggregate to count the number of aggregates processed.
"""
_initialize_a_gear_to_a_db(self, credentials)
return self._internal_loop_impl(setup_calculations=False,
loop_unimolecular=self.options.enable_unimolecular_trials,
loop_bimolecular=False,
observer=observer)[0]
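# A hedged usage sketch (comments only): one way to traverse the nested return value of
# ``unimolecular_coordinates``. ``gear`` and ``my_credentials`` are assumed to be a concrete
# ElementaryStepGear subclass instance and a db.Credentials object, respectively.
#
#     coordinates = gear.unimolecular_coordinates(my_credentials)
#     for aggregate_id, per_structure in coordinates.items():
#         for structure_id, trials in per_structure.items():
#             for coords, n_dissociations in trials:
#                 # ``coords`` is a list of reaction coordinates, each a list of atom index pairs
#                 print(aggregate_id, structure_id, coords, n_dissociations)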
def bimolecular_coordinates(self, credentials: db.Credentials,
observer: Optional[Callable[[], None]] = None) -> \
Dict[str,
Dict[str,
Dict[Tuple[List[Tuple[int, int]], int],
List[Tuple[ndarray, ndarray, float, float]]
]
]
]:
"""
Returns the reaction coordinates allowed for bimolecular reactions for the whole database based on
the set options and filters. This method does not set up new calculations.
The returned object is a dictionary of dictionaries holding a dictionary that specifies the coordinates.
The outer dictionary is keyed by aggregate pair; the inner dictionary maps each structure pair of the
aggregates to its reaction coordinate information.
The keys of the innermost dictionary are tuples of a reaction coordinate and the number of dissociations.
The values hold a list of instructions; each entry in this list allows one to construct a reactive complex.
Therefore, the number of reactive complexes per reaction coordinate can also be inferred.

Notes
-----
The index basis (total system or separate systems) of the returned indices in the reaction coordinates
varies between different TrialGenerator implementations!

Parameters
----------
credentials : db.Credentials
The credentials of the database.
observer : Optional[Callable[[], None]]
A function that is called after each aggregate to count the number of aggregates processed.
"""
_initialize_a_gear_to_a_db(self, credentials)
return self._internal_loop_impl(setup_calculations=False,
loop_unimolecular=False,
loop_bimolecular=self.options.enable_bimolecular_trials,
observer=observer)[1]
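# A hedged usage sketch (comments only) for the bimolecular counterpart; the variable names are
# assumptions. Each innermost key is a (reaction coordinate, number of dissociations) tuple and
# each value lists the instructions (two arrays and two floats) for building a reactive complex.
#
#     coordinates = gear.bimolecular_coordinates(my_credentials)
#     for aggregate_pair, per_structure_pair in coordinates.items():
#         for structure_pair, trials in per_structure_pair.items():
#             for (coord, n_dissociations), complexes in trials.items():
#                 print(aggregate_pair, structure_pair, coord, n_dissociations, len(complexes))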
def _sanity_check_configuration(self):
if not isinstance(self.aggregate_filter, AggregateFilter):
raise TypeError(f"Expected a AggregateFilter (or a class derived "
f"from it) in {self.name}.aggregate_filter.")
if hasattr(self.trial_generator, 'reactive_site_filter'):
if not isinstance(getattr(self.trial_generator, 'reactive_site_filter'), ReactiveSiteFilter):
raise TypeError(f"Expected a ReactiveSiteFilter (or a class derived "
f"from it) in {self.name}.trial_generator.reactive_site_filter.")
def _propagate_db_manager(self, manager: db.Manager):
self._sanity_check_configuration()
self.trial_generator.initialize_collections(manager)
if hasattr(self, 'aggregate_filter'):
self.aggregate_filter.initialize_collections(manager)
if hasattr(self.trial_generator, 'reactive_site_filter'):
self.trial_generator.reactive_site_filter.initialize_collections(manager)
def _loop_impl(self):
if self.options.run_one_cycle_with_settings_enhancement:
self.clear_cache()
self._internal_loop_impl(setup_calculations=True,
loop_unimolecular=self.options.enable_unimolecular_trials,
loop_bimolecular=self.options.enable_bimolecular_trials)
self.options.run_one_cycle_with_settings_enhancement = False
def _internal_loop_impl(self, setup_calculations: bool, loop_unimolecular: bool, loop_bimolecular: bool,
observer: Optional[Callable[[], None]] = None) \
-> Tuple[Dict[str,
Dict[str,
List[Tuple[List[List[Tuple[int, int]]], int]]
]
],
Dict[str,
Dict[str,
Dict[Tuple[List[Tuple[int, int]], int],
List[Tuple[ndarray, ndarray, float, float]]
]
]
]
]:
if self.options.model != self.trial_generator.options.model:
raise TypeError(f"Elementary step gear {self.name} and trial generator "
f"{self.trial_generator.__class__.__name__} have diverging models")
uni_result: Dict[str,
Dict[str,
List[Tuple[List[List[Tuple[int, int]]], int]]]] \
= defaultdict(lambda: defaultdict(list))
bi_result: Dict[str,
Dict[str,
Dict[Tuple[List[Tuple[int, int]], int],
List[Tuple[ndarray, ndarray, float, float]]
]
]
] \
= defaultdict(lambda: defaultdict(dict))
# Loop over all aggregates
collection, iterator = self._get_collection_iterator()
for aggregate_one in stop_on_timeout(iterator):
aggregate_one.link(collection)
if self.stop_at_next_break_point:
return {}, {}
if observer is not None:
observer()
eligible_sid_one = None
if loop_unimolecular and self.aggregate_filter.filter(aggregate_one):
eligible_sid_one = sorted(self._get_eligible_structures(aggregate_one))
for sid_one in eligible_sid_one:
if self.stop_at_next_break_point:
return {}, {}
if not self.options.run_one_cycle_with_settings_enhancement and sid_one.string() in self._cache:
continue
structure_one = db.Structure(sid_one, self._structures)
if not self._check_structure_model(structure_one):
continue
if self._rebuild_cache:
if sid_one.string() not in self._cache:
self._update_cache(
structure_one,
self.trial_generator.get_unimolecular_job_order(),
self.trial_generator.options.model
)
if not self.options.run_one_cycle_with_settings_enhancement and sid_one.string() in self._cache:
continue
if setup_calculations:
self.trial_generator.unimolecular_reactions(
structure_one, self.options.run_one_cycle_with_settings_enhancement)
else:
uni_result[str(aggregate_one.id())][str(sid_one)] = \
self.trial_generator.unimolecular_coordinates(
structure_one, self.options.run_one_cycle_with_settings_enhancement)
self._cache.add(sid_one.string())
# Get intermolecular reaction partners
if not loop_bimolecular:
continue
if eligible_sid_one is None:
eligible_sid_one = sorted(self._get_eligible_structures(aggregate_one))
if not eligible_sid_one:
continue
c_id_one = aggregate_one.id().string()
_, second_iterator = self._get_collection_iterator()
for aggregate_two in stop_on_timeout(second_iterator):
aggregate_two.link(collection)
if self.stop_at_next_break_point:
return {}, {}
# Make this loop run lower triangular + diagonal only
c_id_two = aggregate_two.id().string()
sorted_ids = sorted([c_id_one, c_id_two])
# The second criterion is needed so the diagonal (identical aggregates) is not excluded
if sorted_ids[0] == c_id_two and c_id_one != c_id_two:
continue
# Filter
if not self.aggregate_filter.filter(aggregate_one, aggregate_two):
continue
eligible_sid_two = sorted(self._get_eligible_structures(aggregate_two))
if not eligible_sid_two:
continue
same_compounds = c_id_one == c_id_two
for i, sid_one in enumerate(eligible_sid_one):
for j, sid_two in enumerate(eligible_sid_two):
if self.stop_at_next_break_point:
return {}, {}
if same_compounds and j > i:
break
joined_ids = ';'.join(sorted([sid_one.string(), sid_two.string()]))
if not self.options.run_one_cycle_with_settings_enhancement and joined_ids in self._cache:
continue
structure_one = db.Structure(sid_one, self._structures)
structure_two = db.Structure(sid_two, self._structures)
if not self._check_structure_model(structure_one) or \
not self._check_structure_model(structure_two):
continue
if self._rebuild_cache:
if i == 0 and (sid_one.string() not in self._cache):
self._update_cache(
structure_one,
self.trial_generator.get_bimolecular_job_order(),
self.trial_generator.options.model
)
if sid_two.string() not in self._cache:
self._update_cache(
structure_two,
self.trial_generator.get_bimolecular_job_order(),
self.trial_generator.options.model
)
if not self.options.run_one_cycle_with_settings_enhancement and joined_ids in self._cache:
continue
if setup_calculations:
self.trial_generator.bimolecular_reactions(
[structure_one, structure_two], self.options.run_one_cycle_with_settings_enhancement)
else:
# split to make more readable
compound_key = f"{str(aggregate_one.id())}-{str(aggregate_two.id())}"
structure_key = f"{str(sid_one)}-{str(sid_two)}"
bi_result[compound_key][structure_key] = \
self.trial_generator.bimolecular_coordinates(
[structure_one, structure_two],
self.options.run_one_cycle_with_settings_enhancement)
self._cache.add(joined_ids)
if self._rebuild_cache:
self._rebuild_cache = False
return uni_result, bi_result
def _update_cache(self, structure: db.Structure, job_order: str, model: db.Model) -> None:
calc_ids = structure.get_calculations(job_order)
if not calc_ids:
return
for calc_id in calc_ids:
calculation = db.Calculation(calc_id, self._calculations)
if calculation.get_model() != model:
continue
structures_in_calc_ids = calculation.get_structures()
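# The cache key is the ';'-joined, sorted list of structure ID strings of the calculation,
# matching the joined keys built in the main loop above.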
joined_ids = ';'.join(sorted([s.string() for s in structures_in_calc_ids]))
self._cache.add(joined_ids)
@abstractmethod
def _get_eligible_structures(self, aggregate: Union[db.Compound, db.Flask]) -> List[db.ID]:
pass
def _check_structure_model(self, structure: db.Structure) -> bool:
# Only check the model if the structure model option has been set (i.e., it is not a placeholder)
if not isinstance(self.options.structure_model, PlaceHolderModelType):
return self.options.structure_model == structure.get_model()
return True
def _get_collection_iterator(self) -> Tuple[db.Collection, Iterator[Union[db.Compound, db.Flask]]]:
selection = {"exploration_disabled": {"$ne": True}}
if self.options.looped_collection == "compounds":
return self._compounds, self._compounds.iterate_compounds(dumps(selection))
if self.options.looped_collection == "flasks":
self._check_filters_for_flask_compatibility()
return self._flasks, self._flasks.iterate_flasks(dumps(selection))
raise ValueError(f"Invalid value for looped_collection: '{self.options.looped_collection}'. "
f"Only 'compounds' and 'flasks' are allowed.")
def _check_filters_for_flask_compatibility(self) -> None:
"""
Checks if the aggregate filter and the trial generator are compatible with flasks.
Raises
------
ValueError
If the aggregate filter or the trial generator are not compatible with flasks.
"""
if not self.aggregate_filter.supports_flasks():
raise ValueError(f"The aggregate filter {self.aggregate_filter.name} does not support flasks.")
if not self.trial_generator.reactive_site_filter.supports_flasks():
raise ValueError(f"The reactive site filter {self.trial_generator.reactive_site_filter.name} does not "
f"support flasks.")
further_filter = getattr(self.trial_generator, "further_reactive_site_filter", None)
if isinstance(further_filter, ReactiveSiteFilter) and not further_filter.supports_flasks():
raise ValueError(f"The further reactive site filter "
f"{getattr(self.trial_generator, 'further_reactive_site_filter').name} "
f"does not support flasks.")