Commit 4fb066cc authored by Martino Bertoni 🌋


added function to apply mask to dataset, improved chunk iterator, added function to add raw data to datasets
parent 3af24250
@@ -29,6 +29,24 @@ class DataSignature(object):
         self.PVALRANGES = np.array(
             [0, 0.001, 0.01, 0.1] + list(np.arange(1, 100)) + [100]) / 100.
 
+    def add_datasets(self, data_dict, overwrite=True):
+        """Add datasets to the H5 file."""
+        for k, v in data_dict.items():
+            with h5py.File(self.data_path, 'a') as hf:
+                if k in hf.keys():
+                    if overwrite:
+                        del hf[k]
+                    else:
+                        self.__log.info('Skipping `%s`: already there' % k)
+                        continue
+                # string lists and arrays are converted via h5_str
+                if isinstance(v, list):
+                    if hasattr(v[0], 'decode') or isinstance(v[0], str) or isinstance(v[0], np.str_):
+                        v = self.h5_str(v)
+                else:
+                    if hasattr(v.flat[0], 'decode') or isinstance(v.flat[0], str) or isinstance(v.flat[0], np.str_):
+                        v = self.h5_str(v)
+                hf.create_dataset(k, data=v)
+
     def _check_data(self):
         """Test if data file is available"""
         if not os.path.isfile(self.data_path):
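A quick usage sketch of the new `add_datasets` (the file name and dataset keys are hypothetical, and the constructor call assumes the object is created from its `data_path`; `h5_str` is the class helper that converts string collections to an HDF5-safe dtype):

```python
import numpy as np

sign = DataSignature('signature.h5')  # hypothetical H5-backed object

# numeric arrays are written directly; string lists go through h5_str
sign.add_datasets({
    'V': np.random.rand(10, 128),                # 2D float dataset
    'keys': ['mol_%d' % i for i in range(10)],   # list of str
})

# with overwrite=False an existing dataset is kept and the key skipped
sign.add_datasets({'V': np.zeros((10, 128))}, overwrite=False)
```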
@@ -40,17 +58,20 @@ class DataSignature(object):
             if key not in hf.keys():
                 raise Exception("No '%s' dataset in this signature!" % key)
 
     @staticmethod
     def _decode(arg):
         """Apply decode function to input"""
         decoder = np.vectorize(lambda x: x.decode())
         return decoder(arg)
 
     def clear(self):
         with h5py.File(self.data_path, 'w') as hf:
             pass
 
-    def _get_shape(self, key):
+    def _get_shape(self, key, axis=None):
         """Get shape of dataset"""
         with h5py.File(self.data_path, 'r') as hf:
             data = hf[key].shape
-        return data
+        if axis is not None:
+            # 1D fallback: a shape (n,) dataset has no second axis
+            if len(data) == axis:
+                return data[0]
+            return data[axis]
+        else:
+            return data
 
     def _get_dtype(self, key):
         """Get dtype of dataset"""
@@ -63,25 +84,31 @@ class DataSignature(object):
         with h5py.File(self.data_path, 'r') as hf:
             data = hf[key][:]
         if hasattr(data.flat[0], 'decode'):
-            return self._decode(data)
+            return data.astype(str)
         return data
 
-    def _get_chunk(self, key, chunk):
+    def _get_data_chunk(self, key, chunk, axis=0):
         """Get chunk of dataset"""
         with h5py.File(self.data_path, 'r') as hf:
-            data = hf[key][chunk]
+            # slice rows (axis 0) or columns (axis 1)
+            if axis == 0:
+                data = hf[key][chunk]
+            else:
+                data = hf[key][:, chunk]
         if hasattr(data.flat[0], 'decode'):
-            return self._decode(data)
+            return data.astype(str)
         return data
 
-    def chunk_iter(self, key, chunk_size):
+    def chunk_iter(self, key, chunk_size, axis=0, chunk=False):
         """Iterator on chunks of data"""
         self._check_data()
         self._check_dataset(key)
-        tot_size = self._get_shape(key)[0]
+        tot_size = self._get_shape(key, axis)
         for i in range(0, tot_size, chunk_size):
-            chunk = slice(i, i + chunk_size)
-            yield self._get_chunk(key, chunk)
+            mychunk = slice(i, i + chunk_size)
+            if chunk:
+                # also yield the slice, useful for writing back aligned data
+                yield mychunk, self._get_data_chunk(key, mychunk, axis)
+            else:
+                yield self._get_data_chunk(key, mychunk, axis)
 
     def __iter__(self):
         """By default iterate on signatures V."""
@@ -249,8 +276,7 @@ class DataSignature(object):
     def make_filtered_copy(self, destination, mask, include_all=False,
                            data_file=None):
-        """
-        Make a copy of applying a filtering mask on rows.
+        """Make a copy applying a filtering mask on rows.
 
         destination (str): The destination file path.
         mask (bool array): A numpy mask array (e.g. result of `np.isin`)
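For context, a hedged usage sketch of `make_filtered_copy` (the `keys` and `keys_to_keep` arrays and the file name are hypothetical, loaded beforehand): a boolean row mask, e.g. from `np.isin` as the docstring suggests, selects which rows survive in the copy.

```python
import numpy as np

# keep only the rows whose key appears in a reference list
mask = np.isin(keys, keys_to_keep)            # bool array over rows
sign.make_filtered_copy('filtered.h5', mask)  # writes the destination file
```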
@@ -297,6 +323,47 @@ class DataSignature(object):
                 hf_out[dset][idx_dst] = hf_in[dset][idx_src]
                 idx_dst += 1
 
+    def filter_h5_dataset(self, key, mask, axis):
+        """Apply a mask to a dataset, dropping columns or rows.
+
+        key (str): The H5 dataset to filter.
+        mask (np.array): A boolean one-dimensional mask array. True values
+            will be kept.
+        axis (int): Whether the mask refers to rows (0) or columns (1).
+        """
+        self._check_dataset(key)
+        if self._get_shape(key, axis) != mask.shape[0]:
+            raise Exception("Shape mismatch:", self._get_shape(
+                key, axis), mask.shape[0])
+        key_tmp = "%s_tmp" % key
+        with h5py.File(self.data_path, "a") as hf:
+            if key_tmp in hf.keys():
+                self.__log.debug('Deleting pre-existing `%s`' % key_tmp)
+                del hf[key_tmp]
+            # if we have a list directly apply the mask
+            if hf[key].ndim == 1:
+                hf.create_dataset(key_tmp, (sum(mask),), dtype=hf[key].dtype)
+                hf[key_tmp][:] = hf[key][mask]
+            # otherwise apply the mask on chunks to avoid overloading memory
+            else:
+                new_shape = list(hf[key].shape)
+                new_shape[axis] = sum(mask)
+                hf.create_dataset(key_tmp, new_shape, dtype=hf[key].dtype)
+                # if we filter out rows we iterate on smaller vertical slices
+                cs = 10
+                it_axis = 1
+                if axis == 1:
+                    it_axis = 0
+                    cs = 1000
+                for chunk, data in self.chunk_iter(key, cs, it_axis, True):
+                    if axis == 1:
+                        hf[key_tmp][chunk] = data[:, mask]
+                    else:
+                        hf[key_tmp][:, chunk] = data[mask]
+            del hf[key]
+            hf[key] = hf[key_tmp]
+            del hf[key_tmp]
+
     @staticmethod
     def hstack_signatures(sign_list, destination, chunk_size=1000,
                           aggregate_keys=None):
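A sketch of the new in-place filter (names hypothetical). Note the chunking strategy: dropping columns (`axis=1`) streams horizontal slices of 1000 rows, while dropping rows (`axis=0`) streams vertical slices of 10 columns, so the full matrix never has to sit in memory.

```python
import numpy as np

# drop every other row of 'V'
row_mask = np.ones(sign._get_shape('V', 0), dtype=bool)
row_mask[::2] = False
sign.filter_h5_dataset('V', row_mask, axis=0)

# keep only columns passing some per-column criterion (hypothetical)
col_mask = np.asarray(column_variances) > 0.01
sign.filter_h5_dataset('V', col_mask, axis=1)
```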
@@ -366,8 +433,7 @@ class DataSignature(object):
             if mask is None:
                 ndim = hf[h5_dataset_name].ndim
                 if hasattr(hf[h5_dataset_name][(0,) * ndim], 'decode'):
-                    decoder = np.vectorize(lambda x: x.decode())
-                    return decoder(hf[h5_dataset_name][:])
+                    return hf[h5_dataset_name][:].astype(str)
                 else:
                     return hf[h5_dataset_name][:]
             else:
@@ -376,8 +442,7 @@ class DataSignature(object):
                 mask = np.argwhere(mask.ravel()).ravel()
                 ndim = hf[h5_dataset_name].ndim
                 if hasattr(hf[h5_dataset_name][(0,) * ndim], 'decode'):
-                    decoder = np.vectorize(lambda x: x.decode())
-                    return decoder(hf[h5_dataset_name][mask])
+                    return hf[h5_dataset_name][mask].astype(str)
                 else:
                     return hf[h5_dataset_name][mask, :]
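The last two hunks replace the `np.vectorize` decoder with NumPy's native `astype(str)`, which casts a bytes (`S`) array to unicode (`U`) in a single vectorized operation. A minimal illustration of the equivalence:

```python
import numpy as np

data = np.array([b'foo', b'bar'])
# old approach: element-wise decode through np.vectorize
decoder = np.vectorize(lambda x: x.decode())
old = decoder(data)      # array(['foo', 'bar'], dtype='<U3')
# new approach: one cast, same result, no Python-level loop
new = data.astype(str)
assert (old == new).all()
```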