"""Regularization methods."""
import copy
from collections import defaultdict
from typing import List

import torch
import torch.nn.functional as F

from avalanche.models import avalanche_forward


def cross_entropy_with_oh_targets(outputs, targets, eps=1e-5):
    """Return the mean cross-entropy between logits and one-hot/soft targets.

    The targets may be one-hot vectors or soft targets; in both cases each
    row is expected to sum to 1.

    :param outputs: unnormalized scores (logits); the softmax is taken
        over dimension 1.
    :param targets: target distribution, multiplied elementwise with the
        log-probabilities (so it must be broadcastable to `outputs`).
    :param eps: unused by the computation; kept only so that existing
        callers passing it keep working.
    :return: scalar tensor, the per-sample cross-entropy averaged over
        all samples.
    """
    # log_softmax is computed in a numerically stable way. Taking
    # softmax(...).log() instead yields -inf whenever a probability
    # underflows to 0, and `0 * -inf` (a zero entry of a one-hot target)
    # turns the whole loss into NaN.
    log_probs = F.log_softmax(outputs, dim=1)
    ce = -(targets * log_probs).sum(1)
    return ce.mean()


class RegularizationMethod:
    """Base interface for regularization strategies.

    Instances are callables: invoking one is how the regularization is
    applied. The `update` method is called to update the loss, typically
    at the end of an experience.

    Both methods are placeholders here and must be overridden by
    subclasses.
    """

    def __call__(self, *args, **kwargs):
        """Apply the regularization; subclasses must override this."""
        raise NotImplementedError

    def update(self, *args, **kwargs):
        """Update the method's state; subclasses must override this."""
        raise NotImplementedError


class ACECriterion(RegularizationMethod):
    """
    Asymmetric cross-entropy (ACE) criterion from
    "New Insights on Reducing Abrupt Representation
    Change in Online Continual Learning"
    by Lucas Caccia et al.
    https://openreview.net/forum?id=N8MaByOzUfb
    """

    def __init__(self):
        """Create the criterion; no attributes are set."""

    def __call__(self, out_in, target_in, out_buffer, target_buffer):
        """Return the average of the buffer loss and the current-batch loss.

        :param out_in: logits for the incoming batch, with classes along
            dimension 1.
        :param target_in: integer class labels for the incoming batch.
        :param out_buffer: logits for the buffer samples.
        :param target_buffer: targets for the buffer samples.
        :return: `(loss_buffer + loss_current) / 2`.
        """
        # Classes that actually occur in the incoming batch.
        present = torch.unique(target_in)

        # Buffer samples: ordinary cross-entropy over all classes.
        buffer_term = F.cross_entropy(out_buffer, target_buffer)

        # Incoming samples: both the logits and the one-hot targets are
        # restricted to the columns of the classes present in the batch,
        # so the softmax only runs over those classes.
        n_classes = out_in.shape[1]
        restricted_targets = F.one_hot(target_in, num_classes=n_classes)[
            :, present
        ]
        restricted_logits = out_in[:, present]
        current_term = cross_entropy_with_oh_targets(
            restricted_logits, restricted_targets
        )

        return (buffer_term + current_term) / 2


# Names exported by a star-import of this module. The helper
# `cross_entropy_with_oh_targets` is not listed here.
__all__ = ["RegularizationMethod", "ACECriterion"]
