import torch
import torch.nn as nn
from transformers import TrainingArguments
from trl import SFTTrainer 
import os
from abc import ABC, abstractmethod

from ldifs.featextract import compute_concat_intermediate_features, compute_feature_space_distance_v1, compute_feature_space_distance_normalized
from ldifs.kl_divergence import kl_divergence_loss

class LDIFSTrainer(SFTTrainer, ABC):
    """SFT trainer that adds a weighted penalty computed against a frozen model.

    The loss returned by ``compute_loss`` is
    ``base_loss + ldifs_lambda * distance``, where ``base_loss`` comes from the
    parent trainer and ``distance`` is whatever the subclass's
    ``compute_feature_space_distance`` returns.

    Every ``compute_loss`` call also records the two components; they are
    flushed into the log dict the next time ``log`` is called.
    """

    def __init__(
        self, *args, frozen_model=None, ldifs_lambda=10.0, num_intermediate_layers=5, **kwargs
    ):
        """Create the trainer.

        Args:
            *args: Forwarded unchanged to the parent trainer.
            frozen_model: Reference model handed to
                ``compute_feature_space_distance``. Required.
            ldifs_lambda: Weight applied to the distance term in the total loss.
            num_intermediate_layers: Stored on the instance for subclasses to
                pass on to their distance function.
            **kwargs: Forwarded unchanged to the parent trainer.

        Raises:
            ValueError: If ``frozen_model`` is ``None``.
        """
        # Validate up front with a real exception: an ``assert`` disappears
        # under ``python -O`` and there is no point building the whole parent
        # trainer only to reject the call afterwards.
        if frozen_model is None:
            raise ValueError("frozen_model must be provided")
        super().__init__(*args, **kwargs)
        self.frozen_model = frozen_model
        self.ldifs_lambda = ldifs_lambda
        self._accum_l2_losses = []
        self._accum_base_losses = []
        self.num_intermediate_layers = num_intermediate_layers

    @abstractmethod
    def compute_feature_space_distance(self, model_trained, model_frozen, input_ids, attention_mask):
        """Return the penalty term between the trained and the frozen model.

        Args:
            model_trained: The model passed to ``compute_loss``.
            model_frozen: ``self.frozen_model``.
            input_ids: ``inputs['input_ids']`` of the current batch.
            attention_mask: ``inputs['attention_mask']`` of the current batch.

        Returns:
            A value that can be scaled by ``ldifs_lambda``, added to the base
            loss, and converted with ``.item()`` (a scalar tensor).
        """
        pass

    def _flush_accumulated(self, logs, key, attr_name):
        """Move the values accumulated under ``attr_name`` into ``logs[key]``.

        The values are summed; when ``logs`` contains ``"eval_loss"`` the sum
        is divided by the number of accumulated values. The list is cleared
        afterwards. Nothing happens if the attribute is missing or empty.
        """
        # getattr with a default keeps this safe if ``log`` is ever invoked
        # before ``__init__`` has created the accumulator lists.
        values = getattr(self, attr_name, None)
        if not values:
            return
        total = sum(values)
        # NOTE(review): outside evaluation the *sum* is logged, not the mean.
        # This mirrors the original behavior — confirm it is intended when
        # several compute_loss calls happen between two log calls.
        logs[key] = total / len(values) if "eval_loss" in logs else total
        values.clear()

    def log(self, logs, start_time=None):
        """Add ``l2_loss`` / ``base_loss`` to ``logs`` and delegate to the parent.

        Args:
            logs: Dict of values to log; mutated in place.
            start_time: Optional value forwarded to the parent's ``log`` when
                it is not ``None``.
        """
        self._flush_accumulated(logs, 'l2_loss', '_accum_l2_losses')
        self._flush_accumulated(logs, 'base_loss', '_accum_base_losses')
        # Forward start_time only when the caller supplied one, so parents
        # whose ``log`` takes a single argument keep working.
        if start_time is None:
            super().log(logs)
        else:
            super().log(logs, start_time)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the parent's loss plus ``ldifs_lambda`` times the distance term.

        Args:
            model: Model being trained.
            inputs: Batch mapping; must contain ``'input_ids'`` and
                ``'attention_mask'``.
            return_outputs: When true, return ``(total_loss, outputs)``.
            num_items_in_batch: Forwarded unchanged to the parent's
                ``compute_loss``.

        Returns:
            ``total_loss``, or ``(total_loss, outputs)`` when
            ``return_outputs`` is true.
        """
        # Get outputs and compute the base loss using the parent class's method
        base_loss, outputs = super().compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )

        # Compute the distance term between the fine-tuned and frozen models
        l2_loss = self.compute_feature_space_distance(
            model_trained=model,
            model_frozen=self.frozen_model,
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
        )

        total_loss = base_loss + (self.ldifs_lambda * l2_loss)
        # Record plain Python floats so the accumulators do not keep the
        # autograd graph alive.
        self._accum_l2_losses.append(l2_loss.item())
        self._accum_base_losses.append(base_loss.item())
        return (total_loss, outputs) if return_outputs else total_loss

class LDIFSTrainer_v1(LDIFSTrainer):
    """LDIFS trainer whose penalty term is ``compute_feature_space_distance_v1``."""

    def compute_feature_space_distance(self, model_trained, model_frozen, input_ids, attention_mask):
        """Delegate to ``compute_feature_space_distance_v1``.

        All four arguments are passed through by keyword, together with
        ``self.num_intermediate_layers``; the helper's result is returned as is.
        """
        distance_kwargs = dict(
            model_trained=model_trained,
            model_frozen=model_frozen,
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_intermediate_layers=self.num_intermediate_layers,
        )
        return compute_feature_space_distance_v1(**distance_kwargs)
    
class LDIFSTrainer_normalized(LDIFSTrainer):
    """LDIFS trainer whose penalty term is ``compute_feature_space_distance_normalized``."""

    def compute_feature_space_distance(self, model_trained, model_frozen, input_ids, attention_mask):
        """Delegate to ``compute_feature_space_distance_normalized``.

        The models and batch tensors are forwarded by keyword along with
        ``self.num_intermediate_layers``; the helper's result is returned
        unchanged.
        """
        layer_count = self.num_intermediate_layers
        distance = compute_feature_space_distance_normalized(
            model_trained=model_trained,
            model_frozen=model_frozen,
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_intermediate_layers=layer_count,
        )
        return distance
    
class KLTrainer(LDIFSTrainer):
    """Trainer whose penalty term is ``kl_divergence_loss`` between model logits."""

    def compute_feature_space_distance(self, model_trained, model_frozen, input_ids, attention_mask):
        """Return ``kl_divergence_loss`` of the trained vs. frozen logits.

        Both models are run on the same ``input_ids`` / ``attention_mask``.
        The frozen model is evaluated under ``torch.no_grad()``; the trained
        model is evaluated normally so gradients can flow through its logits.
        The trained logits are passed first, the frozen logits second, then the
        attention mask, with ``reduction='batchmean'`` and ``temperature=1.0``.
        """
        shared_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

        # The reference model only supplies targets; no graph is needed for it.
        with torch.no_grad():
            reference_logits = model_frozen(**shared_inputs).logits

        student_logits = model_trained(**shared_inputs).logits

        return kl_divergence_loss(
            student_logits,
            reference_logits,
            attention_mask,
            reduction='batchmean',
            temperature=1.0,
        )


class InterleavedKLTrainer(LDIFSTrainer):
    """
    KL divergence trainer that only applies the KL term to interleaved samples.

    This trainer expects the dataset to provide an 'is_interleaved' field in the
    batch inputs, which is a binary tensor of shape (batch_size,) indicating which
    samples are interleaved (1.0) vs. standard samples (0.0).

    Features that do not carry the field are treated as not interleaved (0).

    NOTE(review): the field has to survive until the data collator runs;
    presumably this requires ``remove_unused_columns=False`` in the training
    arguments — confirm against the trainer configuration in use.

    Usage:
        trainer = InterleavedKLTrainer(
            model=model,
            frozen_model=frozen_model,
            ldifs_lambda=10.0,
            train_dataset=dataset,  # Must provide 'is_interleaved' field
            ...
        )
    """

    def __init__(self, *args, **kwargs):
        """Initialize and wrap the data collator so it emits 'is_interleaved'."""
        super().__init__(*args, **kwargs)
        self.data_collator = self._wrap_collator(self.data_collator)

    @staticmethod
    def _wrap_collator(original_collator):
        """Return a collator that adds an 'is_interleaved' float tensor to the batch.

        Args:
            original_collator: The collator to delegate to, or ``None`` to
                pass the feature list through untouched.

        Returns:
            A callable taking the list of features and returning the batch.
        """
        # Local import keeps this block self-contained (stdlib only).
        from collections.abc import Mapping, MutableMapping

        def custom_collator(features):
            # Read the flags before collation; the wrapped collator is free to
            # drop or transform the field.
            flags = [
                feature['is_interleaved']
                if isinstance(feature, Mapping) and 'is_interleaved' in feature
                else 0  # Default to not interleaved
                for feature in features
            ]

            batch = original_collator(features) if original_collator is not None else features

            # Accept any mutable mapping, not just ``dict``: tokenizer-produced
            # batches are ``UserDict`` subclasses and would otherwise be skipped.
            if isinstance(batch, MutableMapping):
                batch['is_interleaved'] = torch.tensor(flags, dtype=torch.float)

            return batch

        return custom_collator

    def compute_feature_space_distance(self, model_trained, model_frozen, input_ids, attention_mask, sample_mask=None):
        """
        Compute KL divergence with optional sample-level masking.

        Args:
            model_trained: The model being trained
            model_frozen: The frozen reference model
            input_ids: Input token IDs, shape (B, T)
            attention_mask: Attention mask, shape (B, T)
            sample_mask: Optional binary mask for samples to include, shape (B,)

        Returns:
            KL divergence loss (scalar)
        """
        # The frozen model only supplies targets, so skip building its graph.
        with torch.no_grad():
            logits_frozen = model_frozen(input_ids=input_ids, attention_mask=attention_mask).logits

        logits_trained = model_trained(input_ids=input_ids, attention_mask=attention_mask).logits

        # NOTE(review): behavior when sample_mask is all zeros depends on
        # kl_divergence_loss — confirm it does not divide by zero.
        return kl_divergence_loss(
            logits_trained, logits_frozen, attention_mask,
            sample_mask=sample_mask,  # Only compute KL for interleaved samples
            reduction='batchmean', temperature=1.0
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute loss with KL divergence applied only to interleaved samples.

        Extracts 'is_interleaved' from inputs if present, and passes it to
        compute_feature_space_distance to selectively apply KL divergence.

        Args:
            model: Model being trained (possibly wrapped for parallel training).
            inputs: Batch mapping; must contain 'input_ids' and
                'attention_mask'. 'is_interleaved' is removed from it.
            return_outputs: When true, return ``(total_loss, outputs)``.
            num_items_in_batch: Forwarded unchanged to the base loss computation.

        Returns:
            ``total_loss``, or ``(total_loss, outputs)`` when
            ``return_outputs`` is true.
        """
        # Pop the mask so it doesn't interfere with the base model forward pass
        sample_mask = inputs.pop('is_interleaved', None)

        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        if sample_mask is not None:
            # Take the device from the batch rather than ``model.device``:
            # parallel wrappers around the model do not expose ``.device``.
            sample_mask = torch.as_tensor(
                sample_mask, dtype=torch.float, device=input_ids.device
            )

        # Skip LDIFSTrainer.compute_loss on purpose: it would add an unmasked
        # penalty. Go straight to the next class in the MRO for the base loss.
        base_loss, outputs = super(LDIFSTrainer, self).compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )

        # Compute the KL loss with sample masking
        l2_loss = self.compute_feature_space_distance(
            model_trained=model,
            model_frozen=self.frozen_model,
            input_ids=input_ids,
            attention_mask=attention_mask,
            sample_mask=sample_mask,
        )

        total_loss = base_loss + (self.ldifs_lambda * l2_loss)
        # Store plain floats so the accumulators do not hold the autograd graph.
        self._accum_l2_losses.append(l2_loss.item())
        self._accum_base_losses.append(base_loss.item())

        return (total_loss, outputs) if return_outputs else total_loss
