Commit ff8dea23 authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

also save shuffled keys, just in case

parent e48b72fe
......@@ -377,6 +377,8 @@ class NeighborTripletTraintest(object):
combo_dists = dict()
with h5py.File(out_file, "w") as fh:
fh.create_dataset('x', data=X)
fh.create_dataset('x_ink', data=np.array(
X_inks, dtype=DataSignature.string_dtype()))
if mean_center_x:
fh.create_dataset(
'scaler',
......@@ -441,7 +443,8 @@ class NeighborTripletTraintest(object):
# save list of split indices
#anchors_split = np.repeat(np.arange(len(neig_idxs)), triplet_per_mol)
# anchors_split = np.repeat(np.arange(len(neig_idxs)),
# triplet_per_mol)
easy_a_split = list()
easy_p_split = list()
easy_n_split = list()
......@@ -505,8 +508,8 @@ class NeighborTripletTraintest(object):
hard_n_split.extend(h_negatives)
# easy negatives (sampled from everywhere; in general should be fine although it may sample positives...)
#e_negatives = np.random.choice(len(neig_idxs), triplet_per_mol*2, replace=True)
#e_negatives = np.random.choice(list(e_negatives.difference(neig_idxs[idx])), triplet_per_mol, replace=True)
# e_negatives = np.random.choice(len(neig_idxs), triplet_per_mol*2, replace=True)
# e_negatives = np.random.choice(list(e_negatives.difference(neig_idxs[idx])), triplet_per_mol, replace=True)
# if too slow just uncomment the following and comment the
# above
e_negatives = np.random.choice(
......@@ -515,7 +518,8 @@ class NeighborTripletTraintest(object):
# get reference ids
NeighborTripletTraintest.__log.info("Mapping triplets")
#anchors_ref = [split_ref_map[split2][x] for x in anchors_split]
# anchors_ref = [split_ref_map[split2][x] for x in
# anchors_split]
easy_a_ref = [split_ref_map[split2][x] for x in easy_a_split]
easy_p_ref = [split_ref_map[split1][x] for x in easy_p_split]
easy_n_ref = [split_ref_map[split1][x] for x in easy_n_split]
......@@ -704,7 +708,8 @@ class NeighborTripletTraintest(object):
if sharedx_trim is not None:
X = sharedx_trim
else:
X = sharedx[:, np.argwhere(np.repeat(trim_mask, 128)).ravel()]
X = sharedx[:, np.argwhere(
np.repeat(trim_mask, 128)).ravel()]
else:
NeighborTripletTraintest.__log.debug('Reading X in memory')
if trim_mask is None:
......@@ -713,7 +718,8 @@ class NeighborTripletTraintest(object):
if sharedx_trim is not None:
X = sharedx_trim
else:
X = reader.get_x_columns(np.argwhere(np.repeat(trim_mask, 128)).ravel())
X = reader.get_x_columns(np.argwhere(
np.repeat(trim_mask, 128)).ravel())
NeighborTripletTraintest.__log.debug('X shape: %s' % str(X.shape))
# default mask is not masking
if mask_fn is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment