import random
import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))  # Add parent directory to path
import torch
import numpy as np
from run_infinity import *

# parse arguments
import argparse


from tqdm import tqdm
# Device used for all VAE encode calls: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def tokenize_and_reconstruct(vae, original_image, scale_schedule, display_img=False, resize=1024, img_rec_quant_1=None, img_rec_no_quant_1=None):
    """Encode an image with ``vae`` and decode it with and without quantization.

    The function has two modes, selected by ``img_rec_quant_1``:

    * First pass (``img_rec_quant_1 is None``): ``original_image`` goes through
      ``transform(original_image, resize, resize)``, is encoded once, and is
      decoded from both the quantized and the unquantized (hidden) states.
      Both MSEs are measured against the transformed input.
    * Second pass (``img_rec_quant_1`` given): the two reconstructions from a
      first pass are re-encoded separately. The quantized reconstruction is
      decoded from its quantized states, the unquantized reconstruction from
      its hidden states, and each MSE is measured against the tensor that was
      fed in. In addition, the hidden states of the re-encoded *quantized*
      reconstruction are decoded and compared against ``img_rec_no_quant_1``;
      that MSE is returned as ``sorcery``.

    Args:
        vae: Object exposing ``encode_with_internals(images, scale_schedule)``
            (returning a 4-tuple whose 2nd, 3rd and 4th items are the hidden
            states, quantized states and codebook loss) and ``decode(states)``.
        original_image: Input for the first pass; ignored in the second pass.
        scale_schedule: Forwarded unchanged to ``encode_with_internals``.
        display_img: Unused; kept so existing callers keep working.
        resize: Height and width handed to ``transform`` in the first pass.
        img_rec_quant_1: Quantized reconstruction from a first pass, or None.
        img_rec_no_quant_1: Unquantized reconstruction from a first pass.
            Must be provided whenever ``img_rec_quant_1`` is provided.

    Returns:
        Tuple ``(img_rec_quant, img_rec_no_quant, img_rec_quant_loss_mse,
        img_rec_no_quant_loss_mse, codebook_loss, sorcery)``. ``sorcery`` is
        None in the first pass. ``codebook_loss`` always comes from the encode
        that produced ``img_rec_quant``.
    """
    # NOTE(review): nothing here disables autograd; if the VAE parameters
    # require grad, the returned tensors carry a graph. Confirm whether the
    # caller should wrap this in torch.no_grad().

    def per_image_mse(reconstruction, target):
        # Compared on the CPU; the mean is taken over dims 1-3.
        return torch.mean((reconstruction.cpu() - target.cpu()) ** 2, dim=[1, 2, 3])

    if img_rec_quant_1 is not None:
        # Encode the quantized reconstruction once (the original code ran this
        # identical encode twice and discarded the first result).
        _, hidden_states, quantized_states, codebook_loss = vae.encode_with_internals(img_rec_quant_1.to(device), scale_schedule)
        img_rec_quant = vae.decode(quantized_states)
        # Decode of the *unquantized* states of the quantized reconstruction.
        cross_reconstruction = vae.decode(hidden_states)

        _, hidden_states, _, _ = vae.encode_with_internals(img_rec_no_quant_1.to(device), scale_schedule)
        img_rec_no_quant = vae.decode(hidden_states)

        img_rec_no_quant_loss_mse = per_image_mse(img_rec_no_quant, img_rec_no_quant_1)
        img_rec_quant_loss_mse = per_image_mse(img_rec_quant, img_rec_quant_1)
        sorcery = per_image_mse(cross_reconstruction, img_rec_no_quant_1)
    else:
        # If resize differs from 1024, this acts as a resize attack, since the image is upsized again afterwards.
        image = transform(original_image, resize, resize)
        _, hidden_states, quantized_states, codebook_loss = vae.encode_with_internals(image.unsqueeze(0).to(device), scale_schedule)
        img_rec_quant = vae.decode(quantized_states)
        img_rec_no_quant = vae.decode(hidden_states)
        sorcery = None
        # `image` has no batch dimension here; the subtraction relies on broadcasting.
        img_rec_no_quant_loss_mse = per_image_mse(img_rec_no_quant, image)
        img_rec_quant_loss_mse = per_image_mse(img_rec_quant, image)
    return img_rec_quant, img_rec_no_quant, img_rec_quant_loss_mse, img_rec_no_quant_loss_mse, codebook_loss, sorcery





# calculate losses for any image dataset
def calculate_losses_dataset(vae, dataset, dataset_name, scale_schedule, get_overlap=False,resize=1024):
    """Compute single- and double-reconstruction losses for every image.

    Each image is reconstructed once from the file and then a second time from
    its own reconstructions (see ``tokenize_and_reconstruct``). Images that
    fail anywhere in that process are reported and skipped as a whole, so all
    returned lists always have the same length.

    Args:
        vae: Forwarded to ``tokenize_and_reconstruct``.
        dataset: Iterable of image files accepted by ``Image.open``.
        dataset_name: Label used in log messages only.
        scale_schedule: Forwarded to ``tokenize_and_reconstruct``.
        get_overlap: When true, prints the mean of ``overlapping_ratios``.
            Nothing in this function fills that list, so the printed value is
            the mean of an empty list.
        resize: Forwarded to ``tokenize_and_reconstruct``.

    Returns:
        A 17-tuple. Items 0-12 are lists of floats, one entry per successfully
        processed image, in this order: codebook loss, codebook loss (second
        pass), codebook ratio, quantized MSE, quantized MSE (second pass),
        quantized ratio, unquantized MSE, unquantized MSE (second pass),
        unquantized ratio, sorcery loss, sorcery ratio, sorcery combined,
        main combined. Every ratio is first-pass value divided by second-pass
        value. Items 13-16 are the reconstructions of the last successfully
        processed image (quantized/unquantized first pass, quantized/
        unquantized second pass), or None if no image succeeded.
    """
    codebook_loss_mses, rec_quant_loss_mses, rec_no_quant_loss_mses = [], [], []
    codebook_loss_mses_double, rec_quant_loss_mses_double, rec_no_quant_loss_mses_double = [], [], []
    codebook_loss_mses_double_ratio, rec_quant_loss_mses_double_ratio, rec_no_quant_loss_mses_double_ratio = [], [], []
    sorcery_loss, sorcery_loss_ratio_list = [], []
    overlapping_ratios = []
    main_combined, sorcery_combined = [], []

    # Defined up front so the return statement works even when the dataset is
    # empty or every image fails.
    img_rec_quant1 = img_rec_no_quant1 = img_rec_quant2 = img_rec_no_quant2 = None

    def guarded_ratio(first, second):
        # A zero first-pass loss yields 0 instead of dividing. The division is
        # done on tensors so zero denominators give inf/nan, not an exception.
        if first == 0:
            return 0.0
        return (first / second).item()

    # tqdm wraps the dataset itself so it can show a total when one is known.
    for i, image_path in enumerate(tqdm(dataset)):
        try:
            # Close the file handle as soon as the pixels are in memory;
            # convert() always returns a loaded image independent of the file.
            with Image.open(image_path) as opened:
                image = opened.convert('RGB')

            # first reconstruction
            rec_quant1, rec_no_quant1, quant_mse, no_quant_mse, codebook_mse, _ = tokenize_and_reconstruct(vae, image, scale_schedule, resize=resize, display_img=False)
            # second reconstruction, starting from the first-pass outputs
            rec_quant2, rec_no_quant2, quant_mse_double, no_quant_mse_double, codebook_mse_double, sorcery = tokenize_and_reconstruct(vae, None, scale_schedule, resize=resize, display_img=False, img_rec_quant_1=rec_quant1, img_rec_no_quant_1=rec_no_quant1)

            # Compute every scalar before touching any list: if anything above
            # or below raises, no list receives a partial entry.
            codebook_value = codebook_mse.item()
            sorcery_ratio = (no_quant_mse / sorcery).item()
            no_quant_ratio = guarded_ratio(no_quant_mse, no_quant_mse_double)
            row = (
                (codebook_loss_mses, codebook_value),
                (rec_quant_loss_mses, quant_mse.item()),
                (rec_no_quant_loss_mses, no_quant_mse.item()),
                (codebook_loss_mses_double, codebook_mse_double.item()),
                (rec_no_quant_loss_mses_double, no_quant_mse_double.item()),
                (rec_quant_loss_mses_double, quant_mse_double.item()),
                (sorcery_loss, sorcery.item()),
                (sorcery_loss_ratio_list, sorcery_ratio),
                (codebook_loss_mses_double_ratio, guarded_ratio(codebook_mse, codebook_mse_double)),
                (rec_quant_loss_mses_double_ratio, guarded_ratio(quant_mse, quant_mse_double)),
                (rec_no_quant_loss_mses_double_ratio, no_quant_ratio),
                (main_combined, no_quant_ratio * codebook_value),
                (sorcery_combined, sorcery_ratio * codebook_value),
            )
        except Exception as e:
            # Best effort: one unreadable or failing image must not abort the run.
            print(f"Error processing image {i} in {dataset_name}: {e}")
            continue

        for target, value in row:
            target.append(value)
        img_rec_quant1, img_rec_no_quant1 = rec_quant1, rec_no_quant1
        img_rec_quant2, img_rec_no_quant2 = rec_quant2, rec_no_quant2

    if get_overlap:
        print(f"Average overlapping ratio for {dataset_name} images: {np.mean(overlapping_ratios)}")
    return codebook_loss_mses, codebook_loss_mses_double, codebook_loss_mses_double_ratio, rec_quant_loss_mses, rec_quant_loss_mses_double, rec_quant_loss_mses_double_ratio, rec_no_quant_loss_mses, rec_no_quant_loss_mses_double, rec_no_quant_loss_mses_double_ratio, sorcery_loss, sorcery_loss_ratio_list, sorcery_combined, main_combined, img_rec_quant1, img_rec_no_quant1, img_rec_quant2, img_rec_no_quant2


# Metric helpers: per-dataset summary statistics and TPR at 1% FPR
import numpy as np
from sklearn import metrics

def calc_metrics(data_dict):
    """Summarise each dataset and score it against the last dataset.

    The last key of ``data_dict`` is the reference. Every other dataset is
    compared with it through a ROC curve in which the reference samples are
    the positive class and scores are negated, so lower scores count as
    "more positive". The reported value is the TPR at the last ROC point
    whose FPR is strictly below 1%.

    Args:
        data_dict: Mapping from dataset name to a sequence of scores. Insertion
            order matters: the last entry is the reference.

    Returns:
        Dict mapping each dataset name to a dict with keys ``tpr_at_1fpr``,
        ``mean``, ``std``, ``min``, ``max``, ``median``, ``q25`` and ``q75``.
        ``tpr_at_1fpr`` is a NumPy float; it is NaN for the reference dataset
        and whenever only one dataset is supplied, since there is nothing to
        compare against.

    Raises:
        ValueError: If a dataset is empty (from ``np.min`` / ``np.max``).
    """
    datasets = list(data_dict.keys())
    data_arrays = [np.array(data_dict[key]) for key in datasets]

    # One TPR per dataset. Previously a single variable was shared, so every
    # entry reported the value of the last comparison and a single-dataset
    # call failed with NameError.
    tpr_by_name = {name: np.float64(np.nan) for name in datasets}

    # Calculate TPR@1%FPR for each dataset vs the last dataset
    if len(datasets) > 1:
        reference_data = data_arrays[-1]  # Last dataset as reference

        for name, data in zip(datasets[:-1], data_arrays[:-1]):
            all_labels = np.concatenate([np.zeros(len(data)), np.ones(len(reference_data))])
            all_scores = np.concatenate([data, reference_data])

            # Negated because roc_curve treats higher scores as more positive.
            fpr, tpr, _ = metrics.roc_curve(all_labels, -all_scores)
            # The ROC curve starts at FPR 0, so this selection is never empty.
            idx = np.where(fpr < 0.01)[0][-1]
            tpr_by_name[name] = tpr[idx]

    stats_dict = {}
    for name, data in zip(datasets, data_arrays):
        stats_dict[name] = {
            'tpr_at_1fpr': tpr_by_name[name],
            'mean': np.mean(data),
            'std': np.std(data),
            'min': np.min(data),
            'max': np.max(data),
            'median': np.median(data),
            'q25': np.percentile(data, 25),
            'q75': np.percentile(data, 75),
        }

    return stats_dict


def main():
    """Evaluate reconstruction-loss detectors and print a LaTeX table.

    Steps:
      1. Parse arguments, seed the RNGs and load the dataset configuration
         (a JSON object mapping dataset name to image directory).
      2. Load the visual tokenizer, optionally replacing its encoder weights
         with the checkpoint given by ``--ft_path``.
      3. Compute all losses for the "Generated" directory (reference set).
      4. For every other configured dataset, compute the same losses and the
         TPR at 1% FPR against the reference for each metric.
      5. Print the TPR table (in percent) as LaTeX.

    Result tuples are saved to ``gen_results.pt`` and
    ``test_results_<dataset>.pt`` in the current working directory.
    """
    # Imported up front so a missing dependency fails before the long run.
    import json
    import pandas as pd
    from PIL import Image

    parser = argparse.ArgumentParser()
    add_common_arguments(parser)
    parser.add_argument("--ft_path", type=str, default="", help="Path to the inversed decoder")    
    parser.add_argument('--resize', type=int, default=1024, help='Resize dimension for images')
    parser.add_argument('--dataset_config', type=str, default='dataset_config.json', help='Path to dataset configuration JSON file')
    args = parser.parse_args()

    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    try:
        with open(args.dataset_config, 'r') as f:
            dataset_name_image_path = json.load(f)
        print(f"Loaded dataset configuration from: {args.dataset_config}")
    except FileNotFoundError:
        print(f"Dataset config file '{args.dataset_config}' not found. Please create it or specify a valid path.")
        return
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON config file '{args.dataset_config}': {e}")
        return

    # The reference set is mandatory. Without this check a missing key would
    # hand None to os.listdir, which silently lists the current directory.
    gen_path = dataset_name_image_path.get("Generated")
    if not gen_path:
        print(f"Dataset config file '{args.dataset_config}' has no \"Generated\" entry; it is required as the reference dataset.")
        return

    vae = load_visual_tokenizer(args)
    if args.ft_path:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(args.ft_path, map_location=device)
        vae.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        print(f"Loaded finetuned encoder weights at {args.ft_path}.")
    else:
        print("Using original encoder weights.")
    if args.resize != 1024:
        print(f"Using resize attack to {args.resize} instead of 1024.")

    h_div_w = 1/1 # aspect ratio, height:width

    # Pick the template closest to the requested aspect ratio, then force the
    # first component of every scale to 1.
    h_div_w_template_ = h_div_w_templates[np.argmin(np.abs(h_div_w_templates-h_div_w))]
    scale_schedule = dynamic_resolution_h_w[h_div_w_template_][args.pn]['scales']
    scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]

    # Labels for the first 13 entries of the calculate_losses_dataset result;
    # the remaining entries of that tuple are image tensors, not metrics.
    extended_map = {
        0: "Codebook Loss MSE",
        1: "Codebook Loss MSE Double",
        2: "Codebook Loss MSE Double Ratio",
        3: "Reconstruction Quant Loss MSE",
        4: "Reconstruction Quant Loss MSE Double",
        5: "Reconstruction Quant Loss MSE Double Ratio",
        6: "Reconstruction No Quant Loss MSE",
        7: "Reconstruction No Quant Loss MSE Double",
        8: "Reconstruction No Quant Loss MSE Double Ratio",
        9: "Sorcery Loss",
        10: 'Sorcery Loss Ratio',
        11: "Sorcery Combined",
        12: "Main Combined"
    }

    num_samples = 1000

    # os.listdir returns entries in arbitrary order; sorting makes the
    # [:num_samples] subset reproducible across runs and machines.
    gen_dataset = sorted(os.path.join(gen_path, f) for f in os.listdir(gen_path) if f.endswith('.png'))
    gen_results = calculate_losses_dataset(vae, gen_dataset[:num_samples], "gen", scale_schedule, resize=args.resize)
    torch.save(gen_results, f'gen_results.pt')
    # Iterate over the labelled metrics only, not over the whole result tuple.
    for mode, label in extended_map.items():
        mean = np.mean(gen_results[mode])
        std = np.std(gen_results[mode])
        print(f"Gen - {label}: Mean = {mean:.5f}, Std = {std:.5f}")

    image_suffixes = ('.png', '.jpeg', '.JPEG', '.JPG', '.jpg')
    all_results = {}
    for dataset, path in dataset_name_image_path.items():
        if dataset == "Generated":
            # The reference set cannot be compared against itself: both
            # entries would share one dict key and no ROC could be computed.
            continue
        print(f"Dataset: {dataset}, Path: {path}")
        valid_images = []
        for f in sorted(os.listdir(path)):
            if not f.endswith(image_suffixes):
                continue
            img_path = os.path.join(path, f)
            try:
                with Image.open(img_path) as im:
                    im.verify()
                valid_images.append(img_path)
            except Exception as e:
                print(f"Skipping file {img_path}: {e}")
        test_dataset = valid_images[:num_samples]

        test_results = calculate_losses_dataset(vae, test_dataset, dataset, scale_schedule, resize=args.resize)
        torch.save(test_results, f'test_results_{dataset}.pt')
        print(f"Test results for {dataset}: {test_results}")

        all_results[dataset] = {}
        for mode, label in extended_map.items():
            # Order matters: calc_metrics treats the last entry as reference.
            data_dict = {
                dataset: test_results[mode],
                'Generated': gen_results[mode],
            }

            stats = calc_metrics(
                data_dict=data_dict,
            )

            all_results[dataset][label] = stats[dataset]["tpr_at_1fpr"].item()
            # Also print the mean and std
            mean = np.mean(test_results[mode])
            std = np.std(test_results[mode])
            print(f"{dataset} - {label}: Mean = {mean:.5f}, Std = {std:.5f}")
        print(all_results[dataset])

    df = pd.DataFrame(all_results)  # Datasets as columns, metrics as rows

    # DataFrame.applymap is deprecated in newer pandas in favour of
    # DataFrame.map; use whichever this installation provides.
    format_cells = df.map if hasattr(df, "map") else df.applymap
    latex_df = format_cells(lambda x: f"{x*100:.1f}" if pd.notnull(x) else "--")
    print(latex_df.to_latex(escape=False, na_rep="--"))

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()