Commit f4b15a94 authored by Martino Bertoni's avatar Martino Bertoni 🌋
Browse files

added option to specify steps per epoch, fixed some warnings

parent ed8949b3
......@@ -115,6 +115,8 @@ class SiameseTriplets(object):
self.traintest_file = kwargs.get("traintest_file", None)
self.standard = kwargs.get("standard", True)
self.trim_mask = kwargs.get("trim_mask", None)
self.steps_per_epoch = kwargs.get("steps_per_epoch", None)
self.validation_steps = kwargs.get("validation_steps", None)
# internal variables
self.name = self.__class__.__name__.lower()
......@@ -159,8 +161,9 @@ class SiameseTriplets(object):
self.generator = tr_shape_type_gen
self.tr_shapes = tr_shape_type_gen[0]
self.tr_gen = tr_shape_type_gen[2]()
self.steps_per_epoch = np.ceil(
self.tr_shapes[0][0] / self.batch_size)
if self.steps_per_epoch is None:
self.steps_per_epoch = np.ceil(
self.tr_shapes[0][0] / self.batch_size)
# load the scaler
if not self.standard:
......@@ -175,10 +178,6 @@ class SiameseTriplets(object):
pickle.dump(self.scaler, open(scaler_path, 'wb'))
else:
self.__log.warning("No scaler has been loaded")
self.scaler.scale_ = self.scaler.scale_[
np.repeat(self.trim_mask, 128)]
self.scaler.center_ = self.scaler.center_[
np.repeat(self.trim_mask, 128)]
# initialize validation/test generator
if evaluate:
......@@ -198,8 +197,9 @@ class SiameseTriplets(object):
standard=self.standard, trim_mask=self.trim_mask)
self.val_shapes = val_shape_type_gen[0]
self.val_gen = val_shape_type_gen[2]()
self.validation_steps = np.ceil(
self.val_shapes[0][0] / self.batch_size)
if self.validation_steps is None:
self.validation_steps = np.ceil(
self.val_shapes[0][0] / self.batch_size)
else:
self.val_shapes = None
self.val_gen = None
......@@ -563,9 +563,9 @@ class SiameseTriplets(object):
lfuncs_dict[self.loss_func].__name__)
if self.learning_rate == 'auto':
optimizer = keras.optimizers.Adam(lr=MIN_LR)
optimizer = keras.optimizers.Adam(learning_rate=MIN_LR)
else:
optimizer = keras.optimizers.Adam(lr=self.learning_rate)
optimizer = keras.optimizers.Adam(learning_rate=self.learning_rate)
model.compile(
optimizer=optimizer,
......@@ -597,7 +597,7 @@ class SiameseTriplets(object):
self.__log.info('Trying lr %s' % lr)
lr_params['learning_rate'] = lr
siamese = SiameseTriplets(
self.model_dir, evaluate=True, plot=False, save_params=False,
self.model_dir, evaluate=True, plot=True, save_params=False,
generator=generator, **lr_params)
siamese.fit(save=False)
h_file = os.path.join(
......@@ -748,7 +748,8 @@ class SiameseTriplets(object):
standard=self.standard, trim_mask=self.trim_mask)
validation_sets.append((gen, shapes, name))
additional_vals = AdditionalValidationSets(
validation_sets, self.model, batch_size=self.batch_size)
validation_sets, self.model, batch_size=self.batch_size,
validation_steps=self.validation_steps)
callbacks.append(additional_vals)
class CustomEarlyStopping(EarlyStopping):
......@@ -831,8 +832,8 @@ class SiameseTriplets(object):
# call fit and save model
t0 = time()
self.history = self.model.fit_generator(
generator=self.tr_gen,
self.history = self.model.fit(
self.tr_gen,
steps_per_epoch=self.steps_per_epoch,
epochs=self.epochs,
callbacks=callbacks,
......@@ -857,9 +858,9 @@ class SiameseTriplets(object):
pickle.dump(self.history.history, open(history_file, 'wb'))
history_file = os.path.join(self.model_dir, "history.png")
anchor_file = os.path.join(self.model_dir, "anchor_distr.png")
if self.evaluate:
if self.evaluate and self.plot:
self._plot_history(self.history.history, vsets, history_file)
if not self.standard:
if not self.standard and self.plot:
self._plot_anchor_dist(anchor_file)
def predict(self, x_matrix, dropout_fn=None, dropout_samples=10, cp=False):
......@@ -869,8 +870,10 @@ class SiameseTriplets(object):
split(str): which split to predict.
batch_size(int): batch size for prediction.
"""
# apply input scaling
if hasattr(self, 'scaler'):
# scaler has already been trimmed
scaled = self.scaler.transform(x_matrix)
else:
scaled = x_matrix
......@@ -880,7 +883,7 @@ class SiameseTriplets(object):
trimmed = scaled[:, np.repeat(self.trim_mask, 128)]
else:
trimmed = scaled
# load model if not alredy there
if self.model is None:
self.build_model((trimmed.shape[1],), load=True, cp=cp)
......@@ -1091,7 +1094,8 @@ class SiameseTriplets(object):
class AdditionalValidationSets(Callback):
def __init__(self, validation_sets, model, verbose=1, batch_size=None):
def __init__(self, validation_sets, model, verbose=1, batch_size=None,
validation_steps=None):
"""
validation_sets(list): list of 3-tuples (val_data, val_targets,
val_set_name) or 4-tuples (val_data, val_targets, sample_weights,
......@@ -1102,6 +1106,9 @@ class AdditionalValidationSets(Callback):
"""
super(AdditionalValidationSets, self).__init__()
self.validation_sets = validation_sets
self.validation_steps = validation_steps
if self.validation_steps is None:
self.validation_steps = np.ceil(val_shapes[0][0] / self.batch_size)
self.epoch = []
self.history = {}
self.verbose = verbose
......@@ -1122,9 +1129,9 @@ class AdditionalValidationSets(Callback):
# evaluate on the additional validation sets
for val_gen, val_shapes, val_set_name in self.validation_sets:
results = self.model.evaluate_generator(
results = self.model.evaluate(
val_gen(),
steps=np.ceil(val_shapes[0][0] / self.batch_size),
steps=self.validation_steps,
verbose=self.verbose)
for i, result in enumerate(results):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment