Commit 136755e9 authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

now we can export signature and applicability prediction as signle module

parent 4f307b45
......@@ -9,10 +9,12 @@ with warnings.catch_warnings():
def export_smilespred(smilespred_path, destination,
tmp_path=None, clear_tmp=True, compress=True):
tmp_path=None, clear_tmp=True, compress=True,
applicability_path=None):
"""Export our Keras Smiles predictor to the TF-hub module format."""
from keras import backend as K
from chemicalchecker.tool.smilespred import Smilespred
from chemicalchecker.tool.smilespred import ApplicabilityPredictor
if tmp_path is None:
tmp_path = tempfile.mkdtemp()
......@@ -34,13 +36,22 @@ def export_smilespred(smilespred_path, destination,
'''
signature = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'default': model.input}, outputs={'default': model.output})
signature_def_map = {'serving_default': signature}
if applicability_path is not None:
app_pred = ApplicabilityPredictor(applicability_path)
app_pred.build_model(load=True)
app_model = app_pred.model
signature_app = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'default': app_model.input},
outputs={'default': app_model.output})
signature_def_map.update({'applicability': signature_app})
if tmp_path is None:
tmp_path = tempfile.mkdtemp()
builder = tf.saved_model.builder.SavedModelBuilder(tmp_path)
builder.add_meta_graph_and_variables(
sess=K.get_session(),
tags=['serve'],
signature_def_map={'serving_default': signature})
signature_def_map=signature_def_map)
builder.save()
# now export savedmodel to module
export_savedmodel(tmp_path, destination, compress=compress)
......
......@@ -35,6 +35,53 @@ class TestSignaturizer(unittest.TestCase):
shutil.rmtree(self.tmp_dir)
pass
def test_export(self):
# export smilespred
version = 'vXXX'
module_file = 'dest_smilespred.tar.gz'
module_destination = os.path.join(
self.tmp_dir, version, module_file)
tmp_path_smilespred = os.path.join(self.tmp_dir, 'export_smilespred')
smilespred_path = os.path.join(self.data_dir, 'models', 'smiles')
export_smilespred(smilespred_path, module_destination,
tmp_path=tmp_path_smilespred, clear_tmp=False)
base_url = "http://localhost:%d/" % (self.server_port)
module = Signaturizer(module_file, base_url=base_url, version=version)
res = module.predict(self.test_smiles)
pred = res.signature[:]
ref_pred_file = os.path.join(
self.data_dir, 'models', 'smiles_pred.npy')
#np.save(ref_pred_file, pred)
pred_ref = np.load(ref_pred_file)
np.testing.assert_almost_equal(pred_ref, pred)
def test_export_applicability(self):
# export smilespred and applicability
version = 'vXXX'
module_file = 'dest_smilespred.tar.gz'
module_destination = os.path.join(
self.tmp_dir, version, module_file)
tmp_path_smilespred = os.path.join(self.tmp_dir, 'export_smilespred')
smilespred_path = os.path.join(self.data_dir, 'models', 'smiles')
apppred_path = os.path.join(self.data_dir, 'models', 'applicability')
export_smilespred(smilespred_path, module_destination,
tmp_path=tmp_path_smilespred, clear_tmp=False,
applicability_path=apppred_path)
base_url = "http://localhost:%d/" % (self.server_port)
module = Signaturizer(module_file, base_url=base_url, version=version)
res = module.predict(self.test_smiles, applicability=True)
pred = res.signature[:]
ref_pred_file = os.path.join(
self.data_dir, 'models', 'smiles_pred.npy')
pred_ref = np.load(ref_pred_file)
np.testing.assert_almost_equal(pred_ref, pred)
apred = res.applicability[:]
ref_apred_file = os.path.join(
self.data_dir, 'models', 'applicability_pred.npy')
#np.save(ref_apred_file, apred)
apred_ref = np.load(ref_apred_file)
np.testing.assert_almost_equal(apred_ref, apred)
@skip_if_import_exception
def test_export_consistency(self):
"""Compare the exported module to the original SMILES predictor.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment