from typing import Any, Optional

import torch

from utils.logger.logger import Logger
from utils.utils import get_class_name
from .lpips_loss import LPIPSLoss


def get_l2_loss() -> torch.nn.MSELoss:
    """Create a new mean-squared-error (L2) criterion with torch's default settings."""
    mse_criterion = torch.nn.MSELoss()
    return mse_criterion


def get_l1_loss() -> torch.nn.L1Loss:
    """Create a new mean-absolute-error (L1) criterion with torch's default settings."""
    mae_criterion = torch.nn.L1Loss()
    return mae_criterion


def get_lpips_vgg_loss(size: Optional[int] = None) -> torch.nn.Module:
    """Create an ``LPIPSLoss`` configured with the ``'vgg'`` network.

    Args:
        size: Forwarded unchanged to ``LPIPSLoss`` as its ``size`` argument;
            may be ``None``. How ``LPIPSLoss`` interprets it (presumably an
            input resize target) is defined in ``lpips_loss`` — TODO confirm.

    Returns:
        The constructed ``LPIPSLoss`` module.
    """
    return LPIPSLoss(net='vgg', size=size)


def get_loss(loss_name: str, loss_params: Optional[dict[str, Any]] = None) -> torch.nn.Module:
    """Build a loss module from its name and an optional parameter dict.

    Supported names:
      - ``'l2'``: ``torch.nn.MSELoss``; accepts no params.
      - ``'l1'``: ``torch.nn.L1Loss``; accepts no params.
      - ``'lpips-vgg'``: LPIPS with the VGG network; accepts an optional
        ``'size'`` key (must be an ``int``, defaults to 224).

    Args:
        loss_name: Identifier of the loss to build.
        loss_params: Extra parameters for the selected loss; ``None`` is
            treated as an empty dict.

    Returns:
        The constructed loss module.

    Raises:
        AssertionError: If ``loss_name`` is not a str, ``loss_params`` is not a
            dict, ``loss_params`` has keys the selected loss does not accept,
            or ``size`` is not an int.
        ValueError: If ``loss_name`` is not one of the supported names.
    """
    Logger.debug(f'{get_class_name(get_loss)} - loss_name: {loss_name}, loss_params: {loss_params}')
    # Validation raises AssertionError explicitly instead of using `assert`, so
    # the checks still run under `python -O` while callers see the same
    # exception type as before.
    if not isinstance(loss_name, str):
        raise AssertionError('loss_name must be a str')
    if loss_params is None:
        loss_params = {}
    if not isinstance(loss_params, dict):
        raise AssertionError('loss_params must be a dict')

    if loss_name == 'l2':
        if loss_params:
            raise AssertionError('l2 loss has no params')
        Logger.debug(f'{get_class_name(get_loss)} - l2')
        return get_l2_loss()
    elif loss_name == 'l1':
        if loss_params:
            raise AssertionError('l1 loss has no params')
        Logger.debug(f'{get_class_name(get_loss)} - l1')
        return get_l1_loss()
    elif loss_name == 'lpips-vgg':
        # Report only the offending keys, not every key that was passed.
        unknown_params = [key for key in loss_params if key != 'size']
        if unknown_params:
            raise AssertionError(f'unknown loss param: {unknown_params}')
        size = loss_params.get('size', 224)
        Logger.debug(f'{get_class_name(get_loss)} - lpips-vgg - size: {size}')
        # An exact type check (not isinstance) also rejects bool, which
        # subclasses int.
        if type(size) is not int:
            raise AssertionError('size must be an int')
        return get_lpips_vgg_loss(size=size)
    else:
        raise ValueError(f'unknown loss name: {loss_name}')
