Source code for scine_chemoton.steering_wheel.selections

#!/usr/bin/env python3
from __future__ import annotations
# -*- coding: utf-8 -*-
__copyright__ = """ This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.
"""
from abc import abstractmethod, ABC, ABCMeta
from functools import wraps
from typing import Callable, List, Optional, Set, Union

import scine_database as db
from scine_database.energy_query_functions import get_barriers_for_elementary_step_by_type

from scine_chemoton.filters.aggregate_filters import (
    AggregateFilter,
    AggregateFilterAndArray,
    AggregateFilterOrArray
)
from scine_chemoton.filters.reactive_site_filters import (
    ReactiveSiteFilter,
    ReactiveSiteFilterAndArray,
    ReactiveSiteFilterOrArray,
)
from scine_chemoton.filters.further_exploration_filters import (
    FurtherExplorationFilterAndArray,
    FurtherExplorationFilterOrArray,
)
from scine_chemoton.utilities import connect_to_db
from ..datastructures import (
    NetworkExpansionResult,
    SelectionResult,
    ExplorationSchemeStep,
    Status,
    LogicCoupling,
    NoRestartInfoPresent,
    RestartPartialExpansionInfo
)


def status_wrap(fun: Callable):
    """
    Decorator that keeps the ``status`` attribute of the owning object in sync with the execution of
    the wrapped method.

    The status is set to ``Status.CALCULATING`` before the method runs. After a regular return it is set to
    ``Status.FINISHED``, unless the method itself already marked the object as ``Status.FAILED``.
    If the method raises, the status is set to ``Status.FAILED`` and the exception is re-raised unchanged,
    so the object is never left in the calculating state after a crash.

    Parameters
    ----------
    fun : Callable
        The method to wrap. It receives the instance as its first argument.

    Returns
    -------
    Callable
        The wrapped method; it returns whatever ``fun`` returns.
    """
    @wraps(fun)
    def _impl(self, *args, **kwargs):
        self.status = Status.CALCULATING
        try:
            result = fun(self, *args, **kwargs)
        except BaseException:
            # Same exception scope as Selection.__call__ uses for its own failure handling; without this,
            # wrapped methods lacking such handling would stay in CALCULATING after an error.
            self.status = Status.FAILED
            raise
        if self.status != Status.FAILED:
            self.status = Status.FINISHED
        return result

    return _impl


class Selection(ExplorationSchemeStep, metaclass=ABCMeta):
    """
    Abstract base class for every step that selects aggregates, individual structures, and/or reactive sites.

    It provides the shared ``__call__`` execution and declares one abstract method, ``_select``, that each
    concrete selection must implement. It also offers helper functionality for querying, so that new
    selections are easier to write.
    """

    options: Selection.Options

    def __init__(self, model: db.Model,  # pylint: disable=keyword-arg-before-vararg
                 additional_aggregate_filters: Optional[List[AggregateFilter]] = None,
                 additional_reactive_site_filters: Optional[List[ReactiveSiteFilter]] = None,
                 logic_coupling: Union[str, LogicCoupling] = LogicCoupling.AND,
                 *args, **kwargs
                 ) -> None:
        """
        Set up the selection.

        Parameters
        ----------
        model : db.Model
            The model to use for the selection.
        additional_aggregate_filters : Optional[List[AggregateFilter]], optional
            Extra aggregate filters that narrow the selection further. They are combined with the
            filter of the selection result by an 'and' logic step. By default, None.
        additional_reactive_site_filters : Optional[List[ReactiveSiteFilter]], optional
            Extra reactive site filters that narrow the selection further. They are combined with the
            filter of the selection result by an 'and' logic step. By default, None.
        logic_coupling : Union[str, LogicCoupling], optional
            How this selection may be coupled with other selections, by default LogicCoupling.AND.
            A string may carry the ``LogicCoupling.`` prefix and any capitalization.
        """
        super().__init__(model, *args, **kwargs)
        self._add_aggregate_filter = [] if additional_aggregate_filters is None else additional_aggregate_filters
        self._add_site_filter = [] if additional_reactive_site_filters is None \
            else additional_reactive_site_filters
        self._step_result: Optional[NetworkExpansionResult] = None
        if not isinstance(logic_coupling, str):
            self.logic_coupling = logic_coupling
        else:
            # accept both the plain value and the printed enum form, e.g. "LogicCoupling.AND"
            coupling_value = logic_coupling.replace("LogicCoupling.", "").lower()
            self.logic_coupling = LogicCoupling(coupling_value)
        self._result: Optional[SelectionResult] = None

    def get_step_result(self) -> NetworkExpansionResult:
        """
        Access the result of the previous expansion step.

        Returns
        -------
        NetworkExpansionResult
            The result of the previous expansion step.

        Raises
        ------
        RuntimeError
            If no previous step result has been handed to this selection.
        """
        previous = self._step_result
        if previous is not None:
            return previous
        raise RuntimeError(f"The selection {self.name} did not receive a previous step result, "
                           f"but wanted to access it")

    def set_step_result(self, step_result: Optional[NetworkExpansionResult]) -> None:
        """
        Store the result of the previous expansion step.

        Parameters
        ----------
        step_result : Optional[NetworkExpansionResult]
            The result to store; None clears it.
        """
        self._step_result = step_result

    def _not_implemented_arguments_sanity_check(
            self,
            notify_partial_steps_callback: Optional[
                Callable[[Union[NoRestartInfoPresent, RestartPartialExpansionInfo]], None]] = None,
            restart_information: Optional[RestartPartialExpansionInfo] = None) -> None:
        """
        Reject the call arguments that selections do not support.

        Raises
        ------
        NotImplementedError
            If either argument is not None.
        """
        if notify_partial_steps_callback is not None:
            raise NotImplementedError("The notify_partial_steps_callback is not implemented for Selections.")
        if restart_information is not None:
            raise NotImplementedError("The restart_information is not implemented for Selections.")

    @status_wrap
    def __call__(self, credentials: db.Credentials, step_result: Optional[NetworkExpansionResult] = None,
                 notify_partial_steps_callback: Optional[
                     Callable[[Union[NoRestartInfoPresent, RestartPartialExpansionInfo]], None]] = None,
                 restart_information: Optional[RestartPartialExpansionInfo] = None) \
            -> SelectionResult:
        """
        Run the selection against the database described by the credentials.

        Parameters
        ----------
        credentials : db.Credentials
            The credentials of the database to select from.
        step_result : Optional[NetworkExpansionResult], optional
            The result of the previous expansion step; stored on the selection if given.
        notify_partial_steps_callback : optional
            Not supported by selections, must be None.
        restart_information : Optional[RestartPartialExpansionInfo], optional
            Not supported by selections, must be None.

        Returns
        -------
        SelectionResult
            The result of ``_select``, with the additional filters of this selection and-combined into it.

        Raises
        ------
        NotImplementedError
            If one of the unsupported arguments is given.
        """
        self._not_implemented_arguments_sanity_check(notify_partial_steps_callback, restart_information)
        self.initialize_collections(connect_to_db(credentials))
        if step_result is not None:
            self.set_step_result(step_result)
        try:
            selected = self._select()
            self._result = selected
        except BaseException as e:
            # report, mark as failed, and hand the error on to the caller
            print(f"Selection {self.name} failed with error {e}")
            self.status = Status.FAILED
            raise e
        extra_aggregate_filters = self._add_aggregate_filter
        if extra_aggregate_filters:
            selected.aggregate_filter = AggregateFilterAndArray(
                extra_aggregate_filters + [selected.aggregate_filter])
        extra_site_filters = self._add_site_filter
        if extra_site_filters:
            selected.reactive_site_filter = ReactiveSiteFilterAndArray(
                extra_site_filters + [selected.reactive_site_filter])
        return self._result

    def lowest_barrier_per_reaction(self, step_result: NetworkExpansionResult, energy_type: str) \
            -> List[Optional[float]]:
        """
        Convenience method that determines the lowest barrier of each reaction in the given step result.

        Parameters
        ----------
        step_result : NetworkExpansionResult
            The step result whose reactions are evaluated.
        energy_type : str
            The energy type to use for the barrier lookup.

        Returns
        -------
        List[Optional[float]]
            One entry per reaction of the step result: the lowest barrier found among its elementary steps,
            or None if no elementary step provided a barrier.
        """
        lowest: List[Optional[float]] = []
        for reaction_id in step_result.reactions:
            reaction = db.Reaction(reaction_id, self._reactions)
            candidates: List[float] = []
            for step_id in reaction.get_elementary_steps():
                elementary_step = db.ElementaryStep(step_id, self._elementary_steps)
                # NOTE(review): only the second entry of the returned pair is used, the first is discarded;
                # confirm that this is the intended barrier direction.
                _, rhs_barrier = get_barriers_for_elementary_step_by_type(
                    elementary_step, energy_type, self.options.model, self._structures, self._properties)
                if rhs_barrier is None:
                    continue
                candidates.append(rhs_barrier)
            lowest.append(min(candidates, default=None))
        return lowest

    def get_result(self) -> Optional[SelectionResult]:
        """
        Returns
        -------
        Optional[SelectionResult]
            The currently stored selection result, None if there is none.
        """
        return self._result

    def set_result(self, result: Optional[SelectionResult]):  # type: ignore[override]
        """
        Overwrite the stored selection result.

        Parameters
        ----------
        result : Optional[SelectionResult]
            The result to store.
        """
        self._result = result

    @abstractmethod
    def _select(self) -> SelectionResult:
        """
        The actual selection logic, to be provided by each concrete selection.

        Connecting to the database collections and adding the additional filters of the selection
        are already taken care of by the base class.

        Returns
        -------
        SelectionResult
            The result of the selection.
        """
class SafeFirstSelection(Selection, metaclass=ABCMeta):
    """
    Marker base class for selections that can serve as the first selection of a network expansion,
    i.e. that work without a previous step result.

    The class is still abstract and must not be used directly; selections that are valid as a first
    selection should inherit from it.
    """

    def get_step_result(self) -> NetworkExpansionResult:
        """
        Blocks any access to the step_result member.

        Raises
        ------
        PermissionError
            Always.
        """
        raise PermissionError(f"The class {self.__class__.__name__} may not access the step_result member,\n"
                              f"because it inherits from 'SafeFirstSelection' and must therefore give a\n"
                              f"selection without a previous step result.")

    def set_step_result(self, step_result: Union[NetworkExpansionResult, None]) -> None:
        """
        Deliberately ignores the given step result.
        """


class AllCompoundsSelection(SafeFirstSelection):
    """
    The most basic selection: it selects all compounds in the database.

    This is expressed by a default constructed result, because the default filters do not filter anything.
    The structures are left empty to avoid transferring huge lists.
    """

    def _select(self) -> SelectionResult:
        everything = SelectionResult()
        return everything


class PredeterminedSelection(SafeFirstSelection):
    """
    A selection whose result is handed over by the user at construction.
    """

    options: PredeterminedSelection.Options

    def __init__(self, model: db.Model, result: SelectionResult,  # pylint: disable=keyword-arg-before-vararg
                 additional_aggregate_filters: Optional[List[AggregateFilter]] = None,
                 additional_reactive_site_filters: Optional[List[ReactiveSiteFilter]] = None,
                 logic_coupling: Union[str, LogicCoupling] = LogicCoupling.AND,
                 *args, **kwargs) -> None:
        """
        Parameters
        ----------
        model : db.Model
            The model of the selection.
        result : SelectionResult
            The result this selection returns when it is executed.
        additional_aggregate_filters, additional_reactive_site_filters, logic_coupling
            Forwarded to the ``Selection`` constructor unchanged.
        """
        super().__init__(model, additional_aggregate_filters, additional_reactive_site_filters, logic_coupling,
                         *args, **kwargs)
        # the base class constructor resets the result, hence it is stored afterwards
        self._result = result

    def _select(self) -> SelectionResult:
        predetermined = self._result
        assert predetermined is not None
        return predetermined


class _MultipleSelections(Selection, ABC):
    """
    Common base of the classes that combine several selections.

    Notes
    -----
    * It must be initialized with a non-empty list of selections.
    * Its model is taken from the first selection in the list.
    * It cannot be used directly, but is meant to be inherited from.
    """

    options: _MultipleSelections.Options

    def __init__(self, selections: List[Selection], *args, **kwargs) -> None:
        """
        Parameters
        ----------
        selections : List[Selection]
            The selections to combine.

        Raises
        ------
        TypeError
            If the list of selections is empty.
        """
        if len(selections) == 0:
            raise TypeError(f"Cannot give empty list of selections to {self.__class__.__name__}")
        first_model = selections[0].options.model
        super().__init__(first_model, *args, **kwargs)
        self.selections = selections

    def initialize_collections(self, manager: db.Manager) -> None:
        """
        Initialize the collections of this object and of every held selection with the given manager.
        """
        super().initialize_collections(manager)
        for held in self.selections:
            held.initialize_collections(manager)

    def _call_selections(self, credentials: db.Credentials, step_result: Optional[NetworkExpansionResult] = None) \
            -> List[SelectionResult]:
        """
        Execute every held selection, distinguishing whether a selection may receive a previous
        expansion result or not.

        Parameters
        ----------
        credentials : db.Credentials
            The credentials of the database that is selected from.
        step_result : Optional[NetworkExpansionResult]
            The optional previous network expansion result. It is only passed to selections
            that are not a ``SafeFirstSelection``.

        Returns
        -------
        List[SelectionResult]
            The individual results; first those of the selections that received the step result, then
            those of the ``SafeFirstSelection`` instances. Combining them logically is left to the
            inheriting class.
        """
        self.initialize_collections(connect_to_db(credentials))
        with_step_result = [s for s in self.selections if not isinstance(s, SafeFirstSelection)]
        without_step_result = [s for s in self.selections if isinstance(s, SafeFirstSelection)]
        results: List[SelectionResult] = []
        for selection in with_step_result:
            results.append(selection(credentials, step_result))
        for selection in without_step_result:
            results.append(selection(credentials))
        return results

    def _gather_structures_from_filter(self, aggregate_filter: AggregateFilter, excluded_aggregates: Set[str]) \
            -> List[db.ID]:
        """
        Collect the structures of all compounds that pass the given filter and are not excluded.

        Parameters
        ----------
        aggregate_filter : AggregateFilter
            The filter a compound has to pass.
        excluded_aggregates : Set[str]
            String IDs of the compounds that are skipped without consulting the filter.

        Returns
        -------
        List[db.ID]
            The structures of all accepted compounds.
        """
        gathered: List[db.ID] = []
        # TODO extend to flasks if flasks relevant and all aggregate filters safe for flasks
        for compound in self._compounds.iterate_all_compounds():
            compound.link(self._compounds)
            if str(compound.id()) not in excluded_aggregates and aggregate_filter.filter(compound):
                gathered.extend(compound.get_structures())
        return gathered

    def __len__(self) -> int:
        held = self.selections
        return len(held)

    def __iter__(self):
        return (s for s in self.selections)

    def __setitem__(self, key, value):
        self.selections[key] = value
class SelectionAndArray(_MultipleSelections):
    """
    Logical AND combination of multiple selections.
    """

    options: SelectionAndArray.Options

    def __init__(self, selections: List[Selection], *args, **kwargs) -> None:
        """
        Parameters
        ----------
        selections : List[Selection]
            The selections to combine; their names are joined into the name of this selection.
        """
        super().__init__(selections, *args, **kwargs)
        joined_names = "-".join(s.name for s in selections)
        self.name = f"AndSelection[{joined_names}]"

    @status_wrap
    def __call__(self, credentials: db.Credentials, step_result: Optional[NetworkExpansionResult] = None,
                 notify_partial_steps_callback: Optional[
                     Callable[[Union[NoRestartInfoPresent, RestartPartialExpansionInfo]], None]] = None,
                 restart_information: Optional[RestartPartialExpansionInfo] = None) \
            -> SelectionResult:
        """
        Execute all held selections and and-combine their results.

        Parameters
        ----------
        credentials : db.Credentials
            The credentials of the database to select from.
        step_result : Optional[NetworkExpansionResult], optional
            The result of the previous expansion step.
        notify_partial_steps_callback : optional
            Not supported by selections, must be None.
        restart_information : Optional[RestartPartialExpansionInfo], optional
            Not supported by selections, must be None.

        Returns
        -------
        SelectionResult
            The and-arrays of the individual filters together with the structures
            that are part of every individual result.
        """
        self._not_implemented_arguments_sanity_check(notify_partial_steps_callback, restart_information)
        individual = self._call_selections(credentials, step_result)
        common_structures = self._combine_structures(individual)
        aggregate_filter = AggregateFilterAndArray([res.aggregate_filter for res in individual])
        site_filter = ReactiveSiteFilterAndArray([res.reactive_site_filter for res in individual])
        further_filter = FurtherExplorationFilterAndArray([res.further_exploration_filter for res in individual])
        # NOTE: the combined result is returned directly, it is not stored in self._result
        return SelectionResult(
            aggregate_filter,
            site_filter,
            further_filter,
            [db.ID(str_id) for str_id in common_structures]
        )

    def _select(self) -> SelectionResult:
        # the combination is carried out in __call__ directly
        raise NotImplementedError

    @staticmethod
    def _combine_structures(results: List[SelectionResult]) -> Set[str]:
        """
        Determine the string IDs of the structures that are listed in every given result.
        """
        id_sets = [{str(sid) for sid in res.structures} for res in results]
        return set.intersection(*id_sets)
class SelectionOrArray(_MultipleSelections):
    """
    Logical OR combination of multiple selections.
    """

    options: SelectionOrArray.Options

    def __init__(self, selections: List[Selection], *args, **kwargs) -> None:
        """
        Parameters
        ----------
        selections : List[Selection]
            The selections to combine; their names are joined into the name of this selection.
        """
        super().__init__(selections, *args, **kwargs)
        joined_names = "-".join(s.name for s in selections)
        self.name = f"OrSelection[{joined_names}]"

    @status_wrap
    def __call__(self, credentials: db.Credentials, step_result: Optional[NetworkExpansionResult] = None,
                 notify_partial_steps_callback: Optional[
                     Callable[[Union[NoRestartInfoPresent, RestartPartialExpansionInfo]], None]] = None,
                 restart_information: Optional[RestartPartialExpansionInfo] = None) \
            -> SelectionResult:
        """
        Execute all held selections and or-combine their results.

        Parameters
        ----------
        credentials : db.Credentials
            The credentials of the database to select from.
        step_result : Optional[NetworkExpansionResult], optional
            The result of the previous expansion step.
        notify_partial_steps_callback : optional
            Not supported by selections, must be None.
        restart_information : Optional[RestartPartialExpansionInfo], optional
            Not supported by selections, must be None.

        Returns
        -------
        SelectionResult
            The or-arrays of the individual filters together with the combined structures.
        """
        self._not_implemented_arguments_sanity_check(notify_partial_steps_callback, restart_information)
        individual = self._call_selections(credentials, step_result)
        combined_structures = self._combine_structures(individual)
        aggregate_filter = AggregateFilterOrArray([res.aggregate_filter for res in individual])
        site_filter = ReactiveSiteFilterOrArray([res.reactive_site_filter for res in individual])
        further_filter = FurtherExplorationFilterOrArray([res.further_exploration_filter for res in individual])
        # NOTE: the combined result is returned directly, it is not stored in self._result
        return SelectionResult(
            aggregate_filter,
            site_filter,
            further_filter,
            [db.ID(str_id) for str_id in combined_structures]
        )

    def _select(self) -> SelectionResult:
        # the combination is carried out in __call__ directly
        raise NotImplementedError

    def _combine_structures(self, results: List[SelectionResult]) -> Set[str]:
        """
        Build the union of the structures of all results.

        If at least one result does not specify structures, the structures for these results are gathered
        from the database based on their aggregate filters. Aggregates of explicitly listed structures are
        skipped in that gathering.

        Parameters
        ----------
        results : List[SelectionResult]
            The individual selection results.

        Returns
        -------
        Set[str]
            The string IDs of all combined structures.
        """
        filter_only_results = [res for res in results if not res.structures]
        collected: List[db.ID] = []
        if filter_only_results:
            # aggregates that are already represented by explicitly listed structures
            covered_aggregates: Set[str] = set()
            for res in results:
                if res.structures:
                    for sid in res.structures:
                        structure = db.Structure(sid, self._structures)
                        if structure.has_aggregate():
                            covered_aggregates.add(str(structure.get_aggregate()))
            combined_filter = AggregateFilterOrArray([res.aggregate_filter for res in filter_only_results])
            collected = self._gather_structures_from_filter(combined_filter, covered_aggregates)
        # the explicitly listed structures are always part of the union
        for res in results:
            collected += res.structures
        return {str(sid) for sid in collected}