"""
Exponential Moving Average (EMA) utilities for parameter updates during fine-tuning.
Compatible with Hugging Face Trainer checkpoint saving and resuming.
"""

import os
import logging
from typing import Dict, Optional, Union, Callable
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForCausalLM, TrainerCallback
from contextlib import contextmanager

# `peft` is an optional dependency: PEFT_AVAILABLE records whether it could be imported.
try:
    from peft import PeftModel

    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False

    class PeftModel:  # type: ignore[no-redef]
        """Placeholder bound when `peft` is not installed.

        Keeps the name `PeftModel` defined so that code referencing it (type
        annotations, isinstance checks) does not raise NameError at import time.
        Nothing in this module instantiates it, so isinstance checks against it
        are False for real models.
        """


logger = logging.getLogger(__name__)


class EMAUpdateCallback(TrainerCallback):
    """Trainer callback that advances an EMA model on every optimizer step.

    Args:
        ema_model: Object exposing ``step_ema(step)``; when it is None the callback
            does nothing.
        should_update_fn: Optional zero-argument predicate. When given, the EMA
            update is skipped whenever it returns a falsy value.
    """

    def __init__(self, ema_model: 'EMAModel', should_update_fn: Optional[Callable[[], bool]] = None):
        self.should_update_fn = should_update_fn
        self.ema_model = ema_model

    def on_optimizer_step(self, args, state, control, **kwargs):
        """Run one EMA update using ``state.global_step`` as the step counter.

        The predicate is only consulted when an EMA model is present.
        """
        ema = self.ema_model
        if ema is None:
            return
        gate = self.should_update_fn
        if gate is None or gate():
            ema.step_ema(state.global_step)


class EMAModel:
    """
    Exponential Moving Average (EMA) of a model's trainable parameters.

    The instance keeps three name -> tensor mappings:

    * ``trainable_params``: references to the live model parameters (never overwritten),
    * ``ema_params``: the averaged copies, updated by :meth:`step_ema`,
    * ``params_cache``: temporary backup of the live weights while the EMA weights
      are swapped into the model (see :meth:`ema_loaded`).

    Compatible with both regular models and PEFT models; in both cases only
    parameters with ``requires_grad=True`` are tracked.
    """

    def __init__(
        self,
        model: "Union[PreTrainedModel, PeftModel]",
        decay: float = 0.999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: float = 1.0,
        power: float = 2 / 3,
    ):
        """
        Initialize EMA model.

        Args:
            model: The model to create EMA for
            decay: EMA decay rate (upper bound of the effective decay)
            min_decay: Lower bound of the effective decay
            update_after_step: Averaging starts once the optimization step reaches this
                value; before that the effective decay is ``min_decay``
                (with the default 0.0 the EMA simply mirrors the model)
            use_ema_warmup: Whether to ramp the decay up with the warmup schedule
            inv_gamma: Inverse gamma for warmup
            power: Power for warmup
        """
        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power

        # Get only trainable parameters
        self.trainable_params = {}  # reference to the original model parameters; do not overwrite them!
        self.ema_params = {}
        self.params_cache = {}

        # PEFT and regular models are registered through separate entry points
        if PEFT_AVAILABLE and isinstance(model, PeftModel):
            self._init_peft_ema(model)
        else:
            self._init_regular_ema(model)

        self.step = 0
        self.cur_decay_value = None

    def _register_trainable_params(self, model: "Union[PreTrainedModel, PeftModel]") -> None:
        """Record every ``requires_grad`` parameter of ``model`` and clone it as its EMA start value."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.trainable_params[name] = param
                # Detached copy: the EMA must not share storage with the live weight
                self.ema_params[name] = param.data.clone().detach()

        logger.info(f"EMA initialized for {len(self.trainable_params)} trainable parameters")

    def _init_peft_ema(self, model: "PeftModel"):
        """Initialize EMA for PEFT model (tracks the parameters that require grad)."""
        logger.info("Initializing EMA for PEFT model")
        self._register_trainable_params(model)

    def _init_regular_ema(self, model: "PreTrainedModel"):
        """Initialize EMA for regular model (tracks the parameters that require grad)."""
        logger.info("Initializing EMA for regular model")
        self._register_trainable_params(model)

    def get_decay(self, optimization_step: int) -> float:
        """Calculate the effective decay for ``optimization_step``.

        * Before ``update_after_step`` the schedule value is 0, so the result is
          ``min_decay``.
        * With warmup the value is ``1 - (1 + s / inv_gamma) ** -power`` where
          ``s = max(0, optimization_step - update_after_step - 1)``.
        * Without warmup the value is 1.

        The value is multiplied by ``decay`` and floored at ``min_decay``.
        """
        if optimization_step < self.update_after_step:
            # Averaging has not started yet. Without this branch `update_after_step`
            # would only shift the warmup curve and be ignored when warmup is off.
            value = 0.0
        elif self.use_ema_warmup:
            step = max(0, optimization_step - self.update_after_step - 1)
            value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            value = 1

        return max(self.min_decay, value * self.decay)

    def step_ema(self, optimization_step: int):
        """Update EMA parameters: ``ema = decay * ema + (1 - decay) * param``.

        Also records ``optimization_step`` in ``self.step`` and the decay that was
        used in ``self.cur_decay_value``.
        """
        self.step = optimization_step
        decay = self.get_decay(optimization_step)
        self.cur_decay_value = decay

        with torch.no_grad():
            for name, param in self.trainable_params.items():
                ema = self.ema_params.get(name)
                if ema is None:
                    continue
                # The live parameter may sit on another device than the stored EMA
                # (model moved after construction, or state loaded on CPU);
                # `.to` is a no-op when the devices already match.
                ema = ema.to(param.device)
                self.ema_params[name] = decay * ema + (1 - decay) * param.data

    def copy_to(self, model: "Union[PreTrainedModel, PeftModel]"):
        """Copy EMA parameters into the same-named parameters of ``model`` (in place)."""
        for name, param in model.named_parameters():
            if name in self.ema_params:
                param.data.copy_(self.ema_params[name])

    def store(self, model: "Union[PreTrainedModel, PeftModel]"):
        """Back up the current values of the tracked parameters of ``model`` in ``params_cache``."""
        for name, param in model.named_parameters():
            if name in self.trainable_params:
                self.params_cache[name] = param.data.clone()

    def restore(self, model: "Union[PreTrainedModel, PeftModel]"):
        """Write the values saved by :meth:`store` back into ``model`` and empty the cache."""
        for name, param in model.named_parameters():
            if name in self.params_cache:
                param.data.copy_(self.params_cache[name])
        self.params_cache.clear()

    @contextmanager
    def ema_loaded(self, model: "Union[PreTrainedModel, PeftModel]"):
        """Context manager that temporarily swaps the EMA weights into ``model``.

        The original weights are restored on exit, including when the body or the
        swap itself raises. If no parameters are tracked the model is left untouched.
        Not re-entrant: a nested use would overwrite the single shared ``params_cache``.
        """
        if not self.trainable_params:
            yield  # early return if not initialized
            return
        try:
            # Both steps are inside the try so that a failure half-way through the
            # swap still restores whatever was stored and clears the cache.
            self.store(model)    # store current parameters
            self.copy_to(model)  # copy EMA parameters to model
            yield
        finally:
            self.restore(model)  # always restore original parameters, even if evaluation fails

    def save_pretrained(self, model: "Union[PreTrainedModel, PeftModel]", save_dir: str):
        """Call ``model.save_pretrained(save_dir)`` while the EMA weights are loaded into ``model``."""
        with self.ema_loaded(model):
            model.save_pretrained(save_dir)
            logger.info(f"Saved model with EMA adapters to {save_dir}")

    def state_dict(self) -> Dict:
        """Get EMA state dict for checkpointing.

        ``ema_params`` is a shallow copy of the internal mapping, so the returned
        snapshot is not altered by later :meth:`step_ema` calls (which rebind the
        entries to new tensors).
        """
        return {
            "ema_params": dict(self.ema_params),
            "step": self.step,
            "cur_decay_value": self.cur_decay_value,
            "decay": self.decay,
            "min_decay": self.min_decay,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
        }

    def load_state_dict(self, state_dict: Dict, device: Optional[torch.device] = None):
        """Load EMA state dict from checkpoint.

        Args:
            state_dict: The state dict to load (same keys as produced by :meth:`state_dict`)
            device: Optional device to move EMA parameters to. If None, parameters remain on their current device.

        Raises:
            KeyError: If ``state_dict`` lacks one of the expected keys.
        """
        loaded = state_dict["ema_params"]
        if device is not None:
            # Move all EMA parameter tensors to the specified device
            loaded = {name: tensor.to(device) for name, tensor in loaded.items()}
        else:
            # Own mapping: later updates must not write into the caller's dict
            loaded = dict(loaded)
        self.ema_params = loaded

        # Name mismatches are otherwise silent: step_ema/copy_to skip unknown names.
        if self.trainable_params:
            missing = self.trainable_params.keys() - loaded.keys()
            unexpected = loaded.keys() - self.trainable_params.keys()
            if missing or unexpected:
                logger.warning(
                    f"EMA state/model parameter name mismatch: {len(missing)} trainable parameters "
                    f"have no EMA entry, {len(unexpected)} EMA entries match no trainable parameter"
                )

        self.step = state_dict["step"]
        self.cur_decay_value = state_dict["cur_decay_value"]
        self.decay = state_dict["decay"]
        self.min_decay = state_dict["min_decay"]
        self.update_after_step = state_dict["update_after_step"]
        self.use_ema_warmup = state_dict["use_ema_warmup"]
        self.inv_gamma = state_dict["inv_gamma"]
        self.power = state_dict["power"]

        logger.info(f"Loaded EMA state from checkpoint at step {self.step}" + (f" on device {device}" if device is not None else ""))

    def get_weight_norms(self, model: "Optional[Union[PreTrainedModel, PeftModel]]" = None) -> tuple:
        """Compare live parameters with their EMA counterparts.

        Args:
            model: Model whose parameters are compared; when None the parameters
                registered at construction time are used.

        Returns:
            A 3-tuple ``(weight_diffs, param_norms, ema_norms)`` of dicts keyed by
            parameter name: norm of ``param - ema``, norm of ``param`` and norm of
            ``ema``, each as a Python float.
        """
        weight_diffs = {}
        param_norms = {}
        ema_norms = {}
        named_params = model.named_parameters() if model is not None else self.trainable_params.items()
        for name, param in named_params:
            if name in self.ema_params:
                # Align devices so the subtraction works when EMA and model live apart
                ema = self.ema_params[name].to(param.device)
                weight_diffs[name] = torch.norm(param.data - ema).item()
                param_norms[name] = torch.norm(param.data).item()
                ema_norms[name] = torch.norm(ema).item()
        return weight_diffs, param_norms, ema_norms

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path: str) -> 'EMAModel':
        """Not implemented.

        Raises:
            NotImplementedError: Always. Raised instead of silently returning None
                in place of the annotated ``EMAModel``.
        """
        raise NotImplementedError(
            "EMAModel.load_from_checkpoint is not implemented; "
            "use EMAModel.load_from_model_and_ema_state instead"
        )

    @classmethod
    def load_from_model_and_ema_state(cls, model, ema_state_path: str) -> 'EMAModel':
        """Build an EMAModel for ``model`` and, if ``ema_state_path`` exists, load the saved EMA state into it.

        When the file is missing a warning is logged and the freshly initialized
        EMAModel (EMA equal to the current weights) is returned.
        """
        ema_model = cls(model)
        if os.path.exists(ema_state_path):
            # Load to CPU first, then move to model device
            # NOTE(review): torch.load unpickles the file; only load EMA state from trusted sources.
            ema_state = torch.load(ema_state_path, map_location="cpu")
            # Get device from model
            model_to_check = model
            if hasattr(model_to_check, "module"):  # Handle DeepSpeed/DataParallel
                model_to_check = model_to_check.module
            # Device of the first trainable parameter, else of the first parameter,
            # else None (model without parameters: tensors stay on CPU)
            device = next((p.device for p in model_to_check.parameters() if p.requires_grad), None)
            if device is None:
                first_param = next(model_to_check.parameters(), None)
                device = first_param.device if first_param is not None else None
            ema_model.load_state_dict(ema_state, device=device)
            logger.info(f"Loaded EMA state from {ema_state_path}")
        else:
            logger.warning(f"EMA checkpoint not found at {ema_state_path}")
        return ema_model