Commit 3fc3d3fa authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

fixed unittests

parent 3f6e24d6
import os import os
import time
import shutil import shutil
import unittest import unittest
import numpy as np
from .helper import skip_if_import_exception, start_http_server from .helper import skip_if_import_exception, start_http_server
from signaturizer.exporter import export_smilespred, export_savedmodel from signaturizer.exporter import export_smilespred, export_savedmodel
...@@ -64,13 +64,13 @@ class TestSignaturizer(unittest.TestCase): ...@@ -64,13 +64,13 @@ class TestSignaturizer(unittest.TestCase):
module = Signaturizer(tmp_path_smilespred, local=True) module = Signaturizer(tmp_path_smilespred, local=True)
res = module.predict(self.test_smiles) res = module.predict(self.test_smiles)
pred = res.signature[:] pred = res.signature[:]
self.assertEqual(pred_ref.tolist(), pred.tolist()) np.testing.assert_almost_equal(pred_ref, pred)
# test final step # test final step
base_url = "http://localhost:%d/" % (self.server_port) base_url = "http://localhost:%d/" % (self.server_port)
module = Signaturizer(module_file, base_url=base_url, version=version) module = Signaturizer(module_file, base_url=base_url, version=version)
res = module.predict(self.test_smiles) res = module.predict(self.test_smiles)
pred = res.signature[:] pred = res.signature[:]
self.assertEqual(pred_ref.tolist(), pred.tolist()) np.testing.assert_almost_equal(pred_ref, pred)
# export savedmodel # export savedmodel
module_destination = os.path.join( module_destination = os.path.join(
...@@ -83,9 +83,9 @@ class TestSignaturizer(unittest.TestCase): ...@@ -83,9 +83,9 @@ class TestSignaturizer(unittest.TestCase):
module = Signaturizer(tmp_path_savedmodel, local=True) module = Signaturizer(tmp_path_savedmodel, local=True)
res = module.predict(self.test_smiles) res = module.predict(self.test_smiles)
pred = res.signature[:] pred = res.signature[:]
self.assertEqual(pred_ref.tolist(), pred.tolist()) np.testing.assert_almost_equal(pred_ref, pred)
# test final step # test final step
module = Signaturizer(module_file, base_url=base_url, version=version) module = Signaturizer(module_file, base_url=base_url, version=version)
res = module.predict(self.test_smiles) res = module.predict(self.test_smiles)
pred = res.signature[:] pred = res.signature[:]
self.assertEqual(pred_ref.tolist(), pred.tolist()) np.testing.assert_almost_equal(pred_ref, pred)
...@@ -3,6 +3,7 @@ import math ...@@ -3,6 +3,7 @@ import math
import pickle import pickle
import shutil import shutil
import unittest import unittest
import numpy as np
from signaturizer import Signaturizer from signaturizer import Signaturizer
...@@ -38,12 +39,12 @@ class TestSignaturizer(unittest.TestCase): ...@@ -38,12 +39,12 @@ class TestSignaturizer(unittest.TestCase):
module_dir = os.path.join(self.data_dir, 'B1') module_dir = os.path.join(self.data_dir, 'B1')
module = Signaturizer(module_dir, local=True) module = Signaturizer(module_dir, local=True)
res = module.predict(self.test_smiles) res = module.predict(self.test_smiles)
self.assertEqual(pred_ref.tolist(), res.signature.tolist()) np.testing.assert_almost_equal(pred_ref, res.signature[:])
# test saving to file # test saving to file
destination = os.path.join(self.tmp_dir, 'pred.h5') destination = os.path.join(self.tmp_dir, 'pred.h5')
res = module.predict(self.test_smiles, destination) res = module.predict(self.test_smiles, destination)
self.assertTrue(os.path.isfile(destination)) self.assertTrue(os.path.isfile(destination))
self.assertEqual(pred_ref.tolist(), res.signature[:].tolist()) np.testing.assert_almost_equal(pred_ref, res.signature[:])
# test prediction of invalid SMILES # test prediction of invalid SMILES
res = module.predict(self.invalid_smiles) res = module.predict(self.invalid_smiles)
for comp in res.signature[0]: for comp in res.signature[0]:
...@@ -67,12 +68,13 @@ class TestSignaturizer(unittest.TestCase): ...@@ -67,12 +68,13 @@ class TestSignaturizer(unittest.TestCase):
module_A1 = Signaturizer(A1_path, local=True) module_A1 = Signaturizer(A1_path, local=True)
res_A1 = module_A1.predict(self.test_smiles) res_A1 = module_A1.predict(self.test_smiles)
self.assertEqual(res_A1B1.signature[:, :128].tolist(), np.testing.assert_almost_equal(res_A1B1.signature[:, :128],
res_A1.signature.tolist()) res_A1.signature)
module_B1 = Signaturizer(B1_path, local=True) module_B1 = Signaturizer(B1_path, local=True)
res_B1 = module_B1.predict(self.test_smiles) res_B1 = module_B1.predict(self.test_smiles)
self.assertEqual(res_A1B1.signature[:, 128:].tolist(), np.testing.assert_almost_equal(res_A1B1.signature[:, 128:],
res_B1.signature.tolist()) res_B1.signature)
res = module_A1B1.predict(self.invalid_smiles) res = module_A1B1.predict(self.invalid_smiles)
for comp in res.signature[0]: for comp in res.signature[0]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment