Source code for scine_chemoton.utilities.masm

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__copyright__ = """ This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.
"""

# Standard library imports
import ast
import numpy as np

import scine_molassembler as masm
import scine_database as db
import scine_utilities as utils

from typing import List, Set, Any, Callable, Tuple, Optional, DefaultDict
from collections import defaultdict, namedtuple


def mol_to_cbor(mol: masm.Molecule) -> str:
    """
    Serialize a molecule to CBOR and encode the binary as base-64 text

    Parameters
    ----------
    mol : masm.Molecule
        Molecule to serialize

    Returns
    -------
    serialization : str
        The base-64 encoded CBOR representation of the molecule
    """
    json_serialization = masm.JsonSerialization
    # Molecule -> serialization object -> CBOR binary -> base-64 text
    binary = json_serialization(mol).to_binary(json_serialization.BinaryFormat.CBOR)
    return json_serialization.base_64_encode(binary)
def mol_from_cbor(cbor_str: str) -> masm.Molecule:
    """
    Rebuild a molassembler Molecule from base-64 encoded CBOR

    Expects exactly one base-64 encoded CBOR string, i.e. without the ';'
    separator that is used when several molecules are stored together in the
    database.

    Parameters
    ----------
    cbor_str : str
        String to deserialize into a Molecule

    Returns
    -------
    molecule : masm.Molecule
        The deserialized molecule
    """
    json_serialization = masm.JsonSerialization
    # base-64 text -> CBOR binary -> serialization object -> Molecule
    decoded = json_serialization.base_64_decode(cbor_str)
    return json_serialization(decoded, json_serialization.BinaryFormat.CBOR).to_molecule()
def mols_from_properties(structure: db.Structure, properties: db.Collection) -> Optional[List[masm.Molecule]]:
    """
    Generate all molecules based on atomic positions in a structure and the
    bond orders stored in its attached 'bond_orders' property.

    Parameters
    ----------
    structure : db.Structure
        The structure whose contained molecule(s) to analyze.
    properties : db.Collection
        The collection holding all properties.

    Returns
    -------
    molecules : Optional[List[masm.Molecule]]
        A list of all the molecules contained in the database structure, or
        ``None`` if the structure has no 'bond_orders' property.
    """
    # Bail out before any bond detection: without stored bond orders there is
    # nothing to combine the distance-based bonds with.
    if not structure.has_property('bond_orders'):
        return None

    atoms = structure.get_atoms()
    distance_bos = utils.BondDetector.detect_bonds(atoms)

    # Fetch the stored bond orders
    bo_property = db.SparseMatrixProperty(structure.get_property('bond_orders'))
    bo_property.link(properties)
    bond_orders = utils.BondOrderCollection(len(atoms))
    bond_orders.matrix = bo_property.data()

    # Combine stored and distance-based bond orders: take the maximum of both,
    # then multiply by the distance-based matrix.
    # NOTE(review): presumably both are scipy sparse matrices and these are
    # element-wise operations, so entries without a distance-detected bond
    # end up zero -- confirm against scine_utilities.
    final_bo_matrix = (bond_orders.matrix).maximum(distance_bos.matrix)
    final_bo_matrix = (final_bo_matrix).multiply(distance_bos.matrix)
    bos = utils.BondOrderCollection(len(atoms))
    bos.matrix = final_bo_matrix

    # Build and return molecules
    return masm.interpret.molecules(atoms, bos, set(), {}, masm.interpret.BondDiscretization.Binary).molecules
def deserialize_molecules(structure: db.Structure) -> List[masm.Molecule]:
    """
    Deserialize every molecule stored in a structure's "masm_cbor_graph" graph

    Parameters
    ----------
    structure : db.Structure
        The structure whose contained molecules to deserialize

    Returns
    -------
    molecules : List[masm.Molecule]
        A list of all the molecules contained in the database structure
    """
    # The graph string holds one base-64 CBOR chunk per molecule, ';'-separated
    serialized = structure.get_graph("masm_cbor_graph")
    return list(map(mol_from_cbor, serialized.split(";")))
def distinguish_components(components: List[int], map_unary: Callable[[int], Any]) -> List[int]:
    """
    Splits components by the result of a unary mapping function

    Parameters
    ----------
    components : List[int]
        A per-index mapping to a component index. Must contain only sequential
        numbers starting from zero. May be empty.
    map_unary : Callable[[int], Any]
        A unary callable that is called with an index, not a component index,
        yielding some hashable, comparable type. Components of indices are
        then split by matching results of invocations of this callable.

    Returns
    -------
    components : List[int]
        A per-index mapping to a component index. Contains only sequential
        numbers starting from zero. Empty if `components` is empty.
    """
    # Nothing to split; also avoids max() raising ValueError on an empty list
    if not components:
        return []

    n_components = max(components) + 1
    # Together these guarantee the labels are exactly 0..n_components-1
    assert min(components) == 0 and len(set(components)) == n_components, \
        "components must contain only sequential numbers starting from zero"

    # Invert the flat mapping: component index -> set of member indices
    component_sets: List[Set[int]] = [set() for _ in range(n_components)]
    for i, c in enumerate(components):
        component_sets[c].add(i)

    def split_by_unary(indices: Set[int]) -> List[Set[int]]:
        """Partition indices by the value map_unary yields for each of them"""
        results = defaultdict(set)
        for i in indices:
            results[map_unary(i)].add(i)
        return list(results.values())

    split_component_sets: List[Set[int]] = []
    for subset in component_sets:
        split_component_sets.extend(split_by_unary(subset))

    # Flatten back: every index receives the position of its (split) subset
    new_components = [0] * len(components)
    for c, subset in enumerate(split_component_sets):
        for i in subset:
            new_components[i] = c

    return new_components
def distinct_components(mol: masm.Molecule, h_only: bool) -> List[int]:
    """
    Builds a flat atom index to component identifier map for a molecule

    Parameters
    ----------
    mol : masm.Molecule
        A molecule whose atoms to generate distinct components for
    h_only : bool
        Whether to only apply ranking deduplication to hydrogen atoms. If set,
        the ranking-equivalent groups are additionally split by whether an
        atom is a hydrogen atom.

    Returns
    -------
    components : List[int]
        A flat per-atom index mapping to a component index. Contains only
        sequential numbers starting from zero.
    """
    ranking_groups = masm.ranking_equivalent_groups(mol)
    if not h_only:
        return ranking_groups

    def is_hydrogen(atom: int) -> bool:
        return mol.graph.element_type(atom) == utils.ElementType.H

    return distinguish_components(ranking_groups, is_hydrogen)
def distinct_atoms(mol: masm.Molecule, h_only: bool) -> List[int]:
    """
    Lists the ranking-distinct atom indices of a molecule

    Parameters
    ----------
    mol : masm.Molecule
        A molecule whose atoms to list distinct atoms for
    h_only : bool
        Whether to only apply ranking deduplication to hydrogen atoms. If set,
        the result holds the ranking-distinct hydrogen atoms followed by all
        non-hydrogen atoms of the molecule.

    Returns
    -------
    components : List[int]
        A list of ranking-distinct atoms
    """
    ranking_distinct = masm.ranking_distinct_atoms(mol)
    if not h_only:
        return ranking_distinct

    def is_hydrogen(atom: int) -> bool:
        return mol.graph.element_type(atom) == utils.ElementType.H

    # Deduplicate only hydrogens; every heavy atom is kept
    kept_hydrogens = list(filter(is_hydrogen, ranking_distinct))
    all_heavy_atoms = [atom for atom in range(mol.graph.V) if not is_hydrogen(atom)]
    return kept_hydrogens + all_heavy_atoms
def make_sorted_pair(a: int, b: int) -> Tuple[int, int]:
    """
    Orders two values ascendingly

    Parameters
    ----------
    a : int
        First value
    b : int
        Second value

    Returns
    -------
    pair : Tuple[int, int]
        Both values as a tuple with the smaller one first
    """
    return (b, a) if b < a else (a, b)
# Key identifying a class of equivalent atom pairings within one molecule:
# the molecule index, the sorted pair of component indices of both atoms and
# the graph distance between them.
ComponentDistanceTuple = namedtuple(
    "ComponentDistanceTuple",
    ("mol_idx", "components", "distance"),
)
# Sorted pair of atom indices on structure level
StructureIndexPair = Tuple[int, int]
# Maps each pairing key to the structure-level atom pairs that fall under it
EquivalentPairingsMap = DefaultDict[ComponentDistanceTuple, Set[StructureIndexPair]]
def pruned_atom_pairs(
    molecules: List[masm.Molecule], idx_map: List[Tuple[int, int]], distance_bounds: Tuple[int, int], prune: str
) -> Set[Tuple[int, int]]:
    """
    Helper function to generate the set of atom pairs pruned by ranking
    equivalence

    Parameters
    ----------
    molecules : List[masm.Molecule]
        The molecules of the structure.
    idx_map : List[Tuple[int, int]]
        Per structure atom index, the pair of molecule index and atom index
        within that molecule.
    distance_bounds : Tuple[int, int]
        The bounds (inclusive) on the graph distance between both atoms of a
        pair. The order of the two values does not matter.
    prune : str
        Either `'Hydrogen'` or `'All'`.

    Returns
    -------
    pairs : Set[Tuple[int, int]]
        One sorted pair of structure atom indices per distinct combination of
        molecule, component pair and graph distance.

    Raises
    ------
    ValueError
        If a molecule atom is missing from `idx_map`.
    """
    assert prune in ["Hydrogen", "All"]

    # Invert idx_map once instead of a linear `idx_map.index` scan per lookup.
    # `setdefault` keeps the first occurrence, matching `list.index`.
    structure_index_of = {}
    for structure_index, entry in enumerate(idx_map):
        structure_index_of.setdefault(tuple(entry), structure_index)

    def structure_idx(c: int, i: int) -> int:
        try:
            return structure_index_of[(c, i)]
        except KeyError:
            # Same exception type as the list.index lookup this replaces
            raise ValueError(f"{(c, i)!r} is not in idx_map") from None

    h_only = prune == "Hydrogen"
    lower, upper = min(distance_bounds), max(distance_bounds)
    pairings: EquivalentPairingsMap = defaultdict(set)

    # Idea: For each distinct atom in the molecule, distinguish the
    # distinct components of the molecule by the distance to the selected
    # atom. Then, store one atom pairing for each distinct set of component
    # and distance combination.
    for mol_idx, molecule in enumerate(molecules):
        distinct = distinct_atoms(molecule, h_only)
        components = distinct_components(molecule, h_only)

        for i in distinct:
            distances = masm.distance(i, molecule.graph)
            local_components = distinguish_components(
                components, lambda x: distances[x]  # pylint: disable=cell-var-from-loop
            )

            considered_components = set()
            for j, c in enumerate(local_components):
                if c in considered_components or i == j:
                    continue

                # Only the first atom of each local component is considered
                considered_components.add(c)

                if not lower <= distances[j] <= upper:
                    continue

                key = ComponentDistanceTuple(
                    mol_idx=mol_idx,
                    components=make_sorted_pair(components[i], components[j]),
                    distance=distances[j],
                )
                # The first pairing found for a key is its representative
                if key not in pairings:
                    s_ij = make_sorted_pair(structure_idx(mol_idx, i), structure_idx(mol_idx, j))
                    pairings[key].add(s_ij)

    # Pick one element from each set of same-key pairings
    return {next(iter(subset)) for subset in pairings.values()}
def unpruned_atom_pairs(
    molecules: List[masm.Molecule], idx_map: List[Tuple[int, int]], distance_bounds: Tuple[int, int]
) -> Set[Tuple[int, int]]:
    """
    Helper function to generate the set of unpruned atom pairs

    Parameters
    ----------
    molecules : List[masm.Molecule]
        The molecules of the structure.
    idx_map : List[Tuple[int, int]]
        Per structure atom index, the pair of molecule index and atom index
        within that molecule.
    distance_bounds : Tuple[int, int]
        The bounds (inclusive) on the graph distance between both atoms of a
        pair. The order of the two values does not matter.

    Returns
    -------
    pairs : Set[Tuple[int, int]]
        All sorted pairs of structure atom indices whose atoms belong to the
        same molecule and whose graph distance lies within the bounds.

    Raises
    ------
    ValueError
        If a molecule atom is missing from `idx_map`.
    """
    # Invert idx_map once instead of a linear `idx_map.index` scan per lookup.
    # `setdefault` keeps the first occurrence, matching `list.index`.
    structure_index_of = {}
    for structure_index, entry in enumerate(idx_map):
        structure_index_of.setdefault(tuple(entry), structure_index)

    def structure_idx(c: int, i: int) -> int:
        try:
            return structure_index_of[(c, i)]
        except KeyError:
            # Same exception type as the list.index lookup this replaces
            raise ValueError(f"{(c, i)!r} is not in idx_map") from None

    lower, upper = min(distance_bounds), max(distance_bounds)
    pairs: Set[Tuple[int, int]] = set()

    for component, molecule in enumerate(molecules):
        for i in molecule.graph.atoms():
            distances = np.array(masm.distance(i, molecule.graph))
            partners = np.nonzero((distances <= upper) & (distances >= lower))[0]

            # Back-transform to structure indices and add to set
            s_i = structure_idx(component, i)
            pairs.update(
                make_sorted_pair(s_i, structure_idx(component, j))
                for j in partners.tolist()  # plain Python ints for the lookup
            )

    return pairs
def get_atom_pairs(
    structure: db.Structure,
    distance_bounds: Tuple[int, int],
    prune: str = "None",
    superset: Optional[Set[Tuple[int, int]]] = None,
) -> Set[Tuple[int, int]]:
    """
    Gets the set of all atom pairs whose graph distance lies within
    `distance_bounds` (inclusive) on the basis of the interpreted graph
    representation.

    Parameters
    ----------
    structure : db.Structure
        The structure that is investigated
    distance_bounds : Tuple[int, int]
        The minimum and maximum distance between two points that is allowed so
        that they are considered a valid atom pair. The order of the two
        values does not matter.
    prune : str
        Whether to prune atom pairings by Molassembler's ranking distinct
        atoms descriptor. Allowed values: `'None'`, `'Hydrogen'`, `'All'`
    superset : Optional[Set[Tuple[int, int]]]
        Optional superset of pairs to filter. If set, will filter the passed
        set. Otherwise, generates atom pairings from all possible pairs in the
        molecule. Both atoms of each pair must belong to the same molecule.

    Returns
    -------
    pairs : Set[Tuple[int, int]]
        The indices of valid atom pairs.

    Raises
    ------
    RuntimeError
        If `prune` is not one of the allowed values.
    """
    valid_option_values = ["None", "Hydrogen", "All"]
    if prune not in valid_option_values:
        msg = "Option for masm atom pruning invalid: {}"
        raise RuntimeError(msg.format(prune))

    molecules = deserialize_molecules(structure)
    idx_map = ast.literal_eval(structure.get_graph("masm_idx_map"))
    pruning = prune in ["Hydrogen", "All"]

    # Without a superset, pairs are generated from the molecules directly
    if superset is None:
        if pruning:
            return pruned_atom_pairs(molecules, idx_map, distance_bounds, prune)
        return unpruned_atom_pairs(molecules, idx_map, distance_bounds)

    lower, upper = min(distance_bounds), max(distance_bounds)

    # Pairs of a superset commonly share their first atom, so the distances
    # from each source atom are computed only once.
    distances_from = {}

    def graph_distance(mol_idx: int, i: int, j: int) -> int:
        source = (mol_idx, i)
        if source not in distances_from:
            # masm.distance works on the graph of a molecule
            distances_from[source] = masm.distance(i, molecules[mol_idx].graph)
        return distances_from[source][j]

    if pruning:
        # Expand superset into the pruned keyspace
        superset_dict: EquivalentPairingsMap = defaultdict(set)
        molecule_components = [distinct_components(mol, prune == "Hydrogen") for mol in molecules]
        for s_i, s_j in superset:
            mol_idx, i = idx_map[s_i]
            cmp_idx, j = idx_map[s_j]
            assert mol_idx == cmp_idx
            components = molecule_components[mol_idx]
            key = ComponentDistanceTuple(
                mol_idx=mol_idx,
                components=make_sorted_pair(components[i], components[j]),
                distance=graph_distance(mol_idx, i, j),
            )
            if lower <= key.distance <= upper:
                superset_dict[key].add(make_sorted_pair(s_i, s_j))

        # Pick one element from each set of same-key pairings
        return {next(iter(subset)) for subset in superset_dict.values()}

    pairs: Set[Tuple[int, int]] = set()
    for s_i, s_j in superset:
        mol_idx, i = idx_map[s_i]
        cmp_idx, j = idx_map[s_j]
        assert mol_idx == cmp_idx
        if lower <= graph_distance(mol_idx, i, j) <= upper:
            pairs.add(make_sorted_pair(s_i, s_j))

    return pairs