import torch
import random
import numpy as np
import os

class AverageMeter:
    """Compute and store the average and current value.

    By default ``avg`` is the sample-weighted running mean of every value
    passed to :meth:`update` (``sum / count``). With ``ema=True`` it is
    instead an exponential moving average, seeded by the first update.

    Examples::
        >>> # 1. Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # 2. Update meter after every mini-batch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self, ema=False, ema_momentum=0.1):
        """
        Args:
            ema (bool, optional): apply exponential moving average.
            ema_momentum (float, optional): weight given to the newest value
                when ``ema`` is True; the previous average keeps weight
                ``1 - ema_momentum``. Defaults to 0.1.
        """
        self.ema = ema
        self.ema_momentum = ema_momentum
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples.

        Args:
            val (float or torch.Tensor): new value; a tensor is converted with
                ``.item()``, so it must hold exactly one element.
            n (int, optional): number of samples ``val`` accounts for
                (e.g. the batch size). Defaults to 1.
        """
        if isinstance(val, torch.Tensor):
            val = val.item()

        self.val = val
        self.sum += val * n
        self.count += n

        if self.ema:
            if self.count == n:
                # Nothing was recorded before this call: seed the EMA directly.
                self.avg = val
            else:
                momentum = self.ema_momentum
                self.avg = self.avg * (1 - momentum) + val * momentum
        elif self.count:
            # Guarded: an update with n=0 on an empty meter would otherwise
            # divide by zero; in that case the previous avg is kept.
            self.avg = self.sum / self.count


def customed_forgetting(all_preds, all_targets, all_tasks):
    """Return the average forgetting over every task learned before the last one.

    For each earlier task j, forgetting is the gap between the best accuracy
    reached on j at any step from j up to the one before the last, and the
    accuracy on j after the last step. The result is the mean of these gaps;
    lower is better.

    Reference:
    * Riemannian Walk for Incremental Learning: Understanding Forgetting and Intransigence
      Chaudhry et al. ECCV 2018

    See eq. 3.

    :param all_preds: predicted labels per training step; its length is the
        number of tasks seen so far.
    :param all_targets: targets per training step, aligned with all_preds.
    :param all_tasks: task ids per training step, aligned with all_preds.
    :return: the average forgetting, or 0. when fewer than two tasks were seen.
    """
    n_seen = len(all_preds)
    # TODO if we take in account zeroshot, we should take the max of all_tasks?
    if n_seen < 2:
        return 0.

    last = n_seen - 1  # index of the most recently learned task

    def _forgetting_on(task_id):
        # Accuracy on `task_id` once the final task has been learned.
        final_acc = _get_R_ij(last, task_id, all_preds, all_targets, all_tasks)
        # Best accuracy on `task_id` at any earlier step where it was already learned.
        peak_acc = max(
            _get_R_ij(step, task_id, all_preds, all_targets, all_tasks)
            for step in range(task_id, last)
        )
        return peak_acc - final_acc

    total = sum(_forgetting_on(task_id) for task_id in range(last))
    metric = total / last
    assert -1.0 <= metric <= 1.0, metric
    return metric


def _get_R_ij(i, j, all_preds, all_targets, all_tasks):
    """Computes an accuracy after task i on task j.

    R matrix:

          || T_e1 | T_e2 | T_e3
    ============================|
     T_r1 || R*  | R_ij  | R_ij |
    ----------------------------|
     T_r1 || R*  | R_ij  | R_ij |
    ----------------------------|
     T_r1 || R*  | R_ij  | R_ij |
    ============================|

    R_13 is the R of the first column and the third row.

    From Chaudhry et al., calling this function with (i, j) equals to a_{i,j}.

    Except OOD and Zeroshot, i should be >= j.

    Reference:
    * Don’t forget, there is more than forgetting: newmetrics for Continual Learning
      Diaz-Rodriguez and Lomonaco et al. NeurIPS Workshop 2018

    :param i: Task id after which a model was trained.
    :param j: Task id of the test data.
    :param all_preds: All predicted labels up to now.
    :param all_targets: All targets up to now.
    :param all_tasks: All task ids up to now.
    :return: a float metric between 0 and 1.
    """
    preds = all_preds[i]
    targets = all_targets[i]
    tasks = all_tasks[i]

    indexes = np.where(tasks == j)[0]
    if len(indexes) == 0:
        raise ValueError(
            f"You haven't evaluated on any samples of task {j} after "
            f"training on task {i}, therefore it is impossible to compute some "
            "metrics (e.g. FwT, BwT).\n"
            "Either don't use those metrics or evaluate on all seen tasks. "
            "You can do it one by one (`test_scenario[0], test_scenario[1]`) "
            "or all together (`test_scenario[:task_id + 1]`)."
        )
    return (preds[indexes] == targets[indexes]).mean()