import os
import lightning.pytorch as pl 
from collections import OrderedDict
import torch

from src.utils.ema import ExponentialMovingAverage

class EMACallback(pl.Callback):
    """Keep an exponential moving average (EMA) of the model and checkpoint it.

    After a warm-up period the EMA is updated at the end of every training
    batch.  During validation the module's weights are temporarily replaced by
    the EMA weights, so the monitored metric is computed with the EMA weights
    in place; at the end of validation an EMA checkpoint is written, the
    top-k bookkeeping is updated and the training weights are put back.
    """

    def __init__(
        self,
        decay: float,
        ema_warmup: int, # don't do ema until
        dirpath: str,
        filename: str,
        monitor: str,
        mode: str,
        every_n_train_steps: int,
        save_last: bool,
        save_top_k: int,
        save_weights_only: bool,
    ):
        """
        Args:
            decay: Decay rate for EMA.
            ema_warmup: Number of global steps to wait before the EMA is
                updated, used for validation, or checkpointed.
            dirpath: Directory to save checkpoints.
            filename: Checkpoint filename format; it is formatted with
                ``step=<global step>``.
            monitor: Metric to monitor for checkpointing.
            mode: "min" or "max" for the monitored metric. Any value other
                than "min" is treated as "max".
            every_n_train_steps: Save every N training steps. Stored, but
                currently unused: periodic saving is commented out in
                ``on_train_batch_end``.
            save_last: If True, always save a checkpoint at the end of training.
            save_top_k: Number of best checkpoints (by the monitored metric)
                to keep. ``0`` keeps none; a negative value keeps all of them.
            save_weights_only: If True, save only weights; else, also save
                the module's hyper-parameters and the global step.
        """
        super().__init__()
        self.decay = decay
        self.ema_warmup_steps = ema_warmup
        self.ema = None  # created in on_train_start
        self.dirpath = dirpath
        self.filename = filename
        self.monitor = monitor
        self.mode = mode
        self.every_n_train_steps = every_n_train_steps
        self.save_last = save_last
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.best_k_models = {}  # global step -> monitored value (float)
        # True only while the module holds EMA weights during validation.
        self._ema_weights_applied = False

    # def state_dict(self) -> dict:
    #     # return serializable dict of your callback state
    #     return {"ema_state": self.ema.state_dict()}

    # def load_state_dict(self, state_dict: dict) -> None:
    #     # restore callback state
    #     self.ema.load_state_dict(state_dict["ema_state"])

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Create the EMA tracker from the module and move it to the module's device."""
        self.ema = ExponentialMovingAverage(pl_module, decay=self.decay)
        self.ema.to(pl_module.device)

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx: int,
    ) -> None:
        """Update the EMA from the module once the warm-up period is over."""
        if trainer.global_step < self.ema_warmup_steps:
            return
        self.ema.update(pl_module)
        # Save EMA checkpoint every N steps
        # if trainer.global_step % self.every_n_train_steps == 0:
        #     self._save_ema_checkpoint(trainer, pl_module)

    def on_validation_start(self, trainer, pl_module):
        """Store the current weights and load the EMA weights for validation."""
        self._ema_weights_applied = False
        # The EMA does not exist until on_train_start has run, and is not
        # meaningful before the warm-up is over.
        if self.ema is None or trainer.global_step < self.ema_warmup_steps:
            return
        self.ema.store(pl_module.parameters())
        self.ema.copy_to(pl_module)
        self._ema_weights_applied = True

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Checkpoint the EMA weights, then put the training weights back."""
        # Only undo what on_validation_start actually did. This also covers
        # validation that ran while self.ema was still None.
        if not self._ema_weights_applied:
            return
        try:
            # Save EMA checkpoint and update the top-k bookkeeping
            self._save_top_k_ema_checkpoint(trainer, pl_module)
        finally:
            # Even if saving failed, training must not continue on EMA weights.
            self.ema.restore(pl_module.parameters())
            self._ema_weights_applied = False

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Write the final "last" EMA checkpoint when ``save_last`` is set."""
        if self.ema is None or trainer.global_step < self.ema_warmup_steps:
            return
        if self.save_last:
            self._save_ema_checkpoint(trainer, pl_module, suffix="last")

    def _checkpoint_path(self, step: int, suffix: str = None) -> str:
        """Return the checkpoint path for ``step`` (single source of truth for save and prune)."""
        filename = self.filename.format(step=step)
        if suffix:
            filename = f"{filename}_{suffix}"
        return os.path.join(self.dirpath, f"ema_{filename}.ckpt")

    def _save_ema_checkpoint(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule, suffix: str = None
    ) -> None:
        """Write the EMA state to ``dirpath`` with ``torch.save``.

        Args:
            trainer: Provides the global step and the logged metrics.
            pl_module: Provides ``hparams`` when ``save_weights_only`` is False.
            suffix: Optional tag appended to the filename (e.g. "last").
        """
        step = trainer.global_step
        ckpt_path = self._checkpoint_path(step, suffix)
        os.makedirs(self.dirpath, exist_ok=True)

        # Prepare checkpoint
        checkpoint = {
            "state_dict": self.ema.state_dict(),
            "step": step,
            "monitor_value": trainer.callback_metrics.get(self.monitor, float("inf")),
        }
        if not self.save_weights_only:
            checkpoint.update({
                "hyper_parameters": pl_module.hparams,
                "trainer_global_step": step,
            })

        # Save checkpoint
        torch.save(checkpoint, ckpt_path)
        print(f"Saved EMA checkpoint to {ckpt_path}")

    def _save_top_k_ema_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Save a checkpoint for the current step and keep only the best ``save_top_k``.

        Does nothing when the monitored metric has not been logged or when
        ``save_top_k`` is 0. A negative ``save_top_k`` keeps every checkpoint.
        """
        # Only proceed if monitoring a metric
        if self.monitor not in trainer.callback_metrics:
            return
        # Nothing would survive pruning, so do not write the file at all.
        if self.save_top_k == 0:
            return

        current_value = trainer.callback_metrics[self.monitor]
        step = trainer.global_step

        # Save current model
        self._save_ema_checkpoint(trainer, pl_module)

        # Rank on plain floats so the bookkeeping does not hold on to tensors.
        self.best_k_models[step] = float(current_value)

        # Negative means "keep all"; without this guard the loop below would
        # delete every checkpoint and then fail on an empty dict.
        if self.save_top_k < 0:
            return

        # With mode "min" the worst model has the largest value, otherwise the smallest.
        pick_worst = max if self.mode == "min" else min
        while len(self.best_k_models) > self.save_top_k:
            worst_step = pick_worst(self.best_k_models, key=self.best_k_models.get)
            try:
                os.remove(self._checkpoint_path(worst_step))
            except FileNotFoundError:
                # Already gone (e.g. removed by hand); only the bookkeeping is left to update.
                pass
            del self.best_k_models[worst_step]


    # TODO: should be clear from the ckpt filename
    # def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, callback_state: dict) -> None:
    #     # Optionally, load EMA state from checkpoint
    #     if "ema_state_dict" in callback_state:
    #         self.ema.load_state_dict(callback_state["ema_state_dict"])
