"""Regularization methods."""
# https://github.com/ContinualAI/avalanche/blob/master/avalanche/training/regularization.py

import copy
from collections import defaultdict
from typing import List

import torch
import torch.nn.functional as F

#from avalanche.models import MultiTaskModule, avalanche_forward

class SynapticIntelligence:
    """Quadratic penalty that anchors a model to a parameter snapshot.

    A detached copy of every named parameter of ``model`` is stored at
    construction time.  Calling the instance returns a penalty that grows
    with the squared distance between the current parameters and that
    snapshot, weighted by ``fisher`` and by ``decay``.
    """

    def __init__(self, model, fisher, importance, decay=0.9):
        """
        :param model: module whose ``named_parameters()`` are snapshotted.
        :param fisher: mapping from parameter name to a weight that
            multiplies the squared parameter drift.
        :param importance: stored on the instance and forwarded to
            :meth:`loss`, which currently does not use it.
        :param decay: scalar weight of the second, unweighted drift term.
        """
        self.model = model
        self.fisher = fisher
        self.online_reg = True  # Original SI works in an online updating fashion
        self.damping_factor = 0.1
        self.importance = importance
        self.decay = decay
        # Detached clones, so later optimizer steps do not move the anchor.
        self.old_params = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
        }

    def __call__(self, model):
        """Return the penalty of ``model`` w.r.t. the stored snapshot."""
        return self.loss(
            model, self.old_params, self.fisher, self.importance, self.decay
        )

    # Declared static: the function never used instance state, and without
    # the decorator ``self.loss(model, ...)`` passed one positional argument
    # too many and raised ``TypeError``.
    @staticmethod
    def loss(model, old_params, fisher, importance, decay=0.9):
        """Compute the synaptic-intelligence style penalty.

        For every parameter of ``model`` whose name is in ``old_params`` the
        squared drift ``(param - old) ** 2`` is accumulated twice: once
        weighted element-wise by ``fisher[name]`` and once weighted by
        ``decay``.

        :param model: module providing ``named_parameters()``.
        :param old_params: mapping name -> anchor tensor.
        :param fisher: mapping name -> per-element weight; must contain
            every name that is present in ``old_params``.
        :param importance: currently unused (the per-parameter importance
            weight is hard-wired to 1).
        :param decay: weight of the unweighted drift term.
        :return: the accumulated penalty; the integer ``0`` when no
            parameter name matches ``old_params``.
        :raises KeyError: if a matched name is missing from ``fisher``.
        """
        si_loss = 0
        for name, param in model.named_parameters():
            if name not in old_params:
                continue
            squared_drift = (param - old_params[name]).pow(2)
            fish = fisher[name]
            imp = 1  # importance[name] is deliberately disabled
            si_loss += (fish * squared_drift).sum()
            si_loss += (decay * imp * squared_drift).sum()
        print("SI loss: ", si_loss)
        return si_loss

    def calculate_importance(self, dataloader):
        """Build the per-parameter importance (Omega) dictionary.

        For each entry the result is
        ``self.w[n] / ((self.params[n] - self.initial_params[n]) ** 2
        + self.damping_factor)``; ``self.w[n]`` is zeroed afterwards.

        NOTE(review): ``self.params``, ``self.initial_params`` and ``self.w``
        are not assigned anywhere in this class, so they must be attached to
        the instance before this method is called; otherwise it raises
        ``AttributeError``.  Confirm where they are meant to come from.

        :param dataloader: not used by the current implementation.
        :return: dict mapping parameter name to its importance tensor.
        """

        # Initialize the importance matrix
        importance = {}
        for n, p in self.params.items():
            importance[n] = p.clone().detach().fill_(0)  # zero initialized
        prev_params = self.initial_params

        # Calculate or accumulate the Omega (the importance matrix)
        for n, p in importance.items():
            delta_theta = self.params[n].detach() - prev_params[n]
            p += self.w[n]/(delta_theta**2 + self.damping_factor)
            self.w[n].zero_()

        return importance

def ewc_penalty(model, old_params):
    """Sum of squared distances between current and stored parameters.

    Only parameters whose name appears in ``old_params`` contribute.

    :param model: module providing ``named_parameters()``.
    :param old_params: mapping name -> reference tensor.
    :return: the accumulated penalty; the integer ``0`` when no parameter
        name matches.
    """
    squared_drifts = (
        (current - old_params[key]).pow(2).sum()
        for key, current in model.named_parameters()
        if key in old_params
    )
    return sum(squared_drifts, 0)


def stable_softmax(x):
    """Softmax along dimension 1, shifted by the row maximum.

    Subtracting each row's maximum before exponentiating keeps the largest
    exponent at zero, so ``exp`` cannot overflow.

    :param x: tensor with at least two dimensions.
    :return: tensor of the same shape whose entries sum to 1 along dim 1.
    """
    row_max = x.max(dim=1, keepdim=True).values
    exps = (x - row_max).exp()
    return exps / exps.sum(dim=1, keepdim=True)


def cross_entropy_with_oh_targets(outputs, targets, reduction="mean"):
    """Cross-entropy between logits and one-hot (or soft) targets.

    No temperature scaling is applied.  The log-probabilities are computed
    with ``log_softmax`` rather than ``log(softmax(x))``: the latter
    underflows to ``log(0) = -inf`` for confident predictions, and
    ``0 * -inf`` then turns the loss into NaN.

    :param outputs: logits with classes along dimension 1.
    :param targets: target distribution with the same shape as ``outputs``;
        soft targets are allowed but each row must sum to 1.
    :param reduction: ``"mean"`` averages over the batch, ``"none"`` returns
        one value per sample.
    :return: scalar tensor (``"mean"``) or per-sample tensor (``"none"``).
    :raises NotImplementedError: for any other ``reduction``.
    """
    if reduction not in ("mean", "none"):
        raise NotImplementedError("reduction must be mean or none")
    log_probs = torch.log_softmax(outputs, dim=1)
    ce = -(targets * log_probs).sum(1)
    return ce.mean() if reduction == "mean" else ce


class RegularizationMethod:
    """Base interface for regularization strategies.

    Instances are callables that produce a loss term.  Subclasses override
    ``__call__`` and, where they keep state, ``update`` (typically invoked
    at the end of an experience).
    """

    def update(self, *args, **kwargs):
        """Refresh the internal state; must be provided by subclasses."""
        raise NotImplementedError

    def pre_adapt(self, agent, exp):
        """Hook run before adaptation; does nothing unless overridden."""
        return None

    def post_adapt(self, agent, exp):
        """Hook run after adaptation; does nothing unless overridden."""
        return None

    def __call__(self, *args, **kwargs):
        """Compute the regularization term; must be provided by subclasses."""
        raise NotImplementedError


class LearningWithoutForgetting(RegularizationMethod):
    """Learning Without Forgetting.

    The method applies knowledge distillation to mitigate forgetting.
    The teacher is the model checkpoint after the last experience.
    """

    def __init__(self, alpha=1, temperature=2):
        """
        :param alpha: distillation hyperparameter. It can be either a float
                number or a list containing alpha for each experience.
        :param temperature: softmax temperature for distillation
        """
        self.alpha = alpha
        self.temperature = temperature
        self.prev_model = None
        # Number of completed experiences; indexes `alpha` when it is a
        # list or tuple.
        self.expcount = 0
        # In Avalanche, targets of different experiences are not ordered.
        # As a result, some units may be allocated even though their
        # corresponding class has never been seen by the model.  The seen
        # classes are recorded here per task id.
        self.prev_classes_by_task = defaultdict(set)
        # Optional column index applied to the teacher logits before
        # distillation.  It is only read by this class, so callers that want
        # it must assign it; initialised here so the penalty does not fail
        # with AttributeError when it was never assigned.
        self.mask = None

    def _kd_term(self, student_logits, teacher_logits, units):
        """KL divergence between temperature-softened teacher and student
        distributions, restricted to the columns listed in ``units``."""
        log_p = torch.log_softmax(
            student_logits[:, units] / self.temperature, dim=1
        )
        q = torch.softmax(teacher_logits[:, units] / self.temperature, dim=1)
        return F.kl_div(log_p, q, reduction="batchmean")

    def _distillation_loss(self, out, prev_out, active_units):
        """Compute distillation loss between output of the current model and
        output of the previous (saved) model.

        :param out: current logits, or a tuple of two logit tensors
            (contrastive setting), in which case both are distilled and the
            two terms are added.
        :param prev_out: teacher output with the same structure as ``out``.
        :param active_units: iterable of output columns to distill on.
        :return: the distillation loss tensor.
        """
        au = list(active_units)

        # some people use the crossentropy instead of the KL
        # They are equivalent. We compute
        # kl_div(log_p_curr, p_prev) = p_prev * (log (p_prev / p_curr)) =
        #   p_prev * log(p_prev) - p_prev * log(p_curr).
        # Now, the first term is constant (we don't optimize the teacher),
        # so optimizing the crossentropy and kl-div are equivalent.
        if isinstance(out, tuple):  # contrastive loss
            return self._kd_term(out[0], prev_out[0], au) + self._kd_term(
                out[1], prev_out[1], au
            )
        return self._kd_term(out, prev_out, au)

    def _lwf_penalty(self, out, x, curr_model=None, texts=None):
        """
        Compute the (unweighted) distillation loss for one mini-batch.

        Only the single-head case is handled; the multi-head
        (``MultiTaskModule``) path was permanently disabled and has been
        removed.

        :param out: output of the current model for ``x``.
        :param x: mini-batch inputs, fed to the saved teacher model.
        :param curr_model: unused; kept for interface compatibility.
        :param texts: if given, the teacher is queried through
            ``prev_model.get_logits(x, texts)``.
        :return: ``0`` while no teacher has been saved, otherwise the
            distillation loss tensor.
        """
        if self.prev_model is None:
            return 0

        with torch.no_grad():
            if "siglip" in self.prev_model.__class__.__name__.lower():
                prev_out = self.prev_model(pixel_values=x)[0]
            elif texts is None:
                prev_out = self.prev_model(x)
            else:
                prev_out = self.prev_model.get_logits(x, texts)

        # The student produced a plain tensor but the teacher returned a
        # wrapper object: take the logits out of the wrapper.
        if type(out) is torch.Tensor and type(prev_out) is not torch.Tensor:
            prev_out = prev_out.logits

        # The mask is skipped for tuple (contrastive) teacher outputs.
        if self.mask is not None and not isinstance(prev_out, tuple):
            prev_out = prev_out[:, self.mask]

        # Distillation runs over every output unit of the current model,
        # not only over the classes recorded in `prev_classes_by_task`.
        student = out[0] if isinstance(out, tuple) else out
        return self._distillation_loss(
            out, prev_out, range(student.shape[-1])
        )

    def __call__(self, mb_x, mb_pred, model=None, mb_text=None):
        """
        Return the distillation loss scaled by alpha.

        :param mb_x: mini-batch inputs.
        :param mb_pred: current model output for ``mb_x``.
        :param model: current model (forwarded, currently unused).
        :param mb_text: optional texts forwarded to the teacher.
        :raises IndexError: if ``alpha`` is a list/tuple shorter than the
            number of completed experiences plus one.
        """
        alpha = (
            self.alpha[self.expcount]
            if isinstance(self.alpha, (list, tuple))
            else self.alpha
        )
        return alpha * self._lwf_penalty(mb_pred, mb_x, model, mb_text)

    def _register_experience(self, model, experience):
        """Snapshot ``model`` as the new teacher and record the classes of
        ``experience`` per task id."""
        self.expcount += 1
        self.prev_model = copy.deepcopy(model)
        dataset = experience.dataset

        for task_id in dataset.targets_task_labels.uniques:
            seen = set(dataset.task_set[task_id].targets.uniques)
            # defaultdict supplies an empty set for a task seen first time.
            self.prev_classes_by_task[task_id] = self.prev_classes_by_task[
                task_id
            ].union(seen)

    def post_adapt(self, agent, exp):
        """Save a copy of the model after each experience and
        update self.prev_classes_by_task to include the newly learned classes.

        :param agent: agent state; its ``model`` attribute is copied
        :param exp: current experience
        """
        self._register_experience(agent.model, exp)

    def update(self, experience, model):
        """Save a copy of the model after each experience and
        update self.prev_classes_by_task to include the newly learned classes.

        :param experience: current experience
        :param model: current model
        """
        self._register_experience(model, experience)


class ACECriterion(RegularizationMethod):
    """
    Asymmetric cross-entropy (ACE) criterion from
    "New Insights on Reducing Abrupt Representation
    Change in Online Continual Learning"
    by Lucas Caccia et al.
    https://openreview.net/forum?id=N8MaByOzUfb
    """

    def __init__(self):
        """The criterion keeps no state."""
        pass

    def __call__(self, out_in, target_in, out_buffer, target_buffer):
        """Average of the buffer loss and the class-restricted incoming loss.

        :param out_in: logits of the incoming mini-batch.
        :param target_in: integer labels of the incoming mini-batch.
        :param out_buffer: logits of the replayed samples.
        :param target_buffer: integer labels of the replayed samples.
        :return: ``(buffer_loss + incoming_loss) / 2``.
        """
        # Classes that actually occur in the incoming batch.
        present = torch.unique(target_in)
        # Replayed samples use the ordinary cross-entropy.
        replay_term = F.cross_entropy(out_buffer, target_buffer)
        # Incoming samples: logits and one-hot targets are both cut down to
        # the columns of the classes present in the batch.
        n_outputs = out_in.shape[1]
        restricted_targets = F.one_hot(target_in, num_classes=n_outputs)[
            :, present
        ]
        restricted_logits = out_in[:, present]
        incoming_term = cross_entropy_with_oh_targets(
            restricted_logits, restricted_targets
        )
        return (replay_term + incoming_term) / 2


class AMLCriterion(RegularizationMethod):
    """
    Asymmetric metric learning (AML) Criterion used in
    "New Insights on Reducing Abrupt Representation
    Change in Online Continual Learning"
    by Lucas Caccia et. al.
    https://openreview.net/forum?id=N8MaByOzUfb
    """

    def __init__(
        self,
        feature_extractor,
        temp: float = 0.1,
        base_temp: float = 0.07,
        same_task_neg: bool = True,
        device: str = "cpu",
    ):
        """
        ER_AML criterion constructor.
        :param feature_extractor: Model able to map an input in a latent space.
        :param temp: Supervised contrastive temperature.
        :param base_temp: Supervised contrastive base temperature.
        :param same_task_neg: If True, negatives are only drawn from memory
            samples that share the task id of the anchor.
        :param device: Device onto which the positive-pair mask is moved.
        """
        self.device = device
        self.feature_extractor = feature_extractor
        self.temp = temp
        self.base_temp = base_temp
        self.same_task_neg = same_task_neg

    def __sample_pos_neg(
        self,
        y_in: torch.Tensor,
        t_in: torch.Tensor,
        x_memory: torch.Tensor,
        y_memory: torch.Tensor,
        t_memory: torch.Tensor,
    ) -> tuple:
        """
        Sample, for every element of the input minibatch, one positive and
        one negative example from the memory minibatch.
        :param y_in: Labels of new minibatch.
        :param t_in: Task ids of new minibatch.
        :param x_memory: Inputs of memory.
        :param y_memory: Labels of memory.
        :param t_memory: Task ids of memory.
        :return: Tuple ``(pos_x, pos_y, neg_x, neg_y)``.
        """
        # Entry [m, i] is True when memory sample m has the label of input i.
        valid_pos = y_in.reshape(1, -1) == y_memory.reshape(-1, 1)
        if self.same_task_neg:
            same_task = t_in.view(1, -1) == t_memory.view(-1, 1)
            valid_neg = ~valid_pos & same_task
        else:
            valid_neg = ~valid_pos

        # One uniform draw per input sample among its valid candidates.
        # torch.multinomial rejects a row without any valid candidate.
        pos_idx = torch.multinomial(valid_pos.float().T, 1).squeeze(1)
        neg_idx = torch.multinomial(valid_neg.float().T, 1).squeeze(1)

        pos_x = x_memory[pos_idx]
        pos_y = y_memory[pos_idx]
        neg_x = x_memory[neg_idx]
        neg_y = y_memory[neg_idx]

        return pos_x, pos_y, neg_x, neg_y

    def __sup_con_loss(
        self,
        anchor_features: torch.Tensor,
        features: torch.Tensor,
        anchor_targets: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Method able to compute the supervised contrastive loss of new minibatch.
        :param anchor_features: Anchor features related to new minibatch duplicated mapped in latent space.
        :param features: Features related to half positive and half negative examples mapped in latent space.
        :param anchor_targets: Labels related to anchor features.
        :param targets: Labels related to features.
        :return: Supervised contrastive loss.
        """
        pos_mask = (
            (anchor_targets.reshape(-1, 1) == targets.reshape(1, -1))
            .float()
            .to(self.device)
        )
        similarity = anchor_features @ features.T / self.temp
        # keepdim=True keeps the per-anchor statistics as (N, 1) columns so
        # they broadcast across each anchor's own row.  Without it the (N,)
        # vectors broadcast along the columns, mixing the statistics of
        # different anchors (and failing outright when N != M).
        row_max = similarity.max(dim=1, keepdim=True)[0].detach()
        similarity = similarity - row_max  # numerical stability only
        log_norm = torch.log(torch.exp(similarity).sum(1, keepdim=True))
        log_prob = similarity - log_norm
        mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_mask.sum(1)
        loss = -(self.temp / self.base_temp) * mean_log_prob_pos.mean()
        return loss

    def __scale_by_norm(self, x: torch.Tensor) -> torch.Tensor:
        """
        Function able to scale by its norm a certain tensor.
        :param x: Tensor to normalize.
        :return: Normalized tensor.
        """
        x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
        # The small constant avoids a division by zero for all-zero rows.
        return x / (x_norm + 1e-05)

    def __call__(
        self,
        input_in: torch.Tensor,
        target_in: torch.Tensor,
        task_in: torch.Tensor,
        output_buffer: torch.Tensor,
        target_buffer: torch.Tensor,
        pos_neg_replay: tuple,
    ) -> torch.Tensor:
        """
        Method able to compute the ER_AML loss.
        :param input_in: New inputs examples.
        :param target_in: Labels of new examples.
        :param task_in: Task identifiers of new examples.
        :param output_buffer: Predictions of samples from buffer.
        :param target_buffer: Labels of samples from buffer.
        :param pos_neg_replay: Replay data ``(x_memory, y_memory, t_memory)``
            used to draw positive and negative samples.
        :return: ER_AML computed loss.
        """
        pos_x, pos_y, neg_x, neg_y = self.__sample_pos_neg(
            target_in, task_in, *pos_neg_replay
        )
        loss_buffer = F.cross_entropy(output_buffer, target_buffer)
        hidden_in = self.__scale_by_norm(self.feature_extractor(input_in))
        hidden_pos_neg = self.__scale_by_norm(
            self.feature_extractor(torch.cat((pos_x, neg_x)))
        )
        # Anchors are duplicated so that each input is compared against the
        # concatenated positive and negative halves.
        loss_in = self.__sup_con_loss(
            anchor_features=hidden_in.repeat(2, 1),
            features=hidden_pos_neg,
            anchor_targets=target_in.repeat(2),
            targets=torch.cat((pos_y, neg_y)),
        )
        return loss_in + loss_buffer


# Names exported by a star-import of this module.
# NOTE(review): SynapticIntelligence, ewc_penalty, stable_softmax and
# cross_entropy_with_oh_targets are defined in this file but are not listed
# here - confirm whether leaving them out is intentional.
__all__ = [
    "RegularizationMethod",
    "LearningWithoutForgetting",
    "ACECriterion",
    "AMLCriterion",
]
