import argparse
# import pickle
import numpy as np
# import torch.nn as nn

from .karras_diffusion_v3 import KarrasDenoiser
from .unet import UNetModel
# import cm.enc_dec_lib as enc_dec_lib
from cm.resample import create_named_schedule_sampler
# from ldm.modules.diffusionmodules.openaimodel import UNetModel as SR_UNetModel
# from cm.networks import EDMPrecond_CTM
from tango_edm.models_edm import AudioDiffusionEDM


def ctm_train_defaults():
    """Return the default CTM / denoising-score-matching training options.

    A fresh dict is built on every call. Value types matter: when the dict is
    fed to ``add_dict_to_argparser`` each value's type becomes the CLI parser
    type, so float options are written as float literals.
    """
    return {
        # --- CTM hyperparams ---
        "consistency_weight": 1.0,
        "loss_norm": "l2",  # alternatives: "l1", "ictm"
        "loss_distance": "l2",
        "loss_domain": "latent",  # alternatives: "mel", "waveform"
        "weight_schedule": "uniform",  # "karras", "uniform", "snr", "sq-snr"
        "parametrization": "euler",
        "inner_parametrization": "edm",
        "num_heun_step": 39,
        "num_heun_step_random": True,
        "teacher_dropout": 0.1,  # NOTE(review): meaning not documented here
        "training_mode": "ctm",
        "match_point": "zs",
        "target_ema_mode": "fixed",
        "scale_mode": "fixed",
        "start_ema": 0.999,  # 0.999 was the CIFAR-10 setting
        "start_scales": 40,
        "end_scales": 40,
        "sigma_min": 0.002,
        "sigma_max": 80.0,
        "rho": 7,
        # latent_* size settings.
        "latent_channels": 8,
        "latent_f_size": 16,
        "latent_t_size": 256,
        # CFG-distillation options. The "unform" spelling below is the
        # existing CLI flag name and must not be changed.
        "cfg_distill": False,
        "target_cfg": 3.5,
        "unform_sampled_cfg_distill": False,
        "w_min": 2.0,
        "w_max": 5.0,
        # --- DSM hyperparams ---
        "diffusion_training": True,
        "denoising_weight": 1.0,  # 1.0 was the CIFAR-10 setting
        "diffusion_mult": 0.7,
        "diffusion_schedule_sampler": "halflognormal",
        "apply_adaptive_weight": True,
        "dsm_loss_target": "z_0",  # "z_0" or "z_target"
        "diffusion_weight_schedule": "karras_weight",
        "cm_ratio": 0.0,
        "augment": False,
    }


def gan_defaults():
    """Return the default GAN / discriminator training options.

    A new dict, with new list objects inside it, is built on every call, so
    callers may mutate the result without affecting later calls.
    """
    return {
        # --- GAN hyperparams ---
        "discriminator_training": True,
        "discriminator_input": "latent",
        "gan_target": "z_0",  # "z_0" or "z_target"
        "sample_s_strategy": "uniform",
        "heun_step_strategy": "weighted",  # or "uniform"
        "heun_step_multiplier": 1.0,
        "auxiliary_type": "stop_grad",
        "gan_estimate_type": "same",
        "discriminator_fix": False,
        "discriminator_free_target": False,
        "d_apply_adaptive_weight": True,
        "discriminator_start_itr": 3501,
        "discriminator_weight": 1.0,
        "d_lr": 0.0002,
        # R1 regularisation switch and its gamma.
        "r1_reg_enable": False,
        "reg_gamma": 2.0,
        # Discriminator architecture: 'DAC_GAN', 'DAC_SAN', 'DAC_CGAN',
        # 'DAC_CSAN', 'L_VQGAN', 'L_CVQGAN', 'MEL_VQGAN', 'MEL_CVQGAN'.
        "d_architecture": "DAC_GAN",
        "dac_dis_rates": [],
        "dac_dis_periods": [2, 3, 5, 7, 11],
        "dac_dis_fft_sizes": [1024, 512, 256, 128],
        "dac_dis_sample_rate": 16000,
        "dac_dis_bands": [(0.0, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)],
        # Discriminator conditioning: 'text_encoder' or 'clap_text_encoder'.
        "d_cond_type": "text_encoder",
        "c_dim": 1024,  # 1024 for T5, 512 for CLAP
        "cmap_dim": 128,  # 128 for T5, 64 for CLAP
        "vqgan_ndf": 64,
        "vqgan_n_layers": 3,
        "vqgan_use_spectral_norm": False,
        "mbdisc_ndf": 64,
        "n_bins": 64,
        "increase_ch": False,
        # fm_* options (presumably the feature-matching loss).
        "fm_apply_adaptive_weight": True,
        "fm_weight": 2.0,
    }



def ctm_eval_defaults():
    """Return the default evaluation / sampling options as a new dict."""
    return {
        "intermediate_samples": False,
        "compute_ema_fads": True,
        "sampling_steps": 18,
        "ref_path": "",
        "large_log": False,
    }

def others_defaults():
    """Return miscellaneous defaults as a new dict.

    Covers run-control settings (distillation steps, clipping, beta range,
    optimizer loading) and the UNet architecture options that
    ``create_model_and_diffusion`` reads from ``args`` for the image model.
    ``use_fp16`` is not included. Float-valued entries are float literals so
    that argparse type inference parses them as floats.
    """
    return {
        "distill_steps_per_iter": 50000,
        "out_res": -1,
        "clip_denoised": False,
        "clip_output": False,
        "beta_min": 0.1,
        "beta_max": 20.0,
        "multiplier": 1.0,
        "load_optimizer": True,
        # UNet architecture options.
        "num_channels": 128,
        "num_res_blocks": 2,
        "num_heads": 4,
        "num_heads_upsample": -1,
        "num_head_channels": -1,
        "attention_resolutions": "32,16,8",
        "channel_mult": "",
        "dropout": 0.0,
        "class_cond": False,
        "use_checkpoint": False,
        "use_scale_shift_norm": True,
        "resblock_updown": False,
        "use_new_attention_order": False,
        "learn_sigma": False,
        "out_channels": 8,
        "in_channels": 8,  # original note: "adjusted to TAGNO" (presumably Tango)
        "deterministic": False,
        "time_continuous": False,
    }


def create_model_and_diffusion(args, feature_networks=None, teacher=False):
    """Build the network and its ``KarrasDenoiser`` training objective.

    Only the Tango audio backbone (``AudioDiffusionEDM``) is supported; the
    image builder ``create_model`` is not used by this factory.

    Args:
        args: configuration namespace. Read here: ``schedule_sampler``,
            ``diffusion_schedule_sampler``, ``start_scales``, ``tango``,
            ``text_encoder_name``, ``unet_model_config``, ``sigma_data``,
            ``freeze_text_encoder`` and ``ctm_unet_model_config``. The whole
            namespace is also passed to both samplers and to ``KarrasDenoiser``.
        feature_networks: forwarded unchanged to ``KarrasDenoiser``.
        teacher: forwarded to ``AudioDiffusionEDM`` as ``teacher``.

    Returns:
        ``(model, diffusion)``: an ``AudioDiffusionEDM`` and a
        ``KarrasDenoiser``.

    Raises:
        NotImplementedError: if ``args.tango`` is falsy.
    """
    # Sampler selected by args.schedule_sampler (presumably for the CTM loss).
    schedule_sampler = create_named_schedule_sampler(
        args, args.schedule_sampler, args.start_scales
    )
    # Sampler for the DSM loss; ctm_train_defaults sets it to 'halflognormal'.
    diffusion_schedule_sampler = create_named_schedule_sampler(
        args, args.diffusion_schedule_sampler, args.start_scales
    )

    if not args.tango:
        # The old image-UNet path sat after an unconditional raise and was
        # unreachable; fail explicitly instead of carrying dead code.
        raise NotImplementedError(
            "only the Tango audio backbone (args.tango=True) is supported"
        )

    model = AudioDiffusionEDM(
        text_encoder_name=args.text_encoder_name,
        unet_model_name=None,
        unet_model_config_path=args.unet_model_config,
        sigma_data=args.sigma_data,
        freeze_text_encoder=args.freeze_text_encoder,
        teacher=teacher,
        ctm_unet_model_config_path=args.ctm_unet_model_config,
    )
    diffusion = KarrasDenoiser(
        args=args,
        schedule_sampler=schedule_sampler,
        diffusion_schedule_sampler=diffusion_schedule_sampler,
        feature_networks=feature_networks,
    )
    return model, diffusion

def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
    training_mode='',
):
    """Construct an image-space ``UNetModel`` with 3 input channels.

    Args:
        image_size: spatial size. It picks the preset ``channel_mult`` and
            converts each attention resolution into ``image_size // res``.
        channel_mult: comma-separated integers, or ``""`` to use the preset
            for ``image_size`` (512, 256, 128 or 64).
        learn_sigma: when truthy the model has 6 output channels, else 3.
        attention_resolutions: comma-separated integer resolutions.
        class_cond, use_fp16: accepted for signature compatibility but not
            forwarded to ``UNetModel``.
        Remaining arguments are passed through to ``UNetModel`` unchanged.

    Raises:
        ValueError: if ``channel_mult`` is ``""`` and ``image_size`` has no
            preset, or if an entry of ``channel_mult`` /
            ``attention_resolutions`` is not an integer.
    """
    if channel_mult != "":
        mults = tuple(map(int, channel_mult.split(",")))
    else:
        # Preset multipliers, compared against image_size in this order.
        presets = (
            (512, (0.5, 1, 1, 2, 2, 4, 4)),
            (256, (1, 1, 2, 2, 4, 4)),
            (128, (1, 1, 2, 3, 4)),
            (64, (1, 2, 3, 4)),
        )
        mults = next((m for size, m in presets if image_size == size), None)
        if mults is None:
            raise ValueError(f"unsupported image size: {image_size}")

    downsample_factors = tuple(
        image_size // int(res) for res in attention_resolutions.split(",")
    )

    return UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=6 if learn_sigma else 3,
        num_res_blocks=num_res_blocks,
        attention_resolutions=downsample_factors,
        dropout=dropout,
        channel_mult=mults,
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
        training_mode=training_mode,
    )


def create_ema_and_scales_fn(
    target_ema_mode,
    start_ema,
    scale_mode,
    start_scales,
    end_scales,
    total_steps,
    distill_steps_per_iter,
):
    """Return a function mapping a training step to ``(target_ema, scales)``.

    Supported ``(target_ema_mode, scale_mode)`` pairs:

    * ``("fixed", "fixed")``: always ``(start_ema, start_scales)``.
    * ``("fixed", "progressive")``: EMA fixed at ``start_ema``. The scale
      count follows ``ceil(sqrt(t * ((end_scales + 1)**2 - start_scales**2)
      + start_scales**2) - 1)``, clamped to at least 1 and then incremented
      by 1, where ``t = step / total_steps``.
    * ``("adaptive", "progressive")``: same scale schedule. The EMA is
      ``exp(log(start_ema) * start_scales / k)``, where ``k`` is the scale
      count before the final ``+ 1``.
    * ``("fixed", "progdist")``: EMA 1.0. ``start_scales`` is halved every
      ``distill_steps_per_iter`` steps (floor 2); once at 2 it drops to 1
      according to a second, slower stage counter.

    The returned function raises ``NotImplementedError`` for any other pair.
    Its result is always ``(float, int)``.
    """

    def _progressive_base(step):
        # Square-root schedule shared by both "progressive" modes (pre "+1").
        span = (end_scales + 1) ** 2 - start_scales**2
        raw = np.sqrt((step / total_steps) * span + start_scales**2) - 1
        return np.maximum(np.ceil(raw).astype(np.int32), 1)

    def _progdist_scales(step):
        # Coarse stage: halve start_scales every distill_steps_per_iter steps.
        stage = step // distill_steps_per_iter
        coarse = np.maximum(start_scales // (2**stage), 2)
        # Fine stage: after the halving phase, go from 2 down to 1.
        offset = step - distill_steps_per_iter * (np.log2(start_scales) - 1)
        fine_stage = np.maximum(offset, 0) // (distill_steps_per_iter * 2)
        fine = np.maximum(2 // (2**fine_stage), 1)
        return np.where(coarse == 2, fine, coarse)

    def ema_and_scales_fn(step):
        mode = (target_ema_mode, scale_mode)
        if mode == ("fixed", "fixed"):
            ema, count = start_ema, start_scales
        elif mode == ("fixed", "progressive"):
            ema, count = start_ema, _progressive_base(step) + 1
        elif mode == ("adaptive", "progressive"):
            base = _progressive_base(step)
            decay = -np.log(start_ema) * start_scales
            ema, count = np.exp(-decay / base), base + 1
        elif mode == ("fixed", "progdist"):
            ema, count = 1.0, _progdist_scales(step)
        else:
            raise NotImplementedError
        return float(ema), int(count)

    return ema_and_scales_fn


def _sequence_arg_parser(seq_type):
    """Return an argparse ``type`` callable that parses a list/tuple option.

    The option text is read with ``ast.literal_eval``, so ``"[2, 3]"``,
    ``"2,3"`` and ``"2"`` all work, as do nested literals such as
    ``"[(0.0, 0.5), (0.5, 1.0)]"``. A blank string yields an empty sequence.
    The parsed value is converted to ``seq_type`` (``list`` or ``tuple``).
    """
    import ast

    def parse(text):
        if not text.strip():
            return seq_type()
        try:
            value = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError) as err:
            # argparse does not catch SyntaxError itself; convert so the user
            # gets a normal "invalid value" message instead of a traceback.
            raise argparse.ArgumentTypeError(
                f"expected a Python literal sequence, got {text!r}"
            ) from err
        if not isinstance(value, (list, tuple)):
            value = [value]  # a single scalar, e.g. "--periods 5"
        return seq_type(value)

    return parse


def add_dict_to_argparser(parser, default_dict):
    """Register one ``--key`` option on ``parser`` per entry of ``default_dict``.

    The option's default is the dict value, and its ``type`` is inferred from
    that value:

    * ``None`` -> ``str``
    * ``bool`` -> ``str2bool`` (so ``--flag false`` works)
    * ``list`` / ``tuple`` -> a literal-sequence parser. ``type=list`` would
      split the string into characters (``"2,3"`` -> ``['2', ',', '3']``).
    * anything else -> the value's own type (``int``, ``float``, ``str``, ...)

    Defaults are not passed through ``type`` (argparse only converts string
    defaults), so list defaults keep their original objects.
    """
    for key, value in default_dict.items():
        if value is None:
            value_type = str
        elif isinstance(value, bool):
            value_type = str2bool
        elif isinstance(value, (list, tuple)):
            value_type = _sequence_arg_parser(
                tuple if isinstance(value, tuple) else list
            )
        else:
            value_type = type(value)
        parser.add_argument(f"--{key}", default=value, type=value_type)


def args_to_dict(args, keys):
    """Return a new dict mapping each name in ``keys`` to ``getattr(args, name)``."""
    selected = {}
    for name in keys:
        selected[name] = getattr(args, name)
    return selected


def str2bool(v):
    """Convert a command-line string to a bool; intended as an argparse ``type``.

    ``bool`` inputs are returned unchanged. Case-insensitive spellings
    "yes", "true", "t", "y", "1" map to True and "no", "false", "f", "n", "0"
    map to False.

    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")
