"""
calculate_diagram.py
Training Hasse diagrams.
"""
import time
import numpy as np
import pickle
from .hasse import extended_hasse
from ehreact.helpers.rdkit import (
read_in_smiles_unique,
read_in_reactions_unique,
preprocess_seeds,
make_fragment_indices,
morgan_bit_fp_from_mol,
tanimoto_from_fp,
find_matching_atoms,
canonicalize,
make_mol,
make_mol_no_sanit_h,
)
from ehreact.diagram.diagram import sanity_check
from ehreact.diagram.plot_hasse import plot_hasse
from ehreact.helpers.utils import findsubsets
from ehreact.helpers.transition_state import compute_aam_with_h
[docs]def calculate_diagram(
smiles,
verbose=False,
quiet=True,
compute_aam=False,
save_path=None,
save_plot=None,
train_mode="transition_state",
seed=[],
no_props=False,
plot_only_branches=False,
temp_dir_img=None,
):
"""
Computes a Hasse diagram of a list of reaction or molecule smiles.
Parameters
----------
smiles: List[str]
List of SMILES or reaction SMILES.
verbose: bool, default False
Whether to print additional information.
quiet: bool, default True
Whether to silence all output.
compute_aam: bool, default False
Whether to compute atom-mappings for reactions.
save_path: str, default None
File to which diagram is saved.
save_plot: str, default None
File to which save image of diagram.
train_mode: Literal["single_reactant","transition_state"], default "transition_state"
Train mode, either transition states extracted from reaction smiles or single reactants extracted from smiles.
seed: List[str], default []
List of SMILES seeds for the reactant algorithm, usually a single seed is given.
no_props: bool, default False
Do not compute any properties, just output the diagram.
plot_only_branches: bool, default False
Plot only substructures that branch off.
temp_dir_img: str, default None
Directory to save temporary image files
Returns
-------
d: ehreact.diagram.diagram.Diagram
The Hasse diagram of the input list of molecules/reactions.
"""
# Create Hasse diagram
if not quiet:
print("=" * 40)
print("Calculating diagram")
start1 = time.time()
if train_mode == "single_reactant":
d, smiles_dict = calculate_diagram_single_reactant(smiles, seed, verbose, quiet)
elif train_mode == "transition_state":
d, smiles_dict = calculate_diagram_transition_state(
smiles, verbose, quiet, compute_aam
)
else:
raise ValueError("Unknown option used for train_mode.")
d.mode = train_mode
sanity_check(d)
end1 = time.time()
# Fill additional information into diagram
if not quiet:
print("=" * 40)
print("Create additional diagram information")
start2 = time.time()
# Need to calculate the lowest template for each node for correct plotting of diagram
for n in d.nodes:
d.nodes[n].lowest_template = find_lowest_template(d.nodes[n], d)
# Calculate all other properties only if not no_props:
if not no_props:
fill_information(d, train_mode, verbose, smiles_dict)
end2 = time.time()
# Save diagram
start3 = time.time()
if save_path is not None:
if not quiet:
print("=" * 40)
print("Saving model")
with open(save_path, "wb") as handle:
pickle.dump(d, handle)
if not quiet:
print("Saved model to", save_path)
end3 = time.time()
# Plot diagram
start4 = time.time()
if save_plot is not None:
if not quiet:
print("=" * 40)
print("Plotting model")
if temp_dir_img is None:
raise ValueError(
"Must supply a directory for saving temporary images files."
)
plot_hasse(d, temp_dir_img, save_plot, plot_only_branches)
if not quiet:
print("Plotted diagram in png format to " + save_plot)
end4 = time.time()
# Print timing
if not quiet:
print("=" * 40)
print("Finished")
print("Calculating diagram -- time:{}s".format(end1 - start1))
print("Create additional diagram information -- time:{}s".format(end2 - start2))
print("Saving model -- time:{}s".format(end3 - start3))
print("Plotting model -- time:{}s".format(end4 - start4))
return d
[docs]def calculate_diagram_single_reactant(smiles, seed_list, verbose, quiet):
"""
Computes a Hasse diagram of a list of molecule smiles.
Parameters
----------
smiles: List[str]
List of SMILES.
seed_list: List[str]
List of SMILES seeds.
verbose: bool
Whether to print additional information.
quiet: bool
Whether to silence all output.
Returns
-------
d: ehreact.diagram.diagram.Diagram
The Hasse diagram of the input list of molecules.
smiles_dict: dict
A dictionary of the canonicalized input smiles.
"""
# Create canonical smiles and mol objects:
smiles, smiles_dict, skipped = read_in_smiles_unique(smiles)
if not quiet and len(skipped) != 0:
print("Skipped duplicates:", skipped)
if verbose:
print(
"Input:",
len(smiles) + len(skipped),
"(of which",
len(skipped),
"were skipped due to duplicates).",
)
print("Canonicalized smiles used for training:")
for smi in smiles:
print(
"Without stereochemistry:",
smi,
" with stereochemistry:",
smiles_dict[smi]["canonical_smi"],
)
# Create seeds:
if verbose and len(seed_list) == 0:
print(
"No seed given, find maximum common substructure of all molecules instead."
)
seeds, rule_dict, num_smiles_seed = preprocess_seeds(seed_list, smiles, smiles_dict)
if verbose:
print("Seeds", seeds)
if np.sum(num_smiles_seed) < len(smiles):
raise ValueError("Not all molecules fit a seed. Check given seeds.")
elif np.sum(num_smiles_seed) > len(smiles):
raise ValueError("Seeds are not mutually exclusive. Check given seeds.")
tags_core = {}
d = extended_hasse(smiles_dict, seeds, rule_dict, tags_core, verbose, quiet)
return d, smiles_dict
[docs]def calculate_diagram_transition_state(smiles, verbose, quiet, compute_aam):
"""
Computes a Hasse diagram of a list of reaction smiles.
Parameters
----------
smiles: List[str]
List of reaction SMILES.
verbose: bool, default False
Whether to print additional information.
quiet: bool
Whether to silence all output.
compute_aam: bool
Whether to compute atom-mappings for reactions.
Returns
-------
d: ehreact.diagram.diagram.Diagram
The Hasse diagram of the input list of reactions.
smiles_dict: dict
A dictionary of the canonicalized input smiles.
"""
# Create transition states
if compute_aam:
if not quiet:
print("Calculating atom-mappings")
mapped_smiles = []
for smi in smiles:
mapped_smiles.append(compute_aam_with_h(smi))
if verbose:
print(smi, mapped_smiles[-1])
smiles = mapped_smiles
if not quiet:
print("Finished calculating atom-mappings")
(
seeds,
rule_dict,
smiles,
smiles_dict,
skipped,
tags_core,
) = read_in_reactions_unique(smiles)
if not quiet and len(skipped) != 0:
print("Skipped duplicates:", skipped)
if verbose:
print(
"Input:",
len(smiles) + len(skipped),
"(of which",
len(skipped),
"were skipped due to duplicates).",
)
print("Canonicalized smiles used for training:")
for smi in smiles:
print(
"Without stereochemistry:",
smi,
" with stereochemistry:",
smiles_dict[smi]["canonical_smi"],
)
d = extended_hasse(smiles_dict, seeds, rule_dict, tags_core, verbose, quiet)
return d, smiles_dict
[docs]def find_lowest_template(curr_node, d):
"""Function to find the lowest (most general) substructure/reaction rule in the tree.
Parameters
----------
curr_node: ehreact.diagram.diagram.Node
Node for which to find the lowest template.
d: ehreact.diagram.diagram.Diagram
Hasse diagram.
Returns
-------
lowest_template: str
Name of the lowest template.
"""
if curr_node.key == "":
return ""
while curr_node.edges_to_parent[0].parent_node.key != "":
curr_node = curr_node.edges_to_parent[0].parent_node
lowest_template = curr_node.key
return lowest_template
[docs]def write_fragment_list_to_root(d, train_mode, verbose, smiles_dict):
"""Function to calculate a list of reactant rule fragments (only atoms in reaction center). This in needed
to transform inputted molecules to their corresponding transition state. Save to the root node (in-place).
Parameters
----------
d: ehreact.diagram.diagram.Diagram
Hasse diagram.
train_mode: Literal["single_reactant","transition_state"]
Train mode, either transition states extracted from reaction smiles or single reactants extracted from smiles.
verbose: bool, default False
Whether to print additional information.
smiles_dict: dict
A dictionary of the canonicalized input smiles.
"""
if verbose:
print("Calculating reactant fragments for each minimal template")
fragment_dict_reac = {}
minimal_templates_edges = d.nodes[""].edges_to_child
for edge in minimal_templates_edges:
minimal_template = edge.child_node
if verbose:
print("template", minimal_template)
fragment_dict_reac[str(minimal_template)] = {}
rule_reac = d.nodes[str(minimal_template)].rule_reac
rule_reac_fragment_indices, rule_reac_fragment_mols = make_fragment_indices(
rule_reac
)
for i in range(len(rule_reac_fragment_indices)):
fragment_dict_reac[str(minimal_template)][rule_reac_fragment_indices[i]] = {
"rule_fragment": rule_reac_fragment_mols[i],
"molecules": [],
}
for leaf in d.nodes[str(minimal_template)].all_leafs:
if verbose:
print("... leaf", leaf)
tags_core = d.nodes[leaf].tags_core
smi_reac = leaf.split(">")[0]
mol_reac = make_mol_no_sanit_h(smi_reac)
matches_reac = find_matching_atoms(
train_mode, mol_reac, rule_reac, None, tags_core
)
if len(matches_reac) != 1:
print(
"WARNING: Expected exactly one match, but got "
+ str(len(matches_reac))
+ " instead."
)
print("tags:", tags_core)
print("leaf:", leaf)
print("minimal_template:", minimal_template)
if len(matches_reac) == 0:
raise ValueError("Zero matches, something went wrong.")
matches_reac = matches_reac[0]
if len(smi_reac.split(".")) != len(rule_reac_fragment_indices):
print(
"Intramolecular reaction detected. This is currently not tested to work correctly"
)
print("this might lead to erraneous results.")
print(leaf)
print(smi_reac)
print(rule_reac_fragment_indices)
# TODO implement intramolecular version, write exception if rule_atoms_in_fragment not found
# (since it is a superset of the combination of available fragments then,
# need to iterate through those!
for smi_fragment in smi_reac.split("."):
fragment = make_mol(smi_fragment)
fragment_matches = mol_reac.GetSubstructMatches(fragment)
fragment_match = []
for match in fragment_matches:
correct_match = True
for i, x in enumerate(match):
if fragment.GetAtomWithIdx(i).HasProp("molAtomMapNumber"):
if fragment.GetAtomWithIdx(i).GetProp(
"molAtomMapNumber"
) != mol_reac.GetAtoms()[x].GetProp("molAtomMapNumber"):
correct_match = False
if correct_match:
fragment_match.append(match)
if len(fragment_match) != 1:
raise ValueError(
"Expected exactly one match, but got",
len(fragment_match),
"instead. Exit program",
)
fragment_match = fragment_match[0]
rule_atoms_in_fragment = tuple(
[i for i, x in enumerate(matches_reac) if x in fragment_match]
)
for atom in fragment.GetAtoms():
atom.SetAtomMapNum(0)
canonical_smi_fragment = canonicalize(fragment)
if (
rule_atoms_in_fragment
not in fragment_dict_reac[str(minimal_template)].keys()
):
# Could be in intramolecular reaction, correct for it here:
# TODO, add to molecules for each of the subsets that apply.
# else this is an error:
raise ValueError(
"An error occured: Found rule fragment doesn't match available fragments"
)
if (
canonical_smi_fragment
not in fragment_dict_reac[str(minimal_template)][
rule_atoms_in_fragment
]["molecules"]
):
fragment_dict_reac[str(minimal_template)][rule_atoms_in_fragment][
"molecules"
].append(canonical_smi_fragment)
d.nodes[""].fragment_reac = fragment_dict_reac
[docs]def calculate_diversity(d, node, stereo, train_mode):
"""Calculates diversity within a branch or tree.
Parameters
----------
d: ehreact.diagram.diagram.Diagram
Hasse diagram.
node: ehreact.diagram.diagram.Node
Node for which to calculate diversity, all leaf nodes attached to this node by an arbitrary number of
edges toward children are taken into account.
stereo: bool
Whether to include stereochemistry in fingerprints
train_mode: Literal["single_reactant", "transition_state"]
Train mode, either transition states extracted from reaction smiles or single reactants extracted from smiles.
Returns
-------
mean_div_reac: float
Mean pair similarity of reactants.
mean_div_prod: float
Mean pair similarity of products.
"""
diversity_reac = []
diversity_prod = []
if len(d.nodes[node].all_leafs) > 1:
for combi in findsubsets(d.nodes[node].all_leafs, 2):
if not stereo:
fp1 = d.nodes[combi[0]].fp_reac
fp2 = d.nodes[combi[1]].fp_reac
diversity_reac.append(tanimoto_from_fp(fp1, fp2))
if train_mode == "transition_state":
fp1 = d.nodes[combi[0]].fp_prod
fp2 = d.nodes[combi[1]].fp_prod
diversity_prod.append(tanimoto_from_fp(fp1, fp2))
else:
diversity_prod.append(1)
else:
best_sim = 0
for fp1 in d.nodes[combi[0]].fp_reac_stereo:
for fp2 in d.nodes[combi[1]].fp_reac_stereo:
sim = tanimoto_from_fp(fp1, fp2)
if sim > best_sim:
best_sim = sim
diversity_reac.append(best_sim)
if train_mode == "transition_state":
best_sim = 0
for fp1 in d.nodes[combi[0]].fp_prod_stereo:
for fp2 in d.nodes[combi[1]].fp_prod_stereo:
sim = tanimoto_from_fp(fp1, fp2)
if sim > best_sim:
best_sim = sim
diversity_prod.append(best_sim)
else:
diversity_prod.append(1)
else:
diversity_reac.append(1)
diversity_prod.append(1)
mean_div_reac = np.mean(diversity_reac)
mean_div_prod = np.mean(diversity_prod)
return mean_div_reac, mean_div_prod