import os
import re
import torch
import pytorch_lightning as pl

from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities import rank_zero_only
from timm.utils.model_ema import ModelEmaV2


class EMA(Callback):
    """Keep an exponential moving average (EMA) of the model's weights.

    The EMA is created at the start of ``fit`` with timm's ``ModelEmaV2`` and
    updated after every training batch. During validation and testing, the EMA
    parameters and buffers are swapped into ``pl_module``. Metrics are thus
    computed with the averaged weights. They are swapped back afterwards, so
    training continues on the live weights.

    Args:
        decay: EMA decay factor forwarded to ``ModelEmaV2``.
        save_last: If True, rank 0 writes the EMA weights to disk at the end
            of every validation epoch (sanity checks excluded).
        dirpath: Directory for the EMA checkpoint. If None, it is resolved to
            ``trainer.log_dir``, or ``trainer.default_root_dir`` / ``'.'``
            when there is no log dir.
        filename: File name for the EMA checkpoint. If a file with that name
            already exists when the path is first resolved, a ``-vN`` tag is
            inserted (before a trailing ``_ema`` if present) until the name
            is unused.
    """

    # Matches "<base>-v<N>" so an existing version number can be incremented.
    _VERSION_RE = re.compile(r'(.+)-v(\d+)$')

    def __init__(self, decay=0.9995,
                 save_last=False, dirpath=None, filename='last_ema.ckpt'):
        super().__init__()
        self.decay = decay
        self.dirpath = dirpath
        self.filename = filename
        # Resolved lazily on the first save, because it may depend on the
        # trainer's log directory.
        self.filepath = None
        self.save_last = save_last

        # Created in on_fit_start. Stays None in validate/test-only runs.
        self.ema = None

    def on_fit_start(self, trainer, pl_module):
        # ModelEmaV2 keeps its own copy of pl_module (``self.ema.module``) on
        # the module's current device.
        self.ema = ModelEmaV2(pl_module, decay=self.decay, device=pl_module.device)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.ema.update(pl_module)

    def on_validation_epoch_start(self, trainer, pl_module):
        # The sanity check runs before any training, so the EMA equals the
        # live weights. Skip the swap there.
        if trainer.sanity_checking:
            return None
        self._swap_model_weights(pl_module)

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:
            return None
        # Restore the live weights. After this, self.ema.module holds the EMA
        # state again.
        self._swap_model_weights(pl_module)

        if self.save_last and rank_zero_only.rank == 0:
            self._save_ema_checkpoint(trainer, pl_module)

    def on_test_epoch_start(self, trainer, pl_module):
        self._swap_model_weights(pl_module)

    def on_test_epoch_end(self, trainer, pl_module):
        self._swap_model_weights(pl_module)

    def _swap_model_weights(self, pl_module):
        """Exchange pl_module's parameters and buffers with the EMA copy's.

        The swap is its own inverse: calling it twice restores the original
        state. Buffers are swapped as well as parameters. The EMA copy's
        buffers are part of ``self.ema.module.state_dict()`` (what gets
        saved), so swapping parameters alone would evaluate a model that does
        not match the EMA checkpoint. ``parameters()`` and ``buffers()``
        deduplicate shared tensors, so tied weights are swapped exactly once.

        Does nothing when no EMA exists, e.g. ``trainer.validate()`` or
        ``trainer.test()`` called without a preceding ``fit`` in this process.
        """
        if self.ema is None:
            return
        ema_module = self.ema.module
        pairs = (
            (pl_module.parameters(), ema_module.parameters()),
            (pl_module.buffers(), ema_module.buffers()),
        )
        for live_tensors, ema_tensors in pairs:
            for live, averaged in zip(live_tensors, ema_tensors):
                live.data, averaged.data = averaged.data, live.data

    def _save_ema_checkpoint(self, trainer, pl_module):
        """Write the EMA state (plus minimal metadata) to ``self.filepath``.

        The file is written to a temporary path and then moved into place
        with ``os.replace``. A crash mid-write therefore cannot corrupt the
        previously saved checkpoint.
        """
        if self.ema is None:
            return
        if self.filepath is None:
            self._initial_save_path(trainer)
        # The target directory (user-supplied or derived) may not exist yet.
        os.makedirs(os.path.dirname(self.filepath) or '.', exist_ok=True)

        ema_ckpt = {
            'state_dict': self.ema.module.state_dict(),
            'hyper_parameters': pl_module.hparams,
            'pytorch-lightning_version': pl.__version__,
            'epoch': pl_module.current_epoch,
            'global_step': pl_module.global_step
        }
        tmp_path = self.filepath + '.tmp'
        torch.save(ema_ckpt, tmp_path)
        os.replace(tmp_path, self.filepath)

    def _initial_save_path(self, trainer):
        """Resolve ``self.dirpath``/``self.filepath`` once, avoiding existing files.

        If ``<dirpath>/<filename>`` already exists, a version tag is inserted
        or incremented. For example, ``last_ema.ckpt`` becomes
        ``last-v1_ema.ckpt`` and then ``last-v2_ema.ckpt``. This repeats until
        an unused name is found.
        """
        if self.dirpath is None:
            if trainer.log_dir:
                self.dirpath = trainer.log_dir
            else:
                self.dirpath = trainer.default_root_dir or '.'
        filename = self.filename
        while os.path.exists(os.path.join(self.dirpath, filename)):
            base, ext = os.path.splitext(filename)

            # Keep a trailing "_ema" after the version tag.
            suffix = '_ema'
            if base.endswith(suffix):
                base_without_suffix = base[:-len(suffix)]
            else:
                base_without_suffix = base
                suffix = ''

            match = self._VERSION_RE.match(base_without_suffix)
            if match:
                base_part, version_str = match.groups()
                version = int(version_str) + 1
            else:
                base_part = base_without_suffix
                version = 1

            filename = f'{base_part}-v{version}{suffix}{ext}'

        self.filepath = os.path.join(self.dirpath, filename)
