Commit 4fb066cc authored by Martino Bertoni


added function to apply mask to dataset, improved chunk iterator, added function to add raw data to datasets
parent 3af24250
@@ -29,6 +29,24 @@ class DataSignature(object):
        self.PVALRANGES = np.array(
            [0, 0.001, 0.01, 0.1] + list(np.arange(1, 100)) + [100]) / 100.
+    def add_datasets(self, data_dict, overwrite=True):
+        """Add datasets to the H5 file."""
+        for k, v in data_dict.items():
+            with h5py.File(self.data_path, 'a') as hf:
+                if k in hf.keys():
+                    if overwrite:
+                        del hf[k]
+                    else:
+                        self.__log.info('Skipping `%s`: already there' % k)
+                        continue
+                if isinstance(v, list):
+                    if hasattr(v[0], 'decode') or isinstance(v[0], str) or isinstance(v[0], np.str_):
+                        v = self.h5_str(v)
+                else:
+                    if hasattr(v.flat[0], 'decode') or isinstance(v.flat[0], str) or isinstance(v.flat[0], np.str_):
+                        v = self.h5_str(v)
+                hf.create_dataset(k, data=v)
    def _check_data(self):
        """Test if data file is available"""
        if not os.path.isfile(self.data_path):
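A minimal usage sketch for the new `add_datasets` helper follows. The import path, file name and constructor call are assumptions for illustration, not part of this commit:

import numpy as np
from chemicalchecker.core.signature_data import DataSignature  # assumed path

# hypothetical instance backed by some H5 file
ds = DataSignature('example.h5')

# numeric arrays are stored as-is; string lists/arrays are routed
# through h5_str so h5py can serialize them
ds.add_datasets({
    'keys': ['mol_%d' % i for i in range(10)],  # list of strings
    'V': np.random.rand(10, 5),                 # 2D float array
})

# with overwrite=False an existing dataset is skipped, not replaced
ds.add_datasets({'V': np.zeros((10, 5))}, overwrite=False)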
@@ -40,17 +58,20 @@ class DataSignature(object):
            if key not in hf.keys():
                raise Exception("No '%s' dataset in this signature!" % key)
-    @staticmethod
-    def _decode(arg):
-        """Apply decode function to input"""
-        decoder = np.vectorize(lambda x: x.decode())
-        return decoder(arg)
+    def clear(self):
+        """Reset the H5 file to an empty state."""
+        with h5py.File(self.data_path, 'w') as hf:
+            pass

-    def _get_shape(self, key):
+    def _get_shape(self, key, axis=None):
        """Get shape of dataset"""
        with h5py.File(self.data_path, 'r') as hf:
            data = hf[key].shape
-        return data
+        if axis is not None:
+            # an axis equal to the number of dimensions (e.g. axis=1
+            # on a 1D dataset) falls back to the dataset length
+            if len(data) == axis:
+                return data[0]
+            return data[axis]
+        else:
+            return data
    def _get_dtype(self, key):
        """Get dtype of dataset"""
@@ -63,25 +84,31 @@ class DataSignature(object):
        with h5py.File(self.data_path, 'r') as hf:
            data = hf[key][:]
            if hasattr(data.flat[0], 'decode'):
-                return self._decode(data)
+                return data.astype(str)
            return data

-    def _get_chunk(self, key, chunk):
+    def _get_data_chunk(self, key, chunk, axis=0):
        """Get chunk of dataset"""
        with h5py.File(self.data_path, 'r') as hf:
-            data = hf[key][chunk]
+            if axis == 0:
+                data = hf[key][chunk]
+            else:
+                data = hf[key][:, chunk]
            if hasattr(data.flat[0], 'decode'):
-                return self._decode(data)
+                return data.astype(str)
            return data
-    def chunk_iter(self, key, chunk_size):
+    def chunk_iter(self, key, chunk_size, axis=0, chunk=False):
        """Iterator on chunks of data"""
        self._check_data()
        self._check_dataset(key)
-        tot_size = self._get_shape(key)[0]
+        tot_size = self._get_shape(key, axis)
        for i in range(0, tot_size, chunk_size):
-            chunk = slice(i, i + chunk_size)
-            yield self._get_chunk(key, chunk)
+            mychunk = slice(i, i + chunk_size)
+            if chunk:
+                yield mychunk, self._get_data_chunk(key, mychunk, axis)
+            else:
+                yield self._get_data_chunk(key, mychunk, axis)
    def __iter__(self):
        """By default iterate on signatures V."""
@@ -249,8 +276,7 @@ class DataSignature(object):
    def make_filtered_copy(self, destination, mask, include_all=False,
                           data_file=None):
-        """
-        Make a copy of applying a filtering mask on rows.
+        """Make a copy applying a filtering mask on rows.

        destination (str): The destination file path.
        mask (bool array): A numpy mask array (e.g. result of `np.isin`)
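A typical way to build the row mask for this method (illustrative; `get_h5_dataset` is assumed to be the reader whose diff appears further below):

import numpy as np

keys = ds.get_h5_dataset('keys')           # all row identifiers
mask = np.isin(keys, ['mol_1', 'mol_7'])   # True marks rows to keep
ds.make_filtered_copy('filtered.h5', mask)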
@@ -297,6 +323,47 @@ class DataSignature(object):
                        hf_out[dset][idx_dst] = hf_in[dset][idx_src]
                        idx_dst += 1
+    def filter_h5_dataset(self, key, mask, axis):
+        """Apply a mask to a dataset, dropping columns or rows.
+
+        key (str): The H5 dataset to filter.
+        mask (np.array): A boolean one-dimensional mask array. True values
+            will be kept.
+        axis (int): Whether the mask refers to rows (0) or columns (1).
+        """
+        self._check_dataset(key)
+        if self._get_shape(key, axis) != mask.shape[0]:
+            raise Exception("Shape mismatch:", self._get_shape(
+                key, axis), mask.shape[0])
+        key_tmp = "%s_tmp" % key
+        with h5py.File(self.data_path, "a") as hf:
+            if key_tmp in hf.keys():
+                self.__log.debug('Deleting pre-existing `%s`' % key_tmp)
+                del hf[key_tmp]
+            # if we have a list directly apply the mask
+            if hf[key].ndim == 1:
+                hf.create_dataset(key_tmp, (sum(mask),), dtype=hf[key].dtype)
+                hf[key_tmp][:] = hf[key][mask]
+            # otherwise apply mask on chunks to avoid overloading the memory
+            else:
+                new_shape = list(hf[key].shape)
+                new_shape[axis] = sum(mask)
+                hf.create_dataset(key_tmp, new_shape, dtype=hf[key].dtype)
+                # when filtering rows, iterate on narrow column slices
+                # (and vice versa) so only small blocks sit in memory
+                cs = 10
+                it_axis = 1
+                if axis == 1:
+                    it_axis = 0
+                    cs = 1000
+                for chunk, data in self.chunk_iter(key, cs, it_axis, True):
+                    if axis == 1:
+                        hf[key_tmp][chunk] = data[:, mask]
+                    else:
+                        hf[key_tmp][:, chunk] = data[mask]
+            del hf[key]
+            hf[key] = hf[key_tmp]
+            del hf[key_tmp]
    @staticmethod
    def hstack_signatures(sign_list, destination, chunk_size=1000,
                          aggregate_keys=None):
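A sketch of the new in-place filter; the mask must be boolean and match the dataset size along the chosen axis, otherwise the shape check raises:

import numpy as np

# keep every other row of 'V' (axis=0)
row_mask = np.arange(ds._get_shape('V', 0)) % 2 == 0
ds.filter_h5_dataset('V', row_mask, axis=0)

# drop the last 28 columns of 'V' (axis=1)
col_mask = np.ones(ds._get_shape('V', 1), dtype=bool)
col_mask[-28:] = False
ds.filter_h5_dataset('V', col_mask, axis=1)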
@@ -366,8 +433,7 @@ class DataSignature(object):
            if mask is None:
                ndim = hf[h5_dataset_name].ndim
                if hasattr(hf[h5_dataset_name][(0,) * ndim], 'decode'):
-                    decoder = np.vectorize(lambda x: x.decode())
-                    return decoder(hf[h5_dataset_name][:])
+                    return hf[h5_dataset_name][:].astype(str)
                else:
                    return hf[h5_dataset_name][:]
            else:
@@ -376,8 +442,7 @@ class DataSignature(object):
                mask = np.argwhere(mask.ravel()).ravel()
                ndim = hf[h5_dataset_name].ndim
                if hasattr(hf[h5_dataset_name][(0,) * ndim], 'decode'):
-                    decoder = np.vectorize(lambda x: x.decode())
-                    return decoder(hf[h5_dataset_name][mask])
+                    return hf[h5_dataset_name][mask].astype(str)
                else:
                    return hf[h5_dataset_name][mask, :]
...
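The recurring replacement of the vectorized `decode` with `astype(str)` leans on numpy's bytes-to-unicode cast. A quick equivalence check (note the cast assumes ASCII content: numpy raises UnicodeDecodeError on non-ASCII bytes, where the old `decode`, defaulting to UTF-8, would have succeeded):

import numpy as np

raw = np.array([b'CHEMBL25', b'CHEMBL1201585'])
decoder = np.vectorize(lambda x: x.decode())    # the removed approach
assert (raw.astype(str) == decoder(raw)).all()  # the new one, same result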