import HPO.base_grid_hpo
from HPO.hpo_logger import HPOLogger
from models.utils.continual_model import ContinualModel
from copy import deepcopy
import numpy as np
from datasets.utils.continual_dataset import ContinualDataset
from argparse import Namespace
import HPO.CV_grid_hpo
import utils.training
from HPO.hyperparams import register_hyperparams, register_lr_hyperparam


class FixedValHPO(HPO.base_grid_hpo.BaseGridHPO):
    """Grid-search HPO that trains one model copy per setting and keeps the best.

    Every setting yielded by ``self.grid()`` is tried on a fresh deep copy of the
    model and data stream, scored with ``utils.training.evaluate``, and the
    highest-scoring trained copy is returned.
    """

    def __init__(self, mask_classes=False):
        """
        Args:
            mask_classes: forwarded as the ``last`` argument of
                ``utils.training.evaluate`` when scoring each setting.
        """
        super().__init__()
        self.mask_classes = mask_classes

    @staticmethod
    def _score(accs, data_stream: ContinualDataset):
        """Reduce the output of ``evaluate`` to a single scalar score.

        For streams whose NAME contains "hetro", per-task accuracies are weighted
        by the size of each test loader's dataset; otherwise they are averaged
        uniformly along axis 1. Entry 0 of the reduced result is returned.
        """
        # NOTE(review): assumes ``accs`` is 2-D (rows x tasks) and that row 0 is
        # the metric to optimise (presumably class-IL accuracy) — confirm
        # against utils.training.evaluate.
        if "hetro" in data_stream.NAME:
            totals = np.array([len(loader.dataset) for loader in data_stream.test_loaders])
            perf = np.sum(np.array(accs) * totals, axis=1) / np.sum(totals)
        else:
            perf = np.mean(accs, axis=1)
        return perf[0]

    @staticmethod
    def _log_stats(logger: HPOLogger, best_setting, perfs) -> None:
        """Append the selected setting and all per-setting scores to ``logger.logged_vals``."""
        for key, value in (('selected_hp', best_setting), ('hpo_avg_perf_stats', perfs)):
            if key in logger.logged_vals:
                logger.logged_vals[key].append(value)
            else:
                logger.logged_vals[key] = [value]

    def select_model(self, model: ContinualModel, logger: HPOLogger,
                     data_stream: ContinualDataset, args: Namespace, task_id: int) -> ContinualModel:
        """Train one copy of ``model`` per grid setting and return the best-scoring copy.

        Args:
            model: model to tune; it is deep-copied for every setting and is
                never trained in place.
            logger: receives the selected setting and the per-setting scores in
                ``logged_vals`` unless ``args.disable_log`` is set.
            data_stream: continual dataset; deep-copied for every setting.
            args: run configuration (reads ``disable_log`` and ``only_lr_hpo``).
            task_id: index of the current task, forwarded to the trainer.

        Returns:
            The trained copy with the highest score (ties go to the later
            setting). Before returning, all registered hyperparameters are
            unregistered, re-registered against that copy, and set to the
            winning setting.

        Raises:
            RuntimeError: if the grid is empty or every setting failed, so no
                model could be selected.
        """
        # Index forwarded to the shared CV trainer; this strategy has no folds,
        # so it is always 0 (presumably the fold index — confirm in CV_grid_hpo).
        i = 0
        perfs = []
        best_perf = -1
        best_model = None
        best_setting = None
        for j, setting in enumerate(self.grid()):
            self.set_hyperparams(setting)
            model_copy = deepcopy(model)
            stream_copy = deepcopy(data_stream)
            try:
                HPO.CV_grid_hpo._train(stream_copy.train_loader, model_copy, stream_copy, args, j, i, setting, task_id)
                accs = utils.training.evaluate(model_copy, stream_copy, last=self.mask_classes)
                perf = self._score(accs, data_stream)
            except AssertionError as err:
                # Best-effort search: a setting that trips an assertion during
                # training/evaluation is skipped, but reported rather than
                # silently discarded.
                print(f"\nSkipping hyperparams {setting!s}: {err!r}")
                continue

            perfs.append((setting, perf))
            if perf >= best_perf:
                best_perf = perf
                best_setting = setting
                best_model = model_copy

        if best_model is None:
            # Without this guard a None model/setting would be pushed into the
            # hyperparameter registry and returned to the caller.
            raise RuntimeError(
                "FixedValHPO: no hyperparameter setting completed training and "
                "evaluation; cannot select a model."
            )

        print(f"\nSelected Hyperparams: {best_setting!s} avg val acc: {best_perf!s}")

        # log HPO stats
        if not args.disable_log:
            self._log_stats(logger, best_setting, perfs)

        # Copy the names first: unregistering mutates self.hyperparams.
        for name in list(self.hyperparams):
            self.unregister_hyperparam(name)

        if args.only_lr_hpo:
            register_lr_hyperparam(best_model, data_stream, args, self)
        else:
            register_hyperparams(best_model, data_stream, args, self)

        self.set_hyperparams(best_setting)
        return best_model
