#%%
import os
import math
from tqdm import trange, tqdm
from diffusers import AutoencoderDC
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import argparse
import time, datetime

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.nn as nn

from xvq.dataset import get_val_webdataset, get_train_webdataset, Preprocessor, get_imagenet_loader
from xvq.utils import seed_everything, check_rank_zero, to_ddp_model, set_eval
from xvq.models import setup_models
from efficientvit.ae_model_zoo import DCAE_HF
import matplotlib.pyplot as plt

def reconstruct_sample(device, dataset, vae_encode, vae_decode, vq_infer, index):
    """Reconstruct one dataset image through the encode -> VQ -> decode path.

    Args:
        device: Device the image is moved to before the callables are run.
        dataset: Indexable dataset; ``dataset[index]`` must yield an
            ``(image, label)`` pair. The image is rescaled with ``* 2 - 1``
            before encoding, so it is presumably in ``[0, 1]`` -- confirm
            against the dataset transform.
        vae_encode: Callable applied to the rescaled, batched image.
        vae_decode: Callable applied to the output of ``vq_infer``.
        vq_infer: Callable applied to the output of ``vae_encode``.
        index: Position of the sample inside ``dataset``.

    Returns:
        ``(original, reconstruction)`` as CPU tensors with the leading batch
        dimension removed again. Neither tensor carries autograd history.
    """
    img, _ = dataset[index]
    img = img.unsqueeze(0).to(device)

    # Keep an independent copy so the returned original never aliases the
    # tensor handed out by the dataset.
    ori = img.clone()

    # Pure inference: without no_grad the callables would record an autograd
    # graph (extra memory) and the result would carry requires_grad.
    with torch.no_grad():
        lat = vae_encode(img * 2 - 1)
        rec = vae_decode(vq_infer(lat))
        rec = (rec + 1) / 2  # undo the ``* 2 - 1`` rescaling applied above

    return ori.squeeze(0).cpu(), rec.squeeze(0).cpu()

def save_single_image(tensor_img, log_path, title, epoch, idx, rec=False):
    """Write one image tensor into ``log_path`` as a PNG without axes.

    Args:
        tensor_img: Image tensor in channels-first layout; it is transposed
            with ``(1, 2, 0)`` to ``(H, W, C)`` and clipped to ``[0, 1]``
            before being drawn.
        log_path: Directory receiving the file. It is not created here.
        title: Unused; kept so existing call sites keep working.
        epoch: Unused; kept so existing call sites keep working.
        idx: Integer sample index, embedded in the file name and zero-padded
            to at least two digits.
        rec: When true the file is named ``revq_<idx>.png``, otherwise
            ``gt_<idx>.png``.
    """
    img = tensor_img.detach().cpu().numpy().transpose(1, 2, 0)  # (H, W, C)
    img = img.clip(0, 1)

    # Build the target path first so a bad ``idx`` fails before any figure
    # has been allocated.
    name = "revq" if rec else "gt"
    save_path = os.path.join(log_path, f"{name}_{idx:02d}.png")

    fig = plt.figure(figsize=(2, 2), dpi=200)
    try:
        plt.imshow(img)
        plt.axis("off")
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Close exactly this figure, also when drawing or saving raised, so
        # repeated calls cannot pile up open figures.
        plt.close(fig)

def get_config():
    """Assemble the run configuration from the command line.

    Parses ``--config`` and ``--name``, loads the file given by ``--config``
    with ``OmegaConf.load`` and then overrides two fields on the result:
    ``name`` (from ``--name``) and ``world_size`` (always ``1``). The
    directory named by ``config.log_path`` is created if it is missing.

    Returns:
        The loaded configuration object with the overrides applied.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--config", type=str, default="/path/to/pretrained_model/config.yaml"
    )
    arg_parser.add_argument(
        "--name", type=str, default="/path/to/pretrained_model"
    )
    cli_args = arg_parser.parse_args()

    config = OmegaConf.load(cli_args.config)

    # Values forced by this script regardless of what the file contains.
    config.name = cli_args.name
    config.world_size = 1

    log_dir = config.log_path
    os.makedirs(log_dir, exist_ok=True)
    print(f"Log path: {log_dir}")
    return config

def load_preprocessor(
    device,
    config,
    is_eval: bool = True,
    ckpt_path: str = "/path/to//preprocessor.pth",
):
    """Instantiate a ``Preprocessor`` and restore its weights from disk.

    Args:
        device: Device the module is moved to; also used as ``map_location``
            when reading the checkpoint.
        config: Object exposing ``input_data_size``, which is forwarded to
            the ``Preprocessor`` constructor.
        is_eval: Call ``eval()`` on the module before returning it.
        ckpt_path: File holding the state dict to load.

    Returns:
        The ``Preprocessor`` with the checkpoint weights loaded.
    """
    preprocessor = Preprocessor(input_data_size=config.input_data_size)
    preprocessor = preprocessor.to(device)

    # weights_only=True limits unpickling to tensors and plain containers.
    weights = torch.load(ckpt_path, map_location=device, weights_only=True)
    preprocessor.load_state_dict(weights)

    if is_eval:
        preprocessor.eval()
    return preprocessor

def load_frozen_vae(device, config, is_eval: bool = True):
    """Load the pretrained DC-AE and expose gradient-free encode/decode.

    Args:
        device: Device the autoencoder is moved to.
        config: Not read by this function; the checkpoint name is fixed
            below rather than taken from ``config.vae_path``.
        is_eval: Call ``eval()`` on the autoencoder after loading.

    Returns:
        ``(vae_encode, vae_decode)``: two callables forwarding to
        ``vae.encode`` and ``vae.decode`` with autograd disabled.
    """
    vae = DCAE_HF.from_pretrained("mit-han-lab/dc-ae-f32c32-in-1.0")
    vae = vae.to(device)
    if is_eval:
        vae.eval()

    @torch.no_grad()
    def vae_encode(x):
        return vae.encode(x)

    @torch.no_grad()
    def vae_decode(x):
        return vae.decode(x)

    return vae_encode, vae_decode

def main_worker(rank, config):
    """Load the pretrained models on one GPU and save sample reconstructions.

    Steps, in order: bind the CUDA device for ``rank``; initialise the NCCL
    process group when ``config.world_size > 1``; build the ImageNet
    validation dataset; load the preprocessor, quantizer, decoder and viewer;
    seed the quantizer codebook; restore ``ckpt.pth`` from
    ``config.log_path`` if it exists; then reconstruct up to 20 randomly
    drawn validation images and write originals and reconstructions under
    ``config.workspace``.

    Args:
        rank: Index used both as CUDA device id and as process rank.
        config: Run configuration. Reads ``world_size``, ``model``,
            ``log_path`` and ``workspace``; writes ``rank`` and
            ``model.data_dim``.
    """
    # setup devices
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")
    config.rank = rank

    # setup distribution
    if config.world_size > 1:
        dist.init_process_group(
            backend="nccl", init_method="tcp://localhost:23456",
            rank=config.rank, world_size=config.world_size
        )

    # Only the dataset is indexed below; the returned loader is not iterated.
    val_dataset, _ = get_imagenet_loader(
        root_dir="/path/to/imagenet",
        batch_size=64,
        num_workers=8,
        image_size=256,
        split="val",
        shuffle=False
    )

    # load the preprocessor from its ckpt file; the model config needs its
    # data_dim before the models are built
    preprocessor = load_preprocessor(device=device, config=config.model)
    config.model.data_dim = preprocessor.data_dim

    # setup the model
    quantizer, decoder, viewer = setup_models(config.model, device)
    code_bank = torch.load("/path/to/subset.pth", map_location=device, weights_only=True)
    code_bank = viewer.shuffle(code_bank)
    quantizer.prepare_codebook(code_bank, method="random")
    # The code bank is only needed to prepare the codebook; release it early.
    del code_bank
    torch.cuda.empty_cache()

    if config.world_size > 1:
        decoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(decoder)

    # print information
    if check_rank_zero():
        def count_trainable(module):
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        quantizer_params = count_trainable(quantizer)
        decoder_params = count_trainable(decoder)
        print(f"Quantizer: {quantizer_params / 1e6:.2f}M")
        print(f"Decoder: {decoder_params / 1e6:.2f}M")
        total_params = quantizer_params + decoder_params
        print(f"Total params: {total_params / 1e6:.2f}M")

    # auto resume
    ckpt_file = os.path.join(config.log_path, "ckpt.pth")
    if os.path.exists(ckpt_file):
        checkpoint = torch.load(ckpt_file, map_location=device, weights_only=True)
        quantizer.load_state_dict(checkpoint["quantizer"])
        decoder.load_state_dict(checkpoint["decoder"])
        if check_rank_zero():
            print(f"loading from {config.log_path}/ckpt.pth")
    elif check_rank_zero():
        # Make it visible that the images below do not come from a restored
        # checkpoint instead of silently continuing.
        print(
            f"no checkpoint found at {ckpt_file}; "
            "quantizer/decoder keep the weights from setup_models"
        )

    # wrap the models and load the VAE used for evaluation
    quantizer, decoder = to_ddp_model(rank, quantizer, decoder)
    vae_encode, vae_decode = load_frozen_vae(device=device, config=config.model)

    @torch.no_grad()  # inference only: do not record an autograd graph
    def vq_infer(x):
        data = x.contiguous()
        data = preprocessor(data)
        data_shuffle = viewer.shuffle(data)
        quant_shuffle = quantizer(data_shuffle)["x_quant"]
        quant = viewer.unshuffle(quant_shuffle)
        data_rec = decoder(quant)
        data_rec = data_rec.contiguous()
        data_rec = preprocessor.inverse(data_rec)

        return data_rec

    set_eval(quantizer, decoder)

    # visualize
    idx = torch.randint(0, len(val_dataset), (20,)).tolist()
    # randint draws with replacement; a repeated index would only rewrite the
    # same two files, so drop repeats while keeping the draw order.
    idx = list(dict.fromkeys(idx))
    rec_log_path = os.path.join(config.workspace, "rec_log_path")
    ori_log_path = os.path.join(config.workspace, "ori_log_path")

    os.makedirs(rec_log_path, exist_ok=True)
    os.makedirs(ori_log_path, exist_ok=True)
    for i in idx:
        ori, rec = reconstruct_sample(device, val_dataset, vae_encode, vae_decode, vq_infer, i)
        save_single_image(ori, ori_log_path, "original", 0, i, rec=False)
        save_single_image(rec, rec_log_path, "reconstructed", 0, i, rec=True)

    if dist.is_available() and dist.is_initialized():
        # destroy the process group
        dist.destroy_process_group()
    

def main():
    """Script entry point: read the config, seed RNGs, run the worker(s).

    When ``config.world_size`` is greater than one, ``main_worker`` is
    spawned once per process; otherwise it is called directly with rank 0.
    """
    config = get_config()
    seed_everything(config.seed)

    n_procs = config.world_size
    if n_procs > 1:
        # spawn() passes the process index as the first positional argument.
        torch.multiprocessing.spawn(main_worker, args=(config,), nprocs=n_procs)
        return

    main_worker(0, config)
    
# Run the entry point only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()

# %%
