"""
Knowledge Distillation Framework

This module implements the core knowledge distillation components including:
- DistilModel: Wrapper that combines teacher and student models
- DistillationLoss: Loss function that combines KL divergence and similarity measures
- DistilTrainer: Custom trainer for distillation training

The framework supports various similarity measures and layer alignment strategies
for transferring knowledge from teacher to student models.
"""

from typing import List, Optional, Tuple
from torch import nn, Tensor
from torch.nn import functional as F
import torch
from similarity_measures import LinearMeasure, CKA, MSE_w_padding
from transformers import Trainer


class DistilModel(nn.Module):
    """
    Wrapper model that combines teacher and student models for distillation.

    Both models receive the same inputs on every forward pass. The teacher is
    frozen at construction time (eval mode, ``requires_grad=False``) and is
    always run under ``torch.no_grad()``. Only the student is exposed through
    ``parameters()`` / ``named_parameters()`` and toggled by ``train()``.
    """
    # Wrapper state-dict keys that belong to the teacher model.
    # NOTE: this is reassigned on the *class* by every ``__init__`` call, so it
    # always reflects the most recently constructed wrapper.
    _keys_to_ignore_on_save = set()

    def __init__(self, student_model: nn.Module, teacher_model: nn.Module, *args, **kwargs):
        """
        Initialize the distillation model wrapper.

        Args:
            student_model: The model to be trained (student)
            teacher_model: The model to distill knowledge from (teacher)
            *args, **kwargs: Additional arguments passed to nn.Module
        """
        super().__init__(*args, **kwargs)
        self.student_model = student_model
        self.teacher_model = teacher_model

        # Freeze the teacher: evaluation mode and no gradient tracking
        self.teacher_model.eval()
        self.teacher_model.requires_grad_(False)

        # Record the teacher's keys (as they appear in this wrapper's
        # state dict) so they can be skipped when saving
        DistilModel._keys_to_ignore_on_save = {
            'teacher_model.' + key for key in self.teacher_model.state_dict()
        }

    def forward(self, *args, **kwargs):
        """
        Forward pass that returns both student and teacher outputs.

        Both models are called with the same arguments and with
        ``output_hidden_states=True``.

        Args:
            *args, **kwargs: Input arguments passed to both models

        Returns:
            Tuple of (student_output, teacher_output), i.e. whatever each
            wrapped model returns for these inputs
        """
        # Force hidden states on. Building a new dict overrides any value the
        # caller supplied instead of failing with a duplicate-keyword TypeError.
        kwargs = {**kwargs, 'output_hidden_states': True}

        student_output = self.student_model(*args, **kwargs)

        # The teacher never needs gradients
        with torch.no_grad():
            teacher_output = self.teacher_model(*args, **kwargs)

        return student_output, teacher_output

    def train(self, mode: bool = True):
        """
        Set training mode for the wrapper and the student model only.

        The teacher is deliberately left untouched so it stays in the
        evaluation mode set in ``__init__``.

        Args:
            mode: Whether to set training mode (True) or evaluation mode (False)

        Returns:
            self for method chaining
        """
        # Keep the wrapper's own flag in sync so that ``model.training`` is
        # correct after ``model.train()`` / ``model.eval()``.
        self.training = mode
        self.student_model.train(mode)
        return self

    def parameters(self, recurse: bool = True):
        """
        Return parameters of the student model only.

        Args:
            recurse: Whether to recursively return parameters of submodules

        Returns:
            Iterator over student model parameters
        """
        return self.student_model.parameters(recurse=recurse)

    def named_parameters(
            self,
            prefix: str = '',
            recurse: bool = True,
            remove_duplicate: bool = True
    ):
        """
        Return named parameters of the student model only.

        Names are relative to the student model: no ``student_model.`` prefix
        is added unless the caller passes one via ``prefix``.

        Args:
            prefix: Prefix to prepend to all parameter names
            recurse: Whether to recursively return parameters of submodules
            remove_duplicate: Whether to remove duplicate parameters

        Returns:
            Iterator over (name, parameter) pairs for student model
        """
        return self.student_model.named_parameters(
            prefix=prefix,
            recurse=recurse,
            remove_duplicate=remove_duplicate
        )


class DistillationLoss(nn.Module):
    """
    Loss function for knowledge distillation that combines KL divergence and similarity measures.

    The total loss is
    ``(1 - gamma) * temperature**2 * KL(student || teacher) + gamma * similarity``,
    where the similarity term compares selected hidden states of the student
    and the teacher. Which hidden states are compared is controlled by
    ``full_similarity`` and ``align_match``.
    """

    def __init__(
            self,
            gamma=0.6,
            temperature=2.,
            similarity_measure=None,
            full_similarity=False,
            align_match=None,
            **similarity_measure_kwargs,
    ):
        """
        Initialize the distillation loss function.

        Args:
            gamma: Weight of the similarity loss; the KL term is weighted by (1 - gamma)
            temperature: Temperature for softmax in KL divergence
            similarity_measure: Type of similarity measure
                ('linear', 'cosine', 'cka', 'euclidean', 'none' or None)
            full_similarity: Whether to compare all student layers against an
                evenly subsampled set of teacher layers
            align_match: Layer alignment configuration [[student_layers], [teacher_layers]]
            **similarity_measure_kwargs: Additional arguments forwarded to the
                'linear' and 'cka' measures (ignored by the other measures)

        Raises:
            ValueError: If ``similarity_measure`` is not recognized
        """
        super().__init__()
        self.register_buffer('gamma', torch.tensor(gamma))
        self.register_buffer('temperature', torch.tensor(temperature))

        # KL divergence loss for logits
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

        # Similarity measure configuration
        self.similarity_measure = similarity_measure
        self.full_similarity = full_similarity
        self.align_match = align_match

        # Initialize similarity loss based on measure type
        if similarity_measure == 'cosine':
            self.similarity_loss = nn.CosineEmbeddingLoss()
        elif similarity_measure == 'linear':
            self.similarity_loss = LinearMeasure(**similarity_measure_kwargs)
        elif similarity_measure == "cka":
            self.similarity_loss = CKA(**similarity_measure_kwargs)
        elif similarity_measure == "euclidean":
            self.similarity_loss = MSE_w_padding()
        elif similarity_measure is None or similarity_measure == 'none':
            self.similarity_loss = None
        else:
            raise ValueError(f'Unrecognized similarity measure {similarity_measure}')

    def _select_hidden_states(self, student_hidden, teacher_hidden):
        """Return the (student, teacher) hidden-state tensors to be compared."""
        if self.full_similarity:
            # All student layers vs. every k-th teacher layer. The "- 1" keeps
            # the first entry of both sequences aligned with each other.
            student_steps = len(student_hidden) - 1
            teacher_steps = len(teacher_hidden) - 1
            if (student_steps < 1
                    or teacher_steps < student_steps
                    or teacher_steps % student_steps != 0):
                raise ValueError(
                    "full_similarity requires (len(teacher_hidden) - 1) to be a "
                    "positive multiple of (len(student_hidden) - 1); got "
                    f"{len(teacher_hidden)} teacher and {len(student_hidden)} "
                    "student hidden states"
                )
            subsampling_ratio = teacher_steps // student_steps
            # Stack into single tensors, like the align_match branch, so the
            # result supports the tensor operations applied afterwards.
            return (
                torch.stack(tuple(student_hidden)),
                torch.stack(tuple(teacher_hidden[::subsampling_ratio])),
            )

        if self.align_match is not None:
            # Use specified layer alignment
            if len(self.align_match[0]) != len(self.align_match[1]):
                raise ValueError("Need to have the same number of layers for align matches")
            return (
                torch.stack(student_hidden)[self.align_match[0]],
                torch.stack(teacher_hidden)[self.align_match[1]],
            )

        # Default: use last layers
        return student_hidden[-1], teacher_hidden[-1]

    def forward(self,
                student_logits,
                teacher_logits,
                student_hidden: Optional[Tuple[Tensor, ...]] = None,
                teacher_hidden: Optional[Tuple[Tensor, ...]] = None,
                return_parts=False):
        """
        Compute the distillation loss.

        Args:
            student_logits: Logits from student model
            teacher_logits: Logits from teacher model
            student_hidden: Hidden states from student model (required when a
                similarity measure is active)
            teacher_hidden: Hidden states from teacher model (required when a
                similarity measure is active)
            return_parts: Whether to return individual loss components

        Returns:
            The total loss. If ``return_parts=True``, a tuple
            ``(total_loss, kl_loss)`` that is extended with ``sim_loss`` only
            when a similarity loss was actually computed.

        Raises:
            ValueError: If a similarity loss is requested but hidden states
                are missing, or the layer configuration is inconsistent
        """
        # KL divergence between temperature-softened distributions. Skipped
        # when its weight (1 - gamma) is effectively zero.
        if (1 - self.gamma) > 1e-8:
            soft_log_student = F.log_softmax(student_logits / self.temperature, dim=-1)
            soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
            kl_loss = self.kl_div(soft_log_student, soft_teacher)
        else:
            # Scalar zero with the logits' dtype and device
            kl_loss = student_logits.new_zeros(())

        # Weighted KL loss
        output = (1 - self.gamma) * self.temperature ** 2 * kl_loss

        # Compute similarity loss if enabled
        sim_loss = None
        if self.similarity_loss is not None and self.gamma >= 1e-8:
            if student_hidden is None or teacher_hidden is None:
                raise ValueError(
                    "student_hidden and teacher_hidden are required when a "
                    "similarity measure is used"
                )

            student_extracted, teacher_extracted = self._select_hidden_states(
                student_hidden, teacher_hidden
            )

            # Merge the two leading dimensions for similarity computation
            to_align_student = student_extracted.flatten(end_dim=1)
            to_align_teacher = teacher_extracted.flatten(end_dim=1)

            if self.similarity_measure == 'cosine':
                # CosineEmbeddingLoss compares rows of 2-D inputs, so collapse
                # everything except the feature dimension. This works both for
                # a single layer and for stacked layers.
                student_vectors = to_align_student.reshape(-1, to_align_student.shape[-1])
                teacher_vectors = to_align_teacher.reshape(-1, to_align_teacher.shape[-1])
                # Target 1 for every pair: all pairs should be similar
                target = torch.ones(
                    student_vectors.shape[0],
                    device=student_vectors.device,
                    dtype=torch.long,
                )
                sim_loss = self.similarity_loss(student_vectors, teacher_vectors, target)
            elif self.similarity_measure in ["linear", "cka", "euclidean"]:
                sim_loss = self.similarity_loss(to_align_student, to_align_teacher)
            else:
                raise NotImplementedError

        # Add similarity loss to total loss (out of place, so the KL-only
        # tensor is not mutated)
        if sim_loss is not None:
            output = output + self.gamma * sim_loss

        # Return appropriate output format
        if return_parts:
            output = (output, kl_loss) + ((sim_loss,) if sim_loss is not None else ())
        return output


class DistilTrainer(Trainer):
    """
    Custom trainer for knowledge distillation training.

    This trainer extends the standard HuggingFace Trainer: it wraps the
    student and teacher in a DistilModel and replaces the loss computation
    with the supplied distillation loss function.
    """

    def __init__(self, student_model=None, teacher_model=None, loss_fn=None,
                 temperature=None, include_targets=False, *args, **kwargs):
        """
        Initialize the distillation trainer.

        Args:
            student_model: The model to be trained (student)
            teacher_model: The model to distill knowledge from (teacher)
            loss_fn: Distillation loss function, called as
                loss_fn(student_logits, teacher_logits, student_hidden, teacher_hidden)
            temperature: Stored on the trainer (defaults to 1. when not
                provided); it is not read by this class
            include_targets: Whether to add a cross-entropy loss on the
                batch's "labels" to the distillation loss
            *args, **kwargs: Additional arguments passed to Trainer
        """
        # Create distillation model wrapper
        model = DistilModel(student_model, teacher_model)
        # Pass the model positionally: combining ``model=model`` with ``*args``
        # raises "multiple values for argument 'model'" as soon as any extra
        # positional argument is supplied.
        super().__init__(model, *args, **kwargs)

        # Store distillation-specific parameters
        self.temperature = temperature if temperature else 1.
        self.loss_fn = loss_fn
        self.include_targets = include_targets

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the distillation loss for a batch.

        Args:
            model: The distillation model wrapper
            inputs: Input batch
            return_outputs: Whether to return model outputs along with loss
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass it to compute_loss; not used here

        Returns:
            Loss value or tuple of (loss, student_output) if return_outputs=True
        """
        # Get outputs from both student and teacher models
        student_output, teacher_output = model(**inputs)

        # Compute distillation loss
        loss = self.loss_fn(
            student_output.logits,
            teacher_output.logits,
            student_output.hidden_states,
            teacher_output.hidden_states
        )

        # Optionally add cross-entropy loss on targets
        if self.include_targets:
            # NOTE(review): F.cross_entropy expects the class dimension at
            # index 1; confirm the logits/labels shapes match for this task.
            loss = loss + F.cross_entropy(student_output.logits, inputs["labels"])

        return (loss, student_output) if return_outputs else loss