import argparse
import os
import random
import sys

import json
import pandas as pd

import numpy as np
from sklearn import metrics
from PIL import Image
import torch
import torch.distributed as dist
from deps.taming.util import get_ckpt_path
from huggingface_hub import hf_hub_download
from loguru import logger
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from tqdm import tqdm
from wmar.augmentations.geometric import Identity, Rotate, UpperLeftCropWithPadBack
from wmar.augmentations.valuemetric import JPEG, Brightness, GaussianBlur, GaussianNoise
from wmar.models.armm_wrapper import load_model
from wmar.utils.distributed import (
    average_metrics,
    get_rank,
    get_world_size,
    init_distributed_mode,
    is_distributed,
    is_main_process,
)
from wmar.utils.tensorboard import CustomTensorboardWriter
from wmar.utils.utils import (
    CodesOnDiskDataset,
    calculate_gradient_norm,
    compute_and_save_delta,
    get_decoder_dist,
    get_encoder_dist,
    get_model_property,
)
from attacks import apply_attack
from torch.nn import functional as F

from utils.robustness_utils import *


def extended_analysis(watermarked_folder_path, clean_folder_path, vqgan, args):
    """Sweep every attack over a range of strengths and report detection metrics.

    For each attack and each strength in ``range_map``:

    1. the attack's parameter on ``args`` is set to that strength;
    2. the clean and the watermarked images are attacked with ``apply_attack``;
    3. both sets are scored with ``calculate_losses_dataset``;
    4. ``evaluate_real_and_gen_datasets`` turns the two score sets into metrics.

    Each result row is appended to
    ``{outdir}/robustness_results_extended_{model}.csv``. Its columns are
    ``attack``, ``strength`` and ``{loss}_{metric}`` for every
    ``(loss, metric)`` key returned by ``evaluate_real_and_gen_datasets``.
    After every row, the ``TPR@1%FPR`` columns are re-exported to a LaTeX
    table next to the CSV, so partial results survive an interrupted run.

    Args:
        watermarked_folder_path: Folder of watermarked (generated) images.
        clean_folder_path: Folder of clean reference images.
        vqgan: Tokenizer model forwarded to ``calculate_losses_dataset``.
        args: Parsed CLI namespace. The attack-strength attributes
            (``kernel_size``, ``variance``, ...) are overwritten in place and
            keep the last swept value after this function returns.
    """
    file_path = f"{args.outdir}/robustness_results_extended_{args.model}.csv"
    tex_path = f"{args.outdir}/robustness_results_extended_{args.model}.tex"

    # Rows are appended incrementally below, so start from empty output files.
    for stale_path in (file_path, tex_path):
        if os.path.exists(stale_path):
            os.remove(stale_path)

    # One example attacked image per (attack, strength) is written here.
    os.makedirs(f"{args.outdir}/attacks", exist_ok=True)

    all_attacks = ["gauss", "noise", "crop", "jpeg", "resize", "brightness", "saturation", "contrast"]
    range_map = {
        "gauss": [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
        "noise": [0.0, 0.05, 0.1, 0.15, 0.2],
        "resize": [0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5],
        "crop": [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5],
        "jpeg": [100, 90, 80, 70, 60, 50, 40, 30, 20, 10],
        "brightness": [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
        "saturation": [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0],
        "contrast": [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0],
    }

    # Name of the ``args`` attribute that controls each attack's strength.
    args_map = {
        "gauss": "kernel_size",
        "noise": "variance",
        "crop": "crop_ratio",
        "jpeg": "final_quality",
        "brightness": "brightness",
        "saturation": "saturation",
        "contrast": "contrast",
        "resize": "resize_ratio",
    }

    for attack in tqdm(all_attacks, desc="Evaluated Attacks"):
        for attack_strength in tqdm(range_map[attack], desc="Evaluated Attack Strength"):
            # Set the strength BEFORE attacking either dataset, so the clean
            # and the watermarked images receive the same perturbation.
            setattr(args, args_map[attack], attack_strength)

            real_dataset = apply_attack(clean_folder_path, attack, False, args)
            real_dataset_values = calculate_losses_dataset(real_dataset, f"{clean_folder_path}", vqgan, args.model, get_overlap=False)

            attacked_dataset = apply_attack(watermarked_folder_path, attack, False, args)
            attacked_dataset[0][0].save(f"{args.outdir}/attacks/{attack}_{attack_strength}.png")
            gen_dataset_values = calculate_losses_dataset(attacked_dataset, f"{args.model} Generated", vqgan, args.model)

            row_dict = evaluate_real_and_gen_datasets(real_dataset_values, gen_dataset_values)

            flat_row = {"attack": attack, "strength": attack_strength}
            flat_row.update({f"{loss}_{metric}": val for (loss, metric), val in row_dict.items()})

            # Append to CSV (write the header only when the file is new).
            df_row = pd.DataFrame([flat_row])
            if not os.path.exists(file_path):
                df_row.to_csv(file_path, index=False)
            else:
                df_row.to_csv(file_path, mode="a", header=False, index=False)

            # Rebuild the LaTeX table from everything written so far.
            df_flat = pd.read_csv(file_path)
            tpr_cols = [c for c in df_flat.columns if "TPR@1%FPR" in c]
            df_tpr = df_flat[["attack", "strength"] + tpr_cols]
            # '%' starts a comment in LaTeX, so escape it in the headers.
            df_tpr.columns = [c.replace("%", "\\%") for c in df_tpr.columns]

            latex_str = df_tpr.to_latex(
                index=False, float_format="%.3f",
                caption="TPR@1\\%FPR results by attack and strength",
                label="tab:tpr_results"
            )
            with open(tex_path, "w") as f:
                f.write(latex_str)

    print(f"Saved results to {file_path}")
           
def tpr_fpr_robustness(watermarked_folder_path, clean_folder_path, vqgan, args, attack=None):
    """Evaluate detection metrics for a set of attacks at their configured strength.

    For each attack, the clean and the watermarked folders are attacked via
    ``apply_attack``, using whatever strengths are currently on ``args``. One
    example image of each is saved under ``{outdir}/attacks``. Both sets are
    scored with ``calculate_losses_dataset``, and
    ``evaluate_real_and_gen_datasets`` produces a dict keyed by
    ``(loss_type, metric)`` tuples.

    The collected table is printed and saved to
    ``{outdir}/robustness_results.csv``, with columns flattened to
    ``{loss}_{metric}``. The ``TPR@1%FPR`` slice is saved to
    ``{outdir}/robustness_results.tex``.

    Args:
        watermarked_folder_path: Folder of watermarked (generated) images.
        clean_folder_path: Folder of clean reference images.
        vqgan: Tokenizer model forwarded to ``calculate_losses_dataset``.
        args: Parsed CLI namespace (``outdir``, ``model``, attack parameters).
        attack: Falsy for the default attack list, a single attack name, or a
            sequence of attack names.
    """
    if not attack:
        all_attacks = ["none", "noise" , "gauss", "jpeg", "brightness", "saturation", "contrast", "resize", "crop"]
    elif type(attack) == str:
        all_attacks = [attack]
    else:
        all_attacks = attack

    os.makedirs(f"{args.outdir}/attacks", exist_ok=True)

    # attack name -> {(loss_type, metric): value}
    results = {}
    for atk_name in tqdm(all_attacks):
        clean_attacked = apply_attack(clean_folder_path, atk_name, False, args)
        clean_attacked[0][0].save(f"{args.outdir}/attacks/real_{atk_name}.png")
        clean_scores = calculate_losses_dataset(clean_attacked, f"{clean_folder_path}", vqgan, args.model)

        wm_attacked = apply_attack(watermarked_folder_path, atk_name, False, args)
        wm_attacked[0][0].save(f"{args.outdir}/attacks/{atk_name}.png")
        wm_scores = calculate_losses_dataset(wm_attacked, f"{args.model} Generated", vqgan, args.model)

        results[atk_name] = evaluate_real_and_gen_datasets(clean_scores, wm_scores)

    table = pd.DataFrame.from_dict(results, orient="index")
    table.columns = pd.MultiIndex.from_tuples(table.columns, names=["loss_type", "metric"])
    table.index.name = "attack"
    print(table)

    # CSV gets a single header row: "{loss}_{metric}".
    flat_table = table.copy()
    flat_table.columns = [f"{loss}_{metric}" for loss, metric in flat_table.columns]

    file_path = f"{args.outdir}/robustness_results.csv"
    flat_table.to_csv(file_path)

    tpr_only = table.xs("TPR@1%FPR", axis=1, level="metric")
    latex_str = tpr_only.to_latex(float_format="%.3f", caption="TPR@1\\%FPR results", label="tab:tpr_results")
    with open(f"{args.outdir}/robustness_results.tex", "w") as f:
        f.write(latex_str)

    print(f"Saved results to {file_path}")


def tpr_fpr_robustness_latent_tracer(watermarked_folder_path, clean_folder_path, vqgan, args, attack=None):
    """LatentTracer baseline: score images by latent-optimisation reconstruction error.

    Images are processed one at a time. Each image is preprocessed with
    ``preprocess`` (taming) or ``preprocess_rar`` (any other model) and
    reconstructed by ``tokenize_and_reconstruct_batch_latent_optim``
    (100 iterations, lr 0.01). Its score is the mean squared error to the
    preprocessed input. ``evaluate`` receives the (watermarked, clean) score
    lists and returns threshold, AUC, accuracy and TPR@1%FPR.

    The results table is written to
    ``{outdir}/robustness_results_latent_tracer.csv``, and its TPR column to
    ``{outdir}/robustness_results_latent_tracer.tex``.

    Args:
        watermarked_folder_path: Folder of watermarked images.
        clean_folder_path: Folder of clean reference images.
        vqgan: Model forwarded to ``tokenize_and_reconstruct_batch_latent_optim``.
        args: Parsed CLI namespace (``outdir``, ``model``, attack parameters).
        attack: Falsy for the default attack list, a single attack name, or a
            sequence of attack names.
    """
    if not attack:
        all_attacks = ["none", "noise" , "gauss", "jpeg", "brightness", "saturation", "contrast", "resize", "crop"]
    elif type(attack) == str:
        all_attacks = [attack]
    else:
        all_attacks = attack

    os.makedirs(f"{args.outdir}/attacks", exist_ok=True)

    def _reconstruction_errors(dataset, desc):
        """Return one MSE score per image in ``dataset`` (order preserved)."""
        errors = []
        for image, _ in tqdm(dataset, desc=desc):
            model_input = preprocess(image) if args.model == "taming" else preprocess_rar(image)
            reconstruction, _ = tokenize_and_reconstruct_batch_latent_optim(
                model_input, vqgan, args, display_img=False, use_quant=True, lr=0.01, iters=100
            )
            per_sample = F.mse_loss(model_input, reconstruction, reduction='none').mean(dim=(1, 2, 3))
            errors.extend(per_sample.detach().cpu())
        return errors

    results = {}
    for atk_name in tqdm(all_attacks):
        clean_errors = _reconstruction_errors(apply_attack(clean_folder_path, atk_name, False, args), "Real Images")
        wm_errors = _reconstruction_errors(apply_attack(watermarked_folder_path, atk_name, False, args), "Belonging Images")

        threshold_at_1fpr, auc, acc, tpr_at_1fpr = evaluate(wm_errors, clean_errors)
        results[atk_name] = {
            "Threshold": threshold_at_1fpr,
            "AUC": auc,
            "ACC": acc,
            "TPR@1\\%FPR": tpr_at_1fpr,
        }

    table = pd.DataFrame.from_dict(results, orient="index")
    table.index.name = "attack"
    print(table)

    file_path = f"{args.outdir}/robustness_results_latent_tracer.csv"
    table.to_csv(file_path)

    latex_str = table[["TPR@1\\%FPR"]].to_latex(
        float_format="%.3f",
        caption="LatentTracer TPR@1\\%FPR results",
        label="tab:tpr_results",
    )
    with open(f"{args.outdir}/robustness_results_latent_tracer.tex", "w") as f:
        f.write(latex_str)

    print(f"Saved results to {file_path}")




def tpr_fpr_robustness_latent_tracer_batched(watermarked_folder_path, clean_folder_path, vqgan, args, attack=None):

    
    match args.model:
        case "taming":
            transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.ToTensor(),
                lambda x: 2.0 * x - 1.0,  # normalize to [-1, 1]
            ])
        case "rar":
            transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.ToTensor(),
            ])
        case _:
            raise ValueError()
    
    if attack: 
        if type(attack) == str: 
            all_attacks = [attack]
        else:
            all_attacks = attack 
    else:
        all_attacks = ["none", "noise" , "gauss", "jpeg", "brightness", "saturation", "contrast", "resize", "crop"] #, CtrlRegen, DiffPure
    
    
    
    all_rows = []
    
    for attack in tqdm(all_attacks):
        
        real_dataset = apply_attack(clean_folder_path, attack, False, args)


        if real_dataset.transform:
            real_dataset.transform = transforms.Compose([
                real_dataset.transform, 
                transform
            ])
        else:
            real_dataset.transform = transform

        real_dataloader = DataLoader(real_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=collate_images_only)

        os.makedirs(f"{args.outdir}/attacks", exist_ok=True)

        real_dataset_values = []
        
        for batch_idx, batch_data in enumerate(tqdm(real_dataloader, desc=f"Processing Real LatentTracer")):

            
            image_batch = batch_data.to(device)

            if image_batch.ndim == 3:
                image_batch = image_batch.unsqueeze(0) 
            
            optimized_image, _ = tokenize_and_reconstruct_batch_latent_optim(image_batch, vqgan, args, display_img=False, use_quant=True, lr=0.01, iters=100) 

            for val in F.mse_loss(image_batch, optimized_image, reduction='none').mean(dim=(1,2,3)).detach().cpu():
                real_dataset_values.append(val)
        

    
        attacked_dataset = apply_attack(watermarked_folder_path, attack, False, args)
        if attacked_dataset.transform:
            attacked_dataset.transform = transforms.Compose([
                attacked_dataset.transform, 
                transform
            ])
        else:
            attacked_dataset.transform = transform
        
        dataloader = DataLoader(attacked_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=collate_images_only)

        attacked_dataset_values = []
        for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=f"Processing Belonging with {attack} LatentTracer")):
            
            image_batch = batch_data.to(device)

            if image_batch.ndim == 3:
                image_batch = image_batch.unsqueeze(0) 
            optimized_image, _ = tokenize_and_reconstruct_batch_latent_optim(image_batch, vqgan, args, display_img=False, use_quant=True, lr=0.01, iters=100) 

            for val in F.mse_loss(image_batch, optimized_image, reduction='none').mean(dim=(1,2,3)).detach().cpu():
                attacked_dataset_values.append(val)
        
        threshold_at_1fpr, auc, acc, tpr_at_1fpr = evaluate(attacked_dataset_values, real_dataset_values)        
        row_dict = {
            "Threshold": threshold_at_1fpr,
            "AUC": auc,
            "ACC": acc,
            "TPR@1\\%FPR": tpr_at_1fpr
        }
        all_rows.append((attack, row_dict))

    df = pd.DataFrame.from_dict({atk: r for atk, r in all_rows}, orient="index")
    df.index.name = "attack"

    print(df)


    # Save to CSV
    file_path = f"{args.outdir}/robustness_results_latent_tracer.csv"
    df.to_csv(file_path)

    df_tpr = df[["TPR@1\\%FPR"]]

    # Convert to LaTeX table
    latex_str = df_tpr.to_latex(
        float_format="%.3f",
        caption="LatentTracer TPR@1\\%FPR results",
        label="tab:tpr_results"
    )

    # Save to .tex file
    with open(f"{args.outdir}/robustness_results_latent_tracer.tex", "w") as f:
        f.write(latex_str)

    print(f"Saved results to {file_path}")
    
    
def extended_analysis_latent_tracer_batched(watermarked_folder_path, clean_folder_path, vqgan, args, attack=None):
    file_path = f"{args.outdir}/robustness_results_extended_{args.model}_latent_tracer.csv"
    tex_path = f"{args.outdir}/robustness_results_extended_{args.model}_latent_tracer.tex"

    
   
    
    match args.model:
        case "taming":
            transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.ToTensor(),
                lambda x: 2.0 * x - 1.0,  # normalize to [-1, 1]
            ])
        case "rar":
            transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.ToTensor(),
            ])
        case _:
            raise ValueError()
        
        
    
    if attack:
        if type(attack) == str:
            all_attacks = [attack]
        elif type(attack) == list:
            all_attacks = attack
            
    else:
        if os.path.exists(file_path):
            os.remove(file_path)
        if os.path.exists(tex_path):
            os.remove(tex_path)

        all_attacks = ["gauss", "noise", "crop", "jpeg", "resize", "brightness", "saturation", "contrast"]
    
    
    
    range_map = {
        "gauss": [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
        "noise": [0.0, 0.05, 0.1, 0.15, 0.2],
        "resize": [0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5],
        "crop": [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5],
        "jpeg": [100, 90, 80, 70, 60, 50, 40, 30, 20, 10],
        "brightness": [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
        "saturation": [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0],
        "contrast": [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0],
        "none" : [1.0]
    }

    args_map = {
        "gauss": "kernel_size",
        "noise": "variance",
        "crop": "crop_ratio",
        "jpeg": "final_quality",
        "brightness": "brightness",
        "saturation": "saturation",
        "contrast": "contrast",
        "resize": "resize_ratio",
        "none" : "resize_ratio"
    }
    
    all_rows = []

    for attack in tqdm(all_attacks, desc="Evaluated Attacks"):
        for attack_strength in tqdm(range_map[attack], desc="Evaluated Attack Strength"):
            args.__dict__[args_map[attack]] = attack_strength


            real_dataset = apply_attack(clean_folder_path, attack, False, args)    
            if real_dataset.transform:
                real_dataset.transform = transforms.Compose([
                    real_dataset.transform, 
                    transform
                ])
            else:
                real_dataset.transform = transform

            real_dataloader = DataLoader(real_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=collate_images_only)

            real_dataset_values = []


            for batch_idx, batch_data in enumerate(tqdm(real_dataloader, desc=f"Processing Clean LatentTracer")):
                
                
                image_batch = batch_data.to(device)

                if image_batch.ndim == 3:
                    image_batch = image_batch.unsqueeze(0) 
                
                optimized_image, _ = tokenize_and_reconstruct_batch_latent_optim(image_batch, vqgan, args, display_img=False, use_quant=True, lr=0.01, iters=100) 

                for val in F.mse_loss(image_batch, optimized_image, reduction='none').mean(dim=(1,2,3)).detach().cpu():
                    real_dataset_values.append(val)


            attacked_dataset = apply_attack(watermarked_folder_path, attack, False, args)
        
            if attacked_dataset.transform:
                attacked_dataset.transform = transforms.Compose([
                    attacked_dataset.transform, 
                    transform
                ])
            else:
                attacked_dataset.transform = transform
            
            # Use batched version
            dataloader = DataLoader(attacked_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=collate_images_only)


            attacked_dataset_values = []
            for batch_idx, batch_data in enumerate(tqdm(dataloader, desc=f"Processing Attack {attack} LatentTracer")):
               
               
                image_batch = batch_data.to(device)

                if image_batch.ndim == 3:
                    image_batch = image_batch.unsqueeze(0) 
                optimized_image, _ = tokenize_and_reconstruct_batch_latent_optim(image_batch, vqgan, args, display_img=False, use_quant=True, lr=0.01, iters=100) 

                for val in F.mse_loss(image_batch, optimized_image, reduction='none').mean(dim=(1,2,3)).detach().cpu():
                    attacked_dataset_values.append(val)
                
            print(f"Using {len(attacked_dataset_values)} number of values")
    
            threshold_at_1fpr, auc, acc, tpr_at_1fpr = evaluate(attacked_dataset_values, real_dataset_values)        
            row_dict = {
                "Threshold": threshold_at_1fpr,
                "AUC": auc,
                "ACC": acc,
                "TPR@1\\%FPR": tpr_at_1fpr
            }
            flat_row = {"attack": attack, "strength": attack_strength}
            flat_row.update({
                (metric if isinstance(metric, str) else metric[1]): val
                for metric, val in row_dict.items()
            })

            df_row = pd.DataFrame([flat_row])
            df_row = pd.DataFrame([flat_row])

            # Save row incrementally
            if not os.path.exists(file_path):
                df_row.to_csv(file_path, index=False)
            else:
                df_row.to_csv(file_path, mode="a", header=False, index=False)

            # Reload to update LaTeX each step
            df_flat = pd.read_csv(file_path)

            # Extract TPR@1%FPR only
            tpr_cols = [c for c in df_flat.columns if "TPR@1\\%FPR" in c]
            df_tpr = df_flat[["attack", "strength"] + tpr_cols]

            # Escape % for LaTeX
            latex_str = df_tpr.to_latex(
                index=False,
                float_format="%.3f",
                caption="LatentTracer TPR@1\\%FPR results by attack and strength",
                label="tab:tpr_results"
            )

            with open(tex_path, "w") as f:
                f.write(latex_str)

        print(f"Saved results for {attack} (strength={attack_strength})")

    print(f"Saved results to {file_path}")



def tpr_fpr_robustness_batched(watermarked_folder_path, clean_folder_path, vqgan, args, attack=None):
    """Batched counterpart of ``tpr_fpr_robustness``.

    For each attack, the clean and the watermarked folders are attacked via
    ``apply_attack``, using the strengths currently on ``args``. One example
    image of each is saved under ``{outdir}/attacks``. Both sets are scored
    with ``calculate_losses_dataset_batched`` in batches of
    ``args.batch_size``; the clean set uses ``get_overlap=False``.
    ``evaluate_real_and_gen_datasets`` returns a dict keyed by
    ``(loss_type, metric)`` tuples.

    Writes the same files as ``tpr_fpr_robustness``:
    ``{outdir}/robustness_results.csv`` (columns flattened to
    ``{loss}_{metric}``) and the ``TPR@1%FPR`` slice to
    ``{outdir}/robustness_results.tex``.

    Args:
        watermarked_folder_path: Folder of watermarked (generated) images.
        clean_folder_path: Folder of clean reference images.
        vqgan: Tokenizer model forwarded to ``calculate_losses_dataset_batched``.
        args: Parsed CLI namespace (``outdir``, ``model``, ``batch_size``,
            attack parameters).
        attack: Falsy for the default attack list, a single attack name, or an
            iterable of attack names.
    """
    if not attack:
        all_attacks = ["none", "noise", "gauss", "jpeg", "brightness", "saturation", "contrast", "resize", "crop"]
    elif isinstance(attack, str):
        all_attacks = [attack]
    else:
        all_attacks = list(attack)

    # Example images are saved below; create the folder here rather than
    # relying on the caller (the non-batched version does the same).
    os.makedirs(f"{args.outdir}/attacks", exist_ok=True)

    # attack name -> {(loss_type, metric): value}
    results = {}
    for attack_name in tqdm(all_attacks):
        real_dataset = apply_attack(clean_folder_path, attack_name, False, args)
        real_dataset[0][0].save(f"{args.outdir}/attacks/real_{attack_name}.png")
        real_dataset_values = calculate_losses_dataset_batched(
            real_dataset, f"{clean_folder_path}", vqgan, args.model, args=args,
            batch_size=args.batch_size, get_overlap=False
        )

        attacked_dataset = apply_attack(watermarked_folder_path, attack_name, False, args)
        attacked_dataset[0][0].save(f"{args.outdir}/attacks/{attack_name}.png")
        gen_dataset_values = calculate_losses_dataset_batched(
            attacked_dataset, f"{args.model} Generated", vqgan, args.model, args=args,
            batch_size=args.batch_size
        )

        results[attack_name] = evaluate_real_and_gen_datasets(real_dataset_values, gen_dataset_values)

    df = pd.DataFrame.from_dict(results, orient="index")
    df.columns = pd.MultiIndex.from_tuples(df.columns, names=["loss_type", "metric"])
    df.index.name = "attack"
    print(df)

    df_flat = df.copy()
    df_flat.columns = [f"{loss}_{metric}" for loss, metric in df_flat.columns]

    # Save to CSV
    file_path = f"{args.outdir}/robustness_results.csv"
    df_flat.to_csv(file_path)

    # TPR@1%FPR slice as a LaTeX table
    df_tpr = df.xs("TPR@1%FPR", axis=1, level="metric")
    latex_str = df_tpr.to_latex(float_format="%.3f", caption="TPR@1\\%FPR results", label="tab:tpr_results")
    with open(f"{args.outdir}/robustness_results.tex", "w") as f:
        f.write(latex_str)

    print(f"Saved results to {file_path}")

def extended_analysis_batched(watermarked_folder_path, clean_folder_path, vqgan, args):
    """Batched counterpart of ``extended_analysis``: sweep attack strengths.

    For each attack and each strength in ``range_map``:

    1. the strength is written to the matching ``args`` attribute;
    2. the clean and the watermarked folders are attacked;
    3. both sets are scored with ``calculate_losses_dataset_batched`` in
       batches of ``args.batch_size``;
    4. ``evaluate_real_and_gen_datasets`` turns the two score sets into metrics.

    Each row (``attack``, ``strength``, ``{loss}_{metric}`` ...) is appended
    to ``{outdir}/robustness_results_extended_{model}.csv``. The
    ``TPR@1%FPR`` columns are re-exported to a LaTeX table after every row.
    Existing output files are removed at the start.

    Args:
        watermarked_folder_path: Folder of watermarked (generated) images.
        clean_folder_path: Folder of clean reference images.
        vqgan: Tokenizer model forwarded to ``calculate_losses_dataset_batched``.
        args: Parsed CLI namespace. The attack-strength attributes are
            overwritten in place and keep the last swept value.
    """
    file_path = f"{args.outdir}/robustness_results_extended_{args.model}.csv"
    tex_path = f"{args.outdir}/robustness_results_extended_{args.model}.tex"

    # Rows are appended incrementally below, so start from empty output files.
    for stale_path in (file_path, tex_path):
        if os.path.exists(stale_path):
            os.remove(stale_path)

    # One example attacked image per (attack, strength) is written here.
    os.makedirs(f"{args.outdir}/attacks", exist_ok=True)

    all_attacks = ["gauss", "noise", "crop", "jpeg", "resize", "brightness", "saturation", "contrast"]
    range_map = {
        "gauss": [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
        "noise": [0.0, 0.05, 0.1, 0.15, 0.2],
        "resize": [0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5],
        "crop": [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5],
        "jpeg": [100, 90, 80, 70, 60, 50, 40, 30, 20, 10],
        "brightness": [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
        "saturation": [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0],
        "contrast": [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0],
        "none": [1.0],
    }

    # Name of the ``args`` attribute that controls each attack's strength.
    args_map = {
        "gauss": "kernel_size",
        "noise": "variance",
        "crop": "crop_ratio",
        "jpeg": "final_quality",
        "brightness": "brightness",
        "saturation": "saturation",
        "contrast": "contrast",
        "resize": "resize_ratio",
        "none": "resize_ratio",
    }

    for attack in tqdm(all_attacks, desc="Evaluated Attacks"):
        for attack_strength in tqdm(range_map[attack], desc="Evaluated Attack Strength"):
            # Set the strength before attacking either dataset so both see the same perturbation.
            setattr(args, args_map[attack], attack_strength)

            real_dataset = apply_attack(clean_folder_path, attack, False, args)
            real_dataset_values = calculate_losses_dataset_batched(
                real_dataset, f"{clean_folder_path}", vqgan, args.model, args=args,
                batch_size=args.batch_size, get_overlap=False
            )

            attacked_dataset = apply_attack(watermarked_folder_path, attack, False, args)
            attacked_dataset[0][0].save(f"{args.outdir}/attacks/{attack}_{attack_strength}.png")
            gen_dataset_values = calculate_losses_dataset_batched(
                attacked_dataset, f"{args.model} Generated", vqgan, args.model, args=args,
                batch_size=args.batch_size,
            )

            row_dict = evaluate_real_and_gen_datasets(real_dataset_values, gen_dataset_values)

            flat_row = {"attack": attack, "strength": attack_strength}
            flat_row.update({f"{loss}_{metric}": val for (loss, metric), val in row_dict.items()})

            # Append to CSV (write the header only when the file is new).
            df_row = pd.DataFrame([flat_row])
            if not os.path.exists(file_path):
                df_row.to_csv(file_path, index=False)
            else:
                df_row.to_csv(file_path, mode="a", header=False, index=False)

            # Rebuild the LaTeX table from everything written so far.
            df_flat = pd.read_csv(file_path)
            tpr_cols = [c for c in df_flat.columns if "TPR@1%FPR" in c]
            df_tpr = df_flat[["attack", "strength"] + tpr_cols]
            # '%' starts a comment in LaTeX, so escape it in the headers.
            df_tpr.columns = [c.replace("%", "\\%") for c in df_tpr.columns]

            latex_str = df_tpr.to_latex(
                index=False, float_format="%.3f",
                caption="TPR@1\\%FPR results by attack and strength",
                label="tab:tpr_results"
            )
            with open(tex_path, "w") as f:
                f.write(latex_str)

    print(f"Saved results to {file_path}")


        
if __name__ == "__main__":
    # Local import: only the CLI entry point needs it (safe parsing of --attack).
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="taming", choices=["taming", "rar"], help="model to use")
    parser.add_argument("--modelpath", type=str, default="checkpoints/2021-04-03T19-39-50_cin_transformer", help="path to the model (see README.md)")
    parser.add_argument("--dataset", default="codes-imagenet", type=str, help="dataset to use")
    parser.add_argument("--datapath", type=str, default="rar/codes", help="path to the dataset (precomputed imagenet codes)")
    parser.add_argument("--dataset_size", type=int, default=1000, help="size of the dataset to subselect")
    parser.add_argument("--mode", type=str, default="newenc-dec")
    parser.add_argument("--nb_epochs", type=int, help="number of epochs")
    parser.add_argument("--augs_schedule", type=str, default='1,1,4,4', help="augmentations schedule (e.g., 1,1,4,4)")
    parser.add_argument("--optimizer", type=str, default="adam", help="optimizer")
    parser.add_argument("--lr", type=float, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--disable_gan", action="store_true")
    parser.add_argument("--idempotence_loss_weight", type=float, default=1.0, help="idempotence loss weight compared to reg")
    parser.add_argument("--idempotence_loss_weight_factor", type=float, default=1.0, help="factor to multiply idem. loss weight by")
    parser.add_argument("--loss", type=str, default="hard-to-soft-with-ae")
    parser.add_argument("--augs", type=str, choices=["none", "all+geom"], help="augmentations to use in training")
    parser.add_argument("--outdir", type=str, help="output directory")
    parser.add_argument("--encoder_path", type=str, default=None)

    # DDP params
    parser.add_argument("--local_rank", "--local-rank", type=int, default=-1, help="Local rank for distributed training")
    parser.add_argument("--master_port", type=int, default=-1, help="Master port for DDP")
    # NOTE(review): argparse's type=bool turns any non-empty string (even "False") into True.
    parser.add_argument("--debug_slurm", type=bool, default=False, help="Debug SLURM setup")

    parser.add_argument("--extended", action="store_true")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--watermarked_dir", type=str)
    parser.add_argument("--clean_dir", type=str)
    parser.add_argument("--stable_diff_vae", type=str, default='')
    parser.add_argument("--num_samples", type=int, default=-1)  # on how many samples the attacks are supposed to happen
    parser.add_argument("--kernel_size", type=int, default=7)
    parser.add_argument("--variance", type=float, default=0.1)
    parser.add_argument("--crop_ratio", type=float, default=0.7)
    parser.add_argument("--final_quality", type=int, default=75)
    parser.add_argument("--ctrl_regen_steps", type=float, default=0.5)
    # These strengths are fractional factors: parse them as floats so values
    # such as "--brightness 1.5" are accepted.
    parser.add_argument("--brightness", type=float, default=1.2)
    parser.add_argument("--saturation", type=float, default=1.2)
    parser.add_argument("--contrast", type=float, default=1.2)
    parser.add_argument("--resize_ratio", type=float, default=0.8)

    # A single attack name, or a list literal such as '["gauss", "jpeg"]'
    # (names: gauss, noise, crop, jpeg, resize, brightness, saturation, contrast).
    parser.add_argument("--attack", type=str, default=None)
    args, unknown_args = parser.parse_known_args()
    print(args)

    set_seeds(args.seed)

    # `device` and `size` are module globals read at call time by the
    # *_latent_tracer_batched functions above.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.model == "taming":
        vqgan_config_path = os.path.join(args.modelpath, "configs", "vqgan.yaml")
        vqgan_ckpt_path = os.path.join(args.modelpath, "checkpoints", "vqgan.ckpt")
    elif args.model == "rar":
        vqgan_config_path = "deps/rar/configs/training/generator/rar.yaml"
        vqgan_ckpt_path = None  # downloaded
    else:
        raise ValueError(f"Model {args.model} not supported")

    vqgan_codebase = "rar" if args.model == "rar" else "taming"
    vqgan = load_model(
        vqgan_config_path,
        vqgan_ckpt_path,
        device=device,
        vqgan_codebase=vqgan_codebase,
    )

    # Optionally swap in fine-tuned tokenizer encoder weights.
    if args.encoder_path:
        update_weights(vqgan.encoder, args.encoder_path)
        print(f"Using finetuned encoder from {args.encoder_path}")

    size = 256
    print(f"Using size: {size}")

    # NOTE(review): this module-level transform is not referenced by the
    # functions in this file (the batched LatentTracer functions build their own).
    transform = transforms.Compose(
        [
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            lambda x: 2.0 * x - 1.0,  # normalize to [-1, 1]
        ]
    )

    os.makedirs(f"{args.outdir}/attacks", exist_ok=True)

    # Always bind `attack`: it is passed to extended_analysis_latent_tracer_batched below.
    attack = None
    if args.attack:
        # literal_eval parses list literals without executing arbitrary code (unlike eval).
        attack = ast.literal_eval(args.attack) if "[" in args.attack else args.attack

    if not args.extended:
        if args.batch_size:
            tpr_fpr_robustness_batched(watermarked_folder_path=args.watermarked_dir, clean_folder_path=args.clean_dir, vqgan=vqgan, args=args, attack=None)  # runs all attacks
            if not args.encoder_path:
                tpr_fpr_robustness_latent_tracer_batched(watermarked_folder_path=args.watermarked_dir, clean_folder_path=args.clean_dir, vqgan=vqgan, args=args, attack=None)  # runs all attacks
        else:
            tpr_fpr_robustness(watermarked_folder_path=args.watermarked_dir, clean_folder_path=args.clean_dir, vqgan=vqgan, args=args, attack=None)  # runs all attacks
            if not args.encoder_path:
                tpr_fpr_robustness_latent_tracer(watermarked_folder_path=args.watermarked_dir, clean_folder_path=args.clean_dir, vqgan=vqgan, args=args, attack=None)
    else:
        if args.batch_size:
            extended_analysis_batched(watermarked_folder_path=args.watermarked_dir, clean_folder_path=args.clean_dir, vqgan=vqgan, args=args)  # runs all attacks
            if not args.encoder_path:
                extended_analysis_latent_tracer_batched(watermarked_folder_path=args.watermarked_dir, clean_folder_path=args.clean_dir, vqgan=vqgan, args=args, attack=attack)  # if attack=None it will run all attacks
        else:
            extended_analysis(watermarked_folder_path=args.watermarked_dir, clean_folder_path=args.clean_dir, vqgan=vqgan, args=args)  # runs all attacks