from functools import partial
import json
import math
import warnings

from jsonmerge import merge

from . import augmentation, layers, models, utils


def load_config(file):
    """Load a JSON configuration and combine it with the built-in defaults.

    Args:
        file: A readable object accepted by ``json.load``.

    Returns:
        The result of ``merge(defaults, loaded)``, where ``defaults`` is the
        table assembled below and ``loaded`` is the parsed JSON document.
    """
    model_defaults = {
        'sigma_data': 1.,
        'patch_size': 1,
        'dropout_rate': 0.,
        'augment_wrapper': True,
        'augment_prob': 0.,
        'mapping_cond_dim': 0,
        'unet_cond_dim': 0,
        'cross_cond_dim': 0,
        'cross_attn_depths': None,
        'skip_stages': 0,
        'has_variance': False,
        'loss_config': 'karras',
    }
    dataset_defaults = {'type': 'imagefolder'}
    optimizer_defaults = {
        'type': 'adamw',
        'lr': 1e-4,
        'betas': [0.95, 0.999],
        'eps': 1e-6,
        'weight_decay': 1e-3,
    }
    lr_sched_defaults = {'type': 'constant'}
    ema_sched_defaults = {
        'type': 'inverse',
        'power': 0.6667,
        'max_value': 0.9999,
    }

    # Section order matches the original literal so the merged output keeps
    # the same key ordering.
    defaults = {
        'model': model_defaults,
        'dataset': dataset_defaults,
        'optimizer': optimizer_defaults,
        'lr_sched': lr_sched_defaults,
        'ema_sched': ema_sched_defaults,
    }
    loaded = json.load(file)
    return merge(defaults, loaded)


def make_model(config):
    """Instantiate the denoiser network selected by ``config['model']['type']``.

    Supported types are ``'openai'`` (a guided-diffusion model wrapped in
    ``OpenAIDenoiser``), ``'image_v1'`` and ``'image_v2'`` (the corresponding
    ``models.ImageDenoiserModelV*`` classes). When ``augment_wrapper`` is
    truthy the result is additionally wrapped in
    ``augmentation.KarrasAugmentWrapper``.

    Args:
        config: Full configuration mapping; only its ``'model'`` section is
            read.

    Returns:
        The constructed model, possibly wrapped as described above.

    Raises:
        ValueError: If the model type is not one of the supported values.
        KeyError: If a required configuration key is missing.
    """
    config = config['model']
    if config['type'] == 'openai':
        # These imports only run for the 'openai' type, so other model types
        # do not need these packages to be importable.
        from external import OpenAIDenoiser
        # NOTE(review): dist_util is not used below; the import is kept in
        # case importing it has side effects — confirm and drop if not.
        from guided_diffusion import dist_util  # noqa: F401
        from guided_diffusion.script_util import (
            model_and_diffusion_defaults,
            create_model_and_diffusion,
            args_to_dict,
        )
        from condition.diffpir_utils import utils_model
        # `config` already refers to the model section here, so the OpenAI
        # settings normally live at config['openai']. The doubly nested
        # config['model']['openai'] layout is still honoured when present,
        # for backward compatibility with the previous lookup.
        if 'model' in config:
            openai_config = config['model']['openai']
        else:
            openai_config = config['openai']
        openai_args = utils_model.create_argparser(openai_config).parse_args([])
        inner_model, diffusion = create_model_and_diffusion(
            **args_to_dict(openai_args, model_and_diffusion_defaults().keys())
        )
        model = OpenAIDenoiser(inner_model, diffusion)
    else:
        if config['type'] == 'image_v1':
            ImageDenoiser = models.ImageDenoiserModelV1
        elif config['type'] == 'image_v2':
            ImageDenoiser = models.ImageDenoiserModelV2
        else:
            raise ValueError('Invalid denoiser type')
        # With the augment wrapper enabled, 9 extra mapping-conditioning
        # dimensions are reserved (presumably the augmentation conditioning
        # supplied by KarrasAugmentWrapper — confirm against that class).
        augment_cond_dim = 9 if config['augment_wrapper'] else 0
        model = ImageDenoiser(
            config['input_channels'],
            config['mapping_out'],
            config['depths'],
            config['channels'],
            config['self_attn_depths'],
            config['cross_attn_depths'],
            patch_size=config['patch_size'],
            dropout_rate=config['dropout_rate'],
            mapping_cond_dim=config['mapping_cond_dim'] + augment_cond_dim,
            unet_cond_dim=config['unet_cond_dim'],
            cross_cond_dim=config['cross_cond_dim'],
            skip_stages=config['skip_stages'],
            has_variance=config['has_variance'],
        )
    if config['augment_wrapper']:
        model = augmentation.KarrasAugmentWrapper(model)
    return model


def make_denoiser_wrapper(config):
    """Return a factory for the denoiser wrapper selected by the model config.

    Reads ``sigma_data`` (default ``1.``), ``has_variance`` (default
    ``False``) and ``loss_config`` (default ``'karras'``) from
    ``config['model']``.

    Args:
        config: Full configuration mapping containing a ``'model'`` section.

    Returns:
        A ``functools.partial`` of the chosen ``layers`` wrapper class with
        ``sigma_data`` pre-bound.

    Raises:
        ValueError: If ``loss_config`` is unknown, or if it is ``'simple'``
            while ``has_variance`` is truthy.
    """
    model_cfg = config['model']
    sigma_data = model_cfg.get('sigma_data', 1.)
    has_variance = model_cfg.get('has_variance', False)
    loss_config = model_cfg.get('loss_config', 'karras')

    if loss_config == 'karras':
        # The Karras loss has a dedicated wrapper for models with a variance output.
        wrapper_cls = layers.DenoiserWithVariance if has_variance else layers.Denoiser
    elif loss_config == 'simple':
        if has_variance:
            raise ValueError('Simple loss config does not support a variance output')
        wrapper_cls = layers.SimpleLossDenoiser
    else:
        raise ValueError('Unknown loss config type')
    return partial(wrapper_cls, sigma_data=sigma_data)


def make_sample_density(config):
    """Return a sigma sampler factory described by ``config['sigma_sample_density']``.

    The ``'type'`` entry selects which ``utils.rand_*`` function is wrapped:
    ``'lognormal'``, ``'loglogistic'``, ``'loguniform'``, ``'v-diffusion'`` /
    ``'cosine'``, or ``'split-lognormal'``. Remaining entries supply that
    function's keyword arguments, with the per-type fallbacks coded below.

    Args:
        config: Mapping providing ``'sigma_sample_density'`` and
            ``'sigma_data'`` (and ``'sigma_min'`` / ``'sigma_max'`` as
            fallbacks for the ``'loguniform'`` type).

    Returns:
        A ``functools.partial`` of the selected ``utils`` sampling function.

    Raises:
        ValueError: If the density type is not recognised.
        KeyError: If a required entry has no value and no fallback.
    """
    density_cfg = config['sigma_sample_density']
    sigma_data = config['sigma_data']
    kind = density_cfg['type']

    def lookup(key, alias):
        # Prefer `key`; otherwise fall back to `alias` (KeyError if both absent).
        if key in density_cfg:
            return density_cfg[key]
        return density_cfg[alias]

    if kind == 'lognormal':
        loc = lookup('mean', 'loc')
        scale = lookup('std', 'scale')
        return partial(utils.rand_log_normal, loc=loc, scale=scale)

    if kind == 'loglogistic':
        # The log is only taken when 'loc' is not given explicitly.
        if 'loc' in density_cfg:
            loc = density_cfg['loc']
        else:
            loc = math.log(sigma_data)
        scale = density_cfg.get('scale', 0.5)
        lower = density_cfg.get('min_value', 0.)
        upper = density_cfg.get('max_value', float('inf'))
        return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=lower, max_value=upper)

    if kind == 'loguniform':
        # Bounds fall back to the enclosing config's sigma range; those keys
        # are only required when the density section omits the bound.
        if 'min_value' in density_cfg:
            lower = density_cfg['min_value']
        else:
            lower = config['sigma_min']
        if 'max_value' in density_cfg:
            upper = density_cfg['max_value']
        else:
            upper = config['sigma_max']
        return partial(utils.rand_log_uniform, min_value=lower, max_value=upper)

    if kind in {'v-diffusion', 'cosine'}:
        lower = density_cfg.get('min_value', 1e-3)
        upper = density_cfg.get('max_value', 1e3)
        return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=lower, max_value=upper)

    if kind == 'split-lognormal':
        loc = lookup('mean', 'loc')
        scale_1 = lookup('std_1', 'scale_1')
        scale_2 = lookup('std_2', 'scale_2')
        return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)

    raise ValueError('Unknown sample density type')
