import os
import os.path as osp
import random
import subprocess

import numpy as np
import PIL.Image as PImage
import torch
import torchvision
from tqdm import tqdm

from models import build_vae_var

# Turn the default parameter initialisation of Linear and LayerNorm layers
# into a no-op so that constructing these layers is faster.
torch.nn.Linear.reset_parameters = lambda self: None
torch.nn.LayerNorm.reset_parameters = lambda self: None


def setup_environment(gpu_id, seed, tf32):
    """Seed the RNGs, pick the compute device and configure TF32 usage.

    Args:
        gpu_id: Index of the CUDA device to select when CUDA is available.
        seed: Seed applied to Python's ``random``, NumPy and PyTorch.
        tf32: Whether TF32 is allowed for cuDNN and CUDA matmul kernels.

    Returns:
        ``torch.device(f"cuda:{gpu_id}")`` when CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    # seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # setup torch
    # Only touch the CUDA runtime when it exists; calling set_device on a
    # CPU-only host raises before the CPU fallback can take effect.
    if torch.cuda.is_available():
        torch.cuda.set_device(gpu_id)
        print(f"Using device: {torch.cuda.current_device()}")
        device = torch.device(f"cuda:{gpu_id}")
    else:
        print("Using device: cpu")
        device = torch.device("cpu")

    # run faster
    torch.backends.cudnn.allow_tf32 = tf32
    torch.backends.cuda.matmul.allow_tf32 = tf32
    torch.set_float32_matmul_precision("high" if tf32 else "highest")
    return device


def download_checkpoints(data_dir, model_depth):
    """Download the VAE and VAR checkpoints into ``data_dir`` if missing.

    A checkpoint is fetched with ``wget`` only when no file of that name
    exists in ``data_dir`` yet.

    Args:
        data_dir: Directory in which the checkpoint files are looked up and
            into which missing ones are downloaded.
        model_depth: Depth of the VAR model; selects ``var_d{model_depth}.pth``.

    Returns:
        Tuple ``(vae_ckpt_path, var_ckpt_path)``.

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero status.
        FileNotFoundError: If the ``wget`` executable cannot be found.
    """
    hf_home = "https://huggingface.co/FoundationVision/var/resolve/main"
    ckpt_names = ("vae_ch160v4096z32.pth", f"var_d{model_depth}.pth")

    ckpt_paths = []
    for ckpt_name in ckpt_names:
        ckpt_path = osp.join(data_dir, ckpt_name)
        if not osp.exists(ckpt_path):
            # Argument list instead of a shell string: paths containing
            # spaces or shell metacharacters are passed through verbatim,
            # and check=True surfaces a failed download immediately.
            subprocess.run(
                ["wget", "-P", data_dir, f"{hf_home}/{ckpt_name}"],
                check=True,
            )
        ckpt_paths.append(ckpt_path)

    vae_ckpt_path, var_ckpt_path = ckpt_paths
    return vae_ckpt_path, var_ckpt_path


def load_models(model_depth, device, vae_ckpt_path, var_ckpt_path):
    """Construct the VAE/VAR pair, restore their weights and freeze them.

    Both state dicts are read onto the CPU and loaded strictly; the models
    are then moved to ``device``, switched to eval mode and have gradients
    disabled on every parameter.

    Returns:
        Tuple ``(vae, var)``.
    """
    scale_schedule = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

    build_kwargs = dict(
        # hard-coded VQVAE hyperparameters
        V=4096,
        Cvae=32,
        ch=160,
        share_quant_resi=4,
        device=device,
        patch_nums=scale_schedule,
        num_classes=1000,
        depth=model_depth,
        shared_aln=False,
    )
    vae, var = build_vae_var(**build_kwargs)
    nets = (vae, var)

    # Restore the weights (VAE first, then VAR).
    for net, ckpt_path in zip(nets, (vae_ckpt_path, var_ckpt_path)):
        state_dict = torch.load(ckpt_path, map_location="cpu")
        net.load_state_dict(state_dict, strict=True)

    for net in nets:
        net.to(device)

    for net in nets:
        net.eval()

    # Inference only: no parameter needs gradients.
    for net in nets:
        for weight in net.parameters():
            weight.requires_grad_(False)

    print("VAE and VAR models prepared successfully.")
    return vae, var


def _chw_to_pil(img_chw):
    """Convert a (C, H, W) tensor into a PIL image by scaling with 255."""
    # Out-of-place mul: the caller's tensor is left untouched.
    hwc = img_chw.permute(1, 2, 0).mul(255).cpu().numpy().astype(np.uint8)
    return PImage.fromarray(hwc)


def generate_and_save_images(var, n_img, batch_size, cfg, seed, more_smooth, device, output_dir, num_classes=1000):
    """Generate ``n_img`` images with random class labels and save them.

    Images are produced batch by batch and written to ``output_dir`` as
    ``000.png``, ``001.png``, ... right after each batch, so they do not all
    have to be kept in memory. A grid of every batch is additionally shown
    with ``PIL.Image.show``. After the last batch the token indices
    (concatenated per scale along dim 0) are saved to
    ``generated_token_indices.pt`` and the list of per-batch token maps to
    ``generated_token_maps.pt``.

    Args:
        var: Model exposing ``autoregressive_infer_cfg_with_token_map``.
        n_img: Total number of images to generate; nothing is written when
            it is not positive.
        batch_size: Maximum number of images per batch.
        cfg: Classifier-free guidance value passed to the model.
        seed: Value passed to the model as ``g_seed``.
        more_smooth: Passed through to the model.
        device: Device on which the label tensor is created.
        output_dir: Directory for all outputs; created if missing.
        num_classes: Number of classes; labels are drawn uniformly from
            ``[0, num_classes)``.
    """
    os.makedirs(output_dir, exist_ok=True)

    n_saved = 0
    idx_chunks_per_scale = None  # per scale: list of per-batch index tensors
    token_maps = []
    for start in tqdm(range(0, n_img, batch_size)):
        # The final batch shrinks so that exactly n_img images are generated.
        B = min(batch_size, n_img - start)
        # randrange excludes its upper bound, unlike randint(0, num_classes),
        # which could also return the out-of-range label ``num_classes``.
        class_labels = [random.randrange(num_classes) for _ in range(B)]
        label_B = torch.tensor(class_labels, device=device)

        # NOTE(review): every batch receives the same g_seed; if the sampler
        # reseeds from it, all batches share one random stream - confirm
        # whether a per-batch offset is wanted.
        with torch.no_grad(), torch.autocast('cuda', enabled=True, cache_enabled=True):
            recon_B3HW_batch, idxBl_batch, token_map = var.autoregressive_infer_cfg_with_token_map(
                B=B,
                label_B=label_B,
                cfg=cfg,
                top_k=900,
                top_p=0.95,
                g_seed=seed,
                more_smooth=more_smooth,
            )

        # Collect per-batch results; they are concatenated once after the
        # loop instead of re-concatenating the running total every batch.
        token_maps.append(token_map)
        if idx_chunks_per_scale is None:
            idx_chunks_per_scale = [[] for _ in idxBl_batch]
        for chunks, idx_scale in zip(idx_chunks_per_scale, idxBl_batch):
            chunks.append(idx_scale)

        # Preview of the current batch as one grid with 8 images per row.
        grid = torchvision.utils.make_grid(recon_B3HW_batch, nrow=8, padding=0, pad_value=1.0)
        _chw_to_pil(grid).show()

        # Write this batch's images, continuing the global numbering.
        for img_chw in recon_B3HW_batch:
            _chw_to_pil(img_chw).save(osp.join(output_dir, f'{n_saved:03d}.png'))
            n_saved += 1

    if n_saved == 0:
        return

    print(f'Generated {n_saved} images with classifier-free guidance, saved to {output_dir}.')

    # Save the corresponding token indices
    original_idxBl = [torch.cat(chunks, dim=0) for chunks in idx_chunks_per_scale]
    torch.save(original_idxBl, osp.join(output_dir, 'generated_token_indices.pt'))
    print(f'Saved token indices to {osp.join(output_dir, "generated_token_indices.pt")}.')

    torch.save(token_maps, osp.join(output_dir, 'generated_token_maps.pt'))
    print(f'Saved token maps to {osp.join(output_dir, "generated_token_maps.pt")}.')


def main():
    """Run the whole pipeline: environment, checkpoints, models, sampling."""
    # Settings
    model_depth = 16
    assert model_depth in {16, 20, 24, 30}
    checkpoint_dir = "[VAR_MODEL_PATH]"
    output_dir = "./VAR/var_generated_validation/"
    gpu_id = 0
    seed = 123
    n_img = 1024
    batch_size = 128
    cfg_scale = 4.0
    more_smooth = False
    use_tf32 = True

    device = setup_environment(gpu_id=gpu_id, seed=seed, tf32=use_tf32)

    vae_ckpt_path, var_ckpt_path = download_checkpoints(
        data_dir=checkpoint_dir, model_depth=model_depth
    )

    # Only the VAR model is needed for sampling; the VAE handle is discarded.
    _, var = load_models(
        model_depth=model_depth,
        device=device,
        vae_ckpt_path=vae_ckpt_path,
        var_ckpt_path=var_ckpt_path,
    )

    generate_and_save_images(
        var=var,
        n_img=n_img,
        batch_size=batch_size,
        cfg=cfg_scale,
        seed=seed,
        more_smooth=more_smooth,
        device=device,
        output_dir=output_dir,
    )


if __name__ == "__main__":
    main()
