Commit ed8949b3 authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

SNN redict: first scale then trim

parent 9bd4b674
Pipeline #2482 passed with stages
in 11 minutes and 31 seconds
...@@ -869,21 +869,24 @@ class SiameseTriplets(object): ...@@ -869,21 +869,24 @@ class SiameseTriplets(object):
split(str): which split to predict. split(str): which split to predict.
batch_size(int): batch size for prediction. batch_size(int): batch size for prediction.
""" """
# apply input scaling
if hasattr(self, 'scaler'):
scaled = self.scaler.transform(x_matrix)
else:
scaled = x_matrix
# apply trimming of input matrix # apply trimming of input matrix
if self.trim_mask is not None: if self.trim_mask is not None:
trimmed = x_matrix[:, np.repeat(self.trim_mask, 128)] trimmed = scaled[:, np.repeat(self.trim_mask, 128)]
else: else:
trimmed = x_matrix trimmed = scaled
# load model if not alredy there # load model if not alredy there
if self.model is None: if self.model is None:
self.build_model((trimmed.shape[1],), load=True, cp=cp) self.build_model((trimmed.shape[1],), load=True, cp=cp)
# apply input scaling
if hasattr(self, 'scaler'):
scaled = self.scaler.transform(trimmed)
else:
scaled = trimmed
# get rid of NaNs # get rid of NaNs
no_nans = np.nan_to_num(scaled) no_nans = np.nan_to_num(trimmed)
# get default dropout function # get default dropout function
if dropout_fn is None: if dropout_fn is None:
return self.transformer.predict(no_nans) return self.transformer.predict(no_nans)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment