import os
from typing import Any, Dict

import torch
from hydra.utils import get_class
from torch_ema import ExponentialMovingAverage

from ..utils import create_hook
from ..utils.io import setup_model_from_checkpoint


class PropertyMixIn:
    """
    Mixin exposing the ``device`` and ``dtype`` of a model's parameters.

    Both properties look only at the first object yielded by
    ``self.parameters()``; the host class must therefore provide a
    ``parameters()`` method that returns an iterator. If that iterator is
    empty, ``next`` raises ``StopIteration``.

    Properties:
        device (torch.device): Device of the first parameter.
        dtype (torch.dtype): Data type of the first parameter.
    """

    @property
    def device(self):
        """torch.device: device of the first entry yielded by ``self.parameters()``."""
        leading = next(self.parameters())
        return leading.device

    @property
    def dtype(self):
        """torch.dtype: dtype of the first entry yielded by ``self.parameters()``."""
        leading = next(self.parameters())
        return leading.dtype


class EMAMixIn:
    """
    A mixin class that adds Exponential Moving Average (EMA) functionality to the model.
    This class hooks into various stages of the training and validation process to
    ensure that the EMA is updated and transferred correctly.

    The EMA tracks ``self.model.parameters()`` and is created lazily by whichever
    hook needs it first. Its decay is read in ``_after_init`` from
    ``kwargs['model']['ema_decay']`` and defaults to 0.99.

    Hooks attached in ``__new__`` (once per concrete class):
        on_train_start, on_validation_epoch_start -> _ema_to
        on_before_zero_grad                       -> _ema_update
        on_load_checkpoint                        -> _ema_load
        on_save_checkpoint                        -> _ema_save

    Methods:
        _ema_to: Moves the EMA to the correct device and dtype.
        _ema_update: Updates the EMA with the current model parameters.
        _ema_save: Saves the EMA state to the checkpoint.
        _ema_load: Loads the EMA state from the checkpoint.
    """

    def __new__(cls, *args, **kwargs):
        """
        Overrides __new__ to attach hooks for EMA updates and transfers during
        training and checkpointing.

        ``__new__`` runs on every instantiation, but the hooks are stored on the
        class. They are therefore installed only once per concrete class; a flag
        kept in that class's own ``__dict__`` tracks this. Without the guard,
        every new instance would wrap the already-wrapped hooks again.
        Assuming ``create_hook`` (defined elsewhere) returns a fresh wrapper that
        calls both functions, that would run each EMA callback once per instance
        ever created, e.g. several EMA updates per optimizer step.
        """
        if not cls.__dict__.get('_ema_hooks_installed', False):
            cls.on_train_start = create_hook(cls.on_train_start, cls._ema_to)
            cls.on_validation_epoch_start = create_hook(cls.on_validation_epoch_start, cls._ema_to)
            cls.on_before_zero_grad = create_hook(cls.on_before_zero_grad, cls._ema_update)
            cls.on_load_checkpoint = create_hook(cls.on_load_checkpoint, cls._ema_load)
            cls.on_save_checkpoint = create_hook(cls.on_save_checkpoint, cls._ema_save)
            cls._ema_hooks_installed = True
        return super().__new__(cls)

    def _after_init(self, *args, **kwargs):
        """
        Read the EMA decay and reset the EMA slot, then continue the
        cooperative ``_after_init`` chain.

        Args:
            *args: Forwarded unchanged to ``super()._after_init``.
            **kwargs: ``kwargs['model']['ema_decay']`` is used as the decay when
                present. If the lookup fails (missing key, or ``model`` is not
                subscriptable), the default 0.99 is used.

        Returns:
            Whatever ``super()._after_init`` returns.
        """
        try:
            self.ema_decay = kwargs['model']['ema_decay']
        except (KeyError, TypeError, AttributeError):
            # Missing config section/key or a non-mapping ``model`` entry: use the
            # default. Narrowed from a bare ``except:`` so KeyboardInterrupt,
            # SystemExit and unrelated errors are no longer swallowed.
            self.ema_decay = 0.99
        # Built lazily by ``_ema_init`` the first time a hook needs it.
        self.ema = None
        return super()._after_init(*args, **kwargs)

    def _ema_init(self):
        """
        Return ``self.ema``, creating it from ``self.model.parameters()`` on first use.

        Every hook goes through this helper. The EMA therefore exists no matter
        which fires first: train start, validation, an optimizer step, or a
        checkpoint save/load.
        """
        if self.ema is None:
            self.ema = ExponentialMovingAverage(self.model.parameters(), self.ema_decay)
        return self.ema

    def _ema_to(self, *args, **kwargs):
        """
        Moves the EMA object to the same device and dtype as the model's parameters.

        Extra positional/keyword arguments (whatever the hook passes) are ignored.
        """
        self._ema_init().to(self.device, self.dtype)

    def _ema_update(self, *args, **kwargs):
        """
        Updates the EMA with the current state of the model's parameters.

        Extra positional/keyword arguments (whatever the hook passes) are ignored.
        The EMA is created first if no earlier hook has done so. Previously this
        raised ``AttributeError`` on ``None``.
        """
        self._ema_init().update()

    def _ema_save(self, checkpoint: Dict[str, Any]) -> None:
        """
        Saves the EMA state to the checkpoint.

        The EMA is created first if needed. This guarantees the checkpoint always
        carries the ``'ema'`` entry that ``_ema_load`` requires.

        Args:
            checkpoint (dict): The checkpoint dictionary where the EMA state will be saved.
        """
        checkpoint['ema'] = self._ema_init().state_dict()

    def _ema_load(self, checkpoint: Dict[str, Any]) -> None:
        """
        Loads the EMA state from the checkpoint.

        Args:
            checkpoint (dict): The checkpoint dictionary from which the EMA state will be loaded.

        Raises:
            KeyError: If the checkpoint has no ``'ema'`` entry.
        """
        self._ema_init().load_state_dict(checkpoint['ema'])


class AutoEncoderManagerMixin:
    """
    A mixin class for automatically managing an AutoEncoder model instance.
    This allows you to treat the autoencoder similarly to a tokenizer:
      - Loading it from a checkpoint or constructing from hyperparameters.
      - Optionally freezing its parameters.
      - Syncing any tabular transform in the autoencoder to the main class.

    ``_after_init`` reads its configuration from the nested ``kwargs`` mapping it
    receives (i.e. ``kwargs['kwargs']``):
      - ``ae_ckpt_path``: if truthy, the AE is loaded from this checkpoint file.
      - ``ae_hparams``: otherwise the AE is built as ``cls(**ae_hparams)``. The
        class path comes from ``ae_hparams['_target_']`` unless ``_ae_cls``
        already exists on the instance. ``None`` is treated as an empty mapping.

    The AE is frozen (``requires_grad = False`` and ``eval()``) unless the class
    defines ``freeze_ae()`` and it returns a falsy value.

    Usage:
        class MyModel(AutoEncoderManagerMixin, SomeOtherBaseClass):
            def freeze_ae(self):
                # (Optional) Overwrite to control whether the AE is frozen.
                return True

            def __init__(self, ae_cls='path.to.AutoEncoderClass', ae_ckpt_path=None, ae_hparams=None, *args, **kwargs):
                super().__init__(
                    ae_cls=ae_cls,
                    ae_ckpt_path=ae_ckpt_path,
                    ae_hparams=ae_hparams,
                    *args, **kwargs
                )
                # Additional initialization code...

    NOTE(review): this mixin never reads an ``ae_cls`` argument like the one in the
    example above. Confirm how the base class routes these arguments into
    ``_after_init``.
    """

    def __new__(cls, *args, **kwargs):
        """
        Override default instantiation to attach a checkpoint-saving hook.
        This hook saves the autoencoder hyperparameters into the checkpoint.

        ``__new__`` runs on every instantiation, but the hook is stored on the
        class. It is therefore attached only once per concrete class; a flag in
        that class's own ``__dict__`` tracks this. Without the guard, every new
        instance would wrap the already-wrapped hook again. ``create_hook`` is
        defined elsewhere.
        """
        if not cls.__dict__.get('_ae_hooks_installed', False):
            cls.on_save_checkpoint = create_hook(cls.on_save_checkpoint, cls._save_ae_hparams)
            cls._ae_hooks_installed = True
        return super().__new__(cls)

    def _after_init(self, *args, **kwargs):
        """
        Build (or load) the autoencoder, optionally freeze it, then continue the
        cooperative ``_after_init`` chain.

        Args:
            *args: Forwarded unchanged to ``super()._after_init``.
            **kwargs: Expected to contain a nested ``kwargs`` mapping with
                ``ae_ckpt_path`` and/or ``ae_hparams``. A missing or ``None``
                entry is treated as empty.

        Returns:
            Whatever ``super()._after_init`` returns.

        Raises:
            ValueError: If no checkpoint path is given and no autoencoder class
                path is known.
        """
        # Options live in the nested ``kwargs`` entry; tolerate an explicit None.
        kwargs = kwargs.get('kwargs')
        if kwargs is None:
            kwargs = {}
        # ``ae_hparams=None`` is the documented default. Normalise it before
        # calling ``.get`` on it.
        ae_hparams = kwargs.get('ae_hparams')
        if ae_hparams is None:
            ae_hparams = {}

        # Keep ``_ae_cls`` / ``_ae_hparams`` if they already exist on the instance
        # (presumably pre-set by a subclass; verify). Otherwise derive them here.
        if not hasattr(self, '_ae_cls'):
            self._ae_cls = ae_hparams.get('_target_', None)
        if not hasattr(self, '_ae_hparams'):
            self._ae_hparams = ae_hparams

        # Load AE from checkpoint or from hyperparameters
        ae_ckpt_path = kwargs.get('ae_ckpt_path', None)
        if ae_ckpt_path:
            self._load_ae_from_checkpoint(ae_ckpt_path)
        else:
            self._load_ae_from_hparams(self._ae_hparams)

        # Freeze unless the class opts out via ``freeze_ae()`` returning falsy.
        if not hasattr(self, 'freeze_ae') or self.freeze_ae():
            for param in self.ae.parameters():
                param.requires_grad = False
            self.ae.eval()

        # NOTE(review): the nested mapping, not the kwargs this method received, is
        # forwarded up the MRO. Later ``_after_init`` implementations therefore see
        # the contents of ``kwargs['kwargs']``. Preserved as-is; confirm intended.
        return super()._after_init(*args, **kwargs)

    def _load_ae_from_checkpoint(self, ae_ckpt_path: str) -> None:
        """
        Load the autoencoder from a checkpoint file.

        Sets ``self._ae_cls`` from the config's ``_target_``. Sets
        ``self._ae_hparams`` from the checkpoint's ``hyper_parameters``. If the
        AE has a non-None ``ema``, copies its EMA weights into the AE. Finally
        installs the AE via ``_update_transform``.

        Args:
            ae_ckpt_path (str):
                Path to the checkpoint file.

        Raises:
            AssertionError: If the checkpoint path does not exist.
        """
        assert os.path.isfile(ae_ckpt_path), f"{ae_ckpt_path} not found"

        ae, cfg = setup_model_from_checkpoint(ae_ckpt_path)
        # This second read only fetches the hyper-parameters. Mapping tensors to CPU
        # lets a GPU-saved checkpoint load on a CPU-only host, and spends no
        # accelerator memory on weights that are discarded right away.
        # NOTE(review): weights_only=False unpickles arbitrary objects, so only point
        # ae_ckpt_path at trusted files.
        checkpoint_data = torch.load(ae_ckpt_path, map_location='cpu', weights_only=False)
        self._ae_cls = cfg['_target_']
        self._ae_hparams = checkpoint_data['hyper_parameters']

        # If the autoencoder carries an EMA, copy its averaged weights into the AE.
        # ``ema`` may exist but still be None (EMAMixIn initialises it to None), so
        # guard against that as well as against the attribute being absent.
        ae_ema = getattr(ae, 'ema', None)
        if ae_ema is not None:
            ae_ema.copy_to()

        # Synchronize the tabular transform
        self._update_transform(ae)

    def _load_ae_from_hparams(self, ae_hparams: Dict[str, Any]) -> None:
        """
        Construct a new autoencoder using the provided hyperparameters.

        The class is resolved from ``self._ae_cls`` with ``hydra.utils.get_class``.
        It is called with ``**ae_hparams``, including any ``_target_`` entry.

        Args:
            ae_hparams (dict): Hyperparameters for the autoencoder.

        Raises:
            ValueError: If ``self._ae_cls`` is missing or None.
        """
        ae_cls = getattr(self, '_ae_cls', None)
        if ae_cls is None:
            raise ValueError(
                "Cannot build the autoencoder: no 'ae_ckpt_path' was given and no "
                "autoencoder class path is known (set 'ae_hparams._target_')."
            )
        ae = get_class(ae_cls)(**ae_hparams)
        self._ae_hparams = ae_hparams
        self._update_transform(ae)

    def _update_transform(self, ae):
        """
        Install ``ae`` as ``self.ae`` and mirror its transform settings.

        Copies ``ae._transform`` and the ``'onehot'`` / ``'scaler'`` entries of
        ``ae.model_flags`` onto this model, then calls ``self._refresh_schema()``
        (defined elsewhere).
        """
        self.ae = ae
        self._transform = ae._transform
        self.model_flags['onehot'] = ae.model_flags['onehot']
        self.model_flags['scaler'] = ae.model_flags['scaler']
        self._refresh_schema()

    def _save_ae_hparams(self, checkpoint: Dict[str, Any]) -> None:
        """
        Hook function to save the autoencoder hyperparameters into the checkpoint.

        Creates the ``'hyper_parameters'`` entry if the checkpoint lacks one, so
        the hook does not fail with ``KeyError`` on such checkpoints.

        Args:
            checkpoint (dict):
                The checkpoint dictionary that is being saved.
        """
        checkpoint.setdefault('hyper_parameters', {})['ae_hparams'] = self._ae_hparams

    def encode(self, *args, **kwargs):
        """Delegate to ``self.ae.encode`` with the same arguments."""
        return self.ae.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to ``self.ae.decode`` with the same arguments."""
        return self.ae.decode(*args, **kwargs)
