"""CMI Loss Trainer implementation compatible with Hugging Face Transformers."""

import logging
from typing import Optional, Dict, Any, Union, Tuple

import torch
from transformers import Trainer

from .config import CMILossConfig
from .loss import compute_cmi_loss, get_cmi_lambda_scheduled


class CMILossTrainer(Trainer):
    """
    Trainer class that implements CMI Loss for supervised fine-tuning.

    This trainer extends the Hugging Face Trainer to incorporate CMI Loss,
    which encourages models to articulate reasoning before generating responses.

    All standard ``Trainer`` constructor arguments are forwarded unchanged;
    the only addition is ``cmi_config``, which carries the CMI-specific
    hyper-parameters (lambda schedule, thinking weight, normalization, and
    whether the regularizer is restricted to harmful samples).
    """

    def __init__(
        self,
        model=None,
        args=None,
        data_collator=None,
        train_dataset=None,
        eval_dataset=None,
        tokenizer=None,
        model_init=None,
        compute_metrics=None,
        callbacks=None,
        optimizers=(None, None),
        preprocess_logits_for_metrics=None,
        cmi_config: Optional[CMILossConfig] = None,
    ):
        """
        Build the underlying ``Trainer`` and attach the CMI configuration.

        Args:
            cmi_config: CMI hyper-parameters. A default ``CMILossConfig()``
                is created when omitted.

        Every other argument is passed straight through to
        ``Trainer.__init__``.
        """
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Explicit None check: only a missing config triggers the default.
        self.cmi_config = CMILossConfig() if cmi_config is None else cmi_config
        self._setup_thinking_tokens()

    def _setup_thinking_tokens(self):
        """
        Resolve the token IDs of the ``<think>`` / ``</think>`` markers.

        Sets ``self.thinking_start_tokens`` and ``self.thinking_end_tokens``
        to the encoded ID lists, or leaves both as ``None`` when no tokenizer
        is available or encoding fails. This is best-effort: a failure is
        logged as a warning and never raised.
        """
        self.thinking_start_tokens = None
        self.thinking_end_tokens = None

        # Newer Transformers releases expose the tokenizer as
        # `processing_class` and deprecate `self.tokenizer`; older releases
        # only have `self.tokenizer`. Support both.
        tokenizer = getattr(self, "processing_class", None)
        if tokenizer is None:
            tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is None:
            return

        try:
            think_start = tokenizer.encode("<think>", add_special_tokens=False)
            think_end = tokenizer.encode("</think>", add_special_tokens=False)
        except Exception as exc:
            # Deliberately broad (best-effort setup), but no longer silent.
            logging.getLogger(__name__).warning(
                "Could not encode thinking markers (%s); "
                "thinking token IDs are left unset.",
                exc,
            )
            return

        # Assign only after both encodes succeeded so the pair stays consistent.
        self.thinking_start_tokens = think_start
        self.thinking_end_tokens = think_end

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,
        num_items_in_batch=None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
        """
        Compute loss using CMI Loss formulation.

        Overrides the standard compute_loss to implement CMI regularization.

        Args:
            model: The model to run the forward pass on.
            inputs: Batch mapping. ``labels`` and the optional
                ``sample_types`` entries are removed before the forward pass,
                so the model itself does not receive them.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for signature compatibility with
                ``Trainer.compute_loss``; not used here.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
            If labels or logits are missing, the loss is whatever the model
            reported under ``"loss"`` (possibly ``None``).
        """
        labels = inputs.pop("labels", None)

        # Get sample types if available (for selective application)
        sample_types = inputs.pop("sample_types", None)

        # Forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")

        if labels is not None and logits is not None:
            loss, metrics = compute_cmi_loss(
                logits=logits,
                labels=labels,
                cmi_lambda=self._get_current_cmi_lambda(),
                thinking_start_tokens=self.thinking_start_tokens,
                thinking_end_tokens=self.thinking_end_tokens,
                thinking_weight=self.cmi_config.cmi_thinking_weight,
                normalize_losses=self.cmi_config.cmi_loss_normalize,
                sample_types=sample_types,
                apply_to_harmful_only=self.cmi_config.cmi_apply_to_harmful_only,
            )

            # Log only for training batches after the first optimizer step;
            # evaluation also routes through compute_loss and would otherwise
            # mix eval batches into the cmi/* training logs.
            if self.state.global_step > 0 and getattr(model, "training", True):
                self._log_cmi_metrics(metrics)
        else:
            # Fallback to the model-provided loss, if any.
            loss = outputs.get("loss")

        return (loss, outputs) if return_outputs else loss

    def _log_cmi_metrics(self, metrics) -> None:
        """
        Log all CMI metrics in a single ``self.log`` call.

        Keys are prefixed with ``cmi/``. Tensor values are converted to plain
        Python numbers (or lists for multi-element tensors) so that the
        logged record stays JSON-serializable.
        """
        if not metrics:
            return

        payload = {}
        for key, value in metrics.items():
            if isinstance(value, torch.Tensor):
                value = value.detach()
                value = value.item() if value.numel() == 1 else value.tolist()
            payload[f"cmi/{key}"] = value

        # One batched call -> one log-history entry and one on_log callback.
        self.log(payload)

    def _get_current_cmi_lambda(self) -> float:
        """
        Get current CMI lambda value with scheduling.

        The total number of steps is taken from ``self.state.max_steps``
        (the value the Trainer resolves for the running training loop) and,
        failing that, from a positive ``self.args.max_steps``. When neither
        is known -- e.g. before training starts on an epoch-based run, where
        ``args.max_steps`` is ``-1`` -- ``cmi_lambda_start`` is returned.
        """
        state = getattr(self, "state", None)
        current_step = getattr(state, "global_step", 0) or 0

        total_steps = getattr(state, "max_steps", 0) or 0
        if total_steps <= 0:
            # `args.max_steps` uses -1 (not None) to mean "derive from epochs".
            total_steps = self.args.max_steps or 0
        if total_steps <= 0:
            return self.cmi_config.cmi_lambda_start

        return get_cmi_lambda_scheduled(
            current_step=current_step,
            max_steps=total_steps,
            warmup_ratio=self.cmi_config.cmi_warmup_ratio,
            rampup_ratio=self.cmi_config.cmi_rampup_ratio,
            lambda_start=self.cmi_config.cmi_lambda_start,
            lambda_end=self.cmi_config.cmi_lambda,
        )