import os
from collections import defaultdict

import numpy as np
import pandas as pd
import torch.utils.data
from torch.utils.data import DataLoader

import cole as cl


class Evaluator:
    """
    This class creates the results csv file and takes care of evaluating, as well as keeping track of the
    current number of iterations.

    It can store checkpoints too, but that's not fully tested.

    Metrics are buffered in ``self.metrics`` (metric name -> list with one value per evaluation) until
    ``dump_results`` writes them to ``<log_dir>/<result_name>.csv`` and clears the buffers.
    """

    def __init__(self, log_dir, result_name, num_classes, frequency='dense', checkpoints=None):
        """
        :param log_dir: directory the results file and the 'models' checkpoint folder are written to.
        :param result_name: name of the results file, without the '.csv' extension.
        :param num_classes: per-class metric columns are created for every class in range(num_classes).
        :param frequency: key of _eval_defs selecting the evaluation schedule.
        :param checkpoints: key of _checkpoints_defs selecting the checkpoint schedule. None (the default)
            disables checkpointing.
        """
        self.log_dir = log_dir
        self.result_name = result_name
        self.metrics = defaultdict(list)
        self._init_metrics(num_classes)
        # This determines how often evaluation happens, see _eval_defs below to get more details.
        self.evaluate_scheme = _eval_defs[frequency]
        self.current_iter = 0
        self.task_iter = 0
        # None is not a key of _checkpoints_defs, so the default has to be handled before the lookup.
        self.checkpoints = None if checkpoints is None else _checkpoints_defs[checkpoints]

    def evaluate(self, model, test_datasets, train_datasets=None, device='cuda', force_eval=False,
                 batch_size=128, num_workers=4):
        """
        Decides whether to evaluate or not, then does the evaluation and records the metrics.
        force_eval allows to evaluate without increasing the iteration count (mainly used in the beginning and
        at the end of training).

        :param model: model to evaluate; it is moved to `device` when an evaluation happens.
        :param test_datasets: iterable of datasets, one DataLoader is created per dataset.
        :param train_datasets: optional iterable of datasets; when given, train metrics are recorded as well.
        :param device: device passed to the model and to cl.test_per_class.
        :param force_eval: evaluate regardless of the schedule and without counting an iteration.
        :param batch_size: batch size of the evaluation DataLoaders.
        :param num_workers: number of worker processes of the evaluation DataLoaders.
        """
        if not force_eval:
            self.current_iter += 1

        if self._check_do_eval() or force_eval:
            # Complete the rows of earlier evaluations first, so that a metric that was skipped back then
            # (e.g. no train_datasets) cannot end up in the wrong row later on.
            self._align_metric_rows()
            self.update_simple_metric('iter', self.current_iter)
            self.update_simple_metric('task_iter', self.task_iter)

            model = model.to(device)
            # NOTE(review): shuffling is kept from the original code; it looks unnecessary for evaluation,
            # confirm against cl.test_per_class before removing it.
            test_loaders = [DataLoader(ds, batch_size=batch_size, num_workers=num_workers, shuffle=True)
                            for ds in test_datasets]
            # Presumably a pair (accuracy per class, loss per class) of dicts keyed by class -- this is how
            # update_cls_metrics consumes it; verify against cl.test_per_class.
            test_res = cl.test_per_class(model, test_loaders, device=device)
            self.update_cls_metrics(['test_acc', 'test_loss'], test_res)

            if train_datasets is not None:
                train_loaders = [DataLoader(ds, batch_size=batch_size, num_workers=num_workers)
                                 for ds in train_datasets]
                train_res = cl.test_per_class(model, train_loaders, device=device)
                self.update_cls_metrics(['train_acc', 'train_loss'], train_res)

        self.make_checkpoints(model)

    def _check_do_eval(self):
        """
        Return True if the schedule asks for an evaluation at the current iteration.

        The first stage whose bound is not exceeded by the current iteration decides. Past the last bound
        the frequency of the last stage keeps being used, so evaluation never silently stops.
        """
        if not self.evaluate_scheme:
            return False

        for max_iter, every in self.evaluate_scheme:
            if self.current_iter <= max_iter:
                return self.current_iter % every == 0

        last_every = self.evaluate_scheme[-1][1]
        return self.current_iter % last_every == 0

    def _init_metrics(self, num_classes):
        """
        Init some metrics here. Iters to keep them first, cls metrics because not always all metrics are used.
        """
        self.metrics['iter'] = []
        self.metrics['task_iter'] = []
        for metric in ['test_acc', 'test_loss', 'train_acc', 'train_loss']:
            for cls in range(num_classes):
                self.metrics[f'{metric}_{cls}'] = []

    def _align_metric_rows(self):
        """
        Pad every metric that has fewer values than there are recorded evaluations with -1.
        """
        num_rows = len(self.metrics['iter'])
        for name, values in self.metrics.items():
            missing = num_rows - len(values)
            if missing > 0:
                # Rebinding an existing key is safe while iterating, the dict size does not change.
                self.metrics[name] = values + [-1] * missing

    def update_cls_metrics(self, metric_names, values_per_class):
        """
        Append per-class values: values_per_class holds one {class: value} dict per name in metric_names,
        each value is stored under the key '<metric_name>_<class>'.
        """
        for mn, values in zip(metric_names, values_per_class):
            for cls in sorted(values.keys()):
                self.metrics[f'{mn}_{cls}'].append(values[cls])

    def update_simple_metric(self, metric_name, value):
        """Append a single value to the metric called metric_name."""
        self.metrics[metric_name].append(value)

    def make_checkpoints(self, model):
        """
        Store the model's state dict in '<log_dir>/models' if the current iteration is a checkpoint iteration.
        """
        if self.checkpoints is not None and self.current_iter in self.checkpoints:
            model_dir = os.path.join(self.log_dir, 'models')
            os.makedirs(model_dir, exist_ok=True)
            # NOTE(review): the name only contains current_iter, which reset_iter_count sets back to 0, so
            # checkpoints of a later task overwrite those of an earlier one.
            model_name = f"chkpt_{self.current_iter}.pth"
            torch.save(model.state_dict(), os.path.join(model_dir, model_name))

    def dump_results(self):
        """
        Write the buffered metrics to the results file and clear the buffers.

        The header is only written when the file is created; later dumps append rows.
        """
        df = pd.DataFrame.from_dict(self.get_log_metrics())
        result_file = os.path.join(self.log_dir, f'{self.result_name}.csv')

        append = os.path.exists(result_file)
        df.to_csv(result_file, mode='a' if append else 'w', header=not append, index=False)

        for k in self.metrics:
            self.metrics[k] = []  # Don't just reset, like this we keep the order.

    def get_log_metrics(self):
        """
        Return a dict with all metrics, where shorter ones are padded with -1 to the length of the longest.
        """
        max_length = max((len(v) for v in self.metrics.values()), default=0)

        # Pad with -1 if not long enough, although it should probably be empty
        return {k: v + [-1] * (max_length - len(v)) if len(v) < max_length else v
                for k, v in self.metrics.items()}

    def reset_iter_count(self):
        """Move on to the next task: increase the task counter and restart the iteration counter."""
        self.task_iter += 1
        self.current_iter = 0


# These define when the evaluation needs to be done. Each tuple is (max_iteration, evaluation frequency).
# This means: up to iteration max_iteration, evaluate every i-th iteration. The last tuple is meant to be used until
# the end of training, so give it a large max_iteration. These bounds are inclusive (the check is `<=`).

# Evaluation schedules, selected by name. Every entry is a list of (max_iteration, evaluation frequency) stages.
_eval_defs = {
    # Evaluate often in the beginning, then less and less frequently.
    'dense': [
        (10, 1),
        (20, 2),
        (100, 10),
        (1000, 50),
        (100_000, 250),
    ],
    # A single fixed frequency.
    'sparse': [
        (100_000, 250),
    ],
}

# Checkpoint schedules, selected by name: the iterations at which a checkpoint is stored, or None for no checkpoints.
_checkpoints_defs = {
    # Iteration 1, then 10, 20, 40, ... doubling up to 10240.
    'log': [1] + [10 * 2 ** k for k in range(11)],
    'none': None,
}

