"""
Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer.

This implementation supports distilling the softmax attention outputs (MSE loss), the softmax attention
weights (cross-entropy loss), or a weighted combination of both.
"""
import torch
import torch.nn as nn

from .default_lm import OurTrainer as DefaultTrainer


class OurTrainer(DefaultTrainer):
    """
    Custom trainer class for distilling attentions.
    - We compute and store the attention outputs and/or weights for each head and layer,
      for both the "teacher" softmax attentions and "student" learnable subquadratic attentions
    - We then train the student layers to minimize MSE(outputs), CrossEntropy(weights),
      or a weighted sum of both (weights given by `mse_factor` and `xent_factor`)
    """
    def __init__(self,
                 model: nn.Module,
                 metric_for_best_model: str = 'distill/eval/loss',
                 mse_factor: float = 1e3,
                 xent_factor: float = 0,
                 **kwargs: any):
        """
        Args:
            model: Model to distill; `compute_loss` calls it with
                `output_attentions=True, use_cache=False`.
            metric_for_best_model: Forwarded to the parent trainer.
            mse_factor: Scale on the MSE loss between student and teacher
                attention outputs. A value <= 0 disables that term.
            xent_factor: Scale on the cross-entropy loss between student and
                teacher attention weights. A value <= 0 disables that term.
            **kwargs: Forwarded to the parent trainer.
        """
        super().__init__(model=model,
                         metric_for_best_model=metric_for_best_model,
                         **kwargs)
        self.criterion_xent = nn.CrossEntropyLoss(reduction='mean')
        self.criterion_mse = nn.MSELoss(reduction='mean')
        self.mse_factor = mse_factor
        self.xent_factor = xent_factor
        self.compute_loss_backprop = False  # Whether we backprop in self.compute_loss

    def compute_loss(self, model: nn.Module, data: dict[str, torch.Tensor],
                     sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[str, any]]:
        """
        Attention distillation ("attention transfer").

        Runs `model` with `output_attentions=True`. For every layer that returns a
        ((a_pred, a_true), (y_pred, y_true)) pair, it accumulates two terms:
        - cross-entropy between student weights `a_pred` and teacher weights `a_true`
          (if `xent_factor > 0`);
        - MSE between student outputs `y_pred` and teacher outputs `y_true`
          (if `mse_factor > 0`).
        Each term is averaged over the distilled layers and scaled by its factor.

        Args:
            model: Model to run; inputs are moved to `model.device`.
            data: Batch of model inputs. The 'labels' entry is not passed to the model.
                If 'position_ids' is present, its length and first row are also
                reported in the metrics.
            sample_idx: Not used by this method.
            **kwargs: Not used by this method.

        Returns:
            A `(loss, metrics)` tuple.
            - `loss` is `loss_xent + loss_mse`.
            - `metrics` holds the float value of each enabled term and both factors.
            NOTE(review): if no layer returns a distillable pair, `loss` is the plain
            int 0, not a tensor.
        """
        device = model.device
        inputs = {k: v.to(device) for k, v in data.items() if k != 'labels'}
        outputs = model(**inputs, output_attentions=True, use_cache=False)
        attentions = outputs.get('attentions')

        # Each distillable layer yields ((a_pred, a_true), (y_pred, y_true)):
        # - a_pred, a_true: student / teacher attention weights, expected shape
        #   (batch, n_heads, q_len, k_len)
        # - y_pred, y_true: student / teacher attention outputs (shapes must match for MSE)
        # Layers that keep softmax attention yield None; those, plain tensors and any
        # other non-pair values are skipped.
        loss_mse = 0
        loss_xent = 0
        n_layers = 0  # Number of layers actually distilled
        for attns in attentions:
            # Check for a tensor before calling len(): len() of a tensor is its batch
            # size, so a batch of 2 would otherwise be mistaken for a pair.
            if attns is None or isinstance(attns, torch.Tensor) or len(attns) != 2:
                continue
            if self.xent_factor > 0:
                a_pred, a_true = attns[0]
                k_len = a_true.shape[-1]
                # CrossEntropyLoss applies log_softmax to its input. Log-probabilities
                # whose rows already sum to 1 pass through that unchanged, so the log of
                # the student weights can serve as logits. The clamp avoids log(0).
                a_pred = a_pred.clamp(min=1e-12).log()
                # Flatten to (n_queries, k_len) and use the teacher weights as soft
                # (class-probability) targets. The mean is taken over all queries.
                loss_xent += self.criterion_xent(a_pred.contiguous().view(-1, k_len),
                                                 a_true.contiguous().view(-1, k_len))
            if self.mse_factor > 0:
                y_pred, y_true = attns[1]
                loss_mse += self.criterion_mse(y_pred, y_true)
            n_layers += 1

        if n_layers > 0:
            loss_xent = loss_xent / n_layers * self.xent_factor
            loss_mse = loss_mse / n_layers * self.mse_factor
        loss = loss_xent + loss_mse

        # float() accepts both 0-dim tensors and plain numbers, so the metrics also
        # work when no layer was distilled and the losses are still the int 0.
        metrics = {'loss_xent': float(loss_xent) if self.xent_factor > 0 else 0,
                   'loss_mse': float(loss_mse) if self.mse_factor > 0 else 0}
        if 'position_ids' in data:
            metrics['input_len'] = data['position_ids'].shape[1]
            metrics['position_ids'] = data['position_ids'][0].detach().cpu().numpy()
        metrics['mse_factor'] = self.mse_factor
        metrics['xent_factor'] = self.xent_factor
        return loss, metrics
