Commit ed8949b3 authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

SNN redict: first scale then trim

parent 9bd4b674
Pipeline #2482 passed with stages
in 11 minutes and 31 seconds
......@@ -869,21 +869,24 @@ class SiameseTriplets(object):
split(str): which split to predict.
batch_size(int): batch size for prediction.
"""
# apply input scaling
if hasattr(self, 'scaler'):
scaled = self.scaler.transform(x_matrix)
else:
scaled = x_matrix
# apply trimming of input matrix
if self.trim_mask is not None:
trimmed = x_matrix[:, np.repeat(self.trim_mask, 128)]
trimmed = scaled[:, np.repeat(self.trim_mask, 128)]
else:
trimmed = x_matrix
trimmed = scaled
# load model if not alredy there
if self.model is None:
self.build_model((trimmed.shape[1],), load=True, cp=cp)
# apply input scaling
if hasattr(self, 'scaler'):
scaled = self.scaler.transform(trimmed)
else:
scaled = trimmed
# get rid of NaNs
no_nans = np.nan_to_num(scaled)
no_nans = np.nan_to_num(trimmed)
# get default dropout function
if dropout_fn is None:
return self.transformer.predict(no_nans)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment