import torch
import torch.nn as nn
from omegaconf import DictConfig

import lightning.pytorch as pl
from lightning.pytorch.utilities import rank_zero_only, grad_norm

from .modeling.base import PreTrainedModelForAIM
from .aim_impl import get_aim_impl


def freeze(module: nn.Module):
    """Turn off gradient tracking for every parameter of ``module``."""
    for weight in module.parameters():
        weight.requires_grad_(False)


def unfreeze(module: nn.Module):
    """Turn on gradient tracking for every parameter of ``module``."""
    for weight in module.parameters():
        weight.requires_grad_(True)


# copied from `src/transformers/models/gemma3/modeling_gemma3.py`
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along the head dimension.

    Equivalent to torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep): a tensor of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With `n_rep == 1` the input tensor
    itself is returned.
    """
    # Unpacking first means a non-4D input is rejected even when n_rep == 1.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a singleton axis after the head axis, broadcast it to n_rep,
        # then fold it back into the head axis.
        expanded = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
        return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states


# TODO: measure perplexity of the student model against the teacher model as validation??
class AttentionInfluenceModelingTask(pl.LightningModule):
    """Lightning task that fits a student's input embeddings against a frozen teacher.

    All parameters of both models are frozen, then the student's input
    embeddings are made trainable again. Each step runs the teacher and the
    student on their own inputs from the batch, hands their attention weights
    and (head-expanded) value states to the configured ``aim_impl``, and
    compares the two tensors it returns with the configured loss
    (``"mse"`` or ``"cosine"``).
    """

    def __init__(self,
        cfg: DictConfig,
        teacher_model: PreTrainedModelForAIM,
        student_model: PreTrainedModelForAIM,
    ):
        """Store the models, pick the loss, and set up which parameters train.

        Args:
            cfg: Configuration; reads ``cfg.task.aim_impl``, ``cfg.task.loss``,
                ``cfg.task.lr`` and ``cfg.model.freeze_original_embeddings``.
            teacher_model: Model providing the targets; fully frozen.
            student_model: Model being trained; only its input embeddings are
                left trainable.

        Raises:
            ValueError: If ``cfg.task.loss`` is neither ``"mse"`` nor ``"cosine"``.
        """
        super().__init__()
        self.save_hyperparameters(cfg)

        self.cfg = cfg

        self.teacher_model = teacher_model
        self.student_model = student_model

        # The head-repeat factor is a property of each model's own config, so
        # compute it per model instead of reusing the teacher's for the student.
        # `num_key_value_groups` keeps its original name/meaning (teacher).
        self.num_key_value_groups = self._num_key_value_groups(
            self.teacher_model.config,
        )
        self.student_num_key_value_groups = self._num_key_value_groups(
            self.student_model.config,
        )

        self.aim_impl = get_aim_impl(cfg.task.aim_impl)

        if cfg.task.loss == "mse":
            self.loss_fn = nn.MSELoss()
        elif cfg.task.loss == "cosine":
            self.loss_fn = nn.CosineEmbeddingLoss()
        else:
            raise ValueError(f"Unknown loss: {cfg.task.loss}")

        freeze(self.teacher_model)
        freeze(self.student_model)
        # The teacher only provides targets; keep dropout-like layers inactive.
        self.teacher_model.eval()

        unfreeze(self.student_model.get_input_embeddings())
        if self.cfg.model.freeze_original_embeddings:
            # NOTE(review): freeze() is immediately followed by unfreeze() on
            # the same embedding object; whether the net effect is the intended
            # one depends on how that embedding class implements the two
            # methods -- TODO confirm.
            self.student_model.get_input_embeddings().freeze()
            self.student_model.get_input_embeddings().unfreeze()

    @staticmethod
    def _num_key_value_groups(config) -> int:
        """Return how many attention heads share one key/value head for ``config``.

        Falls back to a ratio of 1 when the config has no
        ``num_key_value_heads`` attribute.
        """
        num_attention_heads = config.num_attention_heads
        num_key_value_heads = getattr(
            config,
            'num_key_value_heads',
            num_attention_heads,
        )
        return num_attention_heads // num_key_value_heads

    def _create_dummy_loss_with_grad(self) -> torch.Tensor:
        """
        Create a dummy loss tensor that has a gradient computation graph
        but doesn't affect training. This is used when all tokens in a batch
        are frozen (no new tokens to learn from).

        Returns:
            A scalar tensor with grad_fn that evaluates to 0.0
        """
        # Get a parameter from the student model that requires grad
        # This ensures we have a proper gradient computation graph
        student_embeddings = self.student_model.get_input_embeddings()

        # if using PartlyFrozenEmbeddings FIXME
        if hasattr(student_embeddings, 'active_embeddings'):
            dummy_loss = 0.0 * student_embeddings.active_embeddings.weight.sum()
        else:
            dummy_loss = 0.0 * student_embeddings.weight.sum()

        return dummy_loss

    def common_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Run teacher and student on one batch and return the loss.

        Args:
            batch: Must contain ``teacher_input_ids``, ``teacher_attention_mask``,
                ``teacher_word_ids``, ``student_input_ids``,
                ``student_attention_mask`` and ``student_word_ids``.
            batch_idx: Index of the batch; only used in the diagnostic output.

        Returns:
            The configured loss between the two tensors returned by
            ``aim_impl``, or a zero-valued loss attached to the student
            embedding weights when the two tensors differ in their first
            dimension or the computed loss does not require grad.

        Raises:
            ValueError: If ``cfg.task.loss`` is neither ``"mse"`` nor ``"cosine"``.
        """
        teacher_input_ids = batch["teacher_input_ids"]
        teacher_attention_mask = batch["teacher_attention_mask"]
        teacher_word_ids = batch["teacher_word_ids"]

        student_input_ids = batch["student_input_ids"]
        student_attention_mask = batch["student_attention_mask"]
        student_word_ids = batch["student_word_ids"]

        # Putting this module in train mode (as a training loop may do) also
        # flips the teacher to train mode; force it back so its outputs are
        # not perturbed by dropout-like layers.
        if self.teacher_model.training:
            self.teacher_model.eval()

        # The teacher is frozen and only supplies targets, so no autograd
        # graph is needed for its forward pass.
        with torch.no_grad():
            teacher_output = self.teacher_model(
                input_ids=teacher_input_ids,
                attention_mask=teacher_attention_mask,
            )
        student_output = self.student_model(
            input_ids=student_input_ids,
            attention_mask=student_attention_mask,
        )

        teacher_word_states, student_word_states = self.aim_impl(
            teacher_attn_weights=teacher_output.attentions,
            teacher_value_states=repeat_kv(
                teacher_output.value_states,
                self.num_key_value_groups,
            ),
            teacher_word_ids=teacher_word_ids,
            student_attn_weights=student_output.attentions,
            student_value_states=repeat_kv(
                student_output.value_states,
                self.student_num_key_value_groups,
            ),
            student_word_ids=student_word_ids,
        )

        # Teacher and student produced a different number of rows, so the
        # loss cannot pair them up: dump the inputs for debugging and skip
        # the batch with a zero loss.
        if teacher_word_states.size(0) != student_word_states.size(0):
            print('batch_idx', batch_idx)
            print('teacher_word_states.size(0)', teacher_word_states.size(0))
            print('student_word_states.size(0)', student_word_states.size(0))
            print('teacher_word_ids', teacher_word_ids)
            print('student_word_ids', student_word_ids)
            print('teacher_input_ids', teacher_input_ids)
            print('student_input_ids', student_input_ids)
            print('teacher_word_states.size(0) != student_word_states.size(0)')
            print('\n\n')
            # Return a dummy loss with proper gradient computation
            return self._create_dummy_loss_with_grad()

        if self.cfg.task.loss == "mse":
            loss = self.loss_fn(teacher_word_states, student_word_states)
        elif self.cfg.task.loss == "cosine":
            # A target of +1 for every row asks the loss to pull each pair together.
            target = torch.ones(teacher_word_states.size(0), device=teacher_word_states.device)
            loss = self.loss_fn(teacher_word_states, student_word_states, target)
        else:
            raise ValueError(f"Unknown loss: {self.cfg.task.loss}")

        # Handle edge cases where no gradients are available (all frozen tokens in batch)
        if not loss.requires_grad:
            return self._create_dummy_loss_with_grad()

        self.log('loss', loss)

        return loss

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Return the training loss for ``batch``; delegates to ``common_step``."""
        return self.common_step(batch, batch_idx)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an AdamW optimizer over this module's parameters with ``cfg.task.lr``."""
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.cfg.task.lr,
        )
