Commit f8f8ff80 authored by Martino Bertoni 🌋
Browse files

Added batch_size as an argument for prediction; added the option to save the MFP

parent bde5216d
......@@ -121,8 +121,8 @@ class Signaturizer(object):
else:
RDLogger.DisableLog('rdApp.*')
def predict(self, molecules, destination=None, chunk_size=1000,
keytype='SMILES'):
def predict(self, molecules, destination=None, keytype='SMILES',
save_mfp=False, chunk_size=1000, batch_size=128):
"""Predict signatures for given SMILES.
Perform signature prediction for input SMILES. We recommend that the
......@@ -161,7 +161,8 @@ class Signaturizer(object):
self.inchies = inchies
# Prepare result object
features = len(self.model_names) * 128
results = SignaturizerResult(len(inchies), destination, features)
results = SignaturizerResult(
len(inchies), destination, features, save_mfp=save_mfp)
results.dataset[:] = self.model_names
if results.readonly:
raise Exception(
......@@ -179,7 +180,8 @@ class Signaturizer(object):
# read molecule
inchi = inchi.encode('ascii', 'ignore')
if self.verbose:
print('READING', inchi, type(inchi))
# print('READING', inchi, type(inchi))
pass
mol = Chem.inchi.MolFromInchi(inchi)
if mol is None:
raise Exception(
......@@ -200,14 +202,16 @@ class Signaturizer(object):
sign0s.append(calc_s0)
# stack input fingerprints and run signature predictor
sign0s = np.vstack(sign0s)
preds = self.model.predict(sign0s, batch_size=1)
preds = self.model.predict(sign0s, batch_size=batch_size)
# add NaN where SMILES conversion failed
if failed:
preds[np.array(failed)] = np.full(features, np.nan)
results.signature[chunk] = preds
if save_mfp:
results.mfp[chunk] = sign0s
# run applicability predictor
if self.applicability:
apreds = self.app_model.predict(sign0s, batch_size=1)
apreds = self.app_model.predict(sign0s, batch_size=batch_size)
if failed:
apreds[np.array(failed)] = np.nan
results.applicability[chunk] = apreds
......@@ -251,7 +255,7 @@ class SignaturizerResult():
the same vector available as HDF5 datasets.
"""
def __init__(self, size, destination, features=128):
def __init__(self, size, destination, features=128, save_mfp=False):
"""Initialize a SignaturizerResult instance.
Args:
......@@ -263,6 +267,7 @@ class SignaturizerResult():
"""
self.dst = destination
self.readonly = False
self.save_mfp = save_mfp
if self.dst is None:
# simple numpy arrays
self.h5 = None
......@@ -273,6 +278,8 @@ class SignaturizerResult():
self.dataset = np.full((int(np.ceil(features / 128)),), np.nan,
dtype=h5py.special_dtype(vlen=str))
self.failed = np.full((size, ), False, dtype=np.bool)
if self.save_mfp:
self.mfp = np.full((size, 2048), np.nan, order='F', dtype=int)
else:
# check if the file exists already
if os.path.isfile(self.dst):
......@@ -294,11 +301,15 @@ class SignaturizerResult():
self.h5.create_dataset(
'failed', (size,),
dtype=np.bool)
if self.save_mfp:
self.h5.create_dataset('mfp', (size, 2048), dtype=int)
# expose the datasets
self.signature = self.h5['signature']
self.applicability = self.h5['applicability']
self.dataset = self.h5['dataset']
self.failed = self.h5['failed']
if self.save_mfp:
self.mfp = self.h5['mfp']
def close(self):
if self.h5 is None:
......@@ -311,3 +322,5 @@ class SignaturizerResult():
self.applicability = self.h5['applicability']
self.dataset = self.h5['dataset']
self.failed = self.h5['failed']
if self.save_mfp:
self.mfp = self.h5['mfp']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment