"""Gradient-ascent-based unlearning.

Taken from:
    https://arxiv.org/pdf/2310.10683
"""
from typing import Any, Dict, Optional

import torch
from torch.utils.data import DataLoader
from transformers import PreTrainedModel

from npeff_torch.models import model_utils


###############################################################################



R"""

- Maybe skip the random-mismatch loss (epsilon_2 = 0) at first. The paper includes it as a baseline
  in its experiments, so check its impact there.

- Maybe add the ability to sample from the model on some trivia-qa examples. Seeing what the responses
  are would help choose which examples to pick.

- Could maybe try to increase the KL from the original model's predictions on the examples to forget,
  rather than using the label.

"""

###############################################################################


class LastTokenKlOnlyGradientAscent:
    """Unlearning through gradient ascent on a last-token KL-divergence.

    Two models are held: `model`, the one being altered, and `original_model`, a reference that
    this class only runs under `torch.no_grad()`. For every sequence in a batch, the distribution
    predicted at the last attended position is compared between the two models. The combined loss
    is `epsilon_retain * retain_kl - epsilon_forget * forget_kl`, so minimizing it pushes the
    altered model away from the original on the forget examples while keeping it close on the
    retain examples.

    Intended to only be used for language models. Meant for my initial experimentation on the trivia-qa task.
    """

    def __init__(
        self, *,
        # The model that will get modified.
        model: PreTrainedModel,
        # Should be initialized to the same parameters as the `model`.
        original_model: PreTrainedModel,

        # The weight associated with the gradient ascent loss.
        epsilon_forget: float,
        # The weight associated with retention loss.
        epsilon_retain: float,
    ):
        self._model = model
        self._original_model = original_model

        self._epsilon_forget = epsilon_forget
        self._epsilon_retain = epsilon_retain

    #######################################################

    @staticmethod
    def _move_batch_to_device(batch: Dict[str, Any], device: Optional[torch.device]) -> Dict[str, Any]:
        """Returns a shallow copy of `batch` with its movable values placed on `device`.

        Values without a `.to` method (e.g. plain Python metadata) are passed through unchanged.
        When `device` is None, nothing is moved.
        """
        if device is None:
            return dict(batch)
        return {k: (v.to(device) if hasattr(v, 'to') else v) for k, v in batch.items()}

    def _compute_last_token_log_probs(self, model: PreTrainedModel, batch: Dict[str, Any], device: Optional[torch.device]) -> torch.Tensor:
        """Computes the log-probabilities predicted at the last attended token of each sequence.

        The last attended token is the highest sequence index whose `attention_mask` entry is
        non-zero, so both right-padded and left-padded batches are handled. For a contiguous,
        right-padded mask this is the same as "number of non-padding tokens minus one".

        Args:
            model: The model to run.
            batch: Must contain an `attention_mask` tensor of shape [batch, sequence]. The whole
                batch is forwarded to `model_utils.compute_logits`.
            device: Forwarded as-is to `model_utils.compute_logits`.

        Returns:
            Log-probabilities with shape [batch, vocab].
        """
        # sequence_logits.shape = [batch, sequence, vocab]
        # NOTE(review): shape is what the rest of this method relies on; confirm against `model_utils`.
        sequence_logits = model_utils.compute_logits(model, batch, device)

        # `torch.gather` needs its index on the same device as the logits, and the mask is not
        # guaranteed to have been moved there by the caller.
        is_attended = (batch['attention_mask'] != 0).to(sequence_logits.device)

        # One-based so that padding (0) can never tie with a real token at index 0.
        # one_based_indices.shape = [sequence]
        one_based_indices = torch.arange(1, is_attended.shape[-1] + 1, device=is_attended.device)
        # positions.shape = [batch]
        # A row with no attended token ends up at -1, which is not a valid index for the gather below.
        positions = torch.max(is_attended.type(torch.int64) * one_based_indices, dim=-1).values - 1
        g_positions = positions[:, None, None].expand(-1, -1, sequence_logits.shape[-1])

        # shape = [batch, 1, vocab]
        last_token_logits = torch.gather(sequence_logits, 1, g_positions)
        # shape = [batch, vocab]
        last_token_logits = torch.squeeze(last_token_logits, dim=1)

        return torch.nn.functional.log_softmax(last_token_logits, dim=-1)

    #######################################################

    # TODO: Add options for multiple forget/retain batches (possibly with different epsilons)

    def compute_loss_info(
        self, *,
        # Batch containing examples whose predictions we want to alter.
        forget_batch: Dict[str, Any],
        # Batch of normal examples whose predictions we want to keep.
        retain_batch: Dict[str, Any],

        device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """Computes the unlearning loss for one forget batch and one retain batch.

        Returns:
            A dict with:
                - 'loss': `epsilon_retain * retain_kl - epsilon_forget * forget_kl`.
                - 'forget_kl': KL(altered||original) on the last token of the forget examples.
                - 'retain_kl': KL(original||altered) on the last token of the retain examples.
            Both KL terms use the 'batchmean' reduction.
        """
        forget_batch = self._move_batch_to_device(forget_batch, device)
        retain_batch = self._move_batch_to_device(retain_batch, device)

        # The original model's predictions are constants for the optimization.
        with torch.no_grad():
            original_forget_log_probs = self._compute_last_token_log_probs(self._original_model, forget_batch, device).detach()
            original_retain_log_probs = self._compute_last_token_log_probs(self._original_model, retain_batch, device).detach()

        forget_log_probs = self._compute_last_token_log_probs(self._model, forget_batch, device)
        retain_log_probs = self._compute_last_token_log_probs(self._model, retain_batch, device)

        # NOTE: The order of the arguments to kl_div are reversed. This is correct due to how PyTorch does things:
        # kl_div(input, target) computes KL(target||input).

        # We use KL(altered||original) for the forget since this is most similar to a supervised loss.
        forget_kl = torch.nn.functional.kl_div(original_forget_log_probs, forget_log_probs, reduction='batchmean', log_target=True)
        # NOTE: Following the paper, we use KL(original||altered) for the retain.
        retain_kl = torch.nn.functional.kl_div(retain_log_probs, original_retain_log_probs, reduction='batchmean', log_target=True)

        # Minimizing this loss descends on the retain KL and ascends on the forget KL.
        loss = self._epsilon_retain * retain_kl - self._epsilon_forget * forget_kl

        return {'loss': loss, 'forget_kl': forget_kl, 'retain_kl': retain_kl}
