Source code for scine_art.database

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__copyright__ = """ This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.
"""


import scine_molassembler as masm

from scine_art.reaction_template import ReactionTemplate, Match
from scine_art.io import write_molecule_to_svg


from typing import List, Optional, Any, Generator
from copy import deepcopy
import pickle
import svgutils
import os


class ReactionTemplateDatabase:
    """A database of reaction templates.

    Templates are kept in insertion order and deduplicated on insertion.
    The getters hand out deep copies, so modifying a returned template
    does not modify the stored one.
    """

    def __init__(self) -> None:
        self.__unique_templates: List["ReactionTemplate"] = []

    def get_template(self, uuid: str) -> Optional["ReactionTemplate"]:
        """Getter for a single reaction template.

        Parameters
        ----------
        uuid : str
            The unique ID of a reaction template.

        Returns
        -------
        Optional[ReactionTemplate]
            A deep copy of the reaction template with the given unique ID
            if present, ``None`` otherwise.
        """
        for template in self.__unique_templates:
            if template.get_uuid() == uuid:
                return deepcopy(template)
        return None

    def iterate_templates(self) -> Generator["ReactionTemplate", None, None]:
        """Iterator through all stored reaction templates.

        Yields
        ------
        Generator[ReactionTemplate, None, None]
            Deep copies of all reaction templates in the database, in
            insertion order.
        """
        for template in self.__unique_templates:
            yield deepcopy(template)

    def add_template(
        self,
        new_template: "ReactionTemplate"
    ) -> str:
        """Add a new template to the database.

        Will deduplicate the reaction template against stored ones.

        Parameters
        ----------
        new_template : ReactionTemplate
            The new reaction template. If it is not a duplicate, this very
            object (not a copy) is stored.

        Returns
        -------
        str
            Returns the UUID of the now stored template, if the template
            was a duplicate and thus subsumed, the UUID of the existing
            match is returned.
        """
        for template in self.__unique_templates:
            if template == new_template:
                # Duplicate: fold the new data into the stored template.
                template.update(new_template)
                return template.get_uuid()
        self.__unique_templates.append(new_template)
        return new_template.get_uuid()

    def find_matching_templates(
        self,
        molecules: List["masm.Molecule"],
        energy_cutoff: Optional[float] = 300,
        enforce_atom_shapes: bool = True
    ) -> Optional[List["Match"]]:
        """Finds all matching reaction templates for a given set of molecules.

        All given molecules must be used in the matched reaction template
        exactly once. None may be unused or used twice.
        A single template can match multiple atom combinations in the given
        atoms, hence multiple matches can be returned per reaction template.

        Note
        ----
        For a more fine grained matching algorithm loop over all individual
        templates and use their matching methods which allow for more optional
        arguments.

        Parameters
        ----------
        molecules : List[masm.Molecule]
            The molecules to apply the template to.
        energy_cutoff : float, optional
            Only apply the template if a barrier in the trialed direction of
            less than this value has been reported, by default 300 (kJ/mol).
            If ``None`` is given, every template is queried with
            ``energy_cutoff=0.0``, ``enforce_lhs=True`` and
            ``enforce_rhs=True`` instead.
        enforce_atom_shapes : bool, optional
            If true only allow atoms with the same coordination sphere shapes
            to be considered matching, by default ``True``.

        Returns
        -------
        Optional[List[Match]]
            If any matches exist, returns a list of all possible matches,
            otherwise ``None``.
        """
        # The keyword arguments are identical for every template, so build
        # them once instead of branching inside the loop.
        if energy_cutoff is None:
            # NOTE(review): presumably this combination disables the barrier
            # filter; confirm against ReactionTemplate.determine_all_matches.
            match_options = {
                'energy_cutoff': 0.0,
                'enforce_atom_shapes': enforce_atom_shapes,
                'enforce_lhs': True,
                'enforce_rhs': True,
            }
        else:
            match_options = {
                'energy_cutoff': energy_cutoff,
                'enforce_atom_shapes': enforce_atom_shapes,
            }
        results = []
        for template in self.__unique_templates:
            matches = template.determine_all_matches(molecules, **match_options)
            if matches is not None:
                results.extend(matches)
        return results or None

    @staticmethod
    def _read_template_file(path: str) -> List["ReactionTemplate"]:
        """Read a pickled list of reaction templates from ``path``.

        Raises
        ------
        TypeError
            If the file does not contain a list, which is the only payload
            that ``save`` writes.
        """
        # SECURITY: unpickling can execute arbitrary code. Only read files
        # that come from a trusted source.
        with open(path, 'rb') as f:
            loaded = pickle.load(f)
        if not isinstance(loaded, list):
            raise TypeError(
                f'Expected a pickled list of reaction templates in {path!r}, '
                f'got {type(loaded).__name__}.'
            )
        return loaded

    def save(self, path: str = '.rtdb.pickle.obj') -> None:
        """Saves the currently stored unique reaction templates into a file.

        Parameters
        ----------
        path : str, optional
            The path to store a list of reaction templates,
            by default ``'.rtdb.pickle.obj'``
        """
        with open(path, 'wb') as f:
            pickle.dump(self.__unique_templates, f)

    def load(self, path: str = '.rtdb.pickle.obj') -> None:
        """Loads a stored (pickled) file of templates and replaces existing data.

        The existing data is only replaced once the file was read
        successfully. Only load files from trusted sources, unpickling can
        execute arbitrary code.

        Parameters
        ----------
        path : str, optional
            The path to a stored list of reaction templates,
            by default ``'.rtdb.pickle.obj'``

        Raises
        ------
        TypeError
            If the file does not contain a pickled list.
        """
        self.__unique_templates = self._read_template_file(path)

    def append_file(self, path: str = '.rtdb.pickle.obj') -> None:
        """Loads a stored (pickled) file of templates and adds it to the
        existing unique reaction templates, keeps current template and adds
        data from the file.

        Only load files from trusted sources, unpickling can execute
        arbitrary code.

        Parameters
        ----------
        path : str, optional
            The path to a stored list of reaction templates,
            by default ``'.rtdb.pickle.obj'``

        Raises
        ------
        TypeError
            If the file does not contain a pickled list.
        """
        for template in self._read_template_file(path):
            self.add_template(template)

    def template_count(self) -> int:
        """Getter for the number of reaction templates in the database.

        Returns
        -------
        int
            Number of unique reaction templates in the database.
        """
        return len(self.__unique_templates)

    # TODO requires masm.Molecule JSON representation generation on the fly
    # def dump_json(self, path: str = 'rtdb.dump.json') -> None:
    #     with open(path, 'w') as f:
    #         for template in self.__unique_templates:
    #             f.write(template.to_json())

    def add_scine_database(
        self,
        host: str = 'localhost',
        port: int = 27017,
        name: str = 'default',
        energy_cutoff: float = 300,
        model: Optional[Any] = None,
        gibbs_free_energy: bool = True
    ) -> None:
        """Parse an entire SCINE Database adding and deduplicating all
        templates into this database.

        For each reaction at most one template is generated. It is built
        from the spline of the elementary step with the lowest transition
        state energy among all steps that pass the barrier cutoff, and is
        annotated with the lowest forward and backward barriers found among
        those steps.

        Parameters
        ----------
        host : str, optional
            The IP or hostname of the database server, by default 'localhost'
        port : int, optional
            The port of the database server, by default 27017
        name : str, optional
            The name of the database on the server, by default 'default'
        energy_cutoff : float, optional
            Only store templates where at least one recorded barrier of the
            reaction is below the given threshold (in kJ/mol), by default
            300 kJ/mol.
        model : Optional[scine_database.Model], optional
            The scine_database.Model to be allowed for energy evaluations,
            by default ``None`` is given, in which case the model of each
            transition state structure is used for that elementary step.
        gibbs_free_energy : bool, optional
            If true will only consider Gibbs free energies as valid energies
            for evaluation, by default True
        """
        # Imported lazily so the rest of the class works without
        # scine_database being installed.
        import scine_database as db

        manager = db.Manager()
        manager.set_credentials(db.Credentials(host, port, name))
        manager.connect()
        reaction_collection = manager.get_collection('reactions')
        elementary_step_collection = manager.get_collection('elementary_steps')
        property_collection = manager.get_collection('properties')
        structure_collection = manager.get_collection('structures')

        energy_label = 'gibbs_free_energy' if gibbs_free_energy else 'electronic_energy'
        # Barriers are compared to ``energy_cutoff`` in kJ/mol.
        # NOTE(review): 2625.5 is the Hartree -> kJ/mol factor, so stored
        # energies are presumably in Hartree; confirm.
        to_kj_per_mol = 2625.5
        # Sentinel larger than any energy/barrier expected in practice.
        unset = 9999999.9

        def get_energy_for_structure(
            structure: db.Structure,
            model: db.Model,
        ) -> Optional[float]:
            """Return the requested energy of a structure, ``None`` if absent."""
            structure.link(structure_collection)
            structure_properties = structure.query_properties(energy_label, model, property_collection)
            if len(structure_properties) < 1:
                return None
            # pick last property if multiple
            prop = db.NumberProperty(structure_properties[-1])
            prop.link(property_collection)
            return prop.get_data()

        def sum_energies(structure_ids, model) -> Optional[float]:
            """Sum the energies of all structures, ``None`` if any is missing."""
            total = 0.0
            for structure_id in structure_ids:
                energy = get_energy_for_structure(db.Structure(structure_id), model)
                if energy is None:
                    return None
                total += energy
            return total

        for reaction in reaction_collection.iterate_all_reactions():
            reaction.link(reaction_collection)
            best_spline = None
            best_step = None
            best_barriers = (unset, unset)
            best_ts = unset
            for step_id in reaction.get_elementary_steps():
                step = db.ElementaryStep(step_id)
                step.link(elementary_step_collection)
                if not step.has_spline():
                    continue
                spline = step.get_spline()
                ts = db.Structure(step.get_transition_state())
                ts.link(structure_collection)
                current_model = model
                if current_model is None:
                    current_model = ts.get_model()
                e_ts = get_energy_for_structure(ts, current_model)
                if e_ts is None:
                    continue
                # Skip the step as soon as one energy is missing; this avoids
                # querying the other side for a step that is discarded anyway.
                e_lhs = sum_energies(step.get_reactants(db.Side.LHS)[0], current_model)
                if e_lhs is None:
                    continue
                e_rhs = sum_energies(step.get_reactants(db.Side.RHS)[1], current_model)
                if e_rhs is None:
                    continue
                forward_barrier = (e_ts - e_lhs) * to_kj_per_mol
                backward_barrier = (e_ts - e_rhs) * to_kj_per_mol
                if not (forward_barrier < energy_cutoff or backward_barrier < energy_cutoff):
                    continue
                # Track the lowest barrier of each direction independently.
                best_barriers = (
                    min(forward_barrier, best_barriers[0]),
                    min(backward_barrier, best_barriers[1]),
                )
                if e_ts < best_ts:
                    best_ts = e_ts
                    best_spline = spline
                    best_step = step_id.string()
            if best_spline:
                template = ReactionTemplate.from_trajectory_spline(
                    best_spline,
                    barriers=best_barriers,
                    elementary_step_id=best_step
                )
                if template is not None:
                    self.add_template(template)

    def dump_svg_representation(self, path: str, index: Optional[int] = None) -> None:
        """Generate SVG representations of all templates.

        One file ``template-<index>.svg`` is written per template. The
        per-fragment SVG files that are needed to compose it are temporary
        and are removed again, also if the generation fails.

        Parameters
        ----------
        path : str
            The path (folder) to generate the SVG files in.
        index : Optional[int], optional
            If given only generate an SVG for the template with the given
            index, by default None

        Raises
        ------
        IndexError
            If ``index`` is not smaller than the number of stored templates.
        """
        def _dump_by_index(index: int):
            if index >= len(self.__unique_templates):
                raise IndexError(
                    f'Template index {index} out of range, the database holds '
                    f'{len(self.__unique_templates)} template(s).'
                )
            svgs: List[svgutils.compose.SVG] = []
            fnames: List[str] = []
            # Number of left-hand side fragments, i.e. the position of the
            # first right-hand side fragment in ``svgs``.
            middle = 0
            try:
                for side in ['lhs', 'rhs']:
                    count = 0
                    for mol in self.__unique_templates[index].shelled_nuclei[side][0]:
                        for fragment in mol:
                            count += 1
                            fname = os.path.join(path, f'{side}-{count:03d}.svg')
                            fnames.append(fname)
                            write_molecule_to_svg(fname, fragment)
                            svgs.append(svgutils.compose.SVG(fname))
                    if side == 'lhs':
                        middle = len(svgs)
                max_height = max(s.height for s in svgs)
                # Horizontal offset of each fragment; fragments are placed
                # side by side from left to right.
                shifts = [0]
                for s in svgs[0:-1]:
                    shifts.append(shifts[-1] + s.width)
                    if len(shifts) == middle + 1:
                        # Extra gap separating the two sides of the reaction.
                        shifts[-1] += 100
                for i, s in enumerate(svgs):
                    # Center every fragment vertically.
                    s.move(shifts[i], (max_height - s.height) / 2)
                svgutils.compose.Figure(
                    sum(x.width for x in svgs) + 100,
                    max_height,
                    *svgs
                ).save(os.path.join(path, f'template-{index:04d}.svg'))
            finally:
                # Always remove the temporary fragment files. A file may be
                # missing if writing it was the step that failed.
                for fname in fnames:
                    if os.path.exists(fname):
                        os.remove(fname)

        if index is not None:
            _dump_by_index(index)
        else:
            for i in range(len(self.__unique_templates)):
                _dump_by_index(i)