from typing import Any, Optional

from external.edm.training.networks import EDMPrecond
from utils.logger.logger import Logger
from utils.utils import get_class_name


def create_edm_model(model_name: str, model_params: Optional[dict[str, Any]] = None) -> EDMPrecond:
    """Construct an ``EDMPrecond`` network for one of a fixed set of named configurations.

    Each supported ``model_name`` maps to a hard-coded set of ``EDMPrecond`` constructor
    arguments (image resolution, channel count, label dimension, backbone type, ...).
    This function only constructs the network; it does not load any weights.
    NOTE(review): the names appear to follow the NVlabs EDM pretrained-checkpoint naming,
    so presumably these hyperparameters mirror those checkpoints — TODO confirm.

    Args:
        model_name: One of ``'edm-cifar10-32x32-cond-vp'``, ``'edm-cifar10-32x32-uncond-vp'``,
            ``'edm-afhqv2-64x64-uncond-vp'`` or ``'edm-imagenet-64x64-cond-adm'``.
        model_params: Optional per-configuration overrides; ``None`` is treated as ``{}``.
            The CIFAR-10 and AFHQv2 configurations accept no params.
            ``'edm-imagenet-64x64-cond-adm'`` accepts only ``'use_fp16'`` (default ``True``).

    Returns:
        The newly constructed ``EDMPrecond`` instance.

    Raises:
        AssertionError: If ``model_name`` is not a ``str`` or ``model_params`` is not a ``dict``.
            Also raised if ``model_params`` holds keys the selected configuration does not accept.
            These checks use ``assert``, so they are skipped under ``python -O``.
        ValueError: If ``model_name`` is not one of the supported configurations.
    """
    Logger.debug(f'{get_class_name(create_edm_model)} - model_name: {model_name}, model_params: {model_params}')
    assert isinstance(model_name, str), 'model_name must be a str'
    if model_params is None:
        model_params = {}
    assert isinstance(model_params, dict), 'model_params must be a dict'
    if model_name == 'edm-cifar10-32x32-cond-vp':
        assert len(model_params) == 0, 'edm-cifar10-32x32-cond-vp has no params'
        Logger.debug(f'{get_class_name(create_edm_model)} - edm-cifar10-32x32-cond-vp')
        return EDMPrecond(
            img_resolution=32,
            img_channels=3,
            label_dim=10,
            model_type='SongUNet',
            channel_mult=[2, 2, 2],
            augment_dim=9,
            dropout=0.13
        )
    elif model_name == 'edm-cifar10-32x32-uncond-vp':
        assert len(model_params) == 0, 'edm-cifar10-32x32-uncond-vp has no params'
        Logger.debug(f'{get_class_name(create_edm_model)} - edm-cifar10-32x32-uncond-vp')
        # Same architecture as the conditional CIFAR-10 variant, but without label_dim.
        return EDMPrecond(
            img_resolution=32,
            img_channels=3,
            model_type='SongUNet',
            channel_mult=[2, 2, 2],
            augment_dim=9,
            dropout=0.13
        )
    elif model_name == 'edm-afhqv2-64x64-uncond-vp':
        assert len(model_params) == 0, 'edm-afhqv2-64x64-uncond-vp has no params'
        Logger.debug(f'{get_class_name(create_edm_model)} - edm-afhqv2-64x64-uncond-vp')
        # channel_mult is not passed, so EDMPrecond/SongUNet's own default applies here.
        return EDMPrecond(
            img_resolution=64,
            img_channels=3,
            model_type='SongUNet',
            augment_dim=9,
            dropout=0.25
        )
    elif model_name == 'edm-imagenet-64x64-cond-adm':
        allowed_params = ('use_fp16',)
        # Report only the offending keys, not every key that was passed in.
        unknown_params = [name for name in model_params if name not in allowed_params]
        assert not unknown_params, f'unknown model param: {unknown_params}'
        use_fp16: bool = model_params.get('use_fp16', True)
        Logger.debug(f'{get_class_name(create_edm_model)} - edm-imagenet-64x64-cond-adm - use_fp16: {use_fp16}')
        # model_type is not passed, so EDMPrecond's default backbone is used.
        return EDMPrecond(
            img_resolution=64,
            img_channels=3,
            label_dim=1000,
            use_fp16=use_fp16
        )
    else:
        raise ValueError(f'unknown model name: {model_name}')
