from typing import Any, Optional

import torch.optim

from utils.logger.logger import Logger
from utils.utils import get_class_name


def get_optimizer(
        model: torch.nn.Module,
        optimizer_name: str,
        learning_rate: float,
        optimizer_params: Optional[dict[str, Any]] = None
) -> torch.optim.Optimizer:
    """Build a torch optimizer over ``model.parameters()``.

    The arguments are logged at debug level before any validation runs.

    Args:
        model: Module whose parameters are handed to the optimizer.
        optimizer_name: ``'adam'`` (``torch.optim.Adam``) or ``'r-adam'``
            (``torch.optim.RAdam``).
        learning_rate: Forwarded to the optimizer constructor as ``lr``.
        optimizer_params: Extra keyword arguments for the optimizer
            constructor. Accepted keys are ``betas``, ``eps`` and
            ``weight_decay``, plus ``amsgrad`` for ``'adam'`` only.
            ``None`` is treated as an empty dict.

    Returns:
        The constructed optimizer instance.

    Raises:
        TypeError: If an argument is not of the expected type.
        ValueError: If ``optimizer_name`` is not supported, or if
            ``optimizer_params`` contains a key the selected optimizer
            does not accept.
    """
    Logger.debug(
        f'{get_class_name(get_optimizer)} - '
        f'optimizer_name: {optimizer_name}, '
        f'learning_rate: {learning_rate}, '
        f'optimizer_params: {optimizer_params}'
    )

    # Explicit raises instead of ``assert``: assertions are removed when
    # Python runs with -O, which would silently disable this validation.
    if not isinstance(model, torch.nn.Module):
        raise TypeError('model must be a torch.nn.Module')
    if not isinstance(optimizer_name, str):
        raise TypeError('optimizer_name must be a str')
    if not isinstance(learning_rate, float):
        raise TypeError('learning_rate must be a float')
    if optimizer_params is None:
        optimizer_params = {}
    if not isinstance(optimizer_params, dict):
        raise TypeError('optimizer_params must be a dict')

    # The optimizer class is looked up only inside the matching branch, so
    # selecting one optimizer never touches the attribute of the other.
    if optimizer_name == 'adam':
        optimizer_cls = torch.optim.Adam
        allowed_params = ('betas', 'eps', 'weight_decay', 'amsgrad')
    elif optimizer_name == 'r-adam':
        optimizer_cls = torch.optim.RAdam
        allowed_params = ('betas', 'eps', 'weight_decay')
    else:
        raise ValueError(f'unknown optimizer name: {optimizer_name}')

    # Report only the offending keys rather than every key that was passed.
    unknown_params = [key for key in optimizer_params if key not in allowed_params]
    if unknown_params:
        raise ValueError(f'unknown optimizer param: {unknown_params}')

    return optimizer_cls(model.parameters(), lr=learning_rate, **optimizer_params)
