Commit 06618a6d authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

bugfix in confidence, data must be loaded from unshuffled train.h5

parent ff8dea23
Pipeline #2575 failed with stages
in 10 minutes and 6 seconds
......@@ -121,7 +121,7 @@ class sign3(BaseSignature, DataSignature):
return None
full_trim = np.argwhere(np.repeat(self.trim_mask, 128))
self._sharedx_trim = self.sharedx[:, full_trim.ravel()]
self.__log.debug("sharedx_trim shape: %s" %
self.__log.debug("sharedx_trim shape: %s" %
str(self._sharedx_trim.shape))
return self._sharedx_trim
......@@ -249,7 +249,7 @@ class sign3(BaseSignature, DataSignature):
if ref_cc is not None:
cc_ref = ChemicalChecker(ref_cc)
else:
cc_ref = ChemicalChecker()
cc_ref = self.get_cc()
sign3.__log.info(
"Reference CC (for predict methods): %s" % cc_ref.cc_root)
# create CC instance to host signatures
......@@ -470,7 +470,8 @@ class sign3(BaseSignature, DataSignature):
return siamese, prior_model, prior_sign_model, confidence_model
def train_confidence(self, siamese, suffix='eval', traintest_file=None,
max_x=10000, max_neig=50000, p_self=0.0):
train_file=None, max_x=10000, max_neig=50000,
p_self=0.0):
"""Train confidence and prior models."""
# get sorted keys from siamese traintest file
self.update_status('Training applicability')
......@@ -479,25 +480,30 @@ class sign3(BaseSignature, DataSignature):
'traintest_%s.h5' % suffix)
if not os.path.isfile(traintest_file):
raise Exception('Traintest_file not found: %s' % traintest_file)
if train_file is None:
train_file = os.path.join(self.model_path, 'train.h5')
if not os.path.isfile(train_file):
raise Exception('Train_file not found: %s' % train_file)
self.traintest_file = traintest_file
tt = DataSignature(self.traintest_file)
test_inks = tt.get_h5_dataset('keys_test')[:max_x]
test_inks = np.sort(test_inks)
train_inks = tt.get_h5_dataset('keys_train')[:max_neig]
train_inks = np.sort(train_inks)
traintest = DataSignature(self.traintest_file)
test_inks = traintest.get_h5_dataset('keys_test')[:max_x]
train_inks = traintest.get_h5_dataset('keys_train')[:max_neig]
# confidence is going to be trained only on siamese test data
test_mask = np.isin(list(self.sign2_self.keys), list(test_inks),
assume_unique=True)
assume_unique=True)
train_mask = np.isin(list(self.sign2_self.keys), list(train_inks),
assume_unique=True)
# confidence is going to be trained only on siamese test data
confidence_train_x = self.sharedx[test_mask]
s2_test = self.sign2_self.get_h5_dataset('V', mask=test_mask)
assume_unique=True)
train = DataSignature(train_file)
confidence_train_x = train.get_h5_dataset('x', mask=test_mask)
#s2_test = self.sign2_self.get_h5_dataset('V', mask=test_mask)
_, s2_test = self.sign2_self.get_vectors(test_inks)
s2_test_x = confidence_train_x[:, self.dataset_idx[0]
* 128: (self.dataset_idx[0] + 1) * 128]
self.__log.debug('self.dataset_idx: %s' % str(self.dataset_idx))
assert(np.all(s2_test == s2_test_x))
# siamese train is going to be used for appticability domain
known_x = self.sharedx[train_mask]
known_x = train.get_h5_dataset('x', mask=train_mask)
# generate train-test split for confidence estimation
split_names = ['train', 'test']
split_fractions = [0.8, 0.2]
......@@ -2238,7 +2244,7 @@ def plot_subsample(sign, plot_file, sign2_coverage, traintest_file,
import matplotlib.pyplot as plt
from chemicalchecker import ChemicalChecker
cc = ChemicalChecker()
cc = sign.get_cc()
# NICO sign2_list
if sign2_list is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment