# Neural Transformation Learning for Anomaly Detection (NeuTraLAD) - a self-supervised method for anomaly detection
# Copyright (c) 2022 Robert Bosch GmbH
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.


import os
import json
import torch
import random
import numpy as np
from ..loader import load_dataset
from ..utils import Logger


class KVariantEval:
    """Run an experiment with the first model configuration and summarise its test metrics.

    Files live under ``exp_path``: the experiment's ``results.json`` (whose presence
    :meth:`risk_assessment` checks before running) and the aggregated
    ``assessment_results.json`` written by :meth:`process_results`.
    """

    def __init__(self, dataset, exp_path, model_configs):
        """
        Args:
            dataset: presumably a dataset name; stored as ``data_name`` but not used
                elsewhere in this class.
            exp_path: folder in which experiment results are read and written.
            model_configs: sequence of model configurations; only the first is used.
        """
        self.num_cls = 1  # number of class variants run / aggregated
        self.data_name = dataset
        self.model_configs = model_configs
        self._NESTED_FOLDER = exp_path
        self._FOLD_BASE = '_CLS'  # suffix of the optional per-class sub-folder "<cls>_CLS"
        self._RESULTS_FILENAME = 'results.json'
        self._ASSESSMENT_FILENAME = 'assessment_results.json'
        # Experiment produced by the most recent run; required by get_score().
        self.experiment = None

    def _results_path(self, cls):
        """Return the path of the ``results.json`` file for class index ``cls``.

        A per-class ``<cls>_CLS`` sub-folder is used when it exists; otherwise the file is
        expected directly in the experiment folder, which is where risk_assessment runs
        the experiment and looks for it.
        """
        per_class = os.path.join(self._NESTED_FOLDER, str(cls) + self._FOLD_BASE,
                                 self._RESULTS_FILENAME)
        if os.path.exists(per_class):
            return per_class
        return os.path.join(self._NESTED_FOLDER, self._RESULTS_FILENAME)

    def process_results(self):
        """Aggregate the test metrics of every class into ``assessment_results.json``.

        For each class the ``TS_F1``, ``TS_AUC`` and ``TS_AP`` entries of its
        ``results.json`` are summarised as ``avg_TS_<metric>_<cls>`` /
        ``std_TS_<metric>_<cls>``. The per-class arrays are then averaged element-wise
        and summarised as ``avg_TS_<metric>_all`` / ``std_TS_<metric>_all``.

        Unreadable or incomplete class results are reported and skipped (best effort).
        If no class could be read, the ``*_all`` entries are NaN.
        """
        ts_f1s, ts_aps, ts_aucs = [], [], []
        results = {}

        for i in range(self.num_cls):
            try:
                with open(self._results_path(i), 'r') as fp:
                    variant_scores = json.load(fp)
                ts_f1 = np.array(variant_scores['TS_F1'])
                ts_auc = np.array(variant_scores['TS_AUC'])
                ts_ap = np.array(variant_scores['TS_AP'])
                class_stats = {
                    'avg_TS_f1_' + str(i): ts_f1.mean(),
                    'std_TS_f1_' + str(i): ts_f1.std(),
                    'avg_TS_ap_' + str(i): ts_ap.mean(),
                    'std_TS_ap_' + str(i): ts_ap.std(),
                    'avg_TS_auc_' + str(i): ts_auc.mean(),
                    'std_TS_auc_' + str(i): ts_auc.std(),
                }
            except Exception as e:
                # Deliberate best effort: one bad class must not abort the whole assessment.
                print(f"Could not process results for class {i}: {e}")
                continue

            ts_f1s.append(ts_f1)
            ts_aucs.append(ts_auc)
            ts_aps.append(ts_ap)
            results.update(class_stats)

        for name, per_class in (('f1', ts_f1s), ('ap', ts_aps), ('auc', ts_aucs)):
            if per_class:
                # Element-wise mean over classes, then summarise the resulting array.
                avg = np.mean(np.array(per_class), 0)
                results['avg_TS_' + name + '_all'] = avg.mean()
                results['std_TS_' + name + '_all'] = avg.std()
            else:
                # Nothing to aggregate: record NaN explicitly instead of averaging an
                # empty array (which warns and yields NaN anyway).
                results['avg_TS_' + name + '_all'] = float('nan')
                results['std_TS_' + name + '_all'] = float('nan')

        with open(os.path.join(self._NESTED_FOLDER, self._ASSESSMENT_FILENAME), 'w') as fp:
            json.dump(results, fp, indent=0)

    def risk_assessment(self, experiment_class, train_loader=None, test_loader=None, n_dim=None):
        """Run the experiment for every class (unless already done), then aggregate results.

        Args:
            experiment_class: callable ``experiment_class(config, exp_path)`` returning an
                object with ``exp_path``, ``run_test(dataset, logger)`` and ``get_score``.
            train_loader, test_loader, n_dim: data handed to the experiment; the two
                loaders are required whenever a class still has to be run.

        Raises:
            TypeError: a class must be run but ``train_loader``/``test_loader`` is missing.
        """
        folder = self._NESTED_FOLDER
        os.makedirs(folder, exist_ok=True)

        for cls in range(self.num_cls):
            json_results = os.path.join(folder, self._RESULTS_FILENAME)
            if os.path.exists(json_results):
                print(
                    f"File {json_results} already present! Shutting down to prevent loss of previous experiments")
                continue
            if train_loader is None or test_loader is None:
                raise TypeError("risk_assessment() needs train_loader and test_loader to run an experiment")
            self._risk_assessment_helper(train_loader, test_loader, n_dim, cls, 'normal',
                                         experiment_class, folder)

        self.process_results()
        # self.experiment is intentionally kept so get_score() works after a run.

    def _risk_assessment_helper(self, train_loader, test_loader, n_dim, cls, cls_type, experiment_class,
                                exp_path, seed=41):
        """Build and run one experiment on the given loaders and keep it for scoring.

        Args:
            train_loader, test_loader, n_dim: packed as ``(train_loader, None, test_loader,
                n_dim)`` and passed to ``experiment.run_test``.
            cls, cls_type: accepted for interface compatibility; unused here.
            experiment_class: called as ``experiment_class(config, exp_path)``.
            exp_path: folder handed to the experiment.
            seed: seed for the Python, NumPy and torch (CPU and CUDA) RNGs.
        """
        # Seed every RNG *before* the experiment is constructed so that any randomness
        # used while building it (presumably model initialisation) is reproducible too.
        torch.cuda.empty_cache()
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        self.best_config = self.model_configs[0]
        experiment = experiment_class(self.best_config, exp_path)
        logger = Logger(os.path.join(experiment.exp_path, 'experiment.log'), mode='a')

        dataset = train_loader, None, test_loader, n_dim
        experiment.run_test(dataset, logger)

        self.experiment = experiment

    def get_score(self, test_loader):
        """Return ``self.experiment.get_score(test_loader)`` for the last run experiment.

        Raises:
            RuntimeError: no experiment has been run on this instance yet.
        """
        if self.experiment is None:
            raise RuntimeError("No experiment available; run an experiment before calling get_score().")
        return self.experiment.get_score(test_loader)
