import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
import os.path as osp
from typing import List
import math
import time
import hashlib
import yaml
import argparse
import shutil
import re

import cv2
import numpy as np
import torch
torch._dynamo.config.cache_size_limit=64
import pandas as pd
from transformers import AutoTokenizer, T5EncoderModel, T5TokenizerFast
from PIL import Image, ImageEnhance
import torch.nn.functional as F
from torch.cuda.amp import autocast

import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
print(sys.path)

from infinity.models.infinity import Infinity,BInfinity
from infinity.models.basic import *
import PIL.Image as PImage
from torchvision.transforms.functional import to_tensor
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
import pdb
from torchvision import transforms
from infinity.models.bitwise_self_correction import BitwiseSelfCorrection
from infinity.models.swinir import SwinIR
import importlib
from lora_diffusion import inject_trainable_lora
from transformers import BlipForConditionalGeneration,BlipProcessor

def get_obj_from_str(string: str, reload: bool = False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named attribute.

    Args:
        string: ``"<module path>.<attribute>"``; split at the last dot.
        reload: when true, the module is re-imported with ``importlib.reload``
            before the attribute is looked up.

    Raises:
        ValueError: if ``string`` contains no dot.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)


def instantiate_from_config(config):
    """Instantiate ``config["target"]`` with ``config["params"]`` as keyword arguments.

    ``target`` is a dotted path resolved by ``get_obj_from_str``. ``params`` is
    optional and defaults to an empty dict.

    Raises:
        KeyError: if ``config`` has no ``"target"`` entry.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)

def extract_key_val(text):
    """Parse ``<key:value>`` tags found in ``text`` into a dict.

    Keys are kept verbatim, and leading whitespace is stripped from values.
    When a key occurs more than once, the last occurrence wins.
    """
    return {key: value.lstrip() for key, value in re.findall(r'<(.+?):(.+?)>', text)}

def encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt=False):
    """Encode a single ``prompt`` with the T5 encoder into Infinity's text condition.

    The prompt is tokenized and padded/truncated to 512 tokens. The encoder's
    ``last_hidden_state`` (cast to float32) is then trimmed per sample to the
    number of attended tokens.

    Args:
        enable_positive_prompt: when truthy, ``aug_with_positive_prompt`` is applied
            first and the prompt is printed before and after.

    Returns:
        ``(kv_compact, lens, cu_seqlens_k, Ltext)``:
        - ``kv_compact``: the trimmed per-sample features concatenated along dim 0.
        - ``lens``: the per-sample token counts, as a Python list.
        - ``cu_seqlens_k``: the int32 cumulative sum of ``lens`` with a leading 0.
        - ``Ltext``: ``max(lens)``.
    """
    if enable_positive_prompt:
        print(f'before positive_prompt aug: {prompt}')
        prompt = aug_with_positive_prompt(prompt)
        print(f'after positive_prompt aug: {prompt}')
    tokens = text_tokenizer(text=[prompt], max_length=512, padding='max_length', truncation=True, return_tensors='pt')
    input_ids = tokens.input_ids.cuda(non_blocking=True)
    mask = tokens.attention_mask.cuda(non_blocking=True)
    hidden = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()

    token_counts = mask.sum(dim=-1)
    # Materialize the lengths before the in-place cumsum below.
    lens: List[int] = token_counts.tolist()
    cu_seqlens_k = F.pad(token_counts.to(dtype=torch.int32).cumsum_(0), (1, 0))
    kv_compact = torch.cat([feat[:n] for n, feat in zip(lens, hidden.unbind(0))], dim=0)
    return (kv_compact, lens, cu_seqlens_k, max(lens))

def aug_with_positive_prompt(prompt):
    """Append face-quality hints to ``prompt`` when it mentions a person.

    Person words are matched as whole words, case-insensitively. Each word may
    take a plural-style suffix (``s``, ``es`` or ``ren``), so "Man", "boys" and
    "children" trigger the augmentation. Words that merely contain a keyword,
    such as "mansion" or "season", do not. The previous substring test matched
    those by accident.

    Returns:
        ``prompt`` with the hint suffix appended, or ``prompt`` unchanged when no
        person word is found.
    """
    person_words = (
        'man', 'woman', 'men', 'women', 'boy', 'girl', 'child', 'person', 'human', 'adult',
        'teenager', 'employee', 'employer', 'worker', 'mother', 'father', 'sister', 'brother',
        'grandmother', 'grandfather', 'son', 'daughter',
    )
    pattern = r'\b(?:' + '|'.join(person_words) + r')(?:s|es|ren)?\b'
    if re.search(pattern, prompt, flags=re.IGNORECASE):
        prompt = prompt + '. very smooth faces, good looking faces, face to the camera, perfect facial features'
    return prompt

def enhance_image(image, contrast=1.05, saturation=1.05):
    """Return ``image`` with contrast, then saturation, boosted by the given factors.

    The PIL ``ImageEnhance`` enhancers return new images, so no defensive copy
    is needed. (The former single-iteration loop and ``.copy()`` calls were redundant.)

    Args:
        image: a PIL image.
        contrast: factor for ``ImageEnhance.Contrast``; 1.0 leaves contrast unchanged.
        saturation: factor for ``ImageEnhance.Color``; 1.0 leaves saturation unchanged.
    """
    contrasted = ImageEnhance.Contrast(image).enhance(contrast)  # enhance contrast
    return ImageEnhance.Color(contrasted).enhance(saturation)  # enhance saturation

def _load_default_swinir(device, ckpt_path='weights/general_swinir_v1.ckpt'):
    """Build the default SwinIR restorer, load its checkpoint, and freeze it.

    The architecture hyper-parameters must match the checkpoint, because the
    state dict is loaded with ``strict=True``.

    Args:
        device: device the model is moved to.
        ckpt_path: SwinIR checkpoint path, relative to the working directory.

    Returns:
        The SwinIR module in eval mode with every parameter's ``requires_grad`` off.
    """
    swinir_config = {
        "target": "infinity.models.swinir.SwinIR",
        "params": {
            "img_size": 64,
            "patch_size": 1,
            "in_chans": 3,
            "embed_dim": 180,
            "depths": [6, 6, 6, 6, 6, 6, 6, 6],
            "num_heads": [6, 6, 6, 6, 6, 6, 6, 6],
            "window_size": 8,
            "mlp_ratio": 2,
            "sf": 8,
            "img_range": 1.0,
            "upsampler": "nearest+conv",
            "resi_connection": "1conv",
            "unshuffle": True,
            "unshuffle_scale": 8,
        },
    }
    swinir: SwinIR = instantiate_from_config(swinir_config)
    sd = torch.load(ckpt_path, map_location="cpu")
    if "state_dict" in sd:
        sd = sd["state_dict"]
    # Strip a leading "module." from keys (the prefix added by e.g. DataParallel/DDP wrappers).
    sd = {(k[len("module."):] if k.startswith("module.") else k): v for k, v in sd.items()}
    swinir.load_state_dict(sd, strict=True)
    for p in swinir.parameters():
        p.requires_grad = False
    swinir.eval().to(device)
    return swinir


def gen_one_img(
    infinity_test, 
    vae, 
    text_tokenizer,
    text_encoder,
    prompt, 
    cfg_list=[],
    tau_list=[],
    negative_prompt='',
    scale_schedule=None,
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    enable_positive_prompt=0,
    lq_img_path='',
    gt_img_path='',
    args=None,
    swinir=None,
):
    """Super-resolve one low-quality image with Infinity, conditioned on ``prompt``.

    Pipeline:
    1. Load ``lq_img_path`` as RGB and resize it to 512x512.
    2. Restore it with SwinIR.
    3. VAE-encode it and build the LQ condition with
       ``BitwiseSelfCorrection.flip_requant``.
    4. Run ``infinity_test.autoregressive_infer_cfg`` with the T5 text condition.

    Args:
        cfg_list / tau_list: per-scale lists, or a scalar broadcast to every scale.
        negative_prompt: encoded as the negative condition when non-empty.
        gt_img_path: accepted for interface compatibility; unused.
        args: options forwarded to ``BitwiseSelfCorrection``. When omitted, the
            module-level ``args`` is used if present (it is set when this file
            runs as a script), as the original implementation relied on.
        swinir: optional preloaded SwinIR model. When omitted, the default
            checkpoint is loaded from disk on this call.

    Returns:
        The first element of the ``img_list`` returned by ``autoregressive_infer_cfg``.
    """
    sstt = time.time()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # `with` releases the file handle Image.open keeps for lazy decoding.
    with Image.open(lq_img_path) as pil_img:
        if pil_img.mode != "RGB":
            pil_img = pil_img.convert("RGB")
        pil_img = pil_img.resize((512, 512))
    lq_img = transforms.ToTensor()(pil_img)
    lq_img = lq_img * 2 - 1
    lq_img = lq_img.unsqueeze(0).to(device, non_blocking=True)

    ##### swinir
    if swinir is None:
        swinir = _load_default_swinir(device)
    with torch.no_grad():
        # SwinIR is fed [0, 1] input; its output is mapped back to [-1, 1] for the VAE.
        lq_img = (lq_img + 1) / 2
        lq_img = swinir(lq_img)
        lq_img = lq_img + lq_img - 1
    ##### swinir

    with torch.amp.autocast('cuda', enabled=False), torch.no_grad():
        if infinity_test.apply_spatial_patchify:
            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
        else:
            vae_scale_schedule = scale_schedule
        raw_features_lq, _, _ = vae.encode_for_raw_features(lq_img, scale_schedule=vae_scale_schedule)

    if args is None:
        # Backward compatibility: this function originally read the script-level `args` global.
        args = globals().get('args')
    bitwise_self_correction = BitwiseSelfCorrection(vae, args)
    x_BLC_wo_prefix_lq, _ = bitwise_self_correction.flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)

    if not isinstance(cfg_list, list):
        cfg_list = [cfg_list] * len(scale_schedule)
    if not isinstance(tau_list, list):
        tau_list = [tau_list] * len(scale_schedule)
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
    if negative_prompt:
        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
    else:
        negative_label_B_or_BLT = None
    print(f'cfg: {cfg_list}, tau: {tau_list}')
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        stt = time.time()
        _, _, img_list = infinity_test.autoregressive_infer_cfg(
            vae=vae,
            scale_schedule=scale_schedule,
            label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
            B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
            cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
            returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
            cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
            vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
            ret_img=True, trunk_scale=1000,
            gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
            sampling_per_bits=sampling_per_bits,
            x_BLC_wo_prefix_lq=x_BLC_wo_prefix_lq,
        )

    print(f"cost: {time.time() - sstt}, infinity cost={time.time() - stt}")
    return img_list[0]

@torch.no_grad()
def gen_one_img_eval(
    infinity_test, 
    vae, 
    text_tokenizer,
    text_encoder,
    prompt,
    cfg_list=[],
    tau_list=[],
    negative_prompt='',
    scale_schedule=None,
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    enable_positive_prompt=0,
    lq_img_path='',
    args=None,
    blip_model=None,
    blip_processor=None,
    swinir=None,
    bitwise_self_correction=None,
):
    """Evaluation variant of single-image super-resolution; returns ``(img, prompt)``.

    Pipeline:
    1. Load the LQ image at ``lq_img_path`` as RGB, resize it to 512x512, and
       restore it with ``swinir``.
    2. If ``prompt`` is empty, replace it with a BLIP caption of the original LQ file.
    3. VAE-encode the restored image and build the LQ condition with
       ``BitwiseSelfCorrection.flip_requant``.
    4. Run ``infinity_test.autoregressive_infer_cfg`` with the T5 text condition.

    Args:
        cfg_list / tau_list: per-scale lists, or a scalar broadcast to every scale.
        args: options forwarded to ``BitwiseSelfCorrection`` when it is built here.
        blip_model / blip_processor: required only when ``prompt`` is empty.
        swinir: the SwinIR restorer. When omitted, a module-level ``swinir`` is used
            if one exists; the original code looked this name up globally.
        bitwise_self_correction: optional prebuilt ``BitwiseSelfCorrection``. When
            omitted, one is built from ``(vae, args)`` on this call.

    Returns:
        ``(img, prompt)``: the first generated image and the prompt actually used.

    Raises:
        ValueError: if no SwinIR model is available, or ``prompt`` is empty and the
            BLIP model/processor were not supplied.
    """
    if swinir is None:
        swinir = globals().get('swinir')
    if swinir is None:
        raise ValueError('gen_one_img_eval requires a SwinIR model; pass it via `swinir=`.')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # `with` releases the file handle Image.open keeps for lazy decoding.
    with Image.open(lq_img_path) as pil_img:
        if pil_img.mode != "RGB":
            pil_img = pil_img.convert("RGB")
        pil_img = pil_img.resize((512, 512))
    lq_img = transforms.ToTensor()(pil_img)
    lq_img = lq_img * 2 - 1
    lq_img = lq_img.unsqueeze(0).to(device, non_blocking=True)

    ##### swinir: fed [0, 1] input; output mapped back to [-1, 1] for the VAE.
    lq_img = (lq_img + 1) / 2
    lq_img = swinir(lq_img)
    lq_img = lq_img + lq_img - 1

    ##### blip: caption the original LQ file when no prompt is given.
    if not prompt:
        if blip_model is None or blip_processor is None:
            raise ValueError('An empty prompt requires `blip_model` and `blip_processor` for captioning.')
        with Image.open(lq_img_path) as raw:
            raw_image = raw.convert('RGB')
        inputs = blip_processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
        out = blip_model.generate(**inputs)
        prompt = blip_processor.decode(out[0], skip_special_tokens=True)

    with torch.amp.autocast('cuda', enabled=False):
        if infinity_test.apply_spatial_patchify:
            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
        else:
            vae_scale_schedule = scale_schedule
        raw_features_lq, _, _ = vae.encode_for_raw_features(lq_img, scale_schedule=vae_scale_schedule)

    if bitwise_self_correction is None:
        bitwise_self_correction = BitwiseSelfCorrection(vae, args)
    x_BLC_wo_prefix_lq, _ = bitwise_self_correction.flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)

    if not isinstance(cfg_list, list):
        cfg_list = [cfg_list] * len(scale_schedule)
    if not isinstance(tau_list, list):
        tau_list = [tau_list] * len(scale_schedule)
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
    if negative_prompt:
        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
    else:
        negative_label_B_or_BLT = None

    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        _, _, img_list = infinity_test.autoregressive_infer_cfg(
            vae=vae,
            scale_schedule=scale_schedule,
            label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
            B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
            cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
            returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
            cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
            vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
            ret_img=True, trunk_scale=1000,
            gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
            sampling_per_bits=sampling_per_bits,
            x_BLC_wo_prefix_lq=x_BLC_wo_prefix_lq,
        )

    return img_list[0], prompt

@torch.no_grad()
def gen_one_img_eval_long(
    infinity_test, 
    vae, 
    text_tokenizer,
    text_encoder,
    prompt,
    cfg_list=[],
    tau_list=[],
    negative_prompt='',
    scale_schedule=None,
    top_k=900,
    top_p=0.97,
    cfg_sc=3,
    cfg_exp_k=0.0,
    cfg_insertion_layer=-5,
    vae_type=0,
    gumbel=0,
    softmax_merge_topk=-1,
    gt_leak=-1,
    gt_ls_Bl=None,
    g_seed=None,
    sampling_per_bits=1,
    enable_positive_prompt=0,
    lq_img_path='',
    args=None,
    blip_model=None,
    blip_processor=None,
    swinir=None,
    bitwise_self_correction=None
):
    """Super-resolve one image using the ``long_flip_requant`` LQ condition; returns ``(img, prompt)``.

    Pipeline:
    1. Load the LQ image at ``lq_img_path`` as RGB, resize it to 512x512, and
       restore it with ``swinir``.
    2. If ``prompt`` is empty, replace it with a BLIP caption of the original LQ file.
    3. VAE-encode the restored image and build the LQ condition with
       ``bitwise_self_correction.long_flip_requant``.
    4. Run ``infinity_test.autoregressive_infer_cfg`` with the T5 text condition.

    Args:
        cfg_list / tau_list: per-scale lists, or a scalar broadcast to every scale.
        args: options forwarded to ``BitwiseSelfCorrection`` when it is built here.
        blip_model / blip_processor: required only when ``prompt`` is empty.
        swinir: the SwinIR restorer (required).
        bitwise_self_correction: optional prebuilt ``BitwiseSelfCorrection``. When
            omitted, one is built from ``(vae, args)``.

    Returns:
        ``(img, prompt)``: the first generated image and the prompt actually used.

    Raises:
        ValueError: if ``swinir`` is missing, or ``prompt`` is empty and the BLIP
            model/processor were not supplied.
    """
    if swinir is None:
        raise ValueError('gen_one_img_eval_long requires a SwinIR model; pass it via `swinir=`.')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # `with` releases the file handle Image.open keeps for lazy decoding.
    with Image.open(lq_img_path) as pil_img:
        if pil_img.mode != "RGB":
            pil_img = pil_img.convert("RGB")
        pil_img = pil_img.resize((512, 512))
    lq_img = transforms.ToTensor()(pil_img)
    lq_img = lq_img * 2 - 1
    lq_img = lq_img.unsqueeze(0).to(device, non_blocking=True)

    ##### swinir: fed [0, 1] input; output mapped back to [-1, 1] for the VAE.
    lq_img = (lq_img + 1) / 2
    lq_img = swinir(lq_img)
    lq_img = lq_img + lq_img - 1

    ##### blip: caption the original LQ file when no prompt is given.
    if not prompt:
        if blip_model is None or blip_processor is None:
            raise ValueError('An empty prompt requires `blip_model` and `blip_processor` for captioning.')
        with Image.open(lq_img_path) as raw:
            raw_image = raw.convert('RGB')
        inputs = blip_processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
        out = blip_model.generate(**inputs)
        prompt = blip_processor.decode(out[0], skip_special_tokens=True)

    with torch.amp.autocast('cuda', enabled=False):
        if infinity_test.apply_spatial_patchify:
            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
        else:
            vae_scale_schedule = scale_schedule
        raw_features_lq, _, _ = vae.encode_for_raw_features(lq_img, scale_schedule=vae_scale_schedule)

    if bitwise_self_correction is None:
        bitwise_self_correction = BitwiseSelfCorrection(vae, args)
    x_BLC_wo_prefix_lq_long, _ = bitwise_self_correction.long_flip_requant(vae_scale_schedule, lq_img, raw_features_lq, device)

    if not isinstance(cfg_list, list):
        cfg_list = [cfg_list] * len(scale_schedule)
    if not isinstance(tau_list, list):
        tau_list = [tau_list] * len(scale_schedule)
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, enable_positive_prompt)
    if negative_prompt:
        negative_label_B_or_BLT = encode_prompt(text_tokenizer, text_encoder, negative_prompt)
    else:
        negative_label_B_or_BLT = None

    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        _, _, img_list = infinity_test.autoregressive_infer_cfg(
            vae=vae,
            scale_schedule=scale_schedule,
            label_B_or_BLT=text_cond_tuple, g_seed=g_seed,
            B=1, negative_label_B_or_BLT=negative_label_B_or_BLT, force_gt_Bhw=None,
            cfg_sc=cfg_sc, cfg_list=cfg_list, tau_list=tau_list, top_k=top_k, top_p=top_p,
            returns_vemb=1, ratio_Bl1=None, gumbel=gumbel, norm_cfg=False,
            cfg_exp_k=cfg_exp_k, cfg_insertion_layer=cfg_insertion_layer,
            vae_type=vae_type, softmax_merge_topk=softmax_merge_topk,
            ret_img=True, trunk_scale=1000,
            gt_leak=gt_leak, gt_ls_Bl=gt_ls_Bl, inference_mode=True,
            sampling_per_bits=sampling_per_bits,
            x_BLC_wo_prefix_lq=x_BLC_wo_prefix_lq_long,
        )

    return img_list[0], prompt

def get_prompt_id(prompt):
    """Return the hex MD5 digest of the UTF-8 encoded ``prompt`` as a stable identifier."""
    return hashlib.md5(prompt.encode('utf-8')).hexdigest()

def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_fsdp'):
    """Extract ``checkpoint['trainer'][key]`` from a full checkpoint and save it on its own.

    Args:
        infinity_model_path: path of the full training checkpoint.
        save_file: output path. When falsy, ``<input stem>-slim.pth`` is written
            next to the input.
        device: ``map_location`` used while loading.
        key: entry under ``'trainer'`` to keep.

    Returns:
        The path the slim state was written to.

    Note: ``torch.load`` unpickles the file; only use this on trusted checkpoints.
    """
    print('[Save slim model]')
    trainer_state = torch.load(infinity_model_path, map_location=device)['trainer']
    slim_state = trainer_state[key]
    save_file = save_file or osp.splitext(infinity_model_path)[0] + '-slim.pth'
    print(f'Save to {save_file}')
    torch.save(slim_state, save_file)
    print('[Save slim model] done')
    return save_file

def load_tokenizer(t5_path=''):
    """Load the T5 tokenizer and a float16 T5 encoder from ``t5_path``.

    The tokenizer's ``model_max_length`` is set to 512. The encoder is moved to
    CUDA, switched to eval mode, and has gradients disabled.

    Returns:
        ``(text_tokenizer, text_encoder)``.
    """
    print('[Loading tokenizer and text encoder]')
    tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
    tokenizer.model_max_length = 512
    encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
    encoder.to('cuda')
    encoder.eval()
    encoder.requires_grad_(False)
    return tokenizer, encoder

def load_infinity(
    rope2d_each_sa_layer, 
    rope2d_normalized_by_hw, 
    use_scale_schedule_embedding, 
    pn, 
    use_bit_label, 
    add_lvl_embeding_only_first_block, 
    model_path='', 
    scale_schedule=None, 
    vae=None, 
    device='cuda', 
    model_kwargs=None,
    text_channels=2048,
    apply_spatial_patchify=0,
    use_flex_attn=False,
    bf16=False,
):
    """Build a ``BInfinity`` model, inject LoRA layers, and load checkpoint weights.

    Args:
        use_scale_schedule_embedding: accepted for interface compatibility; not used here.
        model_path: checkpoint to load. Three layouts are accepted, all loaded with
            ``strict=True``:
            - a dict with an ``'infinity'`` entry;
            - a dict with ``['trainer']['gpt_fsdp']``;
            - a bare state dict.
        device: device the model, the checkpoint tensors, and ``rng`` are placed on.
        model_kwargs: extra architecture kwargs (depth, embed_dim, ...). ``None``
            means none; previously ``None`` crashed at the ``**`` expansion.
        bf16: when true, cast ``unregistered_blocks`` to bfloat16. This happens
            before LoRA injection, so the injected layers keep their own dtype.

    Returns:
        The model in eval mode with gradients disabled, on ``device``, with
        ``rng`` set to a ``torch.Generator`` on ``device``.

    Note: ``torch.load`` unpickles ``model_path``; only load trusted checkpoints.
    """
    print(f'[Loading Infinity]')
    if model_kwargs is None:
        model_kwargs = {}
    text_maxlen = 512
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
        infinity_test: Infinity = BInfinity(
            vae_local=vae, text_channels=text_channels, text_maxlen=text_maxlen,
            shared_aln=True, raw_scale_schedule=scale_schedule,
            checkpointing='full-block',
            customized_flash_attn=False,
            fused_norm=True,
            pad_to_multiplier=1,
            use_flex_attn=use_flex_attn,
            add_lvl_embeding_only_first_block=add_lvl_embeding_only_first_block,
            use_bit_label=use_bit_label,
            rope2d_each_sa_layer=rope2d_each_sa_layer,
            rope2d_normalized_by_hw=rope2d_normalized_by_hw,
            pn=pn,
            apply_spatial_patchify=apply_spatial_patchify,
            inference_mode=True,
            train_h_div_w_list=[1.0],
            **model_kwargs,
        ).to(device=device)
        print(f'[you selected Infinity with {model_kwargs=}] model size: {sum(p.numel() for p in infinity_test.parameters())/1e9:.2f}B, bf16={bf16}')

        if bf16:
            for block in infinity_test.unregistered_blocks:
                block.bfloat16()

        # LoRA layers must exist before loading, since the checkpoint is loaded strictly.
        # The returned parameter list and names are only needed for training.
        inject_trainable_lora(infinity_test.block_chunks)

        infinity_test.eval()
        infinity_test.requires_grad_(False)

        # Move again after injection so any newly created LoRA modules also land on
        # `device`. This was previously a hard-coded `.cuda()` that ignored `device`.
        infinity_test.to(device)
        torch.cuda.empty_cache()

        print(f'[Load Infinity weights]')
        checkpoint = torch.load(model_path, map_location=device)
        if 'infinity' in checkpoint:
            state_dict = checkpoint['infinity']
        elif 'trainer' in checkpoint:
            state_dict = checkpoint['trainer']['gpt_fsdp']
        else:
            state_dict = checkpoint
        print(infinity_test.load_state_dict(state_dict, strict=True))

        infinity_test.rng = torch.Generator(device=device)
        return infinity_test

def transform(pil_img, tgt_h, tgt_w):
    """Resize ``pil_img`` to cover ``tgt_h`` x ``tgt_w``, center-crop it, and scale to [-1, 1].

    The image is resized with LANCZOS, preserving its aspect ratio, so that both
    sides are at least the target size. The center ``tgt_h`` x ``tgt_w`` window is
    then cropped.

    Returns:
        The crop as produced by ``to_tensor`` (channels first), mapped from [0, 1]
        to [-1, 1].
    """
    width, height = pil_img.size
    aspect = width / height
    # max(): float division plus int() truncation can land one pixel short of the
    # target. That would make the crop offset below negative and corrupt the slice.
    if aspect <= tgt_w / tgt_h:
        resized_width = tgt_w
        resized_height = max(tgt_h, int(tgt_w / aspect))
    else:
        resized_height = tgt_h
        resized_width = max(tgt_w, int(aspect * tgt_h))
    pil_img = pil_img.resize((resized_width, resized_height), resample=PImage.LANCZOS)
    # crop the center out
    arr = np.array(pil_img)
    crop_y = (arr.shape[0] - tgt_h) // 2
    crop_x = (arr.shape[1] - tgt_w) // 2
    im = to_tensor(arr[crop_y: crop_y + tgt_h, crop_x: crop_x + tgt_w])
    return im * 2 - 1

def joint_vi_vae_encode_decode(vae, image_path, scale_schedule, device, tgt_h, tgt_w):
    """Round-trip an image through ``vae.encode``/``vae.decode`` and return both as uint8 arrays.

    The image is resized and center-cropped to ``tgt_h`` x ``tgt_w`` by ``transform``.
    It is encoded using the first three entries of each ``scale_schedule`` item and
    then decoded. Shapes and encode/decode timings are printed.

    Returns:
        ``(gt_img, recons_img, all_bit_indices)``: the cropped input and its
        reconstruction as channels-last uint8 arrays, plus the bit indices
        returned by ``vae.encode``.
    """
    image = Image.open(image_path).convert('RGB')
    inp = transform(image, tgt_h, tgt_w).unsqueeze(0).to(device)
    schedule = [(s[0], s[1], s[2]) for s in scale_schedule]

    t_start = time.time()
    _, z, _, all_bit_indices, _, _ = vae.encode(inp, scale_schedule=schedule)
    t_encoded = time.time()
    recons_img = vae.decode(z)[0]
    if recons_img.dim() == 4:
        recons_img = recons_img.squeeze(1)
    print(f'recons: z.shape: {z.shape}, recons_img shape: {recons_img.shape}')
    t_decoded = time.time()
    print(f'vae encode takes {t_encoded-t_start:.2f}s, decode takes {t_decoded-t_encoded:.2f}s')

    def _to_uint8_hwc(t):
        # [-1, 1] channels-first tensor -> [0, 255] channels-last uint8 array.
        return ((t + 1) / 2).permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)

    recons_img = _to_uint8_hwc(recons_img)
    gt_img = _to_uint8_hwc(inp[0])
    print(recons_img.shape, gt_img.shape)
    return gt_img, recons_img, all_bit_indices

def load_visual_tokenizer(args):
    """Build the BSQ-VAE visual tokenizer described by ``args``.

    ``args.vae_type`` is the codebook dimension (one of 16, 18, 20, 24, 32, 64),
    and the codebook size is ``2 ** args.vae_type``. A truthy
    ``args.apply_spatial_patchify`` selects ``patch_size=8`` with a four-level
    channel multiplier; otherwise ``patch_size=16`` with five levels is used.
    ``args.vae_path`` is forwarded to ``vae_model``. The model is moved to CUDA
    when available, else CPU.

    Raises:
        ValueError: if ``args.vae_type`` is not a supported codebook dimension.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.vae_type not in [16, 18, 20, 24, 32, 64]:
        raise ValueError(f'vae_type={args.vae_type} not supported')
    from infinity.models.bsq_vae.vae import vae_model

    if args.apply_spatial_patchify:
        patch_size, ch_mult = 8, [1, 2, 4, 4]
    else:
        patch_size, ch_mult = 16, [1, 2, 4, 4, 4]
    return vae_model(
        args.vae_path, "dynamic", args.vae_type, 2 ** args.vae_type,
        patch_size=patch_size,
        encoder_ch_mult=list(ch_mult), decoder_ch_mult=list(ch_mult),
        test_mode=True,
    ).to(device)

def load_visual_tokenizer_lora(args):
    """Build the LoRA variant (``vae_model_lora``) of the BSQ-VAE visual tokenizer.

    The architecture is chosen exactly as for the plain tokenizer:
    - ``args.vae_type`` is the codebook dimension (16, 18, 20, 24, 32 or 64), and
      the codebook size is ``2 ** args.vae_type``;
    - a truthy ``args.apply_spatial_patchify`` selects ``patch_size=8`` with four
      channel-multiplier levels, otherwise ``patch_size=16`` with five.

    ``args.vae_path`` is forwarded to ``vae_model_lora``.

    Raises:
        ValueError: if ``args.vae_type`` is not a supported codebook dimension.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.vae_type not in [16, 18, 20, 24, 32, 64]:
        raise ValueError(f'vae_type={args.vae_type} not supported')
    from infinity.models.bsq_vae.vae import vae_model_lora

    if args.apply_spatial_patchify:
        patch_size, ch_mult = 8, [1, 2, 4, 4]
    else:
        patch_size, ch_mult = 16, [1, 2, 4, 4, 4]
    return vae_model_lora(
        args.vae_path, "dynamic", args.vae_type, 2 ** args.vae_type,
        patch_size=patch_size,
        encoder_ch_mult=list(ch_mult), decoder_ch_mult=list(ch_mult),
        test_mode=True,
    ).to(device)

def load_transformer(vae, args):
    """Build and load the Infinity transformer preset selected by ``args.model_type``.

    The preset supplies depth, width, heads, drop-path, MLP ratio and block
    chunking. Everything else is taken from ``args`` and forwarded to
    ``load_infinity`` together with ``vae``. ``args.checkpoint_type`` and
    ``args.cache_dir`` are not used by this loader.

    Raises:
        ValueError: if ``args.model_type`` is not a known preset. Previously an
            unknown name surfaced as an UnboundLocalError.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_path = args.model_path

    presets = {
        'infinity_2b': dict(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, mlp_ratio=4, block_chunks=8),  # 2b model
        'infinity_layer12': dict(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
        'infinity_layer16': dict(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
        'infinity_layer24': dict(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
        'infinity_layer32': dict(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
        'infinity_layer40': dict(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
        'infinity_layer48': dict(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, mlp_ratio=4, block_chunks=4),
    }
    if args.model_type not in presets:
        raise ValueError(f'model_type={args.model_type!r} not supported; expected one of {sorted(presets)}')
    kwargs_model = presets[args.model_type]

    infinity = load_infinity(
        rope2d_each_sa_layer=args.rope2d_each_sa_layer, 
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        use_scale_schedule_embedding=args.use_scale_schedule_embedding,
        pn=args.pn,
        use_bit_label=args.use_bit_label, 
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block, 
        model_path=model_path, 
        scale_schedule=None, 
        vae=vae, 
        device=device, 
        model_kwargs=kwargs_model,
        text_channels=args.text_channels,
        apply_spatial_patchify=args.apply_spatial_patchify,
        use_flex_attn=args.use_flex_attn,
        bf16=args.bf16,
    )
    return infinity

def add_common_arguments(parser):
    """Register the command-line options shared by the Infinity inference scripts.

    Options are added to ``parser`` in a fixed order. Integer 0/1 switches are
    restricted to those two values via ``choices``. ``--pn`` and ``--model_path``
    are required.
    """
    switch = [0, 1]
    options = [
        ('--cfg', dict(type=str, default='3')),
        ('--tau', dict(type=float, default=1)),
        ('--pn', dict(type=str, required=True, choices=['0.06M', '0.25M', '1M'])),
        ('--model_path', dict(type=str, required=True)),
        ('--cfg_insertion_layer', dict(type=int, default=0)),
        ('--vae_type', dict(type=int, default=1)),
        ('--vae_path', dict(type=str, default='')),
        ('--add_lvl_embeding_only_first_block', dict(type=int, default=0, choices=switch)),
        ('--use_bit_label', dict(type=int, default=1, choices=switch)),
        ('--model_type', dict(type=str, default='infinity_2b')),
        ('--rope2d_each_sa_layer', dict(type=int, default=1, choices=switch)),
        ('--rope2d_normalized_by_hw', dict(type=int, default=2, choices=[0, 1, 2])),
        ('--use_scale_schedule_embedding', dict(type=int, default=0, choices=switch)),
        ('--sampling_per_bits', dict(type=int, default=1, choices=[1, 2, 4, 8, 16])),
        ('--text_encoder_ckpt', dict(type=str, default='')),
        ('--text_channels', dict(type=int, default=2048)),
        ('--apply_spatial_patchify', dict(type=int, default=0, choices=switch)),
        ('--h_div_w_template', dict(type=float, default=1.000)),
        ('--use_flex_attn', dict(type=int, default=0, choices=switch)),
        ('--enable_positive_prompt', dict(type=int, default=0, choices=switch)),
        ('--cache_dir', dict(type=str, default='/dev/shm')),
        ('--checkpoint_type', dict(type=str, default='torch')),
        ('--seed', dict(type=int, default=0)),
        ('--bf16', dict(type=int, default=1, choices=switch)),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)

def encode_and_decode(lq_img_path, vae, save_path):
    """Round-trip an image through the VAE encoder, quantizer and decoder and save it.

    Debug helper. The image at ``lq_img_path`` is processed as follows:
    1. Load it as RGB at native size and scale it to [-1, 1].
    2. Encode it under autocast with ``vae.args.encoder_dtype``.
    3. Quantize with ``vae.quantizer`` and decode with ``vae.decoder``.
    4. Map the reconstruction back to [0, 1], clamp it, and write it to ``save_path``.

    Runs under ``torch.no_grad()``; it previously built an autograd graph for no
    reason. The unused hidden states and shape bookkeeping were dropped.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # `with` releases the file handle Image.open keeps for lazy decoding.
    with Image.open(lq_img_path) as pil_img:
        if pil_img.mode != "RGB":
            pil_img = pil_img.convert("RGB")
        x = transforms.ToTensor()(pil_img)
    x = (x * 2 - 1).unsqueeze(0).to(device, non_blocking=True)

    ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}
    enc_dtype = ptdtype[vae.args.encoder_dtype]

    with torch.no_grad():
        with torch.amp.autocast("cuda", dtype=enc_dtype):
            # return_hidden=True matches the three-value return unpacked here.
            h, _, _ = vae.encoder(x, return_hidden=True)
        z, *_ = vae.quantizer(h.to(dtype=torch.float32))
        x_recon = vae.decoder(z)

    # Clamp before ToPILImage: it scales floats by 255 and casts to uint8, and
    # out-of-range values are not guaranteed to saturate in that cast.
    x_recon = ((x_recon + 1) / 2).clamp(0, 1).squeeze(0)
    transforms.ToPILImage()(x_recon).save(save_path)

if __name__ == '__main__':
    # Command-line entry point: super-resolve one LQ image (--lq_img_path) guided by
    # --prompt and write the result to --save_file.
    parser = argparse.ArgumentParser()
    add_common_arguments(parser)
    parser.add_argument('--prompt', type=str, default='a dog')
    parser.add_argument('--save_file', type=str, default='./tmp.jpg')
    parser.add_argument('--lq_img_path', type=str, default='')
    # Bitwise Self-Correction options; presumably consumed by BitwiseSelfCorrection via `args` — confirm.
    parser.add_argument('--noise_apply_layers',type=int,default=-1)
    parser.add_argument('--noise_apply_requant',type=int,default=1)
    parser.add_argument('--noise_apply_strength',type=float,default=0.3)
    parser.add_argument('--debug_bsc',type=int,default=0)
    # NOTE: `args` must stay a module-level global. gen_one_img below is called
    # without an `args=` keyword and relies on this global for BitwiseSelfCorrection.
    args = parser.parse_args()
    # noise_apply_layers: int = 13        # Bitwise Self-Correction: apply noise to layers, -1 means not apply noise
    # noise_apply_strength: float = 0.3    # Bitwise Self-Correction: apply noise strength, -1 means not apply noise
    # noise_apply_requant: int = 1        # Bitwise Self-Correction: requant after apply noise
    # debug_bsc: int = 0   

    # parse cfg: "3" -> 3.0 (a scalar, broadcast to every scale by gen_one_img);
    # "1,2,3" -> [1.0, 2.0, 3.0] (one value per scale)
    args.cfg = list(map(float, args.cfg.split(',')))
    if len(args.cfg) == 1:
        args.cfg = args.cfg[0]
    
    # load text encoder
    text_tokenizer, text_encoder = load_tokenizer(t5_path =args.text_encoder_ckpt)
    # load vae
    vae = load_visual_tokenizer(args)
    # load infinity
    infinity = load_transformer(vae, args)
    
    # Scales for the chosen aspect-ratio template and resolution. Only (h, w) is kept;
    # the leading entry (presumably the temporal size) is forced to 1.
    scale_schedule = dynamic_resolution_h_w[args.h_div_w_template][args.pn]['scales']
    scale_schedule = [ (1, h, w) for (_, h, w) in scale_schedule]
    
    # #encode decode
    # save_path = "visualize/bed_hq_recon.png"
    # encode_and_decode('visualize/bed_hq.png',vae,save_path)
    # pdb.set_trace()

    with autocast(dtype=torch.bfloat16):
        with torch.no_grad():
            generated_image = gen_one_img(
                infinity,
                vae,
                text_tokenizer,
                text_encoder,
                args.prompt,
                g_seed=args.seed,
                gt_leak=0,
                gt_ls_Bl=None,
                cfg_list=args.cfg,
                tau_list=args.tau,
                scale_schedule=scale_schedule,
                cfg_insertion_layer=[args.cfg_insertion_layer],
                vae_type=args.vae_type,
                sampling_per_bits=args.sampling_per_bits,
                enable_positive_prompt=args.enable_positive_prompt,
                lq_img_path=args.lq_img_path,
            )
    os.makedirs(osp.dirname(osp.abspath(args.save_file)), exist_ok=True)
    # NOTE(review): assumes generated_image is an HxWxC uint8 tensor in the BGR order
    # cv2.imwrite expects — confirm against autoregressive_infer_cfg's output.
    cv2.imwrite(args.save_file, generated_image.cpu().numpy())
    print(f'Save to {osp.abspath(args.save_file)}')
