'''
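Download checkpoints and build models, then compute reconstruction-based losses on several image datasets and evaluate how well each loss separates VAR-generated images from the other datasets.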
'''
import os
import os.path as osp
from tqdm import tqdm
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
# Skip the default parameter initialisation of Linear / LayerNorm to make model
# construction faster (the VQVAE weights are loaded from checkpoints later).
torch.nn.Linear.reset_parameters = lambda self: None
torch.nn.LayerNorm.reset_parameters = lambda self: None
from models import VQVAE, build_vae_var
from attacks import apply_attack
from torch.utils.data import DataLoader


import torch
gpu_id = 0
# Only touch the CUDA runtime when a GPU is present: torch.cuda.set_device
# raises on a CPU-only machine, which previously made the CPU fallback below
# unreachable.
if torch.cuda.is_available():
    torch.cuda.set_device(gpu_id)
    print(torch.cuda.current_device())
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")

ae_type = 'orig' # 'enc_ft' 'orig'

n_img = 1000
batch_size = 8
optim_iters = 200
optim_stop_scale = 4 # 8 4

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}

data_dir_ae_orig = "[VAR_MODEL_PATH]"
vae_ckpt_orig = 'vae_ch160v4096z32.pth'
vae_ckpt_path_orig = osp.join(data_dir_ae_orig, vae_ckpt_orig)

match ae_type:
    case 'orig':
        data_dir_ae = "[VAR_MODEL_PATH]"
        vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
    case 'enc_ft':
        data_dir_ae = "[DATA_SAVE_DIR]/var/finetune"
        vae_ckpt, var_ckpt = 'vqvae_finetuned_fmap_lpips0.0_mse_img0.0_mse_feat1.0_steps20000_encoder_lr5e-05_bs16_with_dataset.pth', f'var_d{MODEL_DEPTH}.pth'

data_dir_ar = "[VAR_MODEL_PATH]"


# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
# vae_ckpt, var_ckpt = 'vqvae_finetuned_fmap_lpips0.0_mse_img0.0_mse_feat1.0_keep_real10.0_steps20000_encoder_lr5e-05_bs8_with_dataset.pth', f'var_d{MODEL_DEPTH}.pth'

vae_ckpt_path, var_ckpt_path = osp.join(data_dir_ae, vae_ckpt), osp.join(data_dir_ar, var_ckpt)
if not osp.exists(vae_ckpt_path): 
    os.system(f'wget -P {data_dir_ae} {hf_home}/{vae_ckpt}')
if not osp.exists(var_ckpt_path): 
    os.system(f'wget -P {data_dir_ar} {hf_home}/{var_ckpt}')

# build vae, var
# Multi-scale schedule passed to build_vae_var (presumably the token-map side
# length at each scale - confirm). len(patch_nums) also sets how many per-scale
# embedding losses are tracked later in this file.
patch_nums = [1, 2, 3, 4, 5, 6, 8, 10, 13, 16]
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Build the models only if they are not already defined (e.g. when re-running
# this code in an interactive session). Because `var` is deleted right after
# the checkpoints are loaded, this condition is true on every re-run in
# practice, so the models are rebuilt each time.
if 'vae' not in globals() or 'var' not in globals():
    # `vae` receives the checkpoint selected by ae_type; `var` is built with it
    # but its checkpoint is never loaded and it is deleted below.
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )
    # load the original vae
    # Second VQVAE that always receives the original (non-finetuned) checkpoint.
    # Its VAR counterpart is not needed and is discarded immediately.
    vae_orig, var_orig = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )
    del var_orig

# load checkpoints
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (recent PyTorch versions support weights_only=True).
vae.load_state_dict(torch.load(vae_ckpt_path, map_location='cpu'), strict=True)
# var.load_state_dict(torch.load(var_ckpt_path, map_location='cpu'), strict=True)
vae.eval()
# var.eval()
# Freeze the VQVAE: it is only used for evaluation here.
for p in vae.parameters(): p.requires_grad_(False)
# for p in var.parameters(): p.requires_grad_(False)
del var  # the VAR transformer itself is not used by this script

# load original vae
# vae_orig is only referenced from commented-out code elsewhere in this file.
vae_orig.load_state_dict(torch.load(vae_ckpt_path_orig, map_location='cpu'), strict=True)
vae_orig.eval()
for p in vae_orig.parameters(): p.requires_grad_(False)

print(f'prepare finished.')


# configs for attack
# Perturbation applied by apply_attack() to every dataset ("none" = clean images).
attack_type = "none"

# Strength per attack. Only the first entry is used; the commented lists are
# the full sweeps.
range_map = {
    "none":      [0.0],       # no attack
    "noise":     [0.1],       # [0.0, 0.05, 0.1, 0.15, 0.2]
    "gauss":     [7],         # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
    "crop":      [0.5],       # [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5]
    "jpeg":      [50],        # [100, 90, 80, 70, 60, 50, 40, 30, 20, 10]
    "rotate":    [(-5, 5)],
    "CtrlRegen": [0.1],       # [0, 0.1, 0.2, 0.3, 0.4, 0.5]
}

# Attribute of `args` that carries each attack's strength.
args_map = {
    "none":      "none",
    "noise":     "variance",
    "gauss":     "kernel_size",
    "crop":      "crop_ratio",
    "jpeg":      "final_quality",
    "rotate":    "degrees",
    "CtrlRegen": "ctrl_regen_steps",
}

import argparse

# The attack settings are handed to apply_attack() as an argparse namespace.
# No options are declared; parse_known_args() leaves unrelated CLI arguments
# in `unknown` instead of failing.
parser = argparse.ArgumentParser()
args, unknown = parser.parse_known_args()
args.num_samples = n_img
args.variance = 0.1  # default noise variance (overwritten when attack_type == "noise")
attack_strength = range_map[attack_type][0]  # default attack strength
setattr(args, args_map[attack_type], attack_strength)


'''
Calculating different losses
'''
# compare how many tokens match for the two token maps on each scale
import torch.nn.functional as F
def compare_token_maps(map1, map2, return_per_scale=False):
    """Count matching tokens between two multi-scale token maps.

    Args:
        map1, map2: sequences with one integer token tensor per scale. Each
            tensor is treated as (batch, tokens): ``.sum(1)`` requires at
            least two dims. Presumably ``(B, h*w)`` - TODO confirm.
        return_per_scale: if True, return per-sample values for every scale;
            otherwise reduce each scale over the batch.

    Returns:
        ``(total_equal, total_elements, overlapping_ratio)``, each a list with
        one entry per scale (``zip`` stops at the shorter map):

        * ``return_per_scale=True``: ``(B,)`` tensors holding the number of
          matching tokens, the number of tokens, and their ratio per sample.
        * ``return_per_scale=False``: the first two summed over the batch and
          the ratio averaged over the batch (0-dim tensors).

    Raises:
        ValueError: if the two maps have different shapes at some scale.
    """
    total_equal = []
    total_elements = []
    overlapping_ratio = []

    for scale_idx, (tokens1, tokens2) in enumerate(zip(map1, map2)):
        if tokens1.shape != tokens2.shape:
            raise ValueError(
                f"Tensors have different shapes at scale {scale_idx}: "
                f"{tuple(tokens1.shape)} vs {tuple(tokens2.shape)}"
            )

        # Compare on CPU so the two maps may live on different devices.
        equal_count = (tokens1.cpu() == tokens2.cpu()).sum(1)  # (B,) matching tokens per sample
        element_count = torch.Tensor([row.numel() for row in tokens1])  # (B,) tokens per sample (float)

        # One entry per scale in every output list (previously only temporary
        # per-iteration lists were filled, so the count lists were always empty).
        total_equal.append(equal_count)
        total_elements.append(element_count)
        overlapping_ratio.append(equal_count / element_count)

    if return_per_scale:
        return total_equal, total_elements, overlapping_ratio
    return (
        [equal.sum() for equal in total_equal],
        [element.sum() for element in total_elements],
        [ratio.mean() for ratio in overlapping_ratio],
    )

def compare_embeddings(embed_scale1, embed_scale2): #List[BhwC]
    """Per-sample MSE between two lists of per-scale embeddings.

    Both arguments are lists with one tensor per scale (annotated as ``BhwC``).
    The squared error is averaged over every non-batch dim, so each tensor
    must be 4-D. Pairs are formed with ``zip``, so the shorter list wins.

    Returns:
        A list with one CPU tensor of shape ``(B,)`` per scale.
    """
    return [
        F.mse_loss(lhs, rhs, reduction='none').mean(dim=(1, 2, 3)).cpu()
        for lhs, rhs in zip(embed_scale1, embed_scale2)
    ]

'''
Define the losses without the optimized token search
'''
# the losses with the quantization
def calculate_loss_batch(all_dataset_results, dataset_name, original_B3HW, display_img=False, original_idxBl=None):
    """Accumulate quantized-reconstruction losses for one image batch.

    The batch is reconstructed twice with the global ``vae``
    (``img_to_reconstructed_img_with_intermediates``): first from the input
    images, then from that first reconstruction. For each round it computes:

    * the per-sample feature-map MSE (``fhat`` vs ``f``);
    * the image MSE;
    * the per-scale embedding MSEs.

    Those values and their 1st/2nd ratios are appended to
    ``all_dataset_results[<loss>][<round>][dataset_name]``.

    Args:
        all_dataset_results: nested results dict; the lists for
            ``dataset_name`` must already exist. Mutated in place.
        dataset_name: key under which the results are stored.
        original_B3HW: image batch; mapped with ``* 2 - 1`` before encoding,
            so presumably in [0, 1] - TODO confirm.
        display_img: if True, show grids of the input and first reconstruction.
        original_idxBl: optional reference token maps (one tensor per scale).
            When given, the per-scale token overlap with the first
            reconstruction is accumulated under ``"overlapping"``.

    Returns:
        ``all_dataset_results`` (the same object that was passed in).
    """
    # first reconstruction
    recon_img, recon_idxBl, f, fhat, embeddings = vae.img_to_reconstructed_img_with_intermediates(original_B3HW.clone().mul_(2).add_(-1).float(), last_one=True)
    # calculate overlapping ratio only when the original tokens are provided
    if original_idxBl is not None:
        _, _, overlapping_ratio = compare_token_maps(original_idxBl, recon_idxBl, return_per_scale=True)
    else:
        overlapping_ratio = []

    recon_img_show = recon_img.clone().add_(1).mul_(0.5)

    feature_map_mse = F.mse_loss(fhat, f, reduction='none').mean(dim=(1,2,3)).cpu()
    recon_img_mse = F.mse_loss(recon_img, original_B3HW.clone().mul_(2).add_(-1).float(), reduction='none').mean(dim=(1,2,3)).cpu()
    embed_mse_current = compare_embeddings(embeddings["current_resolution"]["original"], embeddings["current_resolution"]["quantized"])
    embed_mse_full = compare_embeddings(embeddings["full_resolution"]["original"], embeddings["full_resolution"]["quantized"])

    # second reconstruction: re-encode the first reconstruction
    recon_img_2nd, _, f_2nd, fhat_2nd, embeddings_2nd = vae.img_to_reconstructed_img_with_intermediates(recon_img, last_one=True)

    feature_map_mse_2nd = F.mse_loss(fhat_2nd, f_2nd, reduction='none').mean(dim=(1,2,3)).cpu()
    recon_img_mse_2nd = F.mse_loss(recon_img_2nd, recon_img, reduction='none').mean(dim=(1,2,3)).cpu()
    embed_mse_current_2nd = compare_embeddings(embeddings_2nd["current_resolution"]["original"], embeddings_2nd["current_resolution"]["quantized"])
    embed_mse_full_2nd = compare_embeddings(embeddings_2nd["full_resolution"]["original"], embeddings_2nd["full_resolution"]["quantized"])

    feature_map_mse_ratio = feature_map_mse / feature_map_mse_2nd
    recon_img_mse_ratio = recon_img_mse / recon_img_mse_2nd
    embed_mse_current_ratio = [cur / cur_2nd for cur, cur_2nd in zip(embed_mse_current, embed_mse_current_2nd)]
    embed_mse_full_ratio = [full / full_2nd for full, full_2nd in zip(embed_mse_full, embed_mse_full_2nd)]

    # update the results: per-scale overlap tensors are concatenated across batches
    if len(all_dataset_results["overlapping"]["all"][dataset_name]) == 0:
        all_dataset_results["overlapping"]["all"][dataset_name] = overlapping_ratio
    else:
        for scale in range(len(overlapping_ratio)):
            all_dataset_results["overlapping"]["all"][dataset_name][scale] = torch.cat((all_dataset_results["overlapping"]["all"][dataset_name][scale], overlapping_ratio[scale]))

    per_round = {
        "1st": (feature_map_mse, recon_img_mse, embed_mse_current, embed_mse_full),
        "2nd": (feature_map_mse_2nd, recon_img_mse_2nd, embed_mse_current_2nd, embed_mse_full_2nd),
        "ratio": (feature_map_mse_ratio, recon_img_mse_ratio, embed_mse_current_ratio, embed_mse_full_ratio),
    }
    for loss_round, (fmap_vals, rec_vals, emb_current_vals, emb_full_vals) in per_round.items():
        all_dataset_results["feature_map"][loss_round][dataset_name].extend(fmap_vals)
        all_dataset_results["rec"][loss_round][dataset_name].extend(rec_vals)
        for i in range(len(patch_nums)):
            all_dataset_results[f"embedding_current_{i}"][loss_round][dataset_name].extend(emb_current_vals[i])
            all_dataset_results[f"embedding_full_{i}"][loss_round][dataset_name].extend(emb_full_vals[i])

    # display the images
    if display_img:
        for grid_src in (original_B3HW, recon_img_show):
            grid = torchvision.utils.make_grid(grid_src, nrow=8, padding=0, pad_value=1.0)
            # Out-of-place mul: for a single-image batch make_grid returns a view
            # of its input, so mul_ would corrupt the caller's original_B3HW.
            hwc = grid.permute(1, 2, 0).mul(255).cpu().numpy()
            PImage.fromarray(hwc.astype(np.uint8)).show()

    return all_dataset_results

# the losses without the quantization
def calculate_loss_batch_no_quant(all_dataset_results, dataset_name, original_B3HW, display_img=False, original_idxBl=None):
    """Accumulate reconstruction losses that bypass quantization for one batch.

    Uses the global ``vae.img_to_reconstructed_img_without_quant`` twice:
    first on the input images, then on that first reconstruction. The
    per-sample image MSE of both rounds and their ratio are appended to
    ``all_dataset_results["rec_no_quant"][<round>][dataset_name]``.

    Args:
        all_dataset_results: nested results dict; the lists for
            ``dataset_name`` must already exist. Mutated in place.
        dataset_name: key under which the results are stored.
        original_B3HW: image batch; mapped with ``* 2 - 1`` before encoding,
            so presumably in [0, 1] - TODO confirm.
        display_img: if True, show grids of the input and first reconstruction.
        original_idxBl: optional reference token maps (one tensor per scale).
            When given, the per-scale overlap with the token map returned by
            the first reconstruction is accumulated under
            ``"overlapping_no_quant"``.

    Returns:
        ``all_dataset_results`` (the same object that was passed in).
    """
    # first reconstruction
    recon_img, recon_idxBl, _, _, _ = vae.img_to_reconstructed_img_without_quant(original_B3HW.clone().mul_(2).add_(-1).float(), last_one=True)
    # calculate overlapping ratio only when the original tokens are provided
    if original_idxBl is not None:
        _, _, overlapping_ratio = compare_token_maps(original_idxBl, recon_idxBl, return_per_scale=True)
    else:
        overlapping_ratio = []

    recon_img_show = recon_img.clone().add_(1).mul_(0.5)

    recon_img_mse = F.mse_loss(recon_img, original_B3HW.clone().mul_(2).add_(-1).float(), reduction='none').mean(dim=(1,2,3)).cpu()

    # second reconstruction: re-encode the first reconstruction
    recon_img_2nd, _, _, _, _ = vae.img_to_reconstructed_img_without_quant(recon_img.clone(), last_one=True)

    recon_img_mse_2nd = F.mse_loss(recon_img_2nd, recon_img, reduction='none').mean(dim=(1,2,3)).cpu()

    recon_img_mse_ratio = recon_img_mse / recon_img_mse_2nd

    # update the results: per-scale overlap tensors are concatenated across batches
    if len(all_dataset_results["overlapping_no_quant"]["all"][dataset_name]) == 0:
        all_dataset_results["overlapping_no_quant"]["all"][dataset_name] = overlapping_ratio
    else:
        for scale in range(len(overlapping_ratio)):
            all_dataset_results["overlapping_no_quant"]["all"][dataset_name][scale] = torch.cat((all_dataset_results["overlapping_no_quant"]["all"][dataset_name][scale], overlapping_ratio[scale]))

    all_dataset_results["rec_no_quant"]["1st"][dataset_name].extend(recon_img_mse)
    all_dataset_results["rec_no_quant"]["2nd"][dataset_name].extend(recon_img_mse_2nd)
    all_dataset_results["rec_no_quant"]["ratio"][dataset_name].extend(recon_img_mse_ratio)

    # display the images
    if display_img:
        for grid_src in (original_B3HW, recon_img_show):
            grid = torchvision.utils.make_grid(grid_src, nrow=8, padding=0, pad_value=1.0)
            # Out-of-place mul: for a single-image batch make_grid returns a view
            # of its input, so mul_ would corrupt the caller's original_B3HW.
            hwc = grid.permute(1, 2, 0).mul(255).cpu().numpy()
            PImage.fromarray(hwc.astype(np.uint8)).show()

    return all_dataset_results

# optimized token search
def f_to_idxBl_and_fhat_optimized(f_hat_target, iters=optim_iters, stop_scale=optim_stop_scale):
    """Search for a token map whose features approximate ``f_hat_target``.

    1. ``vae.quantize.f_to_idxBl_or_fhat_with_f_rest`` gives the initial
       token map.
    2. ``vae.quantize.refine_soft_assign`` refines it without ``fix_scale``.
    3. ``refine_soft_assign`` is re-run with ``fix_scale`` going from the last
       scale index (``len(patch_nums) - 1``, previously hard-coded as 9 for
       the 10-scale schedule) down to ``stop_scale + 1``. Each run starts
       from the previous run's tokens. What exactly ``fix_scale`` freezes is
       defined by ``refine_soft_assign`` ("freeze the tokens layer by layer")
       - verify there.

    Args:
        f_hat_target: feature map to approximate (encoder output in this file).
        iters: optimization iterations per ``refine_soft_assign`` call.
        stop_scale: the progressive loop stops before reaching this scale
            index; defaults to the module-level ``optim_stop_scale``.

    Returns:
        ``(idxBl, f_refined)``: the refined token map and its feature map.
    """
    idxBl_init, _, _, _ = vae.quantize.f_to_idxBl_or_fhat_with_f_rest(f_hat_target.clone(), to_fhat=False)
    idxBl_refined, f_refined, _ = vae.quantize.refine_soft_assign(f_hat_target, init_idx_Bl=idxBl_init, iters=iters, lr=0.1, entropy_weight=0, tau_start=2, tau_end=0.5)

    # freeze the tokens layer by layer
    last_scale = len(patch_nums) - 1
    for fix_scale in range(last_scale, stop_scale, -1):
        idxBl_refined, f_refined, _ = vae.quantize.refine_soft_assign(f_hat_target, init_idx_Bl=idxBl_refined, iters=iters, lr=0.1, entropy_weight=0, fix_scale=fix_scale, tau_start=2, tau_end=0.5)

    return idxBl_refined, f_refined

def img_to_reconstructed_img_with_optim(original_B3HW):
    """Encode an image batch and decode it through the optimized token search.

    Args:
        original_B3HW: image batch already in the VAE input range (callers in
            this file pass ``x * 2 - 1`` or a previous reconstruction). It is
            cloned before encoding.

    Returns:
        ``(reconstruction clamped to [-1, 1], refined token map,
        encoder features f, refined features fhat)``.
    """
    encoded = vae.encoder(original_B3HW.clone())
    f_enc = vae.quant_conv(encoded)
    idxBl_opt, f_opt = f_to_idxBl_and_fhat_optimized(f_enc)
    decoded = vae.decoder(vae.post_quant_conv(f_opt))
    return decoded.clamp_(-1, 1), idxBl_opt, f_enc, f_opt


'''
Define the losses with the optimized token search
'''
def calculate_loss_batch_with_optim(all_dataset_results, dataset_name, original_B3HW, display_img=False, original_idxBl=None):
    """Accumulate losses of the optimized token search for one image batch.

    The batch is reconstructed twice with ``img_to_reconstructed_img_with_optim``:
    first from the input images, then from that first reconstruction. For each
    round it computes the per-sample feature-map MSE (refined ``fhat`` vs
    encoder ``f``) and the image MSE. Those values and their 1st/2nd ratios
    are appended to ``all_dataset_results["feature_map_optim" / "rec_optim"]``.

    Args:
        all_dataset_results: nested results dict; the lists for
            ``dataset_name`` must already exist. Mutated in place.
        dataset_name: key under which the results are stored.
        original_B3HW: image batch; mapped with ``* 2 - 1`` before encoding,
            so presumably in [0, 1] - TODO confirm.
        display_img: if True, show grids of the input and first reconstruction.
        original_idxBl: optional reference token maps (one tensor per scale).
            When given, the per-scale overlap with the refined tokens is
            accumulated under ``"overlapping_optim"``.

    Returns:
        ``all_dataset_results`` (the same object that was passed in).
    """
    # first reconstruction
    recon_img, recon_idxBl, f, fhat = img_to_reconstructed_img_with_optim(original_B3HW.clone().mul_(2).add_(-1).float())

    # calculate overlapping ratio only when the original tokens are provided
    if original_idxBl is not None:
        _, _, overlapping_ratio = compare_token_maps(original_idxBl, recon_idxBl, return_per_scale=True)
    else:
        overlapping_ratio = []

    recon_img_show = recon_img.clone().add_(1).mul_(0.5)

    feature_map_mse = F.mse_loss(fhat, f, reduction='none').mean(dim=(1,2,3)).cpu()
    recon_img_mse = F.mse_loss(recon_img, original_B3HW.clone().mul_(2).add_(-1).float(), reduction='none').mean(dim=(1,2,3)).cpu()

    # second reconstruction: re-encode the first reconstruction
    recon_img_2nd, _, f_2nd, fhat_2nd = img_to_reconstructed_img_with_optim(recon_img.clone())

    feature_map_mse_2nd = F.mse_loss(fhat_2nd, f_2nd, reduction='none').mean(dim=(1,2,3)).cpu()
    recon_img_mse_2nd = F.mse_loss(recon_img_2nd, recon_img, reduction='none').mean(dim=(1,2,3)).cpu()

    feature_map_mse_ratio = feature_map_mse / feature_map_mse_2nd
    recon_img_mse_ratio = recon_img_mse / recon_img_mse_2nd

    # update the results: per-scale overlap tensors are concatenated across batches
    if len(all_dataset_results["overlapping_optim"]["all"][dataset_name]) == 0:
        all_dataset_results["overlapping_optim"]["all"][dataset_name] = overlapping_ratio
    else:
        for scale in range(len(overlapping_ratio)):
            all_dataset_results["overlapping_optim"]["all"][dataset_name][scale] = torch.cat((all_dataset_results["overlapping_optim"]["all"][dataset_name][scale], overlapping_ratio[scale]))

    per_round = {
        "1st": (feature_map_mse, recon_img_mse),
        "2nd": (feature_map_mse_2nd, recon_img_mse_2nd),
        "ratio": (feature_map_mse_ratio, recon_img_mse_ratio),
    }
    for loss_round, (fmap_vals, rec_vals) in per_round.items():
        all_dataset_results["feature_map_optim"][loss_round][dataset_name].extend(fmap_vals)
        all_dataset_results["rec_optim"][loss_round][dataset_name].extend(rec_vals)

    # display the images
    if display_img:
        for grid_src in (original_B3HW, recon_img_show):
            grid = torchvision.utils.make_grid(grid_src, nrow=8, padding=0, pad_value=1.0)
            # Out-of-place mul: for a single-image batch make_grid returns a view
            # of its input, so mul_ would corrupt the caller's original_B3HW.
            hwc = grid.permute(1, 2, 0).mul(255).cpu().numpy()
            PImage.fromarray(hwc.astype(np.uint8)).show()

    return all_dataset_results

# initialize the loss results
# Nested results container:
#   all_dataset_results[loss_name][round][dataset_name] -> list of per-sample values
# (for the "overlapping*" entries the list holds one tensor per scale instead).
# The per-dataset lists are created in the evaluation loop. Insertion order is
# significant: it fixes the order of the plots and of the results table.
_loss_rounds = ("1st", "2nd", "ratio")
all_dataset_results = {}

# losses without the optimized token search
for _loss_name in ("feature_map", "rec", "rec_no_quant"):
    all_dataset_results[_loss_name] = {r: {} for r in _loss_rounds}
for _loss_name in ("overlapping", "overlapping_no_quant"):
    all_dataset_results[_loss_name] = {"all": {}}

# losses with the optimized token search
for _loss_name in ("feature_map_optim", "rec_optim"):
    all_dataset_results[_loss_name] = {r: {} for r in _loss_rounds}

# token overlap of the optimized search, plus the two combined scores
for _loss_name in ("overlapping_optim", "fmap_recnoquant_prod", "fmap_recnoquant_sum"):
    all_dataset_results[_loss_name] = {"all": {}}

# per-scale embedding losses (current-resolution and full-resolution)
for i in range(len(patch_nums)):
    all_dataset_results[f"embedding_current_{i}"] = {r: {} for r in _loss_rounds}
    all_dataset_results[f"embedding_full_{i}"] = {r: {} for r in _loss_rounds}


# calculate all the losses for all given datasets
# Image folders to evaluate; the keys are the dataset labels used in plots and tables.
dataset_name_image_path = {
    "LAION": "PATH_TO_LAION_SUBSET",   # TODO: =====> please specify the path to your LAION subset <=====
    "MS-COCO": "PATH_TO_MSCOCO_SUBSET", # TODO: =====> please specify the path to your MS-COCO subset <=====
    "ImageNet (val)": "PATH_TO_IMAGENET_VAL_SUBSET", # TODO: =====> please specify the path to your ImageNet validation subset <=====
    "ImageNet (train)": "PATH_TO_IMAGENET_TRAIN_SUBSET", # TODO: =====> please specify the path to your ImageNet training subset <=====
    "RAR Generated": "PATH_TO_RAR_GENERATED_SUBSET", # TODO: =====> please specify the path to your RAR generated images subset <=====
    "VAR Generated": "PATH_TO_VAR_GENERATED_SUBSET", # TODO: =====> please specify the path to your VAR generated images subset <=====
    "LlamaGen Generated": "PATH_TO_LLAMAGEN_GENERATED_SUBSET", # TODO: =====> please specify the path to your LlamaGen generated images subset <=====
    "Taming Generated": "PATH_TO_TAMING_GENERATED_SUBSET", # TODO: =====> please specify the path to your Taming generated images subset <=====
    "Infinity Generated": "PATH_TO_INFINITY_GENERATED_SUBSET", # TODO: =====> please specify the path to your Infinity generated images subset <=====
}

# Dataset loaded together with its reference VAR token maps. In the evaluation
# it is the positive class compared against every other dataset.
var_dataset_name = "VAR Generated"

for dataset_name in dataset_name_image_path.keys():
    # start every loss/round with an empty result list for this dataset
    for loss_type in all_dataset_results.keys():
        for loss_round in all_dataset_results[loss_type].keys():
            all_dataset_results[loss_type][loss_round][dataset_name] = []

    image_path = dataset_name_image_path[dataset_name]
    print(f'[{dataset_name}] Reading from {image_path}', flush=True)
    load_token_map = (dataset_name == var_dataset_name)
    dataset = apply_attack(img_path=image_path, attack=attack_type, load_token_map=load_token_map, args=args)
    print(f'dataset length: {len(dataset)}', flush=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    display_img = False
    for i, batch in enumerate(dataloader):
        original_B3HW, token_batch = batch
        # Use the configured `device` (the one the models were built on) instead of `.cuda()`.
        original_B3HW = original_B3HW.to(device)
        # Reference tokens are only used for the VAR-generated set. Move them once
        # and share them across the three loss computations below.
        if load_token_map:
            original_idxBl = [idxBl.to(device) for idxBl in token_batch]
        else:
            original_idxBl = None

        # the main losses (with quantization)
        all_dataset_results = calculate_loss_batch(all_dataset_results, dataset_name, original_B3HW, display_img=display_img, original_idxBl=original_idxBl)
        # the rec loss without quant
        all_dataset_results = calculate_loss_batch_no_quant(all_dataset_results, dataset_name, original_B3HW, display_img=display_img, original_idxBl=original_idxBl)
        # the losses with the optimized token search
        all_dataset_results = calculate_loss_batch_with_optim(all_dataset_results, dataset_name, original_B3HW, display_img=display_img, original_idxBl=original_idxBl)
        print(f"batch {i}", flush=True)

    # Combined per-sample scores: feature_map_optim (1st round) with the
    # rec_no_quant ratio, as product and as sum.
    # NOTE(review): the 1st-round optim loss (not its ratio) is used here -
    # confirm this is intended.
    fmap_optim_1st = all_dataset_results['feature_map_optim']['1st'][dataset_name]
    rec_no_quant_ratio = all_dataset_results['rec_no_quant']['ratio'][dataset_name]
    all_dataset_results['fmap_recnoquant_prod']['all'][dataset_name] = [a * b for a, b in zip(fmap_optim_1st, rec_no_quant_ratio)]
    all_dataset_results['fmap_recnoquant_sum']['all'][dataset_name] = [a + b for a, b in zip(fmap_optim_1st, rec_no_quant_ratio)]

    print(f"overlapping ratio: {[ratio.mean() for ratio in all_dataset_results['overlapping']['all'][dataset_name]]}", flush=True)
    print(f"overlapping ratio optim: {[ratio.mean() for ratio in all_dataset_results['overlapping_optim']['all'][dataset_name]]}", flush=True)


# visualize the probability distribution function of codebook losses for both real and generated images with seaborn
from utils.plot import plot_multi_pdf
from sklearn import metrics
import pandas as pd

def evaluate(rar_scores, other_scores):
    """Measure how well a score separates two image sets.

    Args:
        rar_scores: scores of the positive class (label 1). Despite the name,
            the caller in this file passes the VAR-generated scores.
        other_scores: scores of the negative class (label 0).

    Lower scores are treated as more likely positive: the scores are negated
    before the ROC curve is built.

    Returns:
        ``(threshold_at_1fpr, auc, acc, tpr_at_1fpr)``:

        * the threshold is reported in the original (non-negated) score units;
        * ``acc`` is the best balanced accuracy ``1 - (FPR + FNR) / 2`` over
          all ROC points;
        * the 1%-FPR values come from the last ROC point with FPR < 0.01.
    """
    labels = np.concatenate([np.zeros(len(other_scores)), np.ones(len(rar_scores))])
    negated_scores = -np.concatenate([other_scores, rar_scores])

    fpr, tpr, negated_thresholds = metrics.roc_curve(labels, negated_scores)
    auc = metrics.auc(fpr, tpr)
    acc = np.max(1 - (fpr + (1 - tpr)) / 2)

    low_fpr_idx = np.flatnonzero(fpr < 0.01)[-1]
    return -negated_thresholds[low_fpr_idx], auc, acc, tpr[low_fpr_idx]


# Human-readable description of the attack setting, shared by the log line and the output path.
attack_desc = f"{attack_type}({args_map[attack_type]}={getattr(args, args_map[attack_type])})"
print(f'Results for setting: VAR-d[{MODEL_DEPTH}], ae-{ae_type}, {attack_desc}, {n_img}imgs, {batch_size}bs, {optim_iters}iters, {optim_stop_scale}optim_stop_scale')
result_dir = f'results/VAR-d{MODEL_DEPTH}/ae-{ae_type}/{attack_desc}_{n_img}imgs_{batch_size}bs_{optim_iters}iters_{optim_stop_scale}optim_stop_scale'
os.makedirs(result_dir, exist_ok=True)

results_table = []

for loss_type, loss_type_results in all_dataset_results.items():
    # overlap and per-scale embedding entries are not part of the detection report
    if "overlapping" in loss_type or "full" in loss_type or "current" in loss_type:
        continue
    for loss_round, loss_round_results in loss_type_results.items():  # 1st, 2nd, ratio / all
        print(f"[{loss_type}][{loss_round}]")
        # Build labels from the same dict as the data so they can never get out of step.
        data_list, label_list = [], []
        for dataset_name, dataset_results in loss_round_results.items():
            print(f"{dataset_name} dataset size {len(dataset_results)}")
            data_list.append(np.array(dataset_results))
            label_list.append(f"{dataset_name}")
        # plot
        title = f'{loss_type}_{loss_round}_mse'
        xlabel = 'Loss'
        ylabel = 'PDF'
        plot_multi_pdf(data_list, label_list, title, xlabel, ylabel, save_dir=result_dir)

        # quantitative: VAR-generated images (positive class) vs every other dataset
        results_table_single_losses = []
        for dataset_name, dataset_results in loss_round_results.items():
            if dataset_name == var_dataset_name:
                continue
            threshold, auc, acc, tpr1 = evaluate(
                np.array(loss_round_results[var_dataset_name]),
                np.array(dataset_results)
            )
            row = {
                "Comparison": f"VAR Generated vs {dataset_name}",
                "Threshold": round(threshold, 4),
                "AUC": round(auc, 4),
                "ACC": round(acc, 4),
                "TPR@1%FPR": round(tpr1, 4),
            }
            results_table.append({"Loss Type": loss_type, "Loss Round": loss_round, **row})
            results_table_single_losses.append(row)
        df_single = pd.DataFrame(results_table_single_losses)
        print(df_single.to_string(index=False), flush=True)
        print('\n')
import pandas as pd
df = pd.DataFrame(results_table)
print(df.to_string(index=False))
# save the results
df.to_csv(os.path.join(result_dir, 'results.csv'), index=False)

import re
import pandas as pd

# Build a compact LaTeX table from `df`, which must contain the columns
# "Loss Type", "Loss Round", "Comparison", "Threshold", "AUC", "ACC", "TPR@1%FPR".

# Metric shown in the table cells: one of "AUC", "ACC", "TPR@1%FPR", "Threshold".
metric = "TPR@1%FPR"

# Table columns, in display order.
col_order = [
    "ImageNet (train)", "ImageNet (val)", "LAION", "MS-COCO",
    "LlamaGen", "RAR", "Taming", "VAR", "Infinity",
]

# lower-cased dataset name (with "Generated" removed) -> column label
aliases = {
    "imagenet (train)": "ImageNet (train)",
    "imagenet (val)": "ImageNet (val)",
    "imagenet": "ImageNet",  # fallback, usually won't be used
    "laion": "LAION",
    "ms-coco": "MS-COCO",
    "llamagen": "LlamaGen",
    "rar": "RAR",
    "taming": "Taming",
    "var": "VAR",
    "infinity": "Infinity",
}

def extract_dataset(x: str) -> str:
    """Return the column label for the dataset after "vs" in a Comparison string.

    The word "Generated" is dropped and the remaining name is mapped through
    ``aliases`` (case-insensitively); unknown names are returned as-is, and
    strings without a "vs" give ``""``.
    """
    hit = re.search(r"\bvs\b\s+(.*)$", str(x), flags=re.IGNORECASE)
    if hit is None:
        return ""
    # "ImageNet (val)" style names keep their parenthesised suffix
    name = re.sub(r"\bGenerated\b", "", hit.group(1), flags=re.IGNORECASE).strip()
    return aliases.get(name.lower().strip(), name)

tmp = df.copy()
tmp["Dataset"] = tmp["Comparison"].apply(extract_dataset)

# keep only rows for the datasets we care about
tmp = tmp[tmp["Dataset"].isin(col_order)]

# remove duplicates if any (the last occurrence wins)
tmp = tmp.drop_duplicates(subset=["Loss Type", "Loss Round", "Dataset"], keep="last")

# Pivot to wide format: one row per (Loss Type, Loss Round), one column per dataset
wide = tmp.pivot(index=["Loss Type", "Loss Round"], columns="Dataset", values=metric)

# Reorder columns (datasets without results become all-NaN columns)
wide = wide.reindex(columns=col_order)

# Round numeric values
wide = wide.round(3)

# Reset index so Loss Type and Loss Round become columns again
wide = wide.reset_index()

# Express as percentages with one decimal. Vectorized arithmetic replaces the
# element-wise DataFrame.applymap (deprecated since pandas 2.1); NaN cells
# (missing comparisons) propagate unchanged, as with the old notnull guard.
wide[col_order] = (wide[col_order] * 100).round(1)

# Convert to LaTeX
latex = wide.to_latex(
    index=False,
    escape=True,
    na_rep="-",
    float_format=lambda x: f"{x:.1f}",  # ensure one decimal
    column_format="ll" + "c" * len(col_order),
    bold_rows=False,
    longtable=False,
    multicolumn=False,
    multicolumn_format="c",
)
print(latex)
# save the latex
with open(os.path.join(result_dir, 'results.tex'), 'w', encoding='utf-8') as f:
    f.write(latex)
