Packages / chemical_checker / Commits / bbea196d

Commit bbea196d, authored Apr 25, 2022 by Martino Bertoni

speedup on iterator, added option for chunks and compression

Parent: 9e075ea8
Pipeline #2683 failed with stages in 100 minutes and 20 seconds
1 changed file
package/chemicalchecker/core/signature_data.py @ bbea196d
@@ -22,6 +22,7 @@ try:
 except:
     pass

 @logged
 class DataSignature(object):
     """DataSignature class."""
@@ -34,7 +35,8 @@ class DataSignature(object):
         self.PVALRANGES = np.array(
             [0, 0.001, 0.01, 0.1] + list(np.arange(1, 100)) + [100]) / 100.

-    def add_datasets(self, data_dict, overwrite=True):
+    def add_datasets(self, data_dict, overwrite=True, chunks=None,
+                     compression=None):
         """Add dataset to a H5"""
         for k, v in data_dict.items():
             with h5py.File(self.data_path, 'a') as hf:
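For context, the new `chunks` and `compression` arguments are forwarded to h5py's `create_dataset` (see the next hunk): chunked storage writes the dataset in fixed-size blocks and is a prerequisite for compression filters such as gzip. A minimal standalone sketch of the same h5py call, with an illustrative file name and array that are not part of this commit:

import h5py
import numpy as np

V = np.random.rand(1000, 128).astype(np.float32)
with h5py.File('example.h5', 'a') as hf:
    # chunked, gzip-compressed dataset; both options default to None (off)
    hf.create_dataset('V', data=V, chunks=(100, 128), compression='gzip')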
@@ -50,7 +52,8 @@ class DataSignature(object):
                 else:
                     if hasattr(v.flat[0], 'decode') or isinstance(
                             v.flat[0], str) or isinstance(v.flat[0], np.str_):
                         v = self.h5_str(v)
-                    hf.create_dataset(k, data=v)
+                    hf.create_dataset(k, data=v, chunks=chunks,
+                                      compression=compression)

     def _check_data(self):
         """Test if data file is available"""
@@ -107,15 +110,17 @@ class DataSignature(object):
         self._check_data()
         self._check_dataset(key)
         tot_size = self._get_shape(key, axis)
-        with h5py.File(self.data_path, 'r') as hf:
-            myrange = range(0, tot_size, chunk_size)
-            desc = 'Iterating on `%s` axis %s' % (key, axis)
-            for i in tqdm(myrange, disable=not bar, desc=desc):
-                mychunk = slice(i, i + chunk_size)
-                if chunk:
-                    yield mychunk, self._get_data_chunk(hf, key, mychunk, axis)
-                else:
-                    yield self._get_data_chunk(hf, key, mychunk, axis)
+        if not hasattr(self, 'hdf5'):
+            self.open_hdf5()
+        hf = self.hdf5
+        myrange = range(0, tot_size, chunk_size)
+        desc = 'Iterating on `%s` axis %s' % (key, axis)
+        for i in tqdm(myrange, disable=not bar, desc=desc):
+            mychunk = slice(i, i + chunk_size)
+            if chunk:
+                yield mychunk, self._get_data_chunk(hf, key, mychunk, axis)
+            else:
+                yield self._get_data_chunk(hf, key, mychunk, axis)

     def __iter__(self):
         """By default iterate on signatures V."""
@@ -290,7 +295,8 @@ class DataSignature(object):
                 hf[key] = src

     def make_filtered_copy(self, destination, mask, include_all=False,
-                           data_file=None):
+                           data_file=None, datasets=None, dst_datasets=None,
+                           chunk_size=1000, compression=None):
         """Make a copy of applying a filtering mask on rows.

         destination (str): The destination file path.
@@ -305,38 +311,52 @@ class DataSignature(object):
             data_file = self.data_path
         with h5py.File(data_file, 'r') as hf_in:
-            with h5py.File(destination, 'w') as hf_out:
-                for dset in hf_in.keys():
+            with h5py.File(destination, 'a') as hf_out:
+                if datasets is None:
+                    datasets = hf_in.keys()
+                if dst_datasets is None:
+                    dst_datasets = datasets
+                for dset, dst_dset in zip(datasets, dst_datasets):
                     # skip dataset incompatible with mask (or copy unmasked)
                     if hf_in[dset].shape[0] != mask.shape[0]:
                         if not include_all:
                             continue
                         else:
                             masked = hf_in[dset][:][:]
-                            hf_out.create_dataset(dset, data=masked)
+                            hf_out.create_dataset(dst_dset, data=masked,
+                                                  compression=compression)
                             self.__log.debug(
                                 "Copy dataset %s of shape %s" %
                                 (dset, str(masked.shape)))
                         continue
                     # never mask features
                     if dset == 'features':
                         masked = hf_in[dset][:][:]
                         self.__log.debug(
                             "Copy dataset %s of shape %s" %
                             (dset, str(masked.shape)))
-                        hf_out.create_dataset(dset, data=masked)
+                        hf_out.create_dataset(dst_dset, data=masked,
+                                              compression=compression)
                         continue
-                    # memory safe masked copy for other datasets
+                    # mask single value dataset all at once
                     if len(hf_in[dset].shape) == 1:
                         final_shape = (sum(mask),)
-                    else:
-                        final_shape = (sum(mask), hf_in[dset].shape[1])
+                        masked = hf_in[dset][:][mask]
+                        self.__log.debug(
+                            "Copy dataset %s of shape %s" %
+                            (dset, str(masked.shape)))
+                        hf_out.create_dataset(dst_dset, data=masked,
+                                              compression=compression)
+                        continue
+                    # memory safe masked copy for other datasets
+                    final_shape = (sum(mask), hf_in[dset].shape[1])
                     hf_out.create_dataset(
-                        dset, final_shape, dtype=hf_in[dset].dtype)
+                        dst_dset, final_shape, dtype=hf_in[dset].dtype,
+                        compression=compression)
                     self.__log.debug(
                         "Copy dataset %s of shape %s" %
                         (dset, str(final_shape)))
-                    idx_dst = 0
-                    for idx_src in np.argwhere(mask).ravel():
-                        hf_out[dset][idx_dst] = hf_in[dset][idx_src]
-                        idx_dst += 1
+                    for chunk, data in self.chunk_iter(dset, 100, 1, True):
+                        hf_out[dst_dset][:, chunk] = data[mask]

     def filter_h5_dataset(self, key, mask, axis, chunk_size=1000):
         """Apply a maks to a dataset, dropping columns or rows.
@@ -460,7 +480,8 @@ class DataSignature(object):
         else:
             return hf[h5_dataset_name][mask, :]

-    def get_vectors(self, keys, include_nan=False, dataset_name='V',
-                    output_missing=False):
+    def get_vectors(self, keys, include_nan=False, dataset_name='V',
+                    output_missing=False):
         """Get vectors for a list of keys, sorted by default.

         Args:
@@ -564,10 +585,13 @@ class DataSignature(object):
     def open_hdf5(self):
         self.hdf5 = h5py.File(self.data_path, 'r')

-    def __del__(self):
+    def close_hdf5(self):
         if hasattr(self, 'hdf5'):
             self.hdf5.close()

+    def __del__(self):
+        self.close_hdf5()
+
     def __len__(self):
         if not hasattr(self, 'hdf5'):
             self.open_hdf5()
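Factoring the close logic out of `__del__` into a `close_hdf5` method lets callers release the cached HDF5 handle explicitly instead of relying on garbage collection; `__del__` now simply delegates to it. A short sketch of the intended lifecycle (illustrative only):

sign.open_hdf5()      # caches a read-only h5py.File on self.hdf5
row = sign[0]         # __getitem__ and __len__ reuse the cached handle
sign.close_hdf5()     # release the file handle when done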
@@ -589,15 +613,15 @@ class DataSignature(object):
             return self.hdf5[self.ds_data][key]
         if isinstance(key, list):
             key = slice(min(key), max(key))
-        if isinstance(key, bytes):
-            key = key.decode("utf-8")
         if isinstance(key, slice):
             return self.hdf5[self.ds_data][key]
+        if isinstance(key, bytes):
+            key = key.decode("utf-8")
         if isinstance(key, str):
             if key not in self.unique_keys:
                 raise Exception("Key '%s' not found." % key)
             idx = bisect_left(self.keys, key)
-            self.hdf5[self.ds_data][idx]
+            return self.hdf5[self.ds_data][idx]
         else:
             raise Exception("Key type %s not recognized." % type(key))
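Two things change in this hunk: `bytes` keys are now decoded after the `slice` branch instead of before it, and the string lookup now returns the selected row (previously the indexed value was computed and discarded, so string keys returned None). A hedged sketch of the key types `__getitem__` accepts, with illustrative values:

sign[10]           # integer index into the main dataset
sign[2:8]          # slice of rows
sign[[3, 4, 5]]    # list, converted to slice(min(key), max(key))
sign[b'CHEMBL25']  # bytes, decoded to str first
sign['CHEMBL25']   # str, resolved via bisect_left on the sorted keys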
@@ -785,7 +809,8 @@ class DataSignature(object):
             src_vectors = hf['V'][:]
         with h5py.File(out_file, "w") as hf:
             hf.create_dataset('keys', data=np.array(
-                src_keys, DataSignature.string_dtype()), dtype=DataSignature.string_dtype())
+                src_keys, DataSignature.string_dtype()),
+                dtype=DataSignature.string_dtype())
             hf.create_dataset('V', data=src_vectors, dtype=np.float32)
             hf.create_dataset("shape", data=src_vectors.shape)
         return
@@ -807,7 +832,8 @@ class DataSignature(object):
         sorted_idx = np.argsort(dst_keys)
         with h5py.File(out_file, "w") as hf:
             hf.create_dataset('keys', data=np.array(
-                dst_keys[sorted_idx], DataSignature.string_dtype()), dtype=DataSignature.string_dtype())
+                dst_keys[sorted_idx], DataSignature.string_dtype()),
+                dtype=DataSignature.string_dtype())
             hf.create_dataset('V', data=matrix[sorted_idx], dtype=np.float32)
             hf.create_dataset("shape", data=matrix.shape)
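These two hunks only re-wrap the long `create_dataset('keys', ...)` calls; behaviour is unchanged. The pattern stores the keys as HDF5 string data. A small sketch with plain h5py, assuming `DataSignature.string_dtype()` resolves to an h5py string dtype (its definition is not part of this diff):

keys = np.array(['mol_a', 'mol_b'], dtype=h5py.string_dtype())
with h5py.File('example.h5', 'w') as hf:
    hf.create_dataset('keys', data=keys, dtype=h5py.string_dtype())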
@@ -839,16 +865,18 @@ class DataSignature(object):
             hf_out.create_dataset(
                 "features",
                 data=np.array(features, DataSignature.string_dtype()))

-    def dataloader(self, batch_size=32, num_workers=1, shuffle=False,
-                   weak_shuffle=False, drop_last=False):
+    def dataloader(self, batch_size=32, num_workers=1, shuffle=False,
+                   weak_shuffle=False, drop_last=False):
         """Return a pytorch DataLoader object for quick signature iterations."""
         if weak_shuffle:
             return torch.utils.data.DataLoader(
                 self,
-                batch_size=None,  # must be disabled when using samplers
+                batch_size=None,  # must be disabled when using samplers
                 num_workers=num_workers,
                 shuffle=False,
                 sampler=torch.utils.data.BatchSampler(
-                    RandomBatchSampler(self, batch_size),
-                    batch_size=batch_size,
-                    drop_last=drop_last)
+                    RandomBatchSampler(self, batch_size),
+                    batch_size=batch_size,
+                    drop_last=drop_last)
             )
         else:
             return torch.utils.data.DataLoader(
@@ -859,34 +887,42 @@ class DataSignature(object):
             )


-class RandomBatchSampler(torch.utils.data.Sampler):
-    """Sampling class to create random sequential batches from a given dataset
-    E.g. if data is [1,2,3,4] with bs=2. Then first batch, [[1,2], [3,4]] then shuffle batches -> [[3,4],[1,2]]
-    This is useful for cases when you are interested in 'weak shuffling'
-    https://towardsdatascience.com/reading-h5-files-faster-with-pytorch-datasets-3ff86938cc
-    :param dataset: dataset you want to batch
-    :type dataset: torch.utils.data.Dataset
-    :param batch_size: batch size
-    :type batch_size: int
-    :returns: generator object of shuffled batch indices
-    """
-
-    def __init__(self, dataset, batch_size):
-        self.batch_size = batch_size
-        self.dataset_length = len(dataset)
-        self.n_batches = self.dataset_length / self.batch_size
-        self.batch_ids = torch.randperm(int(self.n_batches))
-
-    def __len__(self):
-        return self.batch_size
-
-    def __iter__(self):
-        for id in self.batch_ids:
-            idx = torch.arange(id * self.batch_size, (id + 1) * self.batch_size)
-            for index in idx:
-                yield int(index)
-        if int(self.n_batches) < self.n_batches:
-            idx = torch.arange(int(self.n_batches) * self.batch_size, self.dataset_length)
-            for index in idx:
-                yield int(index)
+try:
+    class RandomBatchSampler(torch.utils.data.Sampler):
+        """Sampling class to create random sequential batches of a dataset.
+        E.g. if data is [1,2,3,4] with bs=2. Then first batch, [[1,2], [3,4]]
+        then shuffle batches -> [[3,4],[1,2]]
+        This is useful for cases when you are interested in 'weak shuffling'
+        https://towardsdatascience.com/
+        reading-h5-files-faster-with-pytorch-datasets-3ff86938cc
+        :param dataset: dataset you want to batch
+        :type dataset: torch.utils.data.Dataset
+        :param batch_size: batch size
+        :type batch_size: int
+        :returns: generator object of shuffled batch indices
+        """
+
+        def __init__(self, dataset, batch_size):
+            self.batch_size = batch_size
+            self.dataset_length = len(dataset)
+            self.n_batches = self.dataset_length / self.batch_size
+            self.batch_ids = torch.randperm(int(self.n_batches))
+
+        def __len__(self):
+            return self.batch_size
+
+        def __iter__(self):
+            for id in self.batch_ids:
+                idx = torch.arange(
+                    id * self.batch_size, (id + 1) * self.batch_size)
+                for index in idx:
+                    yield int(index)
+            if int(self.n_batches) < self.n_batches:
+                idx = torch.arange(
+                    int(self.n_batches) * self.batch_size,
+                    self.dataset_length)
+                for index in idx:
+                    yield int(index)
+except:
+    pass
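Wrapping `RandomBatchSampler` in a try/except lets the module import even when torch is unavailable, mirroring the guarded imports at the top of the file. With `weak_shuffle=True` each batch is a contiguous index range read sequentially from the HDF5 file, while the order of the batches is randomised. A hedged usage sketch (batch size and loop body are illustrative):

loader = sign.dataloader(batch_size=128, num_workers=1, weak_shuffle=True)
for batch in loader:
    # each batch is a contiguous block of rows; batch order is random
    train_step(batch)  # placeholder for user code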