import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var
from attacks import apply_attack
from torch.utils.data import DataLoader


import torch

# ---------------------------------------------------------------------------
# Experiment configuration, checkpoint download and model construction.
# ---------------------------------------------------------------------------
import subprocess  # only used below to fetch missing checkpoints

n_img = 1000        # forwarded to the dataset loader as ``args.num_samples``
batch_size = 8      # DataLoader batch size
iters_optim = 100   # default number of latent-optimisation steps per batch

gpu_id = 0
# Only touch the CUDA runtime when a GPU is present; calling set_device()
# unconditionally made the CPU fallback on the ``device`` line unreachable.
if torch.cuda.is_available():
    torch.cuda.set_device(gpu_id)
    print(torch.cuda.current_device())
device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")

data_dir_ae = "[VAR_MODEL_PATH]"

data_dir_ar = "[VAR_MODEL_PATH]"

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}


# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
# vae_ckpt, var_ckpt = 'vqvae_finetuned_fmap_lpips0.0_mse_img0.0_mse_feat1.0_steps20000_encoder_lr5e-05_bs16_with_dataset.pth', f'var_d{MODEL_DEPTH}.pth'
# vae_ckpt, var_ckpt = 'vqvae_finetuned_fmap_lpips0.0_mse_img0.0_mse_feat1.0_keep_real10.0_steps20000_encoder_lr5e-05_bs8_with_dataset.pth', f'var_d{MODEL_DEPTH}.pth'

vae_ckpt_path, var_ckpt_path = osp.join(data_dir_ae, vae_ckpt), osp.join(data_dir_ar, var_ckpt)
# Argument lists (no shell) keep paths with spaces intact, and check=True
# surfaces a failed download here instead of as a missing-file error later.
if not osp.exists(vae_ckpt_path):
    subprocess.run(['wget', '-P', data_dir_ae, f'{hf_home}/{vae_ckpt}'], check=True)
if not osp.exists(var_ckpt_path):
    subprocess.run(['wget', '-P', data_dir_ar, f'{hf_home}/{var_ckpt}'], check=True)

# build vae, var
patch_nums = [1, 2, 3, 4, 5, 6, 8, 10, 13, 16]
# Skip the rebuild when both models already exist in this namespace
# (e.g. when the script is re-run inside an interactive session).
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
# NOTE(review): torch.load unpickles the file and these checkpoints come from
# the network; consider ``weights_only=True`` if the installed torch supports it.
vae.load_state_dict(torch.load(vae_ckpt_path, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt_path, map_location='cpu'), strict=True)
vae.eval()
var.eval()
# Both networks are frozen; only the latent is optimised later on.
for p in vae.parameters():
    p.requires_grad_(False)
for p in var.parameters():
    p.requires_grad_(False)
print('prepare finished.')



'''
configs for attack
'''
# Name of the attack handed to ``apply_attack``; it is also the lookup key
# into the two tables below.
attack_type = "none"

# Candidate strengths per attack. Only the first entry of the selected list is
# used; the trailing comments list alternative sweep values.
range_map = {
    "none": [0.0],  # no attack
    "noise": [0.1],  # [0.0, 0.05, 0.1, 0.15, 0.2],
    "gauss": [7],  # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
    "crop": [0.5],  # [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5],
    "jpeg": [50],  # [100, 90, 80, 70, 60, 50, 40, 30, 20, 10],
    "rotate": [(-5, 5)],
    "CtrlRegen": [0.1],  # [0, 0.1, 0.2, 0.3, 0.4, 0.5]
}

# Attribute on ``args`` that carries the strength of each attack.
args_map = {
    "none": "none",
    "noise": "variance",
    "gauss": "kernel_size",
    "crop": "crop_ratio",
    "jpeg": "final_quality",
    "rotate": "degrees",
    "CtrlRegen": "ctrl_regen_steps",
}

import argparse

# No options are declared, so parsing just yields an empty namespace (plus
# whatever was on the command line, collected in ``unknown``).
parser = argparse.ArgumentParser()
args, unknown = parser.parse_known_args()
args.num_samples = n_img
args.variance = 0.1  # default noise variance
# Take the first configured strength and store it under the attack's own
# attribute name.
attack_strength = range_map[attack_type][0]  # default attack strength
setattr(args, args_map[attack_type], attack_strength)


# compare how many tokens match for the two token maps on each scale
import torch.nn.functional as F
def compare_token_maps(map1, map2, return_per_scale=False):
    """Count how many token indices agree between two multi-scale token maps.

    Args:
        map1: Sequence of tensors, one per scale, each laid out as
            (batch, tokens).
        map2: Sequence with the same layout and shapes as ``map1``.
        return_per_scale: If True, return the raw per-sample tensors for every
            scale. If False (default), reduce over the batch: matches and
            element counts are summed, ratios are averaged.

    Returns:
        A tuple ``(equal, elements, ratio)`` of three lists with one entry per
        scale. With ``return_per_scale=True`` each entry is a tensor holding
        one value per sample; otherwise each entry is a 0-d tensor.

    Raises:
        ValueError: If the tensors of a scale differ in shape.
    """
    total_equal = []
    total_elements = []
    overlapping_ratio = []

    for tokens1, tokens2 in zip(map1, map2):
        if tokens1.shape != tokens2.shape:
            raise ValueError("Tensors have different shapes")

        # Boolean mask of positions where both maps hold the same token.
        equal_mask = tokens1.cpu() == tokens2.cpu()  # Bl (binary)

        # Per-sample number of matches and per-sample number of tokens.
        equal_count = equal_mask.sum(1)  # B
        element_count = torch.Tensor([sample.numel() for sample in tokens1])  # B

        # Record the counts in the returned lists. They used to be appended to
        # scratch lists recreated on every iteration, so the first two return
        # values were always empty.
        total_equal.append(equal_count)
        total_elements.append(element_count)
        overlapping_ratio.append(equal_count / element_count)

    if return_per_scale:
        return total_equal, total_elements, overlapping_ratio
    return (
        [equal.sum() for equal in total_equal],
        [element.sum() for element in total_elements],
        [ratio.mean() for ratio in overlapping_ratio],
    )

def compare_embeddings(embed_scale1, embed_scale2):  # List[BhwC]
    """Return the per-sample MSE between two lists of embeddings.

    The lists are walked in lockstep, one pair of 4-D tensors per scale. For
    every pair the squared error is averaged over all non-batch dimensions and
    the result is moved to the CPU.

    Returns:
        A list with one tensor per scale, each holding one error per sample.
    """
    return [
        F.mse_loss(lhs, rhs, reduction='none').mean(dim=(1, 2, 3)).cpu()
        for lhs, rhs in zip(embed_scale1, embed_scale2)
    ]

def img_to_reconstructed_img_with_latent_optim(original_B3HW, lr=1e-2, iters=iters_optim):
    """Reconstruct a batch by optimising the quantised VAE latent with Adam.

    The batch is encoded and quantised with the module-level ``vae``. The last
    entry returned by ``f_to_idxBl_or_fhat(..., to_fhat=True)`` initialises a
    trainable latent, which is then refined to minimise the MSE between the
    decoded image and ``original_B3HW``.

    Args:
        original_B3HW: Image batch to reconstruct. The decoder output is
            clamped to [-1, 1], so the input is presumably expected in that
            range as well (the caller in this file rescales to it).
        lr: Initial Adam learning rate; halved whenever ``i % 50 == 0``.
        iters: Number of optimisation steps.

    Returns:
        ``(rec_gen_img, fhat_optim)``: the image decoded in the last iteration
        (i.e. before that iteration's optimiser step) and the optimised latent
        parameter. With ``iters <= 0`` the image is decoded once from the
        unoptimised latent.
    """
    # The encoder pass only provides the starting point, so no graph is needed.
    with torch.no_grad():
        f = vae.quant_conv(vae.encoder(original_B3HW.clone()))
        fhat = vae.quantize.f_to_idxBl_or_fhat(f, to_fhat=True)[-1]

    # Create the Parameter directly on fhat's device. ``Parameter(...).cuda()``
    # returns a non-leaf copy unless the data already sits on the current CUDA
    # device (which the optimiser rejects) and fails outright without CUDA.
    fhat_optim = torch.nn.Parameter(fhat.detach().clone())
    optimizer = torch.optim.Adam([fhat_optim], lr=lr)

    rec_gen_img = None
    for i in range(iters):
        optimizer.zero_grad()
        rec_gen_img = vae.decoder(vae.post_quant_conv(fhat_optim)).clamp_(-1, 1)
        loss = F.mse_loss(rec_gen_img, original_B3HW)
        loss.backward()
        optimizer.step()

        # NOTE(review): this also fires at i == 0, so the learning rate is
        # halved right after the first step - confirm that is intended.
        if i % 50 == 0:
            for g in optimizer.param_groups:
                g['lr'] = g['lr'] * 0.5

    if rec_gen_img is None:
        # No optimisation step ran; still return a valid reconstruction
        # instead of failing on an unbound local.
        with torch.no_grad():
            rec_gen_img = vae.decoder(vae.post_quant_conv(fhat_optim)).clamp_(-1, 1)
    return rec_gen_img, fhat_optim

'''
Define the losses with the optimized token search
'''
def _show_image_grid(images_B3HW):
    """Tile a batch of images (values in [0, 1]) into one grid and open it in the image viewer."""
    grid = torchvision.utils.make_grid(images_B3HW, nrow=8, padding=0, pad_value=1.0)
    # Out-of-place multiply: make_grid may hand back a view of its input for
    # single-image batches, which an in-place ``mul_`` would silently modify.
    grid = grid.permute(1, 2, 0).mul(255).cpu().numpy()
    PImage.fromarray(grid.astype(np.uint8)).show()


def calculate_loss_batch_latent_tracer(all_dataset_results, dataset_name, original_B3HW, display_img=False, original_idxBl=None):
    """Score one batch by its reconstruction error after latent optimisation.

    The batch is rescaled with ``x * 2 - 1``, reconstructed with
    ``img_to_reconstructed_img_with_latent_optim`` and compared with the
    rescaled input. The per-image MSE values are appended to
    ``all_dataset_results["rec_optim"]["1st"][dataset_name]``.

    Args:
        all_dataset_results: Nested result dict; the list for ``dataset_name``
            must already exist. It is modified in place.
        dataset_name: Key of the dataset the batch belongs to.
        original_B3HW: Image batch; assumed to be in [0, 1] because it is
            rescaled with ``x * 2 - 1`` - TODO confirm against the dataset.
        display_img: If True, open the input grid and the reconstruction grid.
        original_idxBl: Unused; kept so existing callers keep working.

    Returns:
        The same ``all_dataset_results`` object.
    """
    # Rescale once and reuse the result as both optimisation target and reference.
    target = original_B3HW.clone().mul_(2).add_(-1).float()

    # first reconstruction
    recon_img, _ = img_to_reconstructed_img_with_latent_optim(target)
    # Drop the autograd graph: only values are needed from here on, and
    # ``.numpy()`` in the display path refuses tensors that require grad.
    recon_img = recon_img.detach()

    recon_img_mse = F.mse_loss(recon_img, target, reduction='none').mean(dim=(1, 2, 3)).cpu()

    # Store plain Python floats rather than 0-d tensors.
    all_dataset_results["rec_optim"]["1st"][dataset_name].extend(recon_img_mse.tolist())

    # display the images
    if display_img:
        _show_image_grid(original_B3HW)
        # Map the reconstruction back with (x + 1) / 2 for display.
        _show_image_grid(recon_img.add(1).mul(0.5))

    return all_dataset_results

# initialize the loss results
# Per-image scores, indexed as [loss_type][loss_round][dataset_name].
all_dataset_results = {
    "rec_optim": {
        "1st": {},
        # "2nd": {},
        # "ratio": {},
        # "ratio_calibrated": {},
    }
}

# calculate all the losses for all given datasets
dataset_name_image_path = {
    "LAION": "PATH_TO_LAION_SUBSET",   # TODO: =====> please specify the path to your LAION subset <=====
    "MS-COCO": "PATH_TO_MSCOCO_SUBSET", # TODO: =====> please specify the path to your MS-COCO subset <=====
    "ImageNet (val)": "PATH_TO_IMAGENET_VAL_SUBSET", # TODO: =====> please specify the path to your ImageNet validation subset <=====
    "ImageNet (train)": "PATH_TO_IMAGENET_TRAIN_SUBSET", # TODO: =====> please specify the path to your ImageNet training subset <=====
    "RAR Generated": "PATH_TO_RAR_GENERATED_SUBSET", # TODO: =====> please specify the path to your RAR generated images subset <=====
    "VAR Generated": "PATH_TO_VAR_GENERATED_SUBSET", # TODO: =====> please specify the path to your VAR generated images subset <=====
    "LlamaGen Generated": "PATH_TO_LLAMAGEN_GENERATED_SUBSET", # TODO: =====> please specify the path to your LlamaGen generated images subset <=====
    "Taming Generated": "PATH_TO_TAMING_GENERATED_SUBSET", # TODO: =====> please specify the path to your Taming generated images subset <=====
    "Infinity Generated": "PATH_TO_INFINITY_GENERATED_SUBSET", # TODO: =====> please specify the path to your Infinity generated images subset <=====
}

# The dataset treated as the positive class; only it is loaded with token maps.
var_dataset_name = "VAR Generated"

for dataset_name, image_path in dataset_name_image_path.items():
    # Start an empty score list for this dataset under every loss type/round.
    for loss_type in all_dataset_results:
        for loss_round in all_dataset_results[loss_type]:
            all_dataset_results[loss_type][loss_round][dataset_name] = []

    print(f'reading from {image_path}', flush=True)
    load_token_map = (dataset_name == var_dataset_name)
    dataset = apply_attack(img_path=image_path, attack=attack_type, load_token_map=load_token_map, args=args)
    print(f'dataset length: {len(dataset)}', flush=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for i, batch in enumerate(dataloader):
        image_batch, token_batch = batch
        # Follow the ``device`` chosen during setup instead of a hard-coded
        # ``.cuda()``, which bypassed the CPU fallback.
        original_B3HW = image_batch.to(device)
        # Token maps are only forwarded for the VAR-generated dataset.
        if load_token_map:
            original_idxBl = [idxBl.to(device) for idxBl in token_batch]
        else:
            original_idxBl = None
        # Set to ``i == 0`` to preview the first batch of every dataset.
        display_img = False
        all_dataset_results = calculate_loss_batch_latent_tracer(
            all_dataset_results, dataset_name, original_B3HW,
            display_img=display_img, original_idxBl=original_idxBl,
        )
        print(i, flush=True)

'''
Evaluate
'''
# visualize the probability distribution function of codebook losses for both real and generated images with seaborn
from utils.plot import plot_multi_pdf
from sklearn import metrics
import pandas as pd

def evaluate(rar_scores, other_scores):
    """Compute detection metrics that separate two score sets.

    ``rar_scores`` is labelled as the positive class (1) and ``other_scores``
    as the negative class (0). The scores are negated before the ROC curve is
    computed, so lower raw scores count as "more positive". The callers in
    this file pass the VAR-generated losses as ``rar_scores``.

    Returns:
        ``(threshold_at_1fpr, auc, acc, tpr_at_1fpr)``: the raw-score threshold
        and the TPR at the last ROC point with FPR below 1%, the area under
        the ROC curve, and the maximum over ROC points of
        ``1 - (fpr + (1 - tpr)) / 2``.
    """
    negative_labels = np.zeros(len(other_scores))
    positive_labels = np.ones(len(rar_scores))
    labels = np.concatenate([negative_labels, positive_labels])

    # Negate so that the ROC routine ranks low losses highest.
    negated_scores = -np.concatenate([other_scores, rar_scores])

    fpr, tpr, negated_thresholds = metrics.roc_curve(labels, negated_scores)
    auc = metrics.auc(fpr, tpr)
    acc = np.max(1 - (fpr + (1 - tpr))/2)

    # Last operating point whose false-positive rate stays below 1%.
    last_below = np.where(fpr < 0.01)[0][-1]
    threshold_at_1fpr = -negated_thresholds[last_below]
    tpr_at_1fpr = tpr[last_below]

    return threshold_at_1fpr, auc, acc, tpr_at_1fpr


# Announce the setting and derive the output directory from it.
print(f'Results for setting: VAR-d[{MODEL_DEPTH}], {attack_type}({args_map[attack_type]}={args.__dict__[args_map[attack_type]]}), {n_img}imgs, {iters_optim}iters', flush=True)
result_dir = f'results/latenttracer/VAR-d{MODEL_DEPTH}/{attack_type}({args_map[attack_type]}={args.__dict__[args_map[attack_type]]})_{n_img}imgs_{iters_optim}iters'
os.makedirs(result_dir, exist_ok=True)

# One row per (loss type, loss round, comparison dataset).
results_table = []

for loss_type, loss_type_results in all_dataset_results.items():  # codebook, rec
    # Loss types carrying one of these tags are not evaluated.
    if any(tag in loss_type for tag in ("overlapping", "full", "current")):
        continue
    for loss_round, loss_round_results in loss_type_results.items():  # 1st, 2nd, ratio, ratio_calibrated
        print(f"[{loss_type}][{loss_round}]")

        # Collect every dataset's scores for the density plot.
        label_list = [f"{name}" for name in dataset_name_image_path]
        data_list = []
        for dataset_name, dataset_results in loss_round_results.items():  # Real, Real (train), VAR Generated, RAR Generated
            print(f"{dataset_name} dataset size {len(dataset_results)}")
            data_list.append(np.array(dataset_results))

        # plot
        title = f'{loss_type}_{loss_round}_mse'
        xlabel = 'Loss'
        ylabel = 'PDF'
        plot_multi_pdf(data_list, label_list, title, xlabel, ylabel, save_dir=result_dir)

        # quantitative: compare the VAR-generated scores with every other dataset
        results_table_single_losses = []
        for dataset_name, dataset_results in loss_round_results.items():  # Real, VAR Generated, etc.
            if dataset_name == var_dataset_name:
                continue
            threshold, auc, acc, tpr1 = evaluate(
                np.array(loss_round_results[var_dataset_name]),
                np.array(dataset_results),
            )
            summary = {
                "Comparison": f"VAR Generated vs {dataset_name}",
                "Threshold": round(threshold, 4),
                "AUC": round(auc, 4),
                "ACC": round(acc, 4),
                "TPR@1%FPR": round(tpr1, 4),
            }
            # The global table additionally records which loss produced the row.
            results_table.append({"Loss Type": loss_type, "Loss Round": loss_round, **summary})
            results_table_single_losses.append(summary)

        df_single = pd.DataFrame(results_table_single_losses)
        print(df_single.to_string(index=False), flush=True)
        print('\n')
import pandas as pd
df = pd.DataFrame(results_table)
print(df.to_string(index=False))
# save the results
df.to_csv(os.path.join(result_dir, 'results.csv'), index=False)

import re
import pandas as pd

# Input: the ``df`` built above from ``results_table``. It must contain the
# columns ["Loss Type","Loss Round","Comparison","Threshold","AUC","ACC","TPR@1%FPR"].

# Column of ``df`` whose values fill the cells of the exported table.
metric = "TPR@1%FPR"  # alternatives: "AUC", "ACC", "Threshold"

# Dataset columns of the exported table, in display order. Datasets that do not
# appear in ``df`` show up as empty ("-") columns after the reindex below.
col_order = [
    "ImageNet (train)",
    "ImageNet (val)",
    "LAION",
    "MS-COCO",
    "LlamaGen",
    "RAR",
    "Taming",
    "VAR",
    "Infinity",
]

# Lower-cased dataset name -> canonical label used in the exported table.
aliases = {
    "imagenet (train)": "ImageNet (train)",
    "imagenet (val)": "ImageNet (val)",
    "imagenet": "ImageNet",  # fallback, usually won't be used
    "laion": "LAION",
    "ms-coco": "MS-COCO",
    "llamagen": "LlamaGen",
    "rar": "RAR",
    "taming": "Taming",
    "var": "VAR",
    "infinity": "Infinity",
}


def extract_dataset(x: str) -> str:
    """Return the canonical dataset label named after "vs" in a Comparison string.

    Everything following the first standalone "vs" (case-insensitive) is
    taken, the word "Generated" is removed, and the remainder is looked up
    case-insensitively in ``aliases``. Names without an alias are returned
    as-is; an input without "vs" yields the empty string.
    """
    found = re.search(r"\bvs\b\s+(.*)$", str(x), flags=re.IGNORECASE)
    if found is None:
        return ""
    # Parenthesised suffixes such as "ImageNet (val)" survive this clean-up.
    name = re.sub(r"\bGenerated\b", "", found.group(1), flags=re.IGNORECASE).strip()
    return aliases.get(name.lower(), name)

# Build a wide table (one row per loss type/round, one column per dataset)
# from the long-format ``df`` and export it as LaTeX.
tmp = df.copy()
tmp["Dataset"] = tmp["Comparison"].apply(extract_dataset)

# keep only rows for the datasets we care about
tmp = tmp[tmp["Dataset"].isin(col_order)]

# remove duplicates if any
tmp = tmp.drop_duplicates(subset=["Loss Type","Loss Round","Dataset"], keep="last")

# Pivot to wide format
wide = tmp.pivot(index=["Loss Type","Loss Round"], columns="Dataset", values=metric)

# Reorder columns (datasets without rows become all-NaN columns)
wide = wide.reindex(columns=col_order)

# Round numeric values
wide = wide.round(3)

# Reset index so Loss Type and Loss Round become columns again
wide = wide.reset_index()

# Express the values as percentages with one decimal. Vectorised arithmetic
# replaces the element-wise ``applymap`` call, which is deprecated in recent
# pandas releases; NaN cells stay NaN.
# NOTE(review): the x100 scaling only makes sense for ratio metrics
# (AUC / ACC / TPR), not for "Threshold".
wide[col_order] = (wide[col_order] * 100).round(1)

# Convert to LaTeX
latex = wide.to_latex(
    index=False,
    escape=True,
    na_rep="-",
    float_format=lambda x: f"{x:.1f}",  # ensure one decimal
    column_format="ll" + "c"*len(col_order),
    bold_rows=False,
    longtable=False,
    multicolumn=False,
    multicolumn_format="c",
)
print(latex)
# save the latex
with open(os.path.join(result_dir, 'results.tex'), 'w', encoding='utf-8') as f:
    f.write(latex)