Commit bbea196d authored by Martino Bertoni

speedup on iterator, added option for chunks and compression

parent 9e075ea8
Pipeline #2683 failed in 100 minutes and 20 seconds
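The commit adds `chunks` and `compression` keyword arguments to `add_datasets`, forwarded to `h5py` dataset creation. A minimal usage sketch (import path, constructor, and file name are assumptions; only the `chunks`/`compression` keywords come from the diff below):

```python
import numpy as np
from chemicalchecker.core.signature_data import DataSignature  # path assumed

# Write a matrix with HDF5 chunking and gzip compression; both default to
# None, which preserves the previous behaviour.
ds = DataSignature('signature.h5')
V = np.random.rand(1000, 128).astype(np.float32)
ds.add_datasets({'V': V}, chunks=(100, 128), compression='gzip')
```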
@@ -22,6 +22,7 @@ try:
except:
pass
@logged
class DataSignature(object):
"""DataSignature class."""
@@ -34,7 +35,8 @@ class DataSignature(object):
self.PVALRANGES = np.array(
[0, 0.001, 0.01, 0.1] + list(np.arange(1, 100)) + [100]) / 100.
-def add_datasets(self, data_dict, overwrite=True):
+def add_datasets(self, data_dict, overwrite=True, chunks=None,
+compression=None):
"""Add dataset to a H5"""
for k, v in data_dict.items():
with h5py.File(self.data_path, 'a') as hf:
@@ -50,7 +52,8 @@
else:
if hasattr(v.flat[0], 'decode') or isinstance(v.flat[0], str) or isinstance(v.flat[0], np.str_):
v = self.h5_str(v)
-hf.create_dataset(k, data=v)
+hf.create_dataset(k, data=v, chunks=chunks,
+compression=compression)
def _check_data(self):
"""Test if data file is available"""
@@ -107,15 +110,17 @@
self._check_data()
self._check_dataset(key)
tot_size = self._get_shape(key, axis)
-with h5py.File(self.data_path, 'r') as hf:
-myrange = range(0, tot_size, chunk_size)
-desc = 'Iterating on `%s` axis %s' % (key, axis)
-for i in tqdm(myrange, disable=not bar, desc=desc):
-mychunk = slice(i, i + chunk_size)
-if chunk:
-yield mychunk, self._get_data_chunk(hf, key, mychunk, axis)
-else:
-yield self._get_data_chunk(hf, key, mychunk, axis)
+if not hasattr(self, 'hdf5'):
+self.open_hdf5()
+hf = self.hdf5
+myrange = range(0, tot_size, chunk_size)
+desc = 'Iterating on `%s` axis %s' % (key, axis)
+for i in tqdm(myrange, disable=not bar, desc=desc):
+mychunk = slice(i, i + chunk_size)
+if chunk:
+yield mychunk, self._get_data_chunk(hf, key, mychunk, axis)
+else:
+yield self._get_data_chunk(hf, key, mychunk, axis)
def __iter__(self):
"""By default iterate on signatures V."""
@@ -290,7 +295,8 @@ class DataSignature(object):
hf[key] = src
def make_filtered_copy(self, destination, mask, include_all=False,
-data_file=None):
+data_file=None, datasets=None, dst_datasets=None,
+chunk_size=1000, compression=None):
"""Make a copy of applying a filtering mask on rows.
destination (str): The destination file path.
@@ -305,38 +311,52 @@
data_file = self.data_path
with h5py.File(data_file, 'r') as hf_in:
-with h5py.File(destination, 'w') as hf_out:
-for dset in hf_in.keys():
+with h5py.File(destination, 'a') as hf_out:
+if datasets is None:
+datasets = hf_in.keys()
+if dst_datasets is None:
+dst_datasets = datasets
+for dset, dst_dset in zip(datasets, dst_datasets):
# skip dataset incompatible with mask (or copy unmasked)
if hf_in[dset].shape[0] != mask.shape[0]:
if not include_all:
continue
else:
masked = hf_in[dset][:][:]
-hf_out.create_dataset(dset, data=masked)
+hf_out.create_dataset(dst_dset, data=masked,
+compression=compression)
self.__log.debug("Copy dataset %s of shape %s" %
(dset, str(masked.shape)))
continue
# never mask features
if dset == 'features':
masked = hf_in[dset][:][:]
self.__log.debug("Copy dataset %s of shape %s" %
(dset, str(masked.shape)))
-hf_out.create_dataset(dset, data=masked)
+hf_out.create_dataset(dst_dset, data=masked,
+compression=compression)
continue
-# memory safe masked copy for other datasets
+# mask single value dataset all at once
if len(hf_in[dset].shape) == 1:
-final_shape = (sum(mask),)
-else:
-final_shape = (sum(mask), hf_in[dset].shape[1])
+masked = hf_in[dset][:][mask]
+self.__log.debug("Copy dataset %s of shape %s" %
+(dset, str(masked.shape)))
+hf_out.create_dataset(dst_dset, data=masked,
+compression=compression)
+continue
+# memory safe masked copy for other datasets
+final_shape = (sum(mask), hf_in[dset].shape[1])
hf_out.create_dataset(
-dset, final_shape, dtype=hf_in[dset].dtype)
+dst_dset, final_shape, dtype=hf_in[dset].dtype,
+compression=compression)
self.__log.debug("Copy dataset %s of shape %s" %
(dset, str(final_shape)))
-idx_dst = 0
-for idx_src in np.argwhere(mask).ravel():
-hf_out[dset][idx_dst] = hf_in[dset][idx_src]
-idx_dst += 1
+for chunk, data in self.chunk_iter(dset, 100, 1, True):
+hf_out[dst_dset][:, chunk] = data[mask]
def filter_h5_dataset(self, key, mask, axis, chunk_size=1000):
"""Apply a maks to a dataset, dropping columns or rows.
@@ -460,7 +480,8 @@
else:
return hf[h5_dataset_name][mask, :]
-def get_vectors(self, keys, include_nan=False, dataset_name='V', output_missing=False):
+def get_vectors(self, keys, include_nan=False, dataset_name='V',
+output_missing=False):
"""Get vectors for a list of keys, sorted by default.
Args:
@@ -564,10 +585,13 @@
def open_hdf5(self):
self.hdf5 = h5py.File(self.data_path, 'r')
-def __del__(self):
+def close_hdf5(self):
if hasattr(self, 'hdf5'):
self.hdf5.close()
+def __del__(self):
+self.close_hdf5()
def __len__(self):
if not hasattr(self, 'hdf5'):
self.open_hdf5()
@@ -589,15 +613,15 @@
return self.hdf5[self.ds_data][key]
if isinstance(key, list):
key = slice(min(key), max(key))
+if isinstance(key, bytes):
+key = key.decode("utf-8")
if isinstance(key, slice):
return self.hdf5[self.ds_data][key]
-if isinstance(key, bytes):
-key = key.decode("utf-8")
if isinstance(key, str):
if key not in self.unique_keys:
raise Exception("Key '%s' not found." % key)
idx = bisect_left(self.keys, key)
-self.hdf5[self.ds_data][idx]
+return self.hdf5[self.ds_data][idx]
else:
raise Exception("Key type %s not recognized." % type(key))
@@ -785,7 +809,8 @@
src_vectors = hf['V'][:]
with h5py.File(out_file, "w") as hf:
hf.create_dataset('keys', data=np.array(
-src_keys, DataSignature.string_dtype()), dtype=DataSignature.string_dtype())
+src_keys, DataSignature.string_dtype()),
+dtype=DataSignature.string_dtype())
hf.create_dataset('V', data=src_vectors, dtype=np.float32)
hf.create_dataset("shape", data=src_vectors.shape)
return
@@ -807,7 +832,8 @@
sorted_idx = np.argsort(dst_keys)
with h5py.File(out_file, "w") as hf:
hf.create_dataset('keys', data=np.array(
-dst_keys[sorted_idx], DataSignature.string_dtype()), dtype=DataSignature.string_dtype())
+dst_keys[sorted_idx], DataSignature.string_dtype()),
+dtype=DataSignature.string_dtype())
hf.create_dataset('V', data=matrix[sorted_idx], dtype=np.float32)
hf.create_dataset("shape", data=matrix.shape)
@@ -839,16 +865,18 @@
hf_out.create_dataset("features", data=np.array(
features, DataSignature.string_dtype()))
-def dataloader(self, batch_size=32, num_workers=1, shuffle=False, weak_shuffle=False, drop_last=False):
+def dataloader(self, batch_size=32, num_workers=1, shuffle=False,
+weak_shuffle=False, drop_last=False):
"""Return a pytorch DataLoader object for quick signature iterations."""
if weak_shuffle:
return torch.utils.data.DataLoader(
self,
-batch_size=None, # must be disabled when using samplers
+batch_size=None,  # must be disabled when using samplers
num_workers=num_workers,
shuffle=False,
sampler=torch.utils.data.BatchSampler(
-RandomBatchSampler(self, batch_size), batch_size=batch_size, drop_last=drop_last)
+RandomBatchSampler(self, batch_size), batch_size=batch_size,
+drop_last=drop_last)
)
else:
return torch.utils.data.DataLoader(
@@ -859,34 +887,42 @@
)
-class RandomBatchSampler(torch.utils.data.Sampler):
-"""Sampling class to create random sequential batches from a given dataset
-E.g. if data is [1,2,3,4] with bs=2. Then first batch, [[1,2], [3,4]] then shuffle batches -> [[3,4],[1,2]]
-This is useful for cases when you are interested in 'weak shuffling'
-https://towardsdatascience.com/reading-h5-files-faster-with-pytorch-datasets-3ff86938cc
-:param dataset: dataset you want to batch
-:type dataset: torch.utils.data.Dataset
-:param batch_size: batch size
-:type batch_size: int
-:returns: generator object of shuffled batch indices
-"""
-def __init__(self, dataset, batch_size):
-self.batch_size = batch_size
-self.dataset_length = len(dataset)
-self.n_batches = self.dataset_length / self.batch_size
-self.batch_ids = torch.randperm(int(self.n_batches))
-def __len__(self):
-return self.batch_size
-def __iter__(self):
-for id in self.batch_ids:
-idx = torch.arange(id * self.batch_size, (id + 1) * self.batch_size)
-for index in idx:
-yield int(index)
-if int(self.n_batches) < self.n_batches:
-idx = torch.arange(int(self.n_batches) * self.batch_size, self.dataset_length)
-for index in idx:
-yield int(index)
+try:
+class RandomBatchSampler(torch.utils.data.Sampler):
+"""Sampling class to create random sequential batches of a dataset.
+E.g. if data is [1,2,3,4] with bs=2. Then first batch, [[1,2], [3,4]]
+then shuffle batches -> [[3,4],[1,2]]
+This is useful for cases when you are interested in 'weak shuffling'
+https://towardsdatascience.com/
+reading-h5-files-faster-with-pytorch-datasets-3ff86938cc
+:param dataset: dataset you want to batch
+:type dataset: torch.utils.data.Dataset
+:param batch_size: batch size
+:type batch_size: int
+:returns: generator object of shuffled batch indices
+"""
+def __init__(self, dataset, batch_size):
+self.batch_size = batch_size
+self.dataset_length = len(dataset)
+self.n_batches = self.dataset_length / self.batch_size
+self.batch_ids = torch.randperm(int(self.n_batches))
+def __len__(self):
+return self.batch_size
+def __iter__(self):
+for id in self.batch_ids:
+idx = torch.arange(id * self.batch_size,
+(id + 1) * self.batch_size)
+for index in idx:
+yield int(index)
+if int(self.n_batches) < self.n_batches:
+idx = torch.arange(int(self.n_batches) *
+self.batch_size, self.dataset_length)
+for index in idx:
+yield int(index)
+except:
+pass
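`RandomBatchSampler` backs the 'weak shuffling' path of `dataloader`: batches are visited in random order, but each batch is read as one contiguous slice from the HDF5 file, which is much cheaper than row-level shuffling. A hedged usage sketch, reusing `ds` from above (behaviour inferred from the diff):

```python
# Batches of 128 contiguous rows, visited in random batch order.
loader = ds.dataloader(batch_size=128, weak_shuffle=True)
for batch in loader:
    print(batch.shape)
    break
```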