import os
import pickle
import numpy as np
from time import time
from functools import partial

from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import concatenate
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.callbacks import EarlyStopping, Callback
from tensorflow.keras.layers import Input, Dropout, Lambda, Dense
from tensorflow.keras.layers import Activation, Masking, BatchNormalization
from tensorflow.keras.layers import GaussianNoise, AlphaDropout, GaussianDropout
from tensorflow.keras import regularizers

from chemicalchecker.util import logged
from chemicalchecker.util.splitter import NeighborTripletTraintest
from .callbacks import CyclicLR, LearningRateFinder

MIN_LR = 1e-8
MAX_LR = 1e-1


class AlphaDropoutCP(keras.layers.AlphaDropout):
    """AlphaDropout variant that can stay active at prediction time.

    When `cp` is set, dropout is applied also outside the training
    phase, which allows sampling stochastic predictions.
    """

    def __init__(self, rate, cp=None, noise_shape=None, seed=None, **kwargs):
        super(AlphaDropoutCP, self).__init__(rate, **kwargs)
        self.cp = cp
        self.rate = rate
        self.noise_shape = noise_shape
        self.seed = seed
        self.supports_masking = True

    def _get_noise_shape(self, inputs):
        return self.noise_shape if self.noise_shape else K.shape(inputs)

    def call(self, inputs, training=None):
        if 0. < self.rate < 1.:
            noise_shape = self._get_noise_shape(inputs)

            def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed):
                alpha = 1.6732632423543772848170429916717
                scale = 1.0507009873554804934193349852946
                alpha_p = -alpha * scale

                kept_idx = K.greater_equal(K.random_uniform(noise_shape,
                                                            seed=seed), rate)
                kept_idx = K.cast(kept_idx, K.floatx())

                # Get affine transformation params
                a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5
                b = -a * alpha_p * rate

                # Apply mask
                x = inputs * kept_idx + alpha_p * (1 - kept_idx)

                # Do affine transformation
                return a * x + b

            if self.cp:
                return dropped_inputs()
            return K.in_train_phase(dropped_inputs, inputs, training=training)
        return inputs

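# A minimal sketch of the intended use (an assumption based on the `cp`
# flag): with cp=True the layer keeps dropping units at inference too,
# enabling Monte Carlo sampling of predictions, e.g.:
#
#   mc_layer = AlphaDropoutCP(0.2, cp=True)
#   out = mc_layer(x)  # stochastic even outside the training phase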

@logged
class SiameseTriplets(object):
    """Siamese class.

    This class implements a simple siamese neural network based on Keras that
    allows metric learning.
    """

    def __init__(self, model_dir, evaluate=False, predict_only=False,
                 plot=True, save_params=True, generator=None, **kwargs):
        """Initialize the Siamese class.

        Args:
            model_dir(str): Directory where models will be stored.
            traintest_file(str): Path to the traintest file (via kwargs).
            evaluate(bool): Whether to run evaluation.
        """
        from chemicalchecker.core.signature_data import DataSignature
        # check if parameter file exists
        param_file = os.path.join(model_dir, 'params.pkl')
        if os.path.isfile(param_file):
            kwargs = pickle.load(open(param_file, 'rb'))
            self.__log.info('Parameters loaded from: %s' % param_file)
        # read parameters
        self.epochs = int(kwargs.get("epochs", 10))
        self.batch_size = int(kwargs.get("batch_size", 100))
        self.learning_rate = kwargs.get("learning_rate", 'auto')
        self.replace_nan = float(kwargs.get("replace_nan", 0.0))
        self.split = str(kwargs.get("split", 'train'))
        self.layers_sizes = kwargs.get("layers_sizes", [128])
        self.layers = list()
        # we can pass layers type as strings
        layers = kwargs.get("layers", [Dense])
        for l in layers:
            if isinstance(l, str):
                self.layers.append(eval(l))
            else:
                self.layers.append(l)
        self.activations = kwargs.get("activations", ['relu'])
        self.dropouts = kwargs.get("dropouts", [None])
        self.augment_fn = kwargs.get("augment_fn", None)
        self.augment_kwargs = kwargs.get("augment_kwargs", {})
        self.loss_func = str(kwargs.get("loss_func", 'orthogonal_tloss'))
        self.margin = float(kwargs.get("margin", 1.0))
        self.alpha = float(kwargs.get("alpha", 1.0))
        self.patience = float(kwargs.get("patience", self.epochs))
        self.traintest_file = kwargs.get("traintest_file", None)
        self.standard = kwargs.get("standard", True)
        self.trim_mask = kwargs.get("trim_mask", None)

        # internal variables
        self.name = self.__class__.__name__.lower()
        self.time = 0
        self.model_dir = os.path.abspath(model_dir)
        self.model_file = os.path.join(self.model_dir, "%s.h5" % self.name)
        self.model = None
        self.evaluate = evaluate
        self.plot = plot

        # check output path
        if not os.path.exists(model_dir):
            self.__log.warning("Creating model directory: %s", self.model_dir)
            os.mkdir(self.model_dir)

        # check input path
        self.sharedx = None
        if self.traintest_file is not None:
            traintest_data = DataSignature(self.traintest_file)
            if not predict_only:
                self.traintest_file = os.path.abspath(self.traintest_file)
                if not os.path.exists(self.traintest_file):
                    raise Exception('Input data file does not exist!')

                # initialize train generator
                if generator is None:
                    if self.sharedx is None:
                        self.sharedx = traintest_data.get_h5_dataset('x')
                    tr_shape_type_gen = NeighborTripletTraintest.generator_fn(
                        self.traintest_file,
                        'train_train',
                        epochs=self.epochs,
                        batch_size=self.batch_size,
                        replace_nan=self.replace_nan,
                        sharedx=self.sharedx,
                        augment_fn=self.augment_fn,
                        augment_kwargs=self.augment_kwargs,
                        train=True, standard=self.standard,
                        trim_mask=self.trim_mask)
                else:
                    tr_shape_type_gen = generator
                self.generator = tr_shape_type_gen
                self.tr_shapes = tr_shape_type_gen[0]
                self.tr_gen = tr_shape_type_gen[2]()
                self.steps_per_epoch = np.ceil(
                    self.tr_shapes[0][0] / self.batch_size)

            # load the scaler
            if not self.standard:
                scaler_path = os.path.join(self.model_dir, 'scaler.pkl')
                if os.path.isfile(scaler_path):
                    self.scaler = pickle.load(open(scaler_path, 'rb'))
                    self.__log.info("Using scaler: %s", scaler_path)
                elif 'scaler' in traintest_data.info_h5:
                    scaler_path_tt = traintest_data.get_h5_dataset('scaler')[0]
                    self.__log.info("Using scaler: %s", scaler_path_tt)
                    self.scaler = pickle.load(open(scaler_path_tt, 'rb'))
                    pickle.dump(self.scaler, open(scaler_path, 'wb'))
                else:
                    self.__log.warning("No scaler has been loaded")
                # trim the scaler to the kept 128-dim signature blocks
                if hasattr(self, 'scaler') and self.trim_mask is not None:
                    self.scaler.scale_ = self.scaler.scale_[
                        np.repeat(self.trim_mask, 128)]
                    self.scaler.center_ = self.scaler.center_[
                        np.repeat(self.trim_mask, 128)]

        # initialize validation/test generator
        if evaluate:
            traintest_data = DataSignature(self.traintest_file)
            if self.sharedx is None:
                self.sharedx = traintest_data.get_h5_dataset('x')
            val_shape_type_gen = NeighborTripletTraintest.generator_fn(
                self.traintest_file,
                'test_test',
                batch_size=self.batch_size,
                replace_nan=self.replace_nan,
                augment_kwargs=self.augment_kwargs,
                augment_fn=self.augment_fn,
                sharedx=self.sharedx,
                train=False,
                shuffle=False,
                standard=self.standard, trim_mask=self.trim_mask)
            self.val_shapes = val_shape_type_gen[0]
            self.val_gen = val_shape_type_gen[2]()
            self.validation_steps = np.ceil(
                self.val_shapes[0][0] / self.batch_size)
        else:
            self.val_shapes = None
            self.val_gen = None
            self.validation_steps = None

        # log parameters
        self.__log.info("**** %s Parameters: ***" % self.__class__.__name__)
        self.__log.info("{:<22}: {:>12}".format("model_dir", self.model_dir))
        if self.traintest_file is not None and not predict_only:
            self.__log.info("{:<22}: {:>12}".format(
                "traintest_file", self.traintest_file))
            tmp = NeighborTripletTraintest(self.traintest_file, 'train_train')
            self.__log.info("{:<22}: {:>12}".format(
                'train_train', str(tmp.get_ty_shapes())))
            if evaluate:
                tmp = NeighborTripletTraintest(
                    self.traintest_file, 'train_test')
                self.__log.info("{:<22}: {:>12}".format(
                    'train_test', str(tmp.get_ty_shapes())))
                tmp = NeighborTripletTraintest(
                    self.traintest_file, 'test_test')
                self.__log.info("{:<22}: {:>12}".format(
                    'test_test', str(tmp.get_ty_shapes())))
        self.__log.info("{:<22}: {:>12}".format(
            "learning_rate", self.learning_rate))
        self.__log.info("{:<22}: {:>12}".format(
            "epochs", self.epochs))
        self.__log.info("{:<22}: {:>12}".format(
            "batch_size", self.batch_size))
        self.__log.info("{:<22}: {:>12}".format(
            "layers", str(self.layers)))
        self.__log.info("{:<22}: {:>12}".format(
            "layers_sizes", str(self.layers_sizes)))
        self.__log.info("{:<22}: {:>12}".format(
            "activations", str(self.activations)))
        self.__log.info("{:<22}: {:>12}".format(
            "dropouts", str(self.dropouts)))
        self.__log.info("{:<22}: {:>12}".format(
            "augment_fn", str(self.augment_fn)))
        self.__log.info("{:<22}: {:>12}".format(
            "augment_kwargs", str(self.augment_kwargs)))
        self.__log.info("**** %s Parameters: ***" % self.__class__.__name__)

        if self.learning_rate == 'auto':
            self.__log.debug("Searching for optimal learning rates.")
            lr = self.find_lr(kwargs, generator=self.generator)
            self.learning_rate = lr
            kwargs['learning_rate'] = self.learning_rate

        if not os.path.isfile(param_file) and save_params:
            self.__log.debug("Saving parameters to %s" % param_file)
            with open(param_file, "wb") as f:
                pickle.dump(kwargs, f)

    def build_model(self, input_shape, load=False, cp=None):
        """Build and compile the Keras model.

        input_shape(tuple): X dimensions (only nr of features is needed).
        load(bool): Whether to load the pretrained model weights.
        cp(bool): Keep dropout active at prediction time.
        """
        def get_model_arch(input_dim, space_dim=128, num_layers=3):
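            # e.g. with defaults, input_dim >= 1024 yields [512, 256, 128];
            # smaller inputs halve from input_dim, floored at 128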
            if input_dim >= space_dim * (2**num_layers):
                layers = [int(space_dim * 2**i)
                          for i in reversed(range(num_layers))]
            else:
                layers = [max(128, int(input_dim / 2**i))
                          for i in range(1, num_layers + 1)]
            return layers

        def dist_output_shape(shapes):
            shape1, shape2 = shapes
            return (shape1[0], 1)

        def euclidean_distance(x, y):
            sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
            return K.sqrt(K.maximum(sum_square, K.epsilon()))

        def add_layer(net, layer, layer_size, activation, dropout,
                      use_bias=True, input_shape=None):
            if input_shape is not None:
                if activation == 'selu':
                    net.add(GaussianDropout(rate=0.1, input_shape=input_shape))
                    net.add(layer(layer_size, use_bias=use_bias,
                                  kernel_initializer='lecun_normal'))
                else:
                    net.add(layer(layer_size, use_bias=use_bias,
                                  input_shape=input_shape))
            else:
                if activation == 'selu':
                    net.add(layer(layer_size, use_bias=use_bias,
                                  kernel_initializer='lecun_normal'))
                else:
                    net.add(layer(layer_size, use_bias=use_bias))
            net.add(Activation(activation))
            if dropout is not None:
                if activation == 'selu':
                    net.add(AlphaDropoutCP(dropout, cp=cp))
                else:
                    net.add(Dropout(dropout))

        # triplet inputs: anchor, positive and negative
        input_a = Input(shape=input_shape)
        input_p = Input(shape=input_shape)
        input_n = Input(shape=input_shape)
        if not self.standard:
            input_o = Input(shape=input_shape)
            input_s = Input(shape=input_shape)

        # Update layers
        if self.layers_sizes is None:
            self.layers_sizes = get_model_arch(
                input_shape[0], num_layers=len(self.layers))

        # each input goes through a network with the same architecture
        assert (len(self.layers) == len(self.layers_sizes) ==
                len(self.activations) == len(self.dropouts))
        basenet = Sequential()
        for i, tple in enumerate(zip(self.layers, self.layers_sizes,
                                     self.activations, self.dropouts)):
            layer, layer_size, activation, dropout = tple
            i_shape = None
            if i == 0:
                i_shape = input_shape
            if i == (len(self.layers) - 1):
                dropout = None
            add_layer(basenet, layer, layer_size, activation,
                      dropout, input_shape=i_shape)

        # previous explicit layer-by-layer construction, kept for reference
        """add_layer(basenet, self.layers[0], self.layers_sizes[0],
                  self.activations[0], self.dropouts[0],
                  input_shape=input_shape)
        hidden_layers = zip(self.layers[1:-1],
                            self.layers_sizes[1:-1],
                            self.activations[1:-1],
                            self.dropouts[1:-1])
        for layer, layer_size, activation, dropout in hidden_layers:
            add_layer(basenet, layer, layer_size, activation, dropout)
        add_layer(basenet, self.layers[-1], self.layers_sizes[-1],
                  self.activations[-1], None)"""
        # final L2-normalization layer, required by the loss
        basenet.add(Lambda(lambda x: K.l2_normalize(x, axis=-1)))
        basenet.summary()

        encodeds = list()
        encodeds.append(basenet(input_a))
        encodeds.append(basenet(input_p))
        encodeds.append(basenet(input_n))
        if not self.standard:
            encodeds.append(basenet(input_o))
            encodeds.append(basenet(input_s))
        merged_vector = concatenate(encodeds, axis=-1, name='merged_layer')

        inputs = [input_a, input_p, input_n]
        if not self.standard:
            inputs.extend([input_o, input_s])
        model = Model(inputs=inputs, outputs=merged_vector)

        # TODO NEED TO CHANGE IF WE USE 4 INPUTS INSTEAD OF 3
        if self.standard:
            def split_output(y_pred):
                total_length = y_pred.shape.as_list()[-1]
                anchor = y_pred[:, 0: int(total_length * 1 / 3)]
                positive = y_pred[
                    :, int(total_length * 1 / 3): int(total_length * 2 / 3)]
                negative = y_pred[
                    :, int(total_length * 2 / 3): int(total_length * 3 / 3)]
                return anchor, positive, negative, None, None
        else:
            def split_output(y_pred):
                total_length = y_pred.shape.as_list()[-1]
                anchor = y_pred[:, 0: int(total_length * 1 / 5)]
                positive = y_pred[
                    :, int(total_length * 1 / 5): int(total_length * 2 / 5)]
                negative = y_pred[
                    :, int(total_length * 2 / 5): int(total_length * 3 / 5)]
                only = y_pred[
                    :, int(total_length * 3 / 5): int(total_length * 4 / 5)]
                n_self = y_pred[
                    :, int(total_length * 4 / 5): int(total_length * 5 / 5)]
                return anchor, positive, negative, only, n_self

        # define monitored metrics
        def accTot(y_true, y_pred):
            anchor, positive, negative, _, _ = split_output(y_pred)
            acc = K.cast(euclidean_distance(anchor, positive) <
                         euclidean_distance(anchor, negative), anchor.dtype)
            return K.mean(acc)

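        # per-class accuracies: y_true appears to flag a triplet
        # difficulty class (0/1/2, presumably easy/medium/hard)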
        def accE(y_true, y_pred):
            anchor, positive, negative, _, _ = split_output(y_pred)
            msk = K.cast(K.equal(y_true, 0), 'float32')
            prd = self.batch_size / K.sum(msk)
            acc = K.cast(
                euclidean_distance(anchor * msk, positive * msk) <
                euclidean_distance(anchor * msk, negative * msk), anchor.dtype)
            return K.mean(acc) * prd

        def accM(y_true, y_pred):
            anchor, positive, negative, _, _ = split_output(y_pred)
            msk = K.cast(K.equal(y_true, 1), 'float32')
            prd = self.batch_size / K.sum(msk)
            acc = K.cast(
                euclidean_distance(anchor * msk, positive * msk) <
                euclidean_distance(anchor * msk, negative * msk), anchor.dtype)
            return K.mean(acc) * prd

        def accH(y_true, y_pred):
            anchor, positive, negative, _, _ = split_output(y_pred)
            msk = K.cast(K.equal(y_true, 2), 'float32')
            prd = self.batch_size / K.sum(msk)
            acc = K.cast(
                euclidean_distance(anchor * msk, positive * msk) <
                euclidean_distance(anchor * msk, negative * msk), anchor.dtype)
            return K.mean(acc) * prd

        def pearson_r(y_true, y_pred):
            x = y_true
            y = y_pred
            mx = K.mean(x, axis=0)
            my = K.mean(y, axis=0)
            xm, ym = x - mx, y - my
            r_num = K.sum(xm * ym)
            x_square_sum = K.sum(xm * xm)
            y_square_sum = K.sum(ym * ym)
            r_den = K.sqrt(x_square_sum * y_square_sum)
            r = r_num / r_den
            return K.mean(r)

        def cor1(y_true, y_pred):
            anchor, positive, negative, only_self, not_self = split_output(
                y_pred)
            return pearson_r(anchor, not_self)

        def cor2(y_true, y_pred):
            anchor, positive, negative, only_self, not_self = split_output(
                y_pred)
            return pearson_r(anchor, only_self)

        def cor3(y_true, y_pred):
            anchor, positive, negative, only_self, not_self = split_output(
                y_pred)
            return pearson_r(not_self, only_self)

        metrics = [accTot]
        if not self.standard:
            metrics.extend([accE,
                            accM,
                            accH,
                            cor1,
                            cor2,
                            cor3])

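        # standard triplet loss:
        #   L = max(||a - p||^2 - ||a - n||^2 + margin, 0)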
        def tloss(y_true, y_pred):
            anchor, positive, negative, _, _ = split_output(y_pred)
            pos_dist = K.sum(K.square(anchor - positive), axis=1)
            neg_dist = K.sum(K.square(anchor - negative), axis=1)
            basic_loss = pos_dist - neg_dist + self.margin
            loss = K.maximum(basic_loss, 0.0)
            return loss

        def bayesian_tloss(y_true, y_pred):
            anchor, positive, negative, _, _ = split_output(y_pred)
            loss = 1.0 - K.sigmoid(
                K.sum(anchor * positive, axis=-1, keepdims=True) -
                K.sum(anchor * negative, axis=-1, keepdims=True))
            return K.mean(loss)

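        # triplet loss plus a global orthogonal regularization (GOR) term
        # pushing anchor-negative dot products toward zero:
        #   gor = E[a.n]^2 + max(E[(a.n)^2] - 1/d, 0)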
        def orthogonal_tloss(y_true, y_pred):
            def global_orthogonal_regularization(y_pred):
                anchor, positive, negative, _, _ = split_output(y_pred)
                neg_dis = K.sum(anchor * negative, axis=1)
                dim = K.int_shape(y_pred)[1]
                gor = K.pow(K.mean(neg_dis), 2) + \
                    K.maximum(K.mean(K.pow(neg_dis, 2)) - 1.0 / dim, 0.0)
                return gor

            gro = global_orthogonal_regularization(y_pred) * self.alpha
            loss = tloss(y_true, y_pred)
            return loss + gro

        def only_self_loss(y_true, y_pred):
            def only_self_regularization(y_pred):
                anchor, positive, negative, only_self, _ = split_output(y_pred)
                pos_dist = K.sum(K.square(anchor - only_self), axis=1)
                neg_dist = K.sum(K.square(anchor - negative), axis=1)
                basic_loss = pos_dist - neg_dist + self.margin
                loss = K.maximum(basic_loss, 0.0)
                neg_dis = K.sum(anchor * negative, axis=1)
                dim = K.int_shape(y_pred)[1]
                gor = K.pow(K.mean(neg_dis), 2) + \
                    K.maximum(K.mean(K.pow(neg_dis, 2)) - 1.0 / dim, 0.0)
                return loss + (gor * self.alpha)

            loss = orthogonal_tloss(y_true, y_pred)
            o_self = only_self_regularization(y_pred)
            return loss + o_self

        def penta_loss(y_true, y_pred):
            def only_self_regularization(y_pred):
                anchor, positive, negative, only_self, not_self = split_output(
                    y_pred)
                pos_dist = K.sum(K.square(anchor - only_self), axis=1)
                neg_dist = K.sum(K.square(anchor - negative), axis=1)
                basic_loss = pos_dist - neg_dist + self.margin
                loss = K.maximum(basic_loss, 0.0)
                neg_dis = K.sum(anchor * negative, axis=1)
                dim = K.int_shape(y_pred)[1]
                gor = K.pow(K.mean(neg_dis), 2) + \
                    K.maximum(K.mean(K.pow(neg_dis, 2)) - 1.0 / dim, 0.0)
                return loss + (gor * self.alpha)

            def not_self_regularization(y_pred):
                anchor, positive, negative, only_self, not_self = split_output(
                    y_pred)
                pos_dist = K.sum(K.square(anchor - not_self), axis=1)
                neg_dist = K.sum(K.square(anchor - negative), axis=1)
                basic_loss = pos_dist - neg_dist + self.margin
                loss = K.maximum(basic_loss, 0.0)
                neg_dis = K.sum(anchor * negative, axis=1)
                dim = K.int_shape(y_pred)[1]
                gor = K.pow(K.mean(neg_dis), 2) + \
                    K.maximum(K.mean(K.pow(neg_dis, 2)) - 1.0 / dim, 0.0)
                return loss + (gor * self.alpha)

            def both_self_regularization(y_pred):
                anchor, positive, negative, only_self, not_self = split_output(
                    y_pred)
                pos_dist = K.sum(K.square(not_self - only_self), axis=1)
                neg_dist = K.sum(K.square(not_self - negative), axis=1)
                basic_loss = pos_dist - neg_dist + self.margin
                loss = K.maximum(basic_loss, 0.0)
                neg_dis = K.sum(anchor * negative, axis=1)
                dim = K.int_shape(y_pred)[1]
                gor = K.pow(K.mean(neg_dis), 2) + \
                    K.maximum(K.mean(K.pow(neg_dis, 2)) - 1.0 / dim, 0.0)
                return loss + (gor * self.alpha)

            loss = orthogonal_tloss(y_true, y_pred)
            o_self = only_self_regularization(y_pred)
            n_self = not_self_regularization(y_pred)
            b_self = both_self_regularization(y_pred)
            return loss + ((o_self + n_self + b_self) / 3)

        def mse_loss(y_true, y_pred):
            def _mse(y_pred):
                anchor, positive, negative, anchor_sign3, _ = split_output(
                    y_pred)
                return keras.losses.mean_squared_error(anchor_sign3, anchor)
            loss = orthogonal_tloss(y_true, y_pred)
            return loss + _mse(y_pred)

        lfuncs_dict = {'tloss': tloss,
                       'bayesian_tloss': bayesian_tloss,
                       'orthogonal_tloss': orthogonal_tloss,
                       'only_self_loss': only_self_loss,
                       'penta_loss': penta_loss}

        # compile and print summary
        self.__log.info('Loss function: %s' %
                        lfuncs_dict[self.loss_func].__name__)

        if self.learning_rate == 'auto':
            optimizer = keras.optimizers.Adam(lr=MIN_LR)
        else:
            optimizer = keras.optimizers.Adam(lr=self.learning_rate)

        model.compile(
            optimizer=optimizer,
            loss=lfuncs_dict[self.loss_func],
            metrics=metrics)
        model.summary()

        # if pre-trained model is specified, load its weights
        self.model = model
        if load:
            self.model.load_weights(self.model_file)
        # this will be the encoder/transformer
        self.transformer = self.model.layers[-2]

    def find_lr(self, params, num_lr=5, generator=None):
        """Find a learning rate via a small grid search over one epoch."""
        import matplotlib.pyplot as plt
        from scipy.stats import rankdata

        # Initialize model
        input_shape = (self.tr_shapes[0][1],)
        self.build_model(input_shape)

        # Find lr by grid search
        self.__log.info('Finding best lr')
        lr_iters = []
        lr_params = params.copy()
        lr_params['epochs'] = 1
        lrs = [1e-6, 1e-5, 1e-4]
        for lr in lrs:
            self.__log.info('Trying lr %s' % lr)
            lr_params['learning_rate'] = lr
            siamese = SiameseTriplets(
                self.model_dir, evaluate=True, plot=False, save_params=False,
                generator=generator, **lr_params)
            siamese.fit(save=False)
            h_file = os.path.join(
                self.model_dir, 'siamesetriplets_history.pkl')
            h_metrics = pickle.load(open(h_file, "rb"))
            loss = h_metrics['loss'][0]
            val_loss = h_metrics['val_loss'][0]
            acc = h_metrics['accTot'][0]
            val_acc = h_metrics['val_accTot'][0]
            lr_iters.append([loss, val_loss, val_acc])

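        # rank each lr per metric: losses ranked ascending, val accuracy
        # descending (via 1/col); the lr with the lowest mean rank wins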
        lr_iters = np.array(lr_iters)
        lr_scores = np.mean(np.array([rankdata(1 / col) if i > 1 else rankdata(col)
                                      for i, col in enumerate(lr_iters.T)]).T, axis=1)
        lr_index = np.argmin(lr_scores)
        lr = lrs[lr_index]
        lr_results = {'lr_iters': lr_iters,
                      'lr_scores': lr_scores, 'lr': lr, 'lrs': lrs}

        fname = 'lr_score.pkl'
        pkl_file = os.path.join(self.model_dir, fname)
        pickle.dump(lr_results, open(pkl_file, "wb"))

        fig, axes = plt.subplots(1, 3, figsize=(9, 3))
        ax = axes.flatten()
        log_lrs = np.log10(lrs)

        ax[0].set_title('Loss')
        ax[0].set_xlabel('lrs')
        ax[0].scatter(log_lrs, lr_iters[:, 0], label='train')
        ax[0].scatter(log_lrs, lr_iters[:, 1], label='test')
        ax[0].legend()

        ax[1].set_title('ValAccT')
        ax[1].set_xlabel('lrs')
        ax[1].scatter(log_lrs, lr_iters[:, 2], label='test')

        ax[2].set_title('Lr score')
        ax[2].set_xlabel('lrs')
        ax[2].scatter(log_lrs, lr_scores)
        fig.tight_layout()

        fname = 'lr_score.png'
        plot_file = os.path.join(self.model_dir, fname)
        plt.savefig(plot_file)
        plt.close()

        return lr

    def fit(self, monitor='val_loss', save=True):
        """Fit the model.

        monitor(str): variable to monitor for early stopping.
        save(bool): Whether to save the trained model.
        """
        # build model
        input_shape = (self.tr_shapes[0][1],)
        self.build_model(input_shape)

        # prepare callbacks
        callbacks = list()

        def mask_keep(idxs, x1_data, x2_data, x3_data):
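            """Keep only the 128-dim blocks at `idxs`, set NaN elsewhere."""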
            # we will fill an array of NaN with values we want to keep
            x1_data_transf = np.zeros_like(x1_data, dtype=np.float32) * np.nan
            for idx in idxs:
                # copy column from original data
                col_slice = slice(idx * 128, (idx + 1) * 128)
                x1_data_transf[:, col_slice] = x1_data[:, col_slice]
            x2_data_transf = np.zeros_like(x2_data, dtype=np.float32) * np.nan
            for idx in idxs:
                # copy column from original data
                col_slice = slice(idx * 128, (idx + 1) * 128)
                x2_data_transf[:, col_slice] = x2_data[:, col_slice]
            x3_data_transf = np.zeros_like(x3_data, dtype=np.float32) * np.nan
            for idx in idxs:
                # copy column from original data
                col_slice = slice(idx * 128, (idx + 1) * 128)
                x3_data_transf[:, col_slice] = x3_data[:, col_slice]
            # keep rows containing at least one not-NaN value
            """
            not_nan1 = np.isfinite(x1_data_transf).any(axis=1)
            not_nan2 = np.isfinite(x2_data_transf).any(axis=1)
            not_nan3 = np.isfinite(x3_data_transf).any(axis=1)
            not_nan = np.logical_and(not_nan1, not_nan2, not_nan3)
            x1_data_transf = x1_data_transf[not_nan]
            x2_data_transf = x2_data_transf[not_nan]
            x3_data_transf = x3_data_transf[not_nan]
            """
            return x1_data_transf, x2_data_transf, x3_data_transf

        def mask_exclude(idxs, x1_data, x2_data, x3_data):
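            """Set the 128-dim blocks at `idxs` to NaN, keep the rest."""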
            x1_data_transf = np.copy(x1_data)
            for idx in idxs:
                # set current space to nan
                col_slice = slice(idx * 128, (idx + 1) * 128)
                x1_data_transf[:, col_slice] = np.nan
            x2_data_transf = np.copy(x2_data)
            for idx in idxs:
                # set current space to nan
                col_slice = slice(idx * 128, (idx + 1) * 128)
                x2_data_transf[:, col_slice] = np.nan
            x3_data_transf = np.copy(x3_data)
            for idx in idxs:
                # set current space to nan
                col_slice = slice(idx * 128, (idx + 1) * 128)
                x3_data_transf[:, col_slice] = np.nan
            # drop rows that only contain NaNs
            """
            not_nan1 = np.isfinite(x1_data_transf).any(axis=1)
            not_nan2 = np.isfinite(x2_data_transf).any(axis=1)
            not_nan3 = np.isfinite(x3_data_transf).any(axis=1)
            not_nan = np.logical_and(not_nan1, not_nan2, not_nan3)
            x1_data_transf = x1_data_transf[not_nan]
            x2_data_transf = x2_data_transf[not_nan]
            x3_data_transf = x3_data_transf[not_nan]
            """
            return x1_data_transf, x2_data_transf, x3_data_transf

        vsets = ['train_test', 'test_test']
        if self.evaluate and self.plot:
            # additional validation sets
            if "dataset_idx" in self.augment_kwargs:
                space_idx = self.augment_kwargs['dataset_idx']
                mask_fns = {
                    'ALL': None,
                    'NOT-SELF': partial(mask_exclude, space_idx),
                    'ONLY-SELF': partial(mask_keep, space_idx),
                }
            else:
                mask_fns = {
                    'ALL': None
                }
            validation_sets = list()
            for split in vsets:
                for set_name, mask_fn in mask_fns.items():
                    name = '_'.join([split, set_name])
                    shapes, dtypes, gen = NeighborTripletTraintest.generator_fn(
                        self.traintest_file, split,
                        batch_size=self.batch_size,
                        replace_nan=self.replace_nan,
                        mask_fn=mask_fn,
                        sharedx=self.sharedx,
                        augment_kwargs=self.augment_kwargs,
                        augment_fn=self.augment_fn,
                        train=False,
                        shuffle=False,
                        standard=self.standard, trim_mask=self.trim_mask)
                    validation_sets.append((gen, shapes, name))
            additional_vals = AdditionalValidationSets(
                validation_sets, self.model, batch_size=self.batch_size)
            callbacks.append(additional_vals)

        class CustomEarlyStopping(EarlyStopping):
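            """EarlyStopping variant: the patience countdown is suppressed
            (and `best` reset) while the train-set version of the monitored
            metric is still below `threshold`."""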

            def __init__(self,
                         monitor='val_loss',
                         min_delta=0,
                         patience=0,
                         verbose=0,
                         mode='auto',
                         baseline=None,
                         threshold=0,
                         restore_best_weights=False):
                super(EarlyStopping, self).__init__()

                self.monitor = monitor
                self.baseline = baseline
                self.patience = patience
                self.verbose = verbose
                self.min_delta = min_delta
                self.wait = 0
                self.stopped_epoch = 0
                self.restore_best_weights = restore_best_weights
                self.best_weights = None
                self.threshold = threshold

                if mode not in ['auto', 'min', 'max']:
                    mode = 'auto'

                if mode == 'min':
                    self.monitor_op = np.less
                elif mode == 'max':
                    self.monitor_op = np.greater
                else:
                    if 'acc' in self.monitor:
                        self.monitor_op = np.greater
                    else:
                        self.monitor_op = np.less

                if self.monitor_op == np.greater:
                    self.min_delta *= 1
                else:
                    self.min_delta *= -1

            def on_epoch_end(self, epoch, logs=None):
                current = self.get_monitor_value(logs)
                threshold = logs.get(self.monitor.replace('val_', ''))
                if current is None:
                    return

                if self.threshold > threshold:
                    self.best = current
                    self.wait = 0
                    if self.restore_best_weights:
                        self.best_weights = self.model.get_weights()
                elif self.monitor_op(current - self.min_delta, self.best):
                    self.best = current
                    self.wait = 0
                    if self.restore_best_weights:
                        self.best_weights = self.model.get_weights()
                else:
                    self.wait += 1
                    if self.wait >= self.patience:
                        self.stopped_epoch = epoch
                        self.model.stop_training = True
                        if self.restore_best_weights:
                            if self.verbose > 0:
                                print('Restoring model weights from the end of '
                                      'the best epoch')
                            self.model.set_weights(self.best_weights)

        early_stopping = EarlyStopping(
            monitor=monitor,
            verbose=1,
            patience=self.patience,
            mode='min',
            restore_best_weights=True)
        if monitor or not self.evaluate:
            callbacks.append(early_stopping)
        # call fit and save model
        t0 = time()
        self.history = self.model.fit_generator(
            generator=self.tr_gen,
            steps_per_epoch=self.steps_per_epoch,
            epochs=self.epochs,
            callbacks=callbacks,
            validation_data=self.val_gen,
            validation_steps=self.validation_steps,
            shuffle=True)
        self.time = time() - t0
        if save:
            self.model.save(self.model_file)
        if self.evaluate and self.plot:
            self.history.history.update(additional_vals.history)

        # check early stopping
        if early_stopping.stopped_epoch != 0:
            self.last_epoch = early_stopping.stopped_epoch - self.patience
        else:
            self.last_epoch = self.epochs

        # save and plot history
        history_file = os.path.join(
            self.model_dir, "%s_history.pkl" % self.name)
        pickle.dump(self.history.history, open(history_file, 'wb'))
        history_file = os.path.join(self.model_dir, "history.png")
        anchor_file = os.path.join(self.model_dir, "anchor_distr.png")
        if self.evaluate:
            self._plot_history(self.history.history, vsets, history_file)
        if not self.standard:
            self._plot_anchor_dist(anchor_file)

    def predict(self, x_matrix, dropout_fn=None, dropout_samples=10, cp=False):
        """Do predictions.

        x_matrix(array): input matrix of Xs.
        dropout_fn(func): optional function applied to the input before
            each sampling round (e.g. randomly masking values).
        dropout_samples(int): number of samples drawn when a dropout_fn
            is given.
        cp(bool): keep dropout layers active at prediction time.
        """
        # apply input scaling
        if hasattr(self, 'scaler'):
            scaled = self.scaler.transform(x_matrix)
        else:
            scaled = x_matrix

        # apply trimming of input matrix
        if self.trim_mask is not None:
            trimmed = scaled[:, np.repeat(self.trim_mask, 128)]
        else:
            trimmed = scaled

        # load model if not already there
        if self.model is None:
            self.build_model((trimmed.shape[1],), load=True, cp=cp)

        # get rid of NaNs
        no_nans = np.nan_to_num(trimmed)
        # without a dropout function return plain predictions
        if dropout_fn is None:
            return self.transformer.predict(no_nans)
        # sample with dropout (repeat input)
        samples = list()
        for i in range(dropout_samples):
            dropped_ds = dropout_fn(no_nans)
            no_nans_drop = np.nan_to_num(dropped_ds)
            samples.append(self.transformer.predict(no_nans_drop))
        samples = np.vstack(samples)
        samples = samples.reshape(
            no_nans.shape[0], dropout_samples, samples.shape[1])
        return samples

    def _plot_history(self, history, vsets, destination):
        """Plot history.

        history(dict): history result from the Keras fit method.
        vsets(list): names of the additional validation sets.
        destination(str): path to output file.
        """
        import matplotlib.pyplot as plt

        metrics = sorted(list({k.split('_')[-1] for k in history}))

        rows = len(metrics)
        cols = len(vsets)

        plt.figure(figsize=(cols * 5, rows * 5), dpi=100)

        c = 1
        for metric in metrics:
            for vset in vsets:
                plt.subplot(rows, cols, c)
                plt.title(metric.capitalize())
                plt.plot(history[metric], label="Train", lw=2, ls='--')
                plt.plot(history['val_' + metric], label="Val", lw=2, ls='--')
                vset_met = [k for k in history if vset in k and metric in k]
                for valset in vset_met:
                    plt.plot(history[valset], label=valset, lw=2)
                plt.legend()
                c += 1

        plt.tight_layout()

        if destination is not None:
            plt.savefig(destination)
        plt.close('all')

    def _plot_anchor_dist(self, plot_file):
        from scipy.spatial.distance import cosine
        import matplotlib.pyplot as plt
        import seaborn as sns

        def sim(a, b):
            # cosine similarity, i.e. 1 - cosine distance
            return -(cosine(a, b) - 1)

        # create a new train_train generator with train=False (no
        # augmentation) and no shuffling
        tr_shape_type_gen = NeighborTripletTraintest.generator_fn(
            self.traintest_file,
            'train_train',
            batch_size=self.batch_size,
            replace_nan=self.replace_nan,
            sharedx=self.sharedx,
            augment_fn=self.augment_fn,
            augment_kwargs=self.augment_kwargs,
            train=False,
            shuffle=False,
            standard=self.standard)

        tr_gen = tr_shape_type_gen[2]()

        if self.evaluate:
            trval_shape_type_gen = NeighborTripletTraintest.generator_fn(
                self.traintest_file,
                'train_test',
                batch_size=self.batch_size,
                replace_nan=self.replace_nan,
                sharedx=self.sharedx,
                augment_fn=self.augment_fn,
                augment_kwargs=self.augment_kwargs,
                train=False,
                shuffle=False,
                standard=self.standard)
            trval_gen = trval_shape_type_gen[2]()
            val_shape_type_gen = NeighborTripletTraintest.generator_fn(
                self.traintest_file,
                'test_test',
                batch_size=self.batch_size,
                replace_nan=self.replace_nan,
                sharedx=self.sharedx,
                augment_fn=self.augment_fn,
                augment_kwargs=self.augment_kwargs,
                train=False,
                shuffle=False,
                standard=self.standard)
            val_gen = val_shape_type_gen[2]()
            vset_dict = {'train_train': tr_gen,
                         'train_test': trval_gen, 'test_test': val_gen}
        else:
            vset_dict = {'train_train': tr_gen}
        fig, axes