Commit 09ef7e48 authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

fixed standardization of molecules now converting SMILES to InChI, fixed batch_size bug

parent 5a5b1e1b
......@@ -116,7 +116,8 @@ class Signaturizer(object):
def predict(self, smiles, destination=None, chunk_size=1000):
def predict(self, molecules, destination=None, chunk_size=1000,
"""Predict signatures for given SMILES.
Perform signature prediction for input SMILES. We recommend that the
......@@ -125,16 +126,37 @@ class Signaturizer(object):
is possible and the corresponding signature will be set to NaN.
smiles(list): List of SMILES strings.
molecules(list): List of strings representing molecules. Can be
SMILES (by default) or InChI.
destination(str): File path where to save predictions.
chunk_size(int): Perform prediction on chunks of this size.
keytype(str): Wether to interpret molecules as InChI or SMILES.
results: `SignaturizerResult` class. The ordering of input SMILES
is preserved.
# convert input molecules to InChI
inchies = list()
if keytype.upper() == 'SMILES':
for smi in molecules:
if smi == '':
mol = Chem.MolFromSmiles(smi)
if mol is None:
if self.verbose:
print("Cannot get molecule from SMILES: %s." % smi)
inchies.append('INVALID SMILES')
inchi = Chem.rdinchi.MolToInchi(mol)[0]
if self.verbose:
print('CONVERTED:', smi, inchi)
inchies = molecules
self.inchies = inchies
# Prepare result object
features = len(self.model_names) * 128
results = SignaturizerResult(len(smiles), destination, features)
results = SignaturizerResult(len(inchies), destination, features)
results.dataset[:] = self.model_names
if results.readonly:
raise Exception(
......@@ -142,18 +164,21 @@ class Signaturizer(object):
'delete or rename to proceed.')
# predict by chunk
all_chunks = range(0, len(smiles), chunk_size)
all_chunks = range(0, len(inchies), chunk_size)
for i in tqdm(all_chunks, disable=not self.verbose):
chunk = slice(i, i + chunk_size)
sign0s = list()
failed = list()
for idx, mol_smiles in enumerate(smiles[chunk]):
for idx, inchi in enumerate(inchies[chunk]):
# read SMILES as molecules
mol = Chem.MolFromSmiles(mol_smiles)
# read molecule
inchi = inchi.encode('ascii', 'ignore')
if self.verbose:
print('READING', inchi, type(inchi))
mol = Chem.inchi.MolFromInchi(inchi)
if mol is None:
raise Exception(
"Cannot get molecule from smiles.")
"Cannot get molecule from InChI.")
info = {}
fp = AllChem.GetMorganFingerprintAsBitVect(
mol, 2, nBits=2048, bitInfo=info)
......@@ -163,21 +188,21 @@ class Signaturizer(object):
except Exception as err:
# in case of failure save idx to fill NaNs
if self.verbose:
print("SKIPPING %s: %s" % (mol_smiles, str(err)))
print("SKIPPING %s: %s" % (inchi, str(err)))
calc_s0 = np.full((2048, ), np.nan)
# stack input fingerprints and run signature predictor
sign0s = np.vstack(sign0s)
preds = self.model.predict(sign0s)
preds = self.model.predict(sign0s, batch_size=1)
# add NaN where SMILES conversion failed
if failed:
preds[np.array(failed)] = np.full(features, np.nan)
results.signature[chunk] = preds
# run applicability predictor
if self.applicability:
apreds = self.app_model.predict(sign0s)
apreds = self.app_model.predict(sign0s, batch_size=1)
if failed:
apreds[np.array(failed)] = np.nan
results.applicability[chunk] = apreds
......@@ -187,11 +212,11 @@ class Signaturizer(object):
failed = np.isnan(results.signature[:, 0])
results.failed = failed
if any(failed) > 0:
print('Some SMILES could not be recognized,'
print('Some molecules could not be recognized,'
' the corresponding signatures are NaN')
if self.verbose:
for idx in np.argwhere(failed).flatten():
return results
......@@ -25,10 +25,14 @@ class TestSignaturizer(unittest.TestCase):
self.invalid_smiles = ['C', 'C&', 'C']
self.tautomer_smiles = ['CC(O)=Nc1ccc(O)cc1', 'CC(=O)NC1=CC=C(C=C1)O']
self.inchi = [
def tearDown(self):
if os.path.exists(self.tmp_dir):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
def test_predict(self):
......@@ -99,3 +103,22 @@ class TestSignaturizer(unittest.TestCase):
# repeating writing will result in an exception
with self.assertRaises(Exception):
module.predict(self.test_smiles, destination)
def test_tautomers(self):
module = Signaturizer('A1')
res = module.predict(self.tautomer_smiles)
self.assertTrue(all(res.signature[0] == res.signature[1]))
def test_inchi(self):
module = Signaturizer('A1')
res_inchi = module.predict(self.inchi, keytype='InChI')
res_smiles = module.predict([self.tautomer_smiles[0]])
self.assertTrue(all(res_inchi.signature[0] == res_smiles.signature[0]))
def test_all_single(self):
module = Signaturizer('A1')
res_all = module.predict(self.test_smiles)
for idx, smile in enumerate(self.test_smiles):
res_single = module.predict([smile])
all(res_all.signature[idx] == res_single.signature[0]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment