Commit de74f99b authored by Martino Bertoni 🌋
Browse files

the signature to use for defining triplets is now parametric

parent f02f095d
Pipeline #2698 failed with stages
in 78 minutes and 56 seconds
......@@ -411,7 +411,7 @@ class sign3(BaseSignature, DataSignature):
cpu = params.get('cpu', 1)
if not reuse or not os.path.isfile(self.traintest_file):
NeighborTripletTraintest.create(
X, self.traintest_file, self.neig_sign,
X, self.traintest_file, self.triplet_sign,
split_names=['train', 'test'],
split_fractions=[.8, .2],
suffix=suffix,
......@@ -424,7 +424,7 @@ class sign3(BaseSignature, DataSignature):
cpu = params.get('cpu', 1)
if not reuse or not os.path.isfile(self.traintest_file):
NeighborTripletTraintest.create(
X, self.traintest_file, self.neig_sign,
X, self.traintest_file, self.triplet_sign,
split_names=['train'],
split_fractions=[1.0],
suffix=suffix,
......@@ -546,7 +546,7 @@ class sign3(BaseSignature, DataSignature):
from chemicalchecker.tool.siamese import SiameseTriplets
except ImportError:
raise ImportError("requires tensorflow https://tensorflow.org")
sign1_self = cc.get_signature("sign1", "full", self.dataset)
triplet_sign = cc.get_signature("sign1", "full", self.dataset)
sign2_self = cc.get_signature("sign2", "full", self.dataset)
sign2_list = [cc.get_signature("sign2", "full", d)
for d in cc.datasets_exemplary()]
......@@ -557,7 +557,7 @@ class sign3(BaseSignature, DataSignature):
cc.cc_root, 'full', 'all_sign2_coverage.h5')
self.src_datasets = [sign.dataset for sign in sign2_list]
self.neig_sign = sign1_self
self.triplet_sign = triplet_sign
self.sign2_self = sign2_self
self.sign2_list = sign2_list
self.sign2_coverage = sign2_coverage
......@@ -1442,7 +1442,7 @@ class sign3(BaseSignature, DataSignature):
p_self=0.0, subsampling=False):
# applicability is whether not-self preds is close to only-self preds
# neighbors between 5 and 25 depending on the size of the dataset
app_thr = int(np.clip(np.log10(self.neig_sign.shape[0])**2, 5, 25))
app_thr = int(np.clip(np.log10(self.triplet_sign.shape[0])**2, 5, 25))
if subsampling:
if dropout_fn is None:
dropout_fn, _ = self.realistic_subsampling_fn()
......@@ -1715,7 +1715,7 @@ class sign3(BaseSignature, DataSignature):
inchikeys = sorted(list(inchikeys))
return inchikeys
def fit(self, sign2_list=None, sign2_self=None, sign1_self=None,
def fit(self, sign2_list=None, sign2_self=None, triplet_sign=None,
sign2_universe=None, complete_universe='full',
sign2_coverage=None,
model_confidence=True, save_correlations=False,
......@@ -1726,7 +1726,8 @@ class sign3(BaseSignature, DataSignature):
Args:
sign2_list(list): List of signature 2 objects to learn from.
sign2_self(sign2): Signature 2 of the current space.
sign2_self(sign1): Signature 1 of the current space.
triplet_sign(sign1): Signature used to define the anchor, positive and
negative in triplets.
sign2_universe(str): Path to the union of all signatures 2 for all
molecules in the CC universe. (~1M x 3200)
complete_universe(str): add chemistry information for molecules not
......@@ -1778,9 +1779,9 @@ class sign3(BaseSignature, DataSignature):
if sign2_self is None:
sign2_self = self.get_sign('sign2')
self.sign2_self = sign2_self
if sign1_self is None:
sign1_self = self.get_sign('sign1')
self.neig_sign = sign1_self
if triplet_sign is None:
triplet_sign = self.get_sign('sign1')
self.triplet_sign = triplet_sign
self.sign2_coverage = sign2_coverage
self.dataset_idx = np.argwhere(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment