import copy
from typing import Optional, Any

import torch
from ema_pytorch import EMA
from pydantic import BaseModel, ConfigDict, Field

from src.utils.load import load_from_state_dict


class EMAConfig(BaseModel):
    """Immutable, strictly validated settings for building an ``ema_pytorch.EMA``.

    ``beta``, ``update_every``, ``update_after_step``, ``inv_gamma``, ``power`` and
    ``allow_different_devices`` are forwarded to the ``EMA`` constructor.
    ``start_step``, ``load_path`` and ``load_keys`` drive the extra setup done in
    :meth:`get_ema` after construction. Instances are frozen and reject unknown
    keys (see ``model_config``).
    """

    model_config: ConfigDict = ConfigDict(
        frozen=True,
        strict=True,
        validate_assignment=True,
        extra='forbid',
    )

    # --- hyperparameters handed straight to ema_pytorch.EMA (required ones first) ---
    beta: float
    update_every: int
    update_after_step: int = 100
    inv_gamma: float = 1.0
    power: float = 2 / 3
    # Added to the EMA's `step` counter after construction; None leaves the counter as built.
    start_step: Optional[int] = None
    allow_different_devices: bool = True
    # Forwarded to load_from_state_dict together with the EMA copy of the model.
    # NOTE(review): exact meaning of load_keys is defined by that helper — confirm there.
    load_path: Optional[str] = None
    load_keys: list[str] = Field(default_factory=list)

    def get_ema(self, model: torch.nn.Module) -> EMA:
        """Wrap *model* in an ``EMA`` configured from this object.

        Sequence of operations:
          1. construct ``EMA`` with the forwarded hyperparameters;
          2. copy *model*'s parameters into the EMA copy and set the ``initted``
             buffer to True;
          3. call ``load_from_state_dict`` on ``ema.ema_model`` with
             ``load_path`` / ``load_keys`` (called even when ``load_path`` is
             None — presumably the helper treats that as a no-op; confirm);
          4. advance ``ema.step`` by ``start_step`` when it is set.

        Args:
            model: The online model whose weights the EMA tracks.

        Returns:
            The prepared ``EMA`` instance.
        """
        constructor_kwargs = dict(
            beta=self.beta,
            update_every=self.update_every,
            update_after_step=self.update_after_step,
            inv_gamma=self.inv_gamma,
            power=self.power,
            allow_different_devices=self.allow_different_devices,
        )
        ema: EMA = EMA(model=model, **constructor_kwargs)

        # Seed the EMA copy from the current model and mark it as initialised.
        ema.copy_params_from_model_to_ema()
        ema.initted.data.copy_(torch.tensor(True))

        load_from_state_dict(
            entity=ema.ema_model,
            load_path=self.load_path,
            load_keys=self.load_keys,
        )

        # NOTE(review): ema_pytorch appears to re-copy model weights into the EMA
        # while step <= update_after_step, which would discard weights loaded
        # above; start_step can presumably be used to skip that window — confirm.
        offset = self.start_step
        if offset is not None:
            ema.step += offset
        return ema

    def get_tuple(self) -> tuple:
        """Return ``(beta, update_every, update_after_step, inv_gamma, power)``.

        ``inv_gamma`` and ``power`` are rounded to 5 decimal places. ``start_step``,
        ``allow_different_devices`` and the load settings are not included.
        """
        precision = 5
        return (
            self.beta,
            self.update_every,
            self.update_after_step,
            round(self.inv_gamma, precision),
            round(self.power, precision),
        )

    def get_str(self) -> str:
        """Render the :meth:`get_tuple` values as ``name_value`` pairs joined by ``_``.

        Example shape: ``beta_<b>_update_every_<u>_update_after_step_<s>_inv_gamma_<g>_power_<p>``.
        """
        labels = ('beta', 'update_every', 'update_after_step', 'inv_gamma', 'power')
        return '_'.join(
            f'{label}_{value}' for label, value in zip(labels, self.get_tuple())
        )
