"""
Customized HuggingFace trainer for RosettaModel
"""

import torch
import os
from typing import List
from transformers import Trainer, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from ..model.wrapper import RosettaModel


class RosettaTrainer(Trainer):
    """Custom Trainer for RosettaModel.

    Overrides ``compute_loss`` so that every forward pass receives explicit
    ``position_ids`` derived from the attention mask and runs with
    ``use_cache=False``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for one RosettaModel batch.

        Args:
            model: The model to run (called as ``model(**inputs, labels=labels)``).
            inputs: Mapping of model inputs. Must contain ``"labels"`` and
                ``"attention_mask"``.
            return_outputs: If True, return ``(loss, outputs)`` instead of ``loss``.
            num_items_in_batch: Accepted for compatibility with newer
                ``transformers`` releases, which pass it to ``compute_loss``.
                It is not used here; the loss is whatever ``outputs.loss`` reports.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            KeyError: if ``inputs`` lacks ``"labels"`` or ``"attention_mask"``.
        """
        # Shallow copy so the caller's batch mapping is not mutated by the
        # pop/insert operations below.
        inputs = dict(inputs)
        labels = inputs.pop("labels")

        # Position ids count attended tokens only (running sum of the mask, minus 1).
        # Masked slots would otherwise get -1 under left padding, which is an
        # invalid index for learned position-embedding tables. Fill them with
        # a harmless placeholder, as HF's generation helpers do.
        attention_mask = inputs["attention_mask"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids = position_ids.masked_fill(attention_mask == 0, 1)
        inputs["position_ids"] = position_ids
        inputs["use_cache"] = False

        # Forward pass; the model is expected to compute its own loss from labels.
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss


class ProjectorSaveCallback(TrainerCallback):
    """Trainer callback that writes projector weights and config into each checkpoint.

    On every ``on_save`` event, on the main process only, it writes into
    ``<args.output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``:

    * ``projector_{i}.pt``: the ``state_dict()`` of ``projector_list[i]``
    * ``projector_config.json``: produced by the model's ``save_projector_config``
    """

    def on_save(self, args, state, control, model=None, **kwargs):
        """Save projector weights and config alongside the current checkpoint.

        Args:
            args: Training arguments; ``args.output_dir`` locates the checkpoint.
            state: Trainer state. ``global_step`` names the checkpoint folder,
                and ``is_world_process_zero`` restricts writing to one process.
            control: Trainer control object (unused).
            model: The model being trained, optionally wrapped in an object
                exposing the real model as ``.module`` (e.g. DDP).

        Raises:
            ValueError: if called on the main process without a model.
        """
        if not state.is_world_process_zero:
            return

        if model is None:
            raise ValueError(
                "ProjectorSaveCallback.on_save requires the model, but none was passed."
            )

        checkpoint_dir = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        # Make sure the folder exists so torch.save cannot fail on a missing path.
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Unwrap DDP-style wrappers that hold the real model in `.module`.
        base_model = getattr(model, "module", model)

        # One file per projector, indexed by its position in projector_list.
        for i, projector in enumerate(base_model.projector_list):
            torch.save(
                projector.state_dict(),
                os.path.join(checkpoint_dir, f"projector_{i}.pt"),
            )

        base_model.save_projector_config(
            os.path.join(checkpoint_dir, "projector_config.json")
        )


def freeze_model_components(rosetta_model: RosettaModel, freeze_config: List[str]):
    """Turn off gradients for the parts of ``rosetta_model`` named in ``freeze_config``.

    Recognised names and the modules they cover:

    * ``"base"``: ``rosetta_model.model_list[0]``
    * ``"teacher"``: ``rosetta_model.model_list[1]``
    * ``"projector"``: every module in ``rosetta_model.projector_list``

    Each name is checked with ``in freeze_config``; other entries are ignored.
    Parameters are updated in place (``requires_grad = False``) and nothing
    is returned.
    """
    # Each entry resolves its modules lazily, so e.g. model_list[1] is only
    # indexed when "teacher" is actually requested.
    component_table = (
        ("base", lambda: [rosetta_model.model_list[0]]),
        ("teacher", lambda: [rosetta_model.model_list[1]]),
        ("projector", lambda: rosetta_model.projector_list),
    )

    for component_name, resolve_modules in component_table:
        if component_name not in freeze_config:
            continue
        for module in resolve_modules():
            for weight in module.parameters():
                weight.requires_grad = False