from typing import Any, Optional

import torch
from pydantic import BaseModel, Field, ConfigDict

from src.optimizers.optimizers import get_optimizer
from src.utils.load import load_from_state_dict


class OptimizerConfig(BaseModel):
    """Immutable, strictly validated description of how to build an optimizer.

    ``model_config`` makes instances frozen, rejects unknown fields and disables
    type coercion. Call :meth:`get_optimizer` to build the optimizer for a model
    and pass it through ``load_from_state_dict``.
    """

    model_config: ConfigDict = ConfigDict(frozen=True, strict=True, validate_assignment=True, extra='forbid')
    # Optimizer identifier, forwarded to the module-level ``get_optimizer`` factory.
    name: str = Field()
    # Base learning rate; optionally divided by ``batch_repeats`` in ``get_optimizer``.
    learning_rate: float = Field()
    # Extra options forwarded to the optimizer factory. ``default_factory`` (instead
    # of a literal ``{}``) matches ``load_keys`` and gives each instance its own dict.
    params: dict[str, Any] = Field(default_factory=dict)
    # Forwarded to ``load_from_state_dict``; presumably a checkpoint path, with
    # ``None`` meaning "nothing to load" — confirm against that helper.
    load_path: Optional[str] = Field(default=None)
    # Forwarded to ``load_from_state_dict``; presumably the keys locating the
    # optimizer state inside the checkpoint — confirm against that helper.
    load_keys: list[str] = Field(default_factory=list)

    def get_optimizer(
            self,
            model: torch.nn.Module,
            batch_repeats: int,
            learning_rate_batch_repeats: bool
    ) -> torch.optim.Optimizer:
        """Build the optimizer for ``model`` and pass it through ``load_from_state_dict``.

        Args:
            model: Module handed to the optimizer factory.
            batch_repeats: Divisor applied to ``learning_rate`` when
                ``learning_rate_batch_repeats`` is true; ignored otherwise.
            learning_rate_batch_repeats: If true, the effective learning rate is
                ``learning_rate / batch_repeats``; otherwise ``learning_rate``.

        Returns:
            The value returned by ``load_from_state_dict`` for the freshly built
            optimizer (presumably the same optimizer, with state restored when
            ``load_path`` is set — confirm against that helper).

        Raises:
            ValueError: If scaling is requested and ``batch_repeats`` is not positive.
        """
        if learning_rate_batch_repeats:
            # Fail with a clear message instead of ZeroDivisionError, and refuse a
            # negative divisor that would silently flip the learning-rate sign.
            if batch_repeats <= 0:
                raise ValueError(
                    "batch_repeats must be a positive integer when "
                    f"learning_rate_batch_repeats is enabled, got {batch_repeats}"
                )
            learning_rate = self.learning_rate / batch_repeats
        else:
            learning_rate = self.learning_rate

        # Inside this method, ``get_optimizer`` resolves to the module-level factory
        # imported from ``src.optimizers.optimizers``, not to this method.
        # ``frozen=True`` does not freeze container contents, so the helpers get
        # shallow copies and cannot mutate this config's ``params`` / ``load_keys``.
        optimizer: torch.optim.Optimizer = get_optimizer(
            model,
            self.name,
            learning_rate,
            dict(self.params),
        )
        return load_from_state_dict(
            entity=optimizer,
            load_path=self.load_path,
            load_keys=list(self.load_keys),
        )
