Commit 0e53603b authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

saving dataset names in result, fadded method to clear tfhub cache, fixed verbosity

parent 8aeb4c34
# using the module
import os
import h5py
import itertools
import shutil
import numpy as np
from tqdm import tqdm
import tensorflow.compat.v1 as tf
......@@ -14,11 +14,12 @@ except ImportError:
"https://www.rdkit.org/docs/Install.html")
class Signaturizer():
class Signaturizer(object):
"""Class loading TF-hub module and performing predictions."""
def __init__(self, model_name, verbose=True,
base_url="file:///aloy/web_checker/exported_smilespreds/"):
def __init__(self, model_name,
base_url="file:///aloy/web_checker/exported_smilespreds/",
tf_version=1, verbose=True):
"""Initialize the Signaturizer.
Args:
......@@ -39,16 +40,22 @@ class Signaturizer():
else:
models = model_name
# load modules
self.model_names = list()
self.modules = list()
self.graph = tf.Graph()
with self.graph.as_default():
for model in models:
if os.path.isdir(model):
if self.verbose:
print('LOADING local:', model)
spec = hub.create_module_spec_from_saved_model(model)
module = hub.Module(spec, tags=['serve'])
else:
if self.verbose:
print('LOADING remote:', base_url + model)
module = hub.Module(base_url + model, tags=['serve'])
self.modules.append(module)
self.model_names.append(model)
def predict(self, smiles, destination=None, chunk_size=1000):
"""Predict signatures for given SMILES.
......@@ -68,13 +75,14 @@ class Signaturizer():
features = len(self.modules) * 128
results = SignaturizerResult(len(smiles), destination,
features)
results.dataset[:] = self.model_names
if results.readonly:
raise Exception(
'Destination file already exists, ' +
'delete or rename to proceed.')
# predict by chunk
all_chunks = range(0, len(smiles), chunk_size)
for i in tqdm(all_chunks, disable=self.verbose):
for i in tqdm(all_chunks, disable=not self.verbose):
chunk = slice(i, i + chunk_size)
sign0s = list()
failed = list()
......@@ -110,8 +118,20 @@ class Signaturizer():
mdl_cols = slice(idx * 128, (idx + 1) * 128)
results.signature[chunk, mdl_cols] = preds
results.close()
if self.verbose:
print('PREDICTION complete!')
return results
@staticmethod
def _clear_tfhub_cache():
cache_dir = os.getenv('TFHUB_CACHE_DIR')
if cache_dir is None:
cache_dir = '/tmp/tfhub_modules/'
if not os.path.isdir(cache_dir):
raise Exception('Cannot find tfhub cache directory, ' +
'please set TFHUB_CACHE_DIR variable')
shutil.rmtree(cache_dir)
class SignaturizerResult():
"""Class storing result of the prediction.
......@@ -137,6 +157,8 @@ class SignaturizerResult():
# simple numpy arrays
self.h5 = None
self.signature = np.zeros((size, features), dtype=np.float32)
self.dataset = np.zeros((int(features / 128),),
dtype=h5py.special_dtype(vlen=str))
else:
# check if the file exists already
if os.path.isfile(self.dst):
......@@ -149,8 +171,12 @@ class SignaturizerResult():
self.h5 = h5py.File(self.dst, 'w')
self.h5.create_dataset(
'signature', (size, features), dtype=np.float32)
self.h5.create_dataset(
'dataset', (int(features / 128),),
dtype=h5py.special_dtype(vlen=str))
# expose the datasets
self.signature = self.h5['signature']
self.dataset = self.h5['dataset']
def close(self):
if self.h5 is None:
......@@ -160,3 +186,4 @@ class SignaturizerResult():
self.h5 = h5py.File(self.dst, 'r')
# expose the datasets
self.signature = self.h5['signature']
self.dataset = self.h5['dataset']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment