# Neural Transformation Learning for Anomaly Detection (NeuTraLAD) - a self-supervised method for anomaly detection
# Copyright (c) 2022 Robert Bosch GmbH
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.


from ..config.base import Config
from torch.utils.data import DataLoader

class runExperiment:
    """Build, train and score one model configuration of an experiment.

    Typical use: construct with a model configuration dict, call
    :meth:`run_test` to build and train the model, then :meth:`get_score`
    to score a data loader with the trained trainer.
    """

    def __init__(self, model_configuration, exp_path):
        """Store the experiment configuration.

        Args:
            model_configuration: dict of model/training settings; wrapped
                into a ``Config`` via ``Config.from_dict``.
            exp_path: experiment output path. Stored on the instance; not
                used by the methods of this class.
        """
        self.model_config = Config.from_dict(model_configuration)
        self.exp_path = exp_path

    def run_test(self, dataset, logger):
        """Instantiate model, optimizer, scheduler and trainer, then train.

        Configuration entries read: ``training_epochs``, ``model``, ``loss``,
        ``optimizer``, ``scheduler`` (may be ``None``), ``network``,
        ``trainer``, ``learning_rate``, ``l2``, ``loss_temp`` and ``device``.

        Args:
            dataset: 4-tuple ``(train_loader, val_loader, test_loader, x_dim)``.
                The loaders are forwarded unchanged to ``trainer.train``;
                ``x_dim`` is forwarded to the model constructor (presumably
                the input feature dimension -- confirm against the models).
            logger: forwarded unchanged to ``trainer.train``.

        Returns:
            Whatever ``trainer.train`` returns (previously discarded).

        Side effects:
            Sets ``self.trainer`` and ``self.model`` so that
            :meth:`get_score` can be called afterwards.
        """
        train_loader, val_loader, test_loader, x_dim = dataset

        # Resolve every config entry up front so a missing key fails before
        # any model/optimizer is built.
        cfg = self.model_config
        n_epoch = cfg.training_epochs
        model_class = cfg.model
        loss_class = cfg.loss
        optim_class = cfg.optimizer
        sched_class = cfg.scheduler
        network = cfg.network
        trainer_class = cfg.trainer

        model = model_class(network(), x_dim, config=cfg)
        optimizer = optim_class(model.parameters(),
                                lr=cfg['learning_rate'], weight_decay=cfg['l2'])
        scheduler = sched_class(optimizer) if sched_class is not None else None

        trainer = trainer_class(model, loss_function=loss_class(cfg['loss_temp']),
                                device=cfg['device'])

        # NOTE: any `early_stopper` present in the configuration is not
        # consulted here; early stopping is explicitly disabled for this run.
        train_result = trainer.train(train_loader=train_loader,
                                     max_epochs=n_epoch,
                                     optimizer=optimizer,
                                     scheduler=scheduler,
                                     validation_loader=val_loader,
                                     test_loader=test_loader,
                                     early_stopping=None,
                                     logger=logger)
        self.trainer = trainer
        self.model = model
        return train_result

    def get_score(self, test_loader):
        """Score ``test_loader`` with the trainer built by :meth:`run_test`.

        Args:
            test_loader: loader forwarded to ``trainer.detect_outliers``
                together with ``cls=None``.

        Returns:
            Whatever ``trainer.detect_outliers`` returns (presumably
            per-sample anomaly scores -- confirm against the trainer).

        Raises:
            AttributeError: if called before :meth:`run_test`.
        """
        trainer = getattr(self, 'trainer', None)
        if trainer is None:
            raise AttributeError(
                "get_score() called before run_test(): no trained model is available")
        return trainer.detect_outliers(test_loader, cls=None)
