#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__copyright__ = """ This code is licensed under the 3-clause BSD license.
Copyright ETH Zurich, Department of Chemistry and Applied Biosciences, Reiher Group.
See LICENSE.txt for details.
"""
import scine_molassembler as masm
from scine_art.molecules import (
all_matching_fragments,
maximum_matching_fragments,
sort_indices_by_subgraph,
mol_from_subgraph_indices,
)
from typing import List, Tuple, Set, Optional
from copy import copy
[docs]def map_reaction_from_molecules_direct(
lhs: List[masm.Molecule],
rhs: List[masm.Molecule],
known_atom_mappings: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]:
"""Tries to match each atom in the molecules of the left-hand side to
exactly one atom on the right-hand side.
Atom type counts must match.
Note
----
The algorithm is experimental, be cautious with its results.
This version of the algorithm is CPU time greedy, but saves memory.
Parameters
----------
lhs : List[masm.Molecule]
A list of molecules on the left-hand side of the matching.
rhs : List[masm.Molecule]
A list of molecules on the right-hand side of the matching.
known_atom_mappings : Optional[List[Tuple[int, int]]]
A list of atom mappings that have to be respected. The atoms are
indexed on a continuous scale in order of the atoms in the molecules
given.
Returns
-------
Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]
A tuple atom matches. The first entry matching from left- to right-hand
side the second from right- to left-hand side.
Raises
------
RuntimeError
If algorithm fails or miss matches appear.
"""
n_atoms = sum([lmol.graph.V for lmol in lhs])
assert sum([rmol.graph.V for rmol in rhs]) == n_atoms
# In this algorithm we will try to rebuild the LHS from fragments of
# the rhs. To this end the LHS will be viewed as having a continuous
# index with breaks where a molecule ends and a new one starts.
# (e.g. the input methane, HCl and benzene, in that order, would
# correspond to [0,1,2,3,4 | 5,6 | 7,8,9,10,11,...,18]
# When generating the intersection based fragment mapping we will generate
# the mapping to this new 'continuous-LHS' index.
l_sizes = [m.graph.V for m in lhs]
r_sizes = [m.graph.V for m in rhs]
l_offsets: List[int] = [sum(l_sizes[:i]) for i in range(len(lhs))]
r_offsets: List[int] = [sum(r_sizes[:i]) for i in range(len(rhs))]
remaining_lhs_fragments = [mol for mol in lhs]
remaining_lhs_indices = [[x + l_offsets[i] for x in range(mol.graph.V)] for i, mol in enumerate(lhs)]
remaining_rhs_fragments = [mol for mol in rhs]
remaining_rhs_indices = [[x + r_offsets[i] for x in range(mol.graph.V)] for i, mol in enumerate(rhs)]
def get_mappings():
mappings = []
for i, (l_mol, l_idxs) in enumerate(zip(remaining_lhs_fragments, remaining_lhs_indices)):
for j, (r_mol, r_idxs) in enumerate(zip(remaining_rhs_fragments, remaining_rhs_indices)):
fragments, fragments_to_l_mol, fragments_to_r_mol = maximum_matching_fragments(
l_mol, r_mol, min_fragment_size=1
)
# Generate continuous-LHS mapping, and sets of the mapped indices
for fragment, l_maps, r_maps in zip(fragments, fragments_to_l_mol, fragments_to_r_mol):
for l_map in l_maps:
assert len(l_map) == fragment.graph.V
for r_map in r_maps:
assert len(r_map) == fragment.graph.V
mapping = [(l_idxs[l_map[x]], r_idxs[r_map[x]]) for x in range(fragment.graph.V)]
l_set = set([x[0] for x in mapping])
r_set = set([x[1] for x in mapping])
mappings.append((
mapping,
l_set,
r_set,
fragment,
i,
j,
[l_map[x] for x in range(fragment.graph.V)],
[r_map[x] for x in range(fragment.graph.V)]
))
if known_atom_mappings is not None:
reduced_mappings = []
known_lhs_indices = [x[0] for x in known_atom_mappings]
for m in mappings:
bad_pairing = False
for pairing in m[0]:
if pairing[0] in known_lhs_indices and pairing not in known_atom_mappings:
bad_pairing = True
break
if not bad_pairing:
reduced_mappings.append(m)
return reduced_mappings
return mappings
# Start building a mapping from fragments to the 'continuous-LHS'
# 1. Pick (first) largest fragment from all lhs/rhs combinations
# 2. Remove all atoms that are now matched on the LHS and RHS
# 3. Reevaluate the best matched for all LHS/RHS fragment combinations
# 4. Repeat 1./2. until no fragments and/or 'continuous-LHS' indices are free
remaining_mappings = get_mappings()
starters = []
# TODO try multiple times, here? (loop)
starters.append(max(remaining_mappings, key=lambda x: len(x[0])))
used = []
l_blocked: Set[int] = set()
r_blocked: Set[int] = set()
while len(remaining_mappings) > 0:
current = max(remaining_mappings, key=lambda x: len(x[0]))
used.append(current)
assert not bool(l_blocked & current[1])
assert not bool(r_blocked & current[2])
l_blocked.update(current[1])
r_blocked.update(current[2])
# Reduce origin fragments
lhs_origin = current[4]
rhs_origin = current[5]
lhs_origin_fragment = remaining_lhs_fragments.pop(lhs_origin)
rhs_origin_fragment = remaining_rhs_fragments.pop(rhs_origin)
lhs_origin_indices = remaining_lhs_indices.pop(lhs_origin)
rhs_origin_indices = remaining_rhs_indices.pop(rhs_origin)
lhs_indices_by_subgraph = sort_indices_by_subgraph(
[i for i in range(lhs_origin_fragment.graph.V) if i not in current[6]],
lhs_origin_fragment.graph
)
for subgraph_indices in lhs_indices_by_subgraph:
mol, original_indices = mol_from_subgraph_indices(subgraph_indices, copy(lhs_origin_fragment))
if mol:
remaining_lhs_fragments.append(mol)
remaining_lhs_indices.append([lhs_origin_indices[x] for x in original_indices])
rhs_indices_by_subgraph = sort_indices_by_subgraph(
[i for i in range(rhs_origin_fragment.graph.V) if i not in current[7]],
rhs_origin_fragment.graph
)
for subgraph_indices in rhs_indices_by_subgraph:
mol, original_indices = mol_from_subgraph_indices(subgraph_indices, copy(rhs_origin_fragment))
if mol:
remaining_rhs_fragments.append(mol)
remaining_rhs_indices.append([rhs_origin_indices[x] for x in original_indices])
remaining_mappings = get_mappings()
def is_overlapping(fragment_info):
return bool(set(l_blocked) & set(fragment_info[1])) or bool(set(r_blocked) & set(fragment_info[2]))
remaining_mappings[:] = [x for x in remaining_mappings if not is_overlapping(x)]
assert len(l_blocked) == n_atoms
# Prepare the final output, for each atom in each molecule, generate
# a tuple pointing at a molecule and an atom within it on the other
# side.
all_maps: List[Tuple[int, int]] = []
for u in used:
all_maps += u[0]
def continuous_to_fragments(index: int, offsets: List[int]) -> Tuple[int, int]:
for o in reversed(offsets):
if o <= index:
return (offsets.index(o), index-o)
raise RuntimeError('Bug: Offsets in `map_reaction_from_molecules` did not match.')
all_maps.sort(key=lambda x: x[0])
lhs_final: List[List[Tuple[int, int]]] = []
for i, l_mol in enumerate(lhs):
tmp: List[Tuple[int, int]] = []
for j in range(l_mol.graph.V):
tmp.append(continuous_to_fragments(all_maps[j+l_offsets[i]][1], r_offsets))
lhs_final.append(tmp)
all_maps.sort(key=lambda x: x[1])
rhs_final: List[List[Tuple[int, int]]] = []
for i, r_mol in enumerate(rhs):
tmp = []
for j in range(r_mol.graph.V):
tmp.append(continuous_to_fragments(all_maps[j+r_offsets[i]][0], l_offsets))
assert lhs_final[tmp[-1][0]][tmp[-1][1]][0] == i
assert lhs_final[tmp[-1][0]][tmp[-1][1]][1] == j
rhs_final.append(tmp)
return lhs_final, rhs_final
[docs]def map_reaction_from_molecules_cached(
lhs: List[masm.Molecule],
rhs: List[masm.Molecule],
known_atom_mappings: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]:
"""Tries to match each atom in the molecules of the left-hand side to
exactly one atom on the right-hand side.
Atom type counts must match.
Note
----
The algorithm is experimental, be cautious with its results.
This version of the algorithm is memory greedy.
Parameters
----------
lhs : List[masm.Molecule]
A list of molecules on the left-hand side of the matching.
rhs : List[masm.Molecule]
A list of molecules on the right-hand side of the matching.
known_atom_mappings : Optional[List[Tuple[int, int]]]
A list of atom mappings that have to be respected. The atoms are
indexed on a continuous scale in order of the atoms in the molecules
given.
Returns
-------
Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]
A tuple atom matches. The first entry matching from left- to right-hand
side the second from right- to left-hand side.
Raises
------
RuntimeError
If algorithm fails or miss matches appear.
"""
n_atoms = sum([lmol.graph.V for lmol in lhs])
assert sum([rmol.graph.V for rmol in lhs]) == n_atoms
# In this algorithm we will try to rebuild the LHS from fragments of
# the rhs. To this end the LHS will be viewed as having a continuous
# index with breaks where a molecule ends and a new one starts.
# (e.g. the input methane, HCl and benzene, in that order, would
# correspond to [0,1,2,3,4 | 5,6 | 7,8,9,10,11,...,18]
# When generating the intersection based fragment mapping we will generate
# the mapping to this new 'continuous-LHS' index.
l_sizes = [m.graph.V for m in lhs]
r_sizes = [m.graph.V for m in rhs]
l_offsets: List[int] = [sum(l_sizes[:i]) for i in range(len(lhs))]
r_offsets: List[int] = [sum(r_sizes[:i]) for i in range(len(rhs))]
mappings = []
for i, l_mol in enumerate(lhs):
for j, r_mol in enumerate(rhs):
fragments, fragments_to_l_mol, fragments_to_r_mol = all_matching_fragments(
l_mol, r_mol, min_fragment_size=1
)
# Generate continuous-LHS mapping, and sets of the mapped indices
for fragment, l_maps, r_maps in zip(fragments, fragments_to_l_mol, fragments_to_r_mol):
for l_map in l_maps:
assert len(l_map) == fragment.graph.V
for r_map in r_maps:
assert len(r_map) == fragment.graph.V
mapping = [(l_map[x]+l_offsets[i], r_map[x]+r_offsets[j]) for x in range(fragment.graph.V)]
l_set = set([x[0] for x in mapping])
r_set = set([x[1] for x in mapping])
mappings.append((mapping, l_set, r_set, fragment))
# Start building a mapping from fragments to the 'continuous-LHS'
# 1. Pick (first) largest fragment
# 2. Remove all fragments that overlap with it or the now blocked
# indices in the 'continuous-LHS' indices
# 3. Repeat 1. until no fragments and/or 'continuous-LHS' indices are free
if known_atom_mappings is not None:
remaining_mappings = []
known_lhs_indices = [x[0] for x in known_atom_mappings]
for m in mappings:
bad_pairing = False
for pairing in m[0]:
if pairing[0] in known_lhs_indices and pairing not in known_atom_mappings:
bad_pairing = True
break
if not bad_pairing:
remaining_mappings.append(m)
else:
remaining_mappings = copy(mappings)
starters = []
# TODO try multiple times, here? (loop)
starters.append(max(remaining_mappings, key=lambda x: len(x[0])))
used = []
l_blocked: Set[int] = set()
r_blocked: Set[int] = set()
while len(remaining_mappings) > 0:
current = max(remaining_mappings, key=lambda x: len(x[0]))
used.append(current)
assert not bool(l_blocked & current[1])
assert not bool(r_blocked & current[2])
l_blocked.update(current[1])
r_blocked.update(current[2])
def is_overlapping(fragment_info):
return bool(set(l_blocked) & set(fragment_info[1])) or bool(set(r_blocked) & set(fragment_info[2]))
remaining_mappings[:] = [x for x in remaining_mappings if not is_overlapping(x)]
if len(l_blocked) != n_atoms:
raise RuntimeError(
'Could not complete atom mapping across reaction. Check "known_atom_mappings" if any were given.'
)
# Prepare the final output, for each atom in each molecule, generate
# a tuple pointing at a molecule and an atom within it on the other
# side.
all_maps: List[Tuple[int, int]] = []
for u in used:
all_maps += u[0]
def continuous_to_fragments(index: int, offsets: List[int]) -> Tuple[int, int]:
for o in reversed(offsets):
if o <= index:
return (offsets.index(o), index-o)
raise RuntimeError('Bug: Offsets in `map_reaction_from_molecules` did not match.')
all_maps.sort(key=lambda x: x[0])
lhs_final: List[List[Tuple[int, int]]] = []
for i, l_mol in enumerate(lhs):
tmp: List[Tuple[int, int]] = []
for j in range(l_mol.graph.V):
tmp.append(continuous_to_fragments(all_maps[j+l_offsets[i]][1], r_offsets))
lhs_final.append(tmp)
all_maps.sort(key=lambda x: x[1])
rhs_final: List[List[Tuple[int, int]]] = []
for i, r_mol in enumerate(rhs):
tmp = []
for j in range(r_mol.graph.V):
tmp.append(continuous_to_fragments(all_maps[j+r_offsets[i]][0], l_offsets))
assert lhs_final[tmp[-1][0]][tmp[-1][1]][0] == i
assert lhs_final[tmp[-1][0]][tmp[-1][1]][1] == j
rhs_final.append(tmp)
return lhs_final, rhs_final