import argparse
import pickle
import numpy as np
import torch.nn as nn

from .karras_diffusion import KarrasDenoiser
from .unet import UNetModel
import cm.enc_dec_lib as enc_dec_lib
from cm.resample import create_named_schedule_sampler
from ldm.modules.diffusionmodules.openaimodel import UNetModel as SR_UNetModel
from cm.networks import EDMPrecond_CTM

# Number of class labels handed to UNetModel (as ``num_classes``) by
# create_model() when class conditioning is enabled.
NUM_CLASSES = 1000

def ctm_data_defaults():
    """Return a fresh dict holding the default data-related settings.

    A new dict is built on every call, so callers may mutate the result.
    Value types are significant when the dict is fed to
    ``add_dict_to_argparser``, which derives each option's parser type from
    the type of its default.
    """
    return {
        "data_name": "cifar10",
        "train_classes": -1,
        "type": "png",
        "sigma_data": 0.5,
    }

def ctm_loss_defaults():
    """Return a fresh dict holding the default loss-related settings.

    A new dict is built on every call, so callers may mutate the result.
    Value types (bool / int / float / str) are significant when the dict is
    fed to ``add_dict_to_argparser``, which derives each option's parser type
    from the type of its default.
    """
    defaults = {
        "diffusion_training": False,
        "denoising_weight": 0.0,
        "discriminator_training": False,
        "discriminator_free_target": False,
        "discriminator_weight": 1.0,
        "discriminator_input_channel": 512,
        "discriminator_start_itr": 10000,
        "discriminator_input": "latent",
        "feature_aggregated": False,
        "decoded_loss": False,
        "apply_adaptive_weight": True,
        "use_d_fp16": False,
        "d_loss": "hinge",
        "d_architecture": "unet",
        "lazy_reg": 10,
        "r1_gamma": 0.05,
        "g_learning_period": 2,
        "gan_target": "denoised",
        "hinge_value": 1.0,
        "d_out_res": 8,
        "d_learning_period": 1,
        "consistency_weight": 1.0,
        "sample_s_strategy": "uniform",
        "heun_step_strategy": "uniform",
        "heun_step_multiplier": 1.0,
        "auxiliary_type": "stop_grad",
        "diffusion_mult": 0.7,
        "diffusion_schedule_sampler": "lognormal",
        "gan_estimate_type": "same",
        "embed": False,
        "discriminator_fix": False,
        "discriminator_checkpoint": "",
        "channelwise_normalization": True,
        "cm_ratio": 0.0,
        "augment": False,
    }
    return defaults

def ctm_train_defaults():
    """Return a fresh dict holding the default training/sampling settings.

    A new dict is built on every call, so callers may mutate the result.
    Value types (bool / int / float / str) are significant when the dict is
    fed to ``add_dict_to_argparser``, which derives each option's parser type
    from the type of its default.
    """
    defaults = {
        "sampling_batch": 64,
        "sample_interval": 1000,
        "sampling_steps": 18,
        "time_conditioned_classifier": False,
        "classifier_model_path": "",
        "parametrization": "euler",
        "out_res": -1,
        "classifier_pool": "attention",
        "clip_denoised": True,
        "clip_output": True,
        "beta_min": 0.1,
        "beta_max": 20.0,
        "multiplier": 1.0,
        "load_optimizer": True,
        "num_heun_step": 1,
        "num_heun_step_random": False,
        "d_lr": 0.0004,
        "loss_type_": "l2",
        "out_channels": 1000,
        "cos_t_classifier": False,
        "in_channels": 3,
        "deterministic": False,
        "time_continuous": False,
        "edm_nn_ncsn": False,
        "edm_nn_ddpm": False,
        "inner_parametrization": "no",
        "save_period": 100,
        "d_apply_adaptive_weight": False,
        "shift_ratio": 0.125,
        "cutout_ratio": 0.2,
    }
    return defaults

def ctm_eval_defaults():
    """Return a fresh dict holding the default evaluation settings.

    A new dict is built on every call, so callers may mutate the result.
    """
    return {
        "eval_interval": 10000,
        "eval_num_samples": 10000,
        "eval_batch": -1,
        "ref_path": "",
        "large_log": False,
        "compute_ema_fids": False,
    }

def cm_train_defaults():
    """Return a fresh dict holding the default distillation/training settings.

    A new dict is built on every call, so callers may mutate the result.
    Several of these keys share names with the parameters of
    ``create_ema_and_scales_fn`` (``target_ema_mode``, ``start_ema``,
    ``scale_mode``, ``start_scales``, ``end_scales``,
    ``distill_steps_per_iter``).
    """
    defaults = {
        "teacher_model_path": "",
        "teacher_dropout": 0.1,
        "training_mode": "ctm",
        "target_ema_mode": "fixed",
        "scale_mode": "fixed",
        "total_training_steps": 600000,
        "start_ema": 0.0,
        "start_scales": 40,
        "end_scales": 40,
        "distill_steps_per_iter": 50000,
        "loss_norm": "lpips",
    }
    return defaults


def model_and_diffusion_defaults():
    """
    Return a fresh dict of default model and diffusion settings for image
    training.

    A new dict is built on every call, so callers may mutate the result.
    ``attention_resolutions`` and ``channel_mult`` are comma-separated
    strings; ``create_model`` splits them itself.
    """
    return {
        "sigma_min": 0.002,
        "sigma_max": 80.0,
        "rho": 7,
        "image_size": 64,
        "num_channels": 128,
        "num_res_blocks": 2,
        "num_heads": 4,
        "num_heads_upsample": -1,
        "num_head_channels": -1,
        "attention_resolutions": "32,16,8",
        "channel_mult": "",
        "dropout": 0.0,
        "class_cond": False,
        "use_checkpoint": False,
        "use_scale_shift_norm": True,
        "resblock_updown": False,
        "use_fp16": False,
        "use_new_attention_order": False,
        "learn_sigma": False,
        "weight_schedule": "karras",
        "diffusion_weight_schedule": "karras_weight",
    }


def create_model_and_diffusion(args, feature_networks=None, teacher=False):
    """Build the denoising network and its ``KarrasDenoiser`` wrapper.

    The network class is chosen from ``args``:

    * ``args.data_name == 'church'``: ``SR_UNetModel`` combined with a fixed
      ``enc_dec_lib.vpsde`` (beta_min=1.5, beta_max=15.5, multiplier=2.).
    * ``args.edm_nn_ncsn`` or ``args.edm_nn_ddpm``: ``EDMPrecond_CTM``
      (only allowed for the ``cifar10`` data name).
    * otherwise: the ``UNetModel`` built by ``create_model``.

    :param args: namespace-like object carrying the hyperparameters read
        below (``schedule_sampler``, ``diffusion_schedule_sampler``,
        ``start_scales``, ``data_name``, the model options, ...).
    :param feature_networks: forwarded unchanged to ``KarrasDenoiser``.
    :param teacher: when true, the UNet variants receive
        ``training_mode='teacher'`` and ``EDMPrecond_CTM`` receives the
        teacher checkpoint path.
    :return: ``(model, diffusion)`` tuple.
    """
    schedule_sampler = create_named_schedule_sampler(args, args.schedule_sampler, args.start_scales)
    diffusion_schedule_sampler = create_named_schedule_sampler(args, args.diffusion_schedule_sampler, args.start_scales)
    if args.data_name in ['church']:
        # These options may arrive as comma-separated strings or as values
        # that are already parsed/empty; unparsable values pass through as-is.
        channel_mult = _int_list_or_passthrough(args.channel_mult)
        attention_resolutions = _int_list_or_passthrough(args.attention_resolutions)
        vpsde = enc_dec_lib.vpsde(beta_min=1.5, beta_max=15.5, multiplier=2.)
        model = SR_UNetModel(image_size=args.image_size, in_channels=args.in_channels,
                              num_classes=None, model_channels=args.num_channels,
                              num_res_blocks=args.num_res_blocks, channel_mult=channel_mult,
                              num_heads=args.num_heads, num_head_channels=args.num_head_channels,
                              num_heads_upsample=args.num_heads_upsample, attention_resolutions=attention_resolutions,
                              dropout=args.dropout, use_checkpoint=args.use_checkpoint,
                              use_scale_shift_norm=args.use_scale_shift_norm, resblock_updown=args.resblock_updown,
                              use_fp16=args.use_fp16, use_new_attention_order=args.use_new_attention_order,
                              training_mode=('teacher' if teacher else args.training_mode), vpsde=vpsde, data_std=args.sigma_data,
                              )
    else:
        if args.edm_nn_ncsn or args.edm_nn_ddpm:
            assert args.data_name.lower() == 'cifar10', \
                "edm_nn_ncsn / edm_nn_ddpm are only supported with data_name 'cifar10'"
            # Parenthesised for readability only: the conditional expression
            # already binds more loosely than ``or``.
            teacher_model_path = (args.teacher_model_path or args.model_path) if teacher else None
            model = EDMPrecond_CTM(img_resolution=args.image_size, img_channels=args.in_channels,
                                   label_dim=10 if args.class_cond else 0, use_fp16=args.use_fp16,
                                   sigma_min=args.sigma_min, sigma_max=args.sigma_max,
                                   sigma_data=args.sigma_data, model_type='SongUNet',
                                   teacher=teacher, teacher_model_path=teacher_model_path,
                                   training_mode=args.training_mode, arch='ddpmpp' if args.edm_nn_ddpm else 'ncsnpp')
        else:
            model = create_model(
                args.image_size,
                args.num_channels,
                args.num_res_blocks,
                channel_mult=args.channel_mult,
                learn_sigma=args.learn_sigma,
                class_cond=args.class_cond,
                use_checkpoint=args.use_checkpoint,
                attention_resolutions=args.attention_resolutions,
                num_heads=args.num_heads,
                num_head_channels=args.num_head_channels,
                num_heads_upsample=args.num_heads_upsample,
                use_scale_shift_norm=args.use_scale_shift_norm,
                dropout=args.dropout,
                resblock_updown=args.resblock_updown,
                use_fp16=args.use_fp16,
                use_new_attention_order=args.use_new_attention_order,
                training_mode=('teacher' if teacher else args.training_mode),
            )
    diffusion = KarrasDenoiser(
        args=args, schedule_sampler=schedule_sampler,
        diffusion_schedule_sampler=diffusion_schedule_sampler,
        feature_networks=feature_networks,
    )
    return model, diffusion


def _int_list_or_passthrough(value):
    """Parse a comma-separated string into a list of ints, else return ``value`` unchanged."""
    try:
        return [int(item) for item in value.split(',')]
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: ``value`` is not a ``str`` (e.g. it is
        # already a list or tuple). ValueError: a field is empty or
        # non-numeric (e.g. the "" default). Only these are swallowed so that
        # KeyboardInterrupt/SystemExit still propagate.
        return value

def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
    training_mode='',
):
    """Construct a ``UNetModel`` from flat hyperparameters.

    :param image_size: input resolution; when ``channel_mult`` is ``""`` it
        must be one of 512, 256, 128 or 64 so a built-in multiplier tuple can
        be selected.
    :param channel_mult: comma-separated integer multipliers, or ``""`` to
        pick the built-in tuple for ``image_size``.
    :param attention_resolutions: comma-separated resolutions; each is turned
        into the downsample rate ``image_size // resolution``.
    :param learn_sigma: selects 6 output channels instead of 3.
    :param class_cond: when true the model is given ``NUM_CLASSES`` classes,
        otherwise ``None``.
    :raises ValueError: if ``channel_mult`` is ``""`` and ``image_size`` has
        no built-in multiplier tuple.
    :return: the constructed ``UNetModel``.

    The remaining arguments are forwarded to ``UNetModel`` unchanged. The
    input always has 3 channels.
    """
    if channel_mult == "":
        # Built-in per-resolution multipliers, matched with ``==`` in order.
        builtin_mults = (
            (512, (0.5, 1, 1, 2, 2, 4, 4)),
            (256, (1, 1, 2, 2, 4, 4)),
            (128, (1, 1, 2, 3, 4)),
            (64, (1, 2, 3, 4)),
        )
        channel_mult = next(
            (mults for size, mults in builtin_mults if image_size == size),
            None,
        )
        if channel_mult is None:
            raise ValueError(f"unsupported image size: {image_size}")
    else:
        channel_mult = tuple(int(piece) for piece in channel_mult.split(","))

    # Attention is specified by resolution but UNetModel receives the
    # corresponding downsample rates.
    attention_ds = tuple(
        image_size // int(resolution)
        for resolution in attention_resolutions.split(",")
    )

    unet_kwargs = dict(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=6 if learn_sigma else 3,
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_ds,
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=NUM_CLASSES if class_cond else None,
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
        training_mode=training_mode,
    )
    return UNetModel(**unet_kwargs)


def create_ema_and_scales_fn(
    target_ema_mode,
    start_ema,
    scale_mode,
    start_scales,
    end_scales,
    total_steps,
    distill_steps_per_iter,
):
    """Return a function mapping a training step to ``(target_ema, scales)``.

    Supported ``(target_ema_mode, scale_mode)`` combinations:

    * ``("fixed", "fixed")``: always ``(start_ema, start_scales)``.
    * ``("fixed", "progressive")``: ``start_ema`` with a scale count that
      follows a square-root interpolation in ``step / total_steps``.
    * ``("adaptive", "progressive")``: the same scale count, with
      ``target_ema = exp(log(start_ema) * start_scales / scales)`` computed
      before the final ``+ 1`` is applied to the scale count.
    * ``("fixed", "progdist")``: the scale count is halved every
      ``distill_steps_per_iter`` steps (never below 2, then further reduced
      to 1 by a sub-stage rule); ``target_ema`` is 1.0.

    Any other combination makes the returned function raise
    ``NotImplementedError`` when it is called (not when it is created).

    :return: callable ``step -> (float target_ema, int scales)``.
    """
    mode = (target_ema_mode, scale_mode)

    def _progressive_base(step):
        # Scale count before the trailing "+ 1", clamped to at least 1.
        progress = step / total_steps
        radicand = (
            progress * ((end_scales + 1) ** 2 - start_scales**2)
            + start_scales**2
        )
        stepped = np.ceil(np.sqrt(radicand) - 1).astype(np.int32)
        return np.maximum(stepped, 1)

    def ema_and_scales_fn(step):
        if mode == ("fixed", "fixed"):
            return float(start_ema), int(start_scales)

        if mode == ("fixed", "progressive"):
            return float(start_ema), int(_progressive_base(step) + 1)

        if mode == ("adaptive", "progressive"):
            base = _progressive_base(step)
            c = -np.log(start_ema) * start_scales
            adaptive_ema = np.exp(-c / base)
            return float(adaptive_ema), int(base + 1)

        if mode == ("fixed", "progdist"):
            stage = step // distill_steps_per_iter
            halved = np.maximum(start_scales // (2**stage), 2)

            # Steps elapsed past the point where halving bottoms out at 2.
            overshoot = np.maximum(
                step - distill_steps_per_iter * (np.log2(start_scales) - 1),
                0,
            )
            sub_stage = overshoot // (distill_steps_per_iter * 2)
            sub_scales = np.maximum(2 // (2**sub_stage), 1)

            final_scales = np.where(halved == 2, sub_scales, halved)
            return float(1.0), int(final_scales)

        raise NotImplementedError

    return ema_and_scales_fn


def add_dict_to_argparser(parser, default_dict):
    """Register one ``--<key>`` option on ``parser`` per entry of ``default_dict``.

    Each option uses the entry's value as its default. The converter is
    chosen from that default: ``str`` for ``None``, ``str2bool`` for booleans,
    and otherwise the default's own type.
    """
    for name, default in default_dict.items():
        if default is None:
            converter = str
        elif isinstance(default, bool):
            # ``bool("False")`` is truthy, so booleans need a real parser.
            converter = str2bool
        else:
            converter = type(default)
        parser.add_argument(f"--{name}", default=default, type=converter)


def args_to_dict(args, keys):
    """Collect the attributes of ``args`` named in ``keys`` into a new dict.

    Raises ``AttributeError`` (from ``getattr``) if ``args`` lacks one of the
    requested attributes.
    """
    picked = {}
    for key in keys:
        picked[key] = getattr(args, key)
    return picked


def str2bool(v):
    """
    Convert a command-line string into a bool.

    Booleans are returned unchanged. Otherwise the value is lower-cased and
    matched against the accepted spellings; anything else raises
    ``argparse.ArgumentTypeError``.

    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v

    text = v.lower()
    if text in ("yes", "true", "t", "y", "1"):
        return True
    if text in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")
