import argparse
import torch

from infinity.models.bsq_vae.flux_vqgan import AutoEncoder

def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
    """Copy checkpoint tensors under ``prefix`` into ``model`` in place, adapting layouts.

    For every checkpoint key starting with ``prefix`` (the prefix is stripped to get
    the model-relative key), three mappings are attempted independently:

    1. Direct: the stripped key exists in ``model.state_dict()``. With ``use_linear``,
       ``.q/.k/.v/.proj_out`` weights are reshaped to the model's shape (1x1 conv
       kernel -> ``nn.Linear`` weight). With ``expand``, a ``*.conv.weight`` whose
       shape differs is inflated along a new dim 2 (2D conv -> 3D conv).
    2. Conv wrapper: ``X.weight`` / ``X.bias`` -> ``X.conv.weight`` / ``X.conv.bias``
       for keys that look like conv/attention layers. A weight whose shape differs
       from the target is inflated along dim 2 (independently of ``expand``).
    3. Norm wrapper: ``X.weight`` / ``X.bias`` -> ``X.norm.weight`` / ``X.norm.bias``
       for keys containing ``"norm"``.

    Args:
        model: ``torch.nn.Module`` whose parameters/buffers are overwritten in place.
        state_dict: checkpoint mapping. Every consumed key is deleted from it in place.
        prefix: only keys starting with this string are considered.
        expand: allow inflating direct ``*.conv.weight`` matches (mapping 1).
        use_linear: load attention projection weights as linear layers (mapping 1)
            and restrict mapping 2 to keys containing ``"conv"``.

    Returns:
        ``(model, state_dict, loaded_keys)``: the same model and (mutated) mapping
        objects, plus ``prefix + target_key`` for every tensor written, in order.

    Raises:
        RuntimeError: from ``reshape``/``copy_`` when a checkpoint tensor cannot be
            made to fit its target.
    """
    # ``model.state_dict()`` walks every submodule and builds a new dict on each call.
    # Calling it per key makes loading quadratic in the number of keys. One snapshot
    # suffices: its tensors are detached views sharing storage with the model's
    # parameters/buffers, so ``copy_`` into them updates the model itself.
    model_sd = model.state_dict()
    attn_weight_markers = (".q.weight", ".k.weight", ".v.weight", ".proj_out.weight")
    if use_linear:
        conv_markers = ("conv",)
    else:
        conv_markers = ("conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut.")
    # Insertion-ordered set of consumed checkpoint keys. One key may feed several
    # targets (e.g. a direct match and a ``.conv`` wrapper). Deleting it twice would
    # raise ``KeyError``.
    consumed = {}
    loaded_keys = []

    def _inflate(src, target):
        # Same shape: use as is. Otherwise treat ``src`` as a 2D kernel and repeat it
        # along a new dim 2 up to the target's size there (2D conv -> 3D conv).
        if target.shape == src.shape:
            return src
        return src.unsqueeze(2).repeat(1, 1, target.shape[2], 1, 1)

    def _store(ckpt_key, target_key, tensor):
        # Write into the model and record the bookkeeping returned to the caller.
        model_sd[target_key].copy_(tensor)
        consumed[ckpt_key] = None
        loaded_keys.append(prefix + target_key)

    for key in state_dict:
        if not key.startswith(prefix):
            continue
        _key = key[len(prefix):]
        src = state_dict[key]

        # 1. Direct match on the model-relative key.
        if _key in model_sd:
            target = model_sd[_key]
            if use_linear and any(m in key for m in attn_weight_markers):
                # nn.Conv2d 1x1 kernel -> nn.Linear weight. ``reshape`` (unlike
                # ``squeeze``) keeps size-1 channel dims the target still expects.
                tensor = src.reshape(target.shape)
            elif expand and _key.endswith(".conv.weight"):
                tensor = _inflate(src, target)
            else:
                tensor = src
            _store(key, _key, tensor)

        # Mappings 2 and 3 only rewrite the trailing ``.weight`` / ``.bias`` leaf.
        stem, dot, leaf = _key.rpartition(".")
        if not dot or leaf not in ("weight", "bias"):
            continue

        # 2. Plain conv/attention layer -> wrapper module exposing it as ``.conv``.
        if any(m in _key for m in conv_markers):
            conv_key = f"{stem}.conv.{leaf}"
            if conv_key in model_sd:
                target = model_sd[conv_key]
                _store(key, conv_key, _inflate(src, target) if leaf == "weight" else src)

        # 3. Plain normalization layer -> wrapper module exposing it as ``.norm``.
        if "norm" in _key:
            norm_key = f"{stem}.norm.{leaf}"
            if norm_key in model_sd:
                _store(key, norm_key, src)

    for key in consumed:
        del state_dict[key]

    return model, state_dict, loaded_keys


def vae_model(
    vqgan_ckpt,
    schedule_mode,
    codebook_dim,
    codebook_size,
    test_mode=True,
    patch_size=16,
    encoder_ch_mult=None,
    decoder_ch_mult=None,
):
    """Build the BSQ ``AutoEncoder`` from a fixed config and optionally load weights.

    Args:
        vqgan_ckpt: path to a checkpoint file (``str`` or ``os.PathLike``, read on CPU
            with ``torch.load(..., weights_only=True)``), an already-loaded checkpoint
            mapping, or a falsy value (e.g. ``None``) to skip weight loading. A
            checkpoint mapping must hold the weights under ``"vae"`` (``"ema"`` is
            only read when ``args.ema == "yes"``, which this config never sets).
        schedule_mode: forwarded as ``args.schedule_mode``.
        codebook_dim: forwarded as ``args.codebook_dim``.
        codebook_size: forwarded as ``args.codebook_size``.
        test_mode: if true, call ``eval()`` and disable gradients on all parameters.
        patch_size: forwarded as ``args.patch_size``.
        encoder_ch_mult: encoder channel multipliers. ``None`` means ``[1, 2, 4, 4, 4]``.
        decoder_ch_mult: decoder channel multipliers. ``None`` means ``[1, 2, 4, 4, 4]``.

    Returns:
        The constructed ``AutoEncoder`` (loaded and/or frozen as requested).
    """
    import os

    # Fresh lists per call: a shared mutable default would leak any mutation made
    # downstream into every later model.
    if encoder_ch_mult is None:
        encoder_ch_mult = [1, 2, 4, 4, 4]
    if decoder_ch_mult is None:
        decoder_ch_mult = [1, 2, 4, 4, 4]

    # NOTE(review): most fields look training-only; presumably AutoEncoder reads only
    # a subset at inference time. Kept verbatim because that subset isn't visible here.
    args = argparse.Namespace(
        vqgan_ckpt=vqgan_ckpt,
        sd_ckpt=None,
        inference_type='image',
        save='./imagenet_val_bsq',
        save_prediction=True,
        image_recon4video=False,
        junke_old=False,
        device='cuda',
        max_steps=1000000.0,
        log_every=1,
        visu_every=1000,
        ckpt_every=1000,
        default_root_dir='',
        compile='no',
        ema='no',
        lr=0.0001,
        beta1=0.9,
        beta2=0.95,
        warmup_steps=0,
        optim_type='Adam',
        disc_optim_type=None,
        lr_min=0.0,
        warmup_lr_init=0.0,
        max_grad_norm=1.0,
        max_grad_norm_disc=1.0,
        disable_sch=False,
        patch_size=patch_size,
        temporal_patch_size=4,
        embedding_dim=256,
        codebook_dim=codebook_dim,
        num_quantizers=8,
        quantizer_type='MultiScaleBSQ',
        use_vae=False,
        use_freq_enc=False,
        use_freq_dec=False,
        preserve_norm=False,
        ln_before_quant=False,
        ln_init_by_sqrt=False,
        use_pxsf=False,
        new_quant=True,
        use_decay_factor=False,
        mask_out=False,
        use_stochastic_depth=False,
        drop_rate=0.0,
        schedule_mode=schedule_mode,
        lr_drop=None,
        lr_drop_rate=0.1,
        keep_first_quant=False,
        keep_last_quant=False,
        remove_residual_detach=False,
        use_out_phi=False,
        use_out_phi_res=False,
        use_lecam_reg=False,
        lecam_weight=0.05,
        perceptual_model='vgg16',
        base_ch_disc=64,
        random_flip=False,
        flip_prob=0.5,
        flip_mode='stochastic',
        max_flip_lvl=1,
        not_load_optimizer=False,
        use_lecam_reg_zero=False,
        freeze_encoder=False,
        rm_downsample=False,
        random_flip_1lvl=False,
        flip_lvl_idx=0,
        drop_when_test=False,
        drop_lvl_idx=0,
        drop_lvl_num=1,
        disc_version='v1',
        magvit_disc=False,
        sigmoid_in_disc=False,
        activation_in_disc='leaky_relu',
        apply_blur=False,
        apply_noise=False,
        dis_warmup_steps=0,
        dis_lr_multiplier=1.0,
        dis_minlr_multiplier=False,
        disc_channels=64,
        disc_layers=3,
        discriminator_iter_start=0,
        disc_pretrain_iter=0,
        disc_optim_steps=1,
        disc_warmup=0,
        disc_pool='no',
        disc_pool_size=1000,
        advanced_disc=False,
        recon_loss_type='l1',
        video_perceptual_weight=0.0,
        image_gan_weight=1.0,
        video_gan_weight=1.0,
        image_disc_weight=0.0,
        video_disc_weight=0.0,
        l1_weight=4.0,
        gan_feat_weight=0.0,
        perceptual_weight=0.0,
        kl_weight=0.0,
        lfq_weight=0.0,
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        diversity_gamma=1,
        norm_type='group',
        disc_loss_type='hinge',
        use_checkpoint=False,
        precision='fp32',
        encoder_dtype='fp32',
        upcast_attention='',
        upcast_tf32=False,
        tokenizer='flux',
        pretrained=None,
        pretrained_mode='full',
        inflation_pe=False,
        init_vgen='no',
        no_init_idis=False,
        init_idis='keep',
        init_vdis='no',
        enable_nan_detector=False,
        turn_on_profiler=False,
        profiler_scheduler_wait_steps=10,
        debug=True,
        video_logger=False,
        bytenas='',
        username='',
        seed=1234,
        vq_to_vae=False,
        load_not_strict=False,
        zero=0,
        bucket_cap_mb=40,
        manual_gc_interval=1000,
        data_path=[''],
        data_type=[''],
        dataset_list=['imagenet'],
        fps=-1,
        dataaug='resizecrop',
        multi_resolution=False,
        random_bucket_ratio=0.0,
        sequence_length=16,
        resolution=[256, 256],
        batch_size=[1],
        num_workers=0,
        image_channels=3,
        codebook_size=codebook_size,
        codebook_l2_norm=True,
        codebook_show_usage=True,
        commit_loss_beta=0.25,
        entropy_loss_ratio=0.0,
        base_ch=128,
        num_res_blocks=2,
        encoder_ch_mult=encoder_ch_mult,
        decoder_ch_mult=decoder_ch_mult,
        dropout_p=0.0,
        cnn_type='2d',
        cnn_version='v1',
        conv_in_out_2d='no',
        conv_inner_2d='no',
        res_conv_2d='no',
        cnn_attention='no',
        cnn_norm_axis='spatial',
        flux_weight=0,
        cycle_weight=0,
        cycle_feat_weight=0,
        cycle_gan_weight=0,
        cycle_loop=0,
        z_drop=0.0,
    )

    vae = AutoEncoder(args)

    if isinstance(vqgan_ckpt, (str, os.PathLike)):
        state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True)
    else:
        state_dict = args.vqgan_ckpt

    if state_dict:
        ckpt_key = "ema" if args.ema == "yes" else "vae"
        # load_cnn deletes every consumed key from the mapping it receives. Hand it a
        # shallow copy so a caller-supplied in-memory checkpoint is not emptied (reusing
        # it would otherwise silently load nothing). Tensors themselves are not copied.
        vae, _, _ = load_cnn(vae, dict(state_dict[ckpt_key]), prefix="", expand=False)

    if test_mode:
        vae.eval()
        vae.requires_grad_(False)
    return vae