import collections

import numpy as np
import torch

from sklearn.metrics import roc_auc_score, f1_score

class MetricLogger:
    """Accumulate continual-learning metrics across successive tasks.

    Each call to :meth:`log_task` appends one value per metric to
    ``self.metrics`` (metric name -> list, one entry per logged task) and
    advances the internal task counter. :attr:`last_results` summarises the
    most recent entries.
    """

    def __init__(self, nb_tasks, nb_classes):
        # metric name -> list of values, one per logged task.
        self.metrics = collections.defaultdict(list)

        self.nb_tasks = nb_tasks
        # Top-5 metrics are only tracked when there are more than 5 classes.
        self.nb_classes = nb_classes

        # Index of the task the next ``log_task`` call refers to.
        self._task_counter = 0

    def log_task(self, ypreds, ytrue, zid, novelty_threshold=None, evaluate_ood=False, novelty=None):
        """Compute and record every metric for the task that was just evaluated.

        :param ypreds: Per-sample class scores; ``argmax(axis=1)`` is the predicted class.
        :param ytrue: Ground-truth class label of each sample.
        :param zid: Task ID of each sample.
        :param novelty_threshold: Only used when ``evaluate_ood`` is True. Either a
            scalar, or a dict mapping a predicted class to its own threshold. A
            sample is flagged as novel when its novelty score is ``>=`` the threshold.
        :param evaluate_ood: If True, also compute out-of-distribution metrics.
            Samples of task ``task_counter + 1`` are treated as novel, samples
            of later tasks are ignored, and the classification metrics are then
            computed on the already-seen tasks only.
        :param novelty: Per-sample novelty scores. Required when ``evaluate_ood``
            is True.
        :raises ValueError: If ``evaluate_ood`` is True but ``novelty`` or
            ``novelty_threshold`` is missing.
        """
        if evaluate_ood:
            if novelty is None or novelty_threshold is None:
                raise ValueError(
                    "evaluate_ood=True requires both `novelty` and `novelty_threshold`."
                )
            ypreds, ytrue, zid = self._log_ood_metrics(
                ypreds, ytrue, zid, np.asarray(novelty), novelty_threshold
            )

        self.metrics["accuracy"].append(
            accuracy_per_task(ypreds, ytrue, zid, self.nb_tasks, topk=1)
        )  # FIXME various task size

        if self.nb_classes > 5:
            self.metrics["accuracy_top5"].append(
                accuracy_per_task(ypreds, ytrue, zid, self.nb_tasks, topk=5)
            )

        self.metrics["incremental_accuracy"].append(incremental_accuracy(self.metrics["accuracy"]))

        if self.nb_classes > 5:
            self.metrics["incremental_accuracy_top5"].append(
                incremental_accuracy(self.metrics["accuracy_top5"])
            )
        self.metrics["forgetting"].append(forgetting(self.metrics["accuracy"]))

        if self._task_counter > 0:
            self.metrics["old_accuracy"].append(old_accuracy(ypreds, ytrue, zid, self._task_counter))
            self.metrics["new_accuracy"].append(new_accuracy(ypreds, ytrue, zid, self._task_counter))

        self._task_counter += 1

    def _log_ood_metrics(self, ypreds, ytrue, zid, novelty, novelty_threshold):
        """Record the OOD metrics and return ``(ypreds, ytrue, zid)`` restricted to seen tasks."""
        next_task = self._task_counter + 1

        # Keep the seen tasks plus the immediately following (novel) task.
        # Every per-sample array, zid included, must be filtered with the same
        # mask so that they stay aligned.
        md_test = zid <= next_task
        ypreds = ypreds[md_test]
        ytrue = ytrue[md_test]
        zid = zid[md_test]
        novelty = novelty[md_test]

        noveltrue = (zid >= next_task).astype(int)
        predicted = np.argmax(ypreds, axis=1)

        if isinstance(novelty_threshold, dict):
            # One threshold per predicted class.
            threshold = np.array([novelty_threshold[c] for c in predicted])
        else:
            threshold = novelty_threshold
        noveltypred = (novelty >= threshold).astype(int)
        self.metrics["ood_accuracy"].append(np.sum(noveltrue == noveltypred) / len(noveltrue))

        # Open-set macro F1: label -1 stands for "novel" in both truth and prediction.
        ytrue2 = ytrue.copy()
        ytrue2[noveltrue == 1] = -1
        ypreds2 = predicted.copy()
        ypreds2[noveltypred == 1] = -1
        self.metrics["md_f1"].append(
            f1_score(ytrue2, ypreds2, labels=np.unique(ytrue2), average='macro')
        )

        # AUROC needs both novel and non-novel samples; roc_auc_score raises otherwise.
        nb_novel = int(np.sum(noveltrue))
        if 0 < nb_novel < len(noveltrue):
            self.metrics["ood_auroc"].append(roc_auc_score(noveltrue, novelty))

        # The classification metrics are computed on already-seen tasks only.
        seen = zid <= self._task_counter
        return ypreds[seen], ytrue[seen], zid[seen]

    @property
    def last_results(self):
        """Return a dict summarising the metrics of the most recently logged task."""
        results = {
            "task_id": len(self.metrics["accuracy"]) - 1,
            "accuracy": self.metrics["accuracy"][-1],
            "incremental_accuracy": self.metrics["incremental_accuracy"][-1],
            "forgetting": self.metrics["forgetting"][-1],
        }

        if self.nb_classes > 5:
            results.update(
                {
                    "accuracy_top5": self.metrics["accuracy_top5"][-1],
                    "incremental_accuracy_top5": self.metrics["incremental_accuracy_top5"][-1],
                }
            )

        # old/new accuracies only exist from the second logged task onward.
        if "old_accuracy" in self.metrics:
            results.update(
                {
                    "old_accuracy": self.metrics["old_accuracy"][-1],
                    "new_accuracy": self.metrics["new_accuracy"][-1],
                    "avg_old_accuracy": np.mean(self.metrics["old_accuracy"]),
                    "avg_new_accuracy": np.mean(self.metrics["new_accuracy"]),
                }
            )

        # OOD metrics only exist when log_task was called with evaluate_ood=True.
        # NOTE: these report the latest recorded value, which may come from an
        # earlier task (e.g. ood_auroc is skipped when it is undefined).
        for key in ("ood_accuracy", "md_f1", "ood_auroc"):
            if key in self.metrics:
                results[key] = self.metrics[key][-1]

        return results


def accuracy_per_task(ypreds, ytrue, zid, cur_task, topk=1):
    """Return top-k accuracy on the whole test set and on each task separately.

    :param ypreds: Per-sample class scores.
    :param ytrue: Ground-truth label of each sample.
    :param zid: Task ID of each sample.
    :param cur_task: Highest task ID to report; tasks ``0..cur_task`` are checked.
    :param topk: The ``k`` of the top-k accuracy.
    :return: A dict holding ``"total"`` plus one ``"Task{i}"`` entry for every
        task ``i`` in ``0..cur_task`` that has at least one sample.
    """
    results = {"total": accuracy(ypreds, ytrue, topk=topk)}

    for task_id in range(cur_task + 1):
        (members,) = np.where(zid == task_id)
        if members.size == 0:
            # Tasks absent from this test set get no entry.
            continue
        results[f"Task{task_id}"] = accuracy(ypreds[members], ytrue[members], topk=topk)

    return results

def old_accuracy(ypreds, ytrue, zid, cur_task):
    """Return top-1 accuracy on the samples of tasks learned before ``cur_task``.

    :param ypreds: Per-sample class scores.
    :param ytrue: Ground-truth label of each sample.
    :param zid: Task ID of each sample.
    :param cur_task: ID of the current task; samples with ``zid < cur_task`` are used.
    :return: A single accuracy value, as returned by ``accuracy`` (not a dict).
    """
    old_task_indexes = np.where(zid < cur_task)[0]
    return accuracy(ypreds[old_task_indexes], ytrue[old_task_indexes], topk=1)


def new_accuracy(ypreds, ytrue, zid, cur_task):
    """Return top-1 accuracy on the samples of task ``cur_task`` only.

    :param ypreds: Per-sample class scores.
    :param ytrue: Ground-truth label of each sample.
    :param zid: Task ID of each sample.
    :param cur_task: ID of the current task; samples with ``zid == cur_task`` are used.
    :return: A single accuracy value, as returned by ``accuracy`` (not a dict).
    """
    new_task_indexes = np.where(zid == cur_task)[0]
    return accuracy(ypreds[new_task_indexes], ytrue[new_task_indexes], topk=1)


def accuracy(output, targets, topk=1):
    """Return the top-``topk`` accuracy (precision@k), rounded to 3 decimals.

    :param output: Per-sample class scores, one row per sample (k is taken along dim 1).
    :param targets: Ground-truth label of each sample.
    :param topk: ``k``; clamped to the number of score columns.
    :return: Fraction of samples whose label is among the ``topk`` highest
        scores, or ``0.`` for an empty batch.
    """
    output, targets = torch.tensor(output), torch.tensor(targets)

    batch_size = targets.shape[0]
    if batch_size == 0:
        return 0.

    # Clamp by the number of score columns (torch.topk fails beyond it), not by
    # the number of distinct labels in the batch: the latter would silently turn
    # e.g. top-5 into top-1 on a single-class subset.
    topk = min(topk, output.shape[1])

    _, pred = output.topk(topk, 1, True, True)  # (batch, k) class indices
    # Compare every sample's k guesses against its own label; the k indices are
    # distinct, so a sample matches at most once.
    correct = pred.eq(targets.view(-1, 1))

    correct_k = correct.any(dim=1).float().sum().item()
    return round(correct_k / batch_size, 3)

def incremental_accuracy(accuracies):
    """Return the average incremental accuracy as described in iCaRL.

    This is the mean of the ``"total"`` accuracy over every logged step, where
    each step was tested on all tasks seen so far.

    :param accuracies: Sequence of per-step accuracy dicts, each holding a
        ``"total"`` key.
    :return: The mean of the ``"total"`` entries.
    """
    totals = [step["total"] for step in accuracies]
    return sum(totals) / len(accuracies)


def forgetting(accuracies):
    """Return the average forgetting over all previously learned tasks.

    For each earlier task ``t``, the best accuracy on ``t`` reached before the
    latest step (floored at 0) minus its accuracy at the latest step; the
    result is averaged over those tasks.

    :param accuracies: List of per-step accuracy dicts keyed ``"Task{t}"``;
        the last entry is the latest step.
    :return: Average forgetting. Returns ``0.`` when at most one step has been
        logged, or when no earlier task was evaluated at the latest step.
    """
    if len(accuracies) <= 1:
        return 0.

    *previous, latest = accuracies

    total = 0.
    nb_counted = 0
    for task in range(len(previous)):
        key = f"Task{task}"
        if key not in latest:
            # Task not evaluated at the latest step (no test samples): skip it
            # rather than raising KeyError.
            continue
        best = max([0.] + [step[key] for step in previous if key in step])
        total += best - latest[key]
        nb_counted += 1

    # Average across the earlier tasks that could be measured.
    return total / nb_counted if nb_counted else 0.


def forward_transfer(accuracies):
    """Measure the influence that learning a task has on the performance of future tasks.

    References:
        * Don't forget, there is more than forgetting: new metrics for Continual Learning
          Diaz et al.
          NeurIPS Workshop 2018

    Not implemented: this raises explicitly instead of returning a placeholder
    value that could be mistaken for a real metric. The expected layout of
    ``accuracies`` for this metric is still to be defined — TODO.

    :param accuracies: Per-step accuracy records (format to be defined).
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("forward_transfer is not implemented yet.")
