Commit 3fc3d3fa authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

fixed unittests

parent 3f6e24d6
import os
import time
import shutil
import unittest
import numpy as np
from .helper import skip_if_import_exception, start_http_server
from signaturizer.exporter import export_smilespred, export_savedmodel
......@@ -64,13 +64,13 @@ class TestSignaturizer(unittest.TestCase):
module = Signaturizer(tmp_path_smilespred, local=True)
res = module.predict(self.test_smiles)
pred = res.signature[:]
self.assertEqual(pred_ref.tolist(), pred.tolist())
np.testing.assert_almost_equal(pred_ref, pred)
# test final step
base_url = "http://localhost:%d/" % (self.server_port)
module = Signaturizer(module_file, base_url=base_url, version=version)
res = module.predict(self.test_smiles)
pred = res.signature[:]
self.assertEqual(pred_ref.tolist(), pred.tolist())
np.testing.assert_almost_equal(pred_ref, pred)
# export savedmodel
module_destination = os.path.join(
......@@ -83,9 +83,9 @@ class TestSignaturizer(unittest.TestCase):
module = Signaturizer(tmp_path_savedmodel, local=True)
res = module.predict(self.test_smiles)
pred = res.signature[:]
self.assertEqual(pred_ref.tolist(), pred.tolist())
np.testing.assert_almost_equal(pred_ref, pred)
# test final step
module = Signaturizer(module_file, base_url=base_url, version=version)
res = module.predict(self.test_smiles)
pred = res.signature[:]
self.assertEqual(pred_ref.tolist(), pred.tolist())
np.testing.assert_almost_equal(pred_ref, pred)
......@@ -3,6 +3,7 @@ import math
import pickle
import shutil
import unittest
import numpy as np
from signaturizer import Signaturizer
......@@ -38,12 +39,12 @@ class TestSignaturizer(unittest.TestCase):
module_dir = os.path.join(self.data_dir, 'B1')
module = Signaturizer(module_dir, local=True)
res = module.predict(self.test_smiles)
self.assertEqual(pred_ref.tolist(), res.signature.tolist())
np.testing.assert_almost_equal(pred_ref, res.signature[:])
# test saving to file
destination = os.path.join(self.tmp_dir, 'pred.h5')
res = module.predict(self.test_smiles, destination)
self.assertTrue(os.path.isfile(destination))
self.assertEqual(pred_ref.tolist(), res.signature[:].tolist())
np.testing.assert_almost_equal(pred_ref, res.signature[:])
# test prediction of invalid SMILES
res = module.predict(self.invalid_smiles)
for comp in res.signature[0]:
......@@ -67,12 +68,13 @@ class TestSignaturizer(unittest.TestCase):
module_A1 = Signaturizer(A1_path, local=True)
res_A1 = module_A1.predict(self.test_smiles)
self.assertEqual(res_A1B1.signature[:, :128].tolist(),
res_A1.signature.tolist())
np.testing.assert_almost_equal(res_A1B1.signature[:, :128],
res_A1.signature)
module_B1 = Signaturizer(B1_path, local=True)
res_B1 = module_B1.predict(self.test_smiles)
self.assertEqual(res_A1B1.signature[:, 128:].tolist(),
res_B1.signature.tolist())
np.testing.assert_almost_equal(res_A1B1.signature[:, 128:],
res_B1.signature)
res = module_A1B1.predict(self.invalid_smiles)
for comp in res.signature[0]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment