Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Packages
chemical_checker
Commits
ff8dea23
Commit
ff8dea23
authored
Nov 29, 2021
by
Martino Bertoni
🌋
Browse files
also save shuffled keys, just in case
parent
e48b72fe
Changes
1
Hide whitespace changes
Inline
Side-by-side
package/chemicalchecker/util/splitter/neighbortriplet.py
View file @
ff8dea23
...
...
@@ -377,6 +377,8 @@ class NeighborTripletTraintest(object):
combo_dists
=
dict
()
with
h5py
.
File
(
out_file
,
"w"
)
as
fh
:
fh
.
create_dataset
(
'x'
,
data
=
X
)
fh
.
create_dataset
(
'x_ink'
,
data
=
np
.
array
(
X_inks
,
dtype
=
DataSignature
.
string_dtype
()))
if
mean_center_x
:
fh
.
create_dataset
(
'scaler'
,
...
...
@@ -441,7 +443,8 @@ class NeighborTripletTraintest(object):
# save list of split indices
#anchors_split = np.repeat(np.arange(len(neig_idxs)), triplet_per_mol)
# anchors_split = np.repeat(np.arange(len(neig_idxs)),
# triplet_per_mol)
easy_a_split
=
list
()
easy_p_split
=
list
()
easy_n_split
=
list
()
...
...
@@ -505,8 +508,8 @@ class NeighborTripletTraintest(object):
hard_n_split
.
extend
(
h_negatives
)
# easy negatives (sampled from everywhere; in general should be fine although it may sample positives...)
#e_negatives = np.random.choice(len(neig_idxs), triplet_per_mol*2, replace=True)
#e_negatives = np.random.choice(list(e_negatives.difference(neig_idxs[idx])), triplet_per_mol, replace=True)
# e_negatives = np.random.choice(len(neig_idxs), triplet_per_mol*2, replace=True)
# e_negatives = np.random.choice(list(e_negatives.difference(neig_idxs[idx])), triplet_per_mol, replace=True)
# if too slow just uncomment the following and comment the
# above
e_negatives
=
np
.
random
.
choice
(
...
...
@@ -515,7 +518,8 @@ class NeighborTripletTraintest(object):
# get reference ids
NeighborTripletTraintest
.
__log
.
info
(
"Mapping triplets"
)
#anchors_ref = [split_ref_map[split2][x] for x in anchors_split]
# anchors_ref = [split_ref_map[split2][x] for x in
# anchors_split]
easy_a_ref
=
[
split_ref_map
[
split2
][
x
]
for
x
in
easy_a_split
]
easy_p_ref
=
[
split_ref_map
[
split1
][
x
]
for
x
in
easy_p_split
]
easy_n_ref
=
[
split_ref_map
[
split1
][
x
]
for
x
in
easy_n_split
]
...
...
@@ -704,7 +708,8 @@ class NeighborTripletTraintest(object):
if
sharedx_trim
is
not
None
:
X
=
sharedx_trim
else
:
X
=
sharedx
[:,
np
.
argwhere
(
np
.
repeat
(
trim_mask
,
128
)).
ravel
()]
X
=
sharedx
[:,
np
.
argwhere
(
np
.
repeat
(
trim_mask
,
128
)).
ravel
()]
else
:
NeighborTripletTraintest
.
__log
.
debug
(
'Reading X in memory'
)
if
trim_mask
is
None
:
...
...
@@ -713,7 +718,8 @@ class NeighborTripletTraintest(object):
if
sharedx_trim
is
not
None
:
X
=
sharedx_trim
else
:
X
=
reader
.
get_x_columns
(
np
.
argwhere
(
np
.
repeat
(
trim_mask
,
128
)).
ravel
())
X
=
reader
.
get_x_columns
(
np
.
argwhere
(
np
.
repeat
(
trim_mask
,
128
)).
ravel
())
NeighborTripletTraintest
.
__log
.
debug
(
'X shape: %s'
%
str
(
X
.
shape
))
# default mask is not masking
if
mask_fn
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment