Commit 9588b0ef authored by Martino Bertoni's avatar Martino Bertoni 🌋

updated model export and load to TF 2.3

parent 73e72fa8
import os
import shutil
import tempfile
import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
import tensorflow as tf
class SignaturizerModule(tf.Module):
def __init__(self, signature_mdl):
self.signature_mdl = signature_mdl
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 2048), dtype=tf.float32)])
def signature(self, mfp):
results = self.signature_mdl(mfp)
return {"signature": results}
class SignaturizerApplicabilityModule(tf.Module):
def __init__(self, signature_mdl, applicability_mdl):
self.signature_mdl = signature_mdl
self.applicability_mdl = applicability_mdl
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 2048), dtype=tf.float32)])
def signature(self, mfp):
results = self.signature_mdl(mfp)
return {"signature": results}
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 2048), dtype=tf.float32)])
def applicability(self, mfp):
results = self.applicability_mdl(mfp)
return {"applicability": results}
def export_smilespred(smilespred_path, destination,
tmp_path=None, clear_tmp=True, compress=True,
applicability_path=None):
"""Export our Keras Smiles predictor to the TF-hub module format."""
from keras import backend as K
from chemicalchecker.tool.smilespred import Smilespred
from chemicalchecker.tool.smilespred import ApplicabilityPredictor
if tmp_path is None:
tmp_path = tempfile.mkdtemp()
# save to savedmodel format
with tf.Graph().as_default():
smilespred = Smilespred(smilespred_path)
smilespred.build_model(load=True)
model = smilespred.model
'''
with tf.Session() as sess:
sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
tf.saved_model.simple_save(
sess,
tmp_path,
inputs={'default': model.input},
outputs={'default': model.output}
)
'''
signature = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'default': model.input}, outputs={'default': model.output})
signature_def_map = {'serving_default': signature}
if applicability_path is not None:
app_pred = ApplicabilityPredictor(applicability_path)
app_pred.build_model(load=True)
app_model = app_pred.model
signature_app = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'default': app_model.input},
outputs={'default': app_model.output})
signature_def_map.update({'applicability': signature_app})
if tmp_path is None:
tmp_path = tempfile.mkdtemp()
builder = tf.saved_model.builder.SavedModelBuilder(tmp_path)
builder.add_meta_graph_and_variables(
sess=K.get_session(),
tags=['serve'],
signature_def_map=signature_def_map)
builder.save()
# now export savedmodel to module
export_savedmodel(tmp_path, destination, compress=compress)
# clean temporary folder
if clear_tmp:
shutil.rmtree(tmp_path)
def export_savedmodel(savedmodel_path, destination,
tmp_path=None, clear_tmp=True, compress=True):
"""Export Tensorflow SavedModel to the TF-hub module format."""
if tmp_path is None:
tmp_path = tempfile.mkdtemp()
# save to hub module format
with tf.Graph().as_default():
spec = hub.create_module_spec_from_saved_model(savedmodel_path)
module = hub.Module(spec, tags=['serve'])
with tf.Session() as sess:
sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
module.export(tmp_path, sess)
# load models
smilespred = Smilespred(smilespred_path)
smilespred.build_model(load=True)
model = smilespred.model
# save simple modelor combined with applicability
if applicability_path is None:
full_mdl = SignaturizerModule(model)
tf.saved_model.save(full_mdl, tmp_path,
signatures={
"signature": full_mdl.signature})
else:
app_pred = ApplicabilityPredictor(applicability_path)
app_pred.build_model(load=True)
app_model = app_pred.model
# combine in a module and save to savedmodel format
full_mdl = SignaturizerApplicabilityModule(model, app_model)
tf.saved_model.save(full_mdl, tmp_path,
signatures={
"signature": full_mdl.signature,
"applicability": full_mdl.applicability})
# now export savedmodel to tfhub module
if compress:
# compress the exported files to destination
os.system("tar -cz -f %s --owner=0 --group=0 -C %s ." %
......
......@@ -4,14 +4,11 @@ import h5py
import shutil
import numpy as np
from tqdm import tqdm
import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
from tensorflow.compat.v1.keras.models import Model
from tensorflow.compat.v1.keras import Input
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import Model
from tensorflow.keras import Input
try:
from rdkit import Chem
from rdkit import RDLogger
......@@ -20,8 +17,6 @@ except ImportError:
raise ImportError("requires RDKit " +
"https://www.rdkit.org/docs/Install.html")
tf.logging.set_verbosity(tf.logging.ERROR)
class Signaturizer(object):
"""Signaturizer Class.
......@@ -31,7 +26,7 @@ class Signaturizer(object):
def __init__(self, model_name,
base_url="http://chemicalchecker.com/api/db/getSignaturizer/",
version='2020_02', local=False, tf_version='1', verbose=False,
version='2020_02', local=False, verbose=False,
applicability=True):
"""Initialize a Signaturizer instance.
......@@ -46,7 +41,6 @@ class Signaturizer(object):
version(int): Signaturizer version.
local(bool): Wethere the specified model_name shoudl be
interpreted as a path to a local model.
tf_version(int): The Tesorflow version.
verbose(bool): If True some more information will be printed.
applicability(bool): Wether to also compute the applicability of
each prediction.
......@@ -69,11 +63,6 @@ class Signaturizer(object):
main_input = Input(shape=(2048,), dtype=tf.float32, name='main_input')
sign_output = list()
app_output = list()
as_dict = False
output_key = 'default'
if len(self.model_names) == 1:
as_dict = True
output_key = None
for name in self.model_names:
# build module spec
if local:
......@@ -88,10 +77,10 @@ class Signaturizer(object):
if self.verbose:
print('LOADING remote:', url)
sign_layer = hub.KerasLayer(url, signature='serving_default',
sign_layer = hub.KerasLayer(url, signature='signature',
trainable=False, tags=['serve'],
output_key=output_key,
signature_outputs_as_dict=as_dict)
output_key='signature',
signature_outputs_as_dict=False)
sign_output.append(sign_layer(main_input))
if self.applicability:
......@@ -99,8 +88,8 @@ class Signaturizer(object):
app_layer = hub.KerasLayer(
url, signature='applicability',
trainable=False, tags=['serve'],
output_key=output_key,
signature_outputs_as_dict=as_dict)
output_key='applicability',
signature_outputs_as_dict=False)
app_output.append(app_layer(main_input))
except Exception as ex:
print('WARNING: applicability predictions not available. '
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment