from typing import Tuple
import torch.nn as nn

from .quant import VectorQuantizer2
from .var import VAR
from .vqvae import VQVAE

# from dataclasses import dataclass, field
# from typing import List

def build_vae_var(
    # Shared args
    device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
    # VQVAE args
    V=4096, Cvae=32, ch=160, share_quant_resi=4, using_znorm=False,
    # VAR args
    num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,    # init_std < 0: automated
) -> Tuple[VQVAE, VAR]:
    """Build a VQVAE and the VAR model that wraps it, both moved to ``device``.

    The VAR size is derived from ``depth``:
    - ``depth`` attention heads;
    - an embedding width of ``64 * depth``;
    - a drop-path rate of ``0.1 * depth / 24``.

    Dropout and attention dropout are fixed at 0. Other fixed settings are
    ``norm_eps=1e-6`` and ``cond_drop_rate=0.1``.

    Side effect: before any model is built, ``reset_parameters`` is replaced
    with a no-op on these ``torch.nn`` classes:
    - Linear, LayerNorm, BatchNorm2d, SyncBatchNorm;
    - Conv1d, Conv2d, ConvTranspose1d, ConvTranspose2d.

    The patch is set on the classes themselves, so it also affects every such
    layer constructed later in the process. The VAR's weights are then set
    explicitly via ``VAR.init_weights``.

    Args:
        device: Argument passed to ``.to()`` for both models.
        patch_nums: Shared by the VQVAE (as ``v_patch_nums``) and the VAR
            (as ``patch_nums``).
        V: Forwarded to VQVAE as ``vocab_size``.
        Cvae: Forwarded to VQVAE as ``z_channels``.
        ch: Forwarded to VQVAE as ``ch``.
        share_quant_resi: Forwarded to VQVAE unchanged.
        using_znorm: Forwarded to VQVAE unchanged.
        num_classes, depth, shared_aln, attn_l2_norm: Forwarded to VAR unchanged.
        flash_if_available, fused_if_available: Forwarded to VAR unchanged.
        init_adaln, init_adaln_gamma, init_head: Forwarded to
            ``VAR.init_weights``.
        init_std: Forwarded to ``VAR.init_weights``; a negative value means
            "automated", per the signature comment.

    Returns:
        A ``(vqvae, var)`` tuple. The VAR is returned as constructed, not
        wrapped in DDP.
    """
    # Model dimensions all scale with depth.
    num_heads = depth
    embed_dim = 64 * depth
    drop_path_rate = 0.1 * depth / 24

    # Turn PyTorch's built-in parameter initialisation into a no-op to speed up
    # construction. This mutates the classes globally and is never undone.
    def _skip_init(self):
        return None

    no_init_classes = (
        nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm,
        nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d,
    )
    for layer_cls in no_init_classes:
        layer_cls.reset_parameters = _skip_init

    # NOTE(review): using_znorm defaults to False, but an older banner here
    # warned about using_znorm=True. Confirm which setting is intended.
    vqvae = VQVAE(
        vocab_size=V,
        z_channels=Cvae,
        ch=ch,
        test_mode=True,
        share_quant_resi=share_quant_resi,
        v_patch_nums=patch_nums,
        using_znorm=using_znorm,
    )
    vqvae = vqvae.to(device)

    transformer = VAR(
        vae_local=vqvae,
        num_classes=num_classes,
        depth=depth,
        embed_dim=embed_dim,
        num_heads=num_heads,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=drop_path_rate,
        norm_eps=1e-6,
        shared_aln=shared_aln,
        cond_drop_rate=0.1,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        flash_if_available=flash_if_available,
        fused_if_available=fused_if_available,
    )
    transformer = transformer.to(device)
    transformer.init_weights(
        init_adaln=init_adaln,
        init_adaln_gamma=init_adaln_gamma,
        init_head=init_head,
        init_std=init_std,
    )
    return vqvae, transformer


# @dataclass
# class ModelArgs:
#     codebook_size: int = 16384 
#     codebook_embed_dim: int = 32
#     codebook_l2_norm: bool = True # l2 norm for codebook (original VAR False)
#     codebook_show_usage: bool = True 
#     commit_loss_beta: float = 0.25 
#     entropy_loss_ratio: float = 0.0 # LFQ entropy loss ratio (useless)
    
#     encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) # encoder channel multiplier
#     decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) # decoder channel multiplier
#     z_channels: int = 256 # channel number of the output of the encoder
#     dropout_p: float = 0.0 # dropout rate

#     v_patch_nums: List[int] = field(default_factory=lambda: [1, 2, 3, 4, 5, 6, 8, 10, 13, 16])
#     enc_type: str = 'cnn'
#     dec_type: str = 'cnn'
#     semantic_guide: str = 'none' # semantic guide type (ImageFolder)
#     detail_guide: str = 'none' # detail guide type (ImageFolder)
#     num_latent_tokens: int = 256 # latent token number (ImageFolder)
#     encoder_model: str = 'vit_small_patch14_dinov2.lvd142m' # encoder model (ImageFolder)
#     decoder_model: str = 'vit_small_patch14_dinov2.lvd142m' # decoder model (ImageFolder)
#     abs_pos_embed: bool = False # absolute position embedding (ImageFolder)

#     share_quant_resi: int = 4 # share quant residual (VAR)
#     product_quant: int = 1 # product quant (ImageFolder)
    
#     codebook_drop: float = 0.0 # codebook (ImageFolder)
#     half_sem: bool = False  # chunk operation (ImageFolder)
#     start_drop: int = 1 # start drop (ImageFolder)
#     sem_loss_weight: float = 0.1 # semantic loss weight (ImageFolder)
#     detail_loss_weight: float = 0.1 # detail loss weight (ImageFolder)

#     clip_norm: bool = False # clip norm

#     sem_loss_scale: float = 1.0 # semantic loss scale (ImageFolder)
#     detail_loss_scale: float = 1.0 # detail loss scale (ImageFolder)
#     guide_type_1: str = "none" # Semantic guide type (ImageFolder) useless for our experiment
#     guide_type_2: str = "none" # Detail guide type(ImageFolder) useless for our experiment

#     lfq: bool = False # LFQ (LFQ)
#     scale: float = 1.0 # scale (LFQ)
#     soft_entropy: bool = True # soft entropy (LFQ)

#     dependency_loss_weight: float = 0.0 # product loss weight (ImageFolder)

#     test_model: bool = False
    
#     ae_training: bool = False
