"""
Metric-Agnostic Input Perturbation (Somepalli et al.): GNI RT CWR RNA
Metric-Aware Input Perturbation (Wen et al.)
"""
import os
import sys
import argparse
import re
import json
import time
from pprint import pprint
from functools import partial
import torch
import pandas as pd
from tqdm import tqdm
import wandb

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from mvdream.model_zoo import build_model
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.attention_control import AttentionController, register_attention_control
from mvdream.editing.perturbation import (
    perturb_prompt_random_tokens, 
    perturb_prompt_word_repetition, 
    perturb_prompt_random_numbers,
    perturb_prompt_tokenwise,
    add_gaussian_noise,
    optimize_embedding_wen,
    optimize_embedding_bright_ending
)

from evaluation.evaluator import ModelEvaluator, save_run_outputs
from evaluation.metrics import NoiseDiffNormMetric, HessianMetric, HessianMetricSlow, BrightEndingMetric, XAttnEntropyMetric, DiversityMetric
from evaluation.utils import load_uids_from_clusters, uids_to_prompts

def _build_perturbation_funcs(method, model):
    """Build the perturbation callable for ``method``.

    Args:
        method: One of the CLI ``--method`` choices.
        model: The loaded model. Only ``model.cond_stage_model.tokenizer`` is
            read, and only by the token-level methods that need it.

    Returns:
        ``(token_perturb_func, embed_perturb_func)``; exactly one is non-None.

    Raises:
        ValueError: If ``method`` is not a known perturbation method.
    """
    if method == 'ip-rt':
        return partial(perturb_prompt_random_tokens, tokenizer=model.cond_stage_model.tokenizer, num_tokens=4), None
    if method == 'ip-rna':
        return partial(perturb_prompt_random_numbers, num_numbers=2), None
    if method == 'ip-cwr':
        return partial(perturb_prompt_word_repetition, tokenizer=model.cond_stage_model.tokenizer, num_repeat=4), None
    if method == 'ip-rt_tokenwise':
        return partial(perturb_prompt_tokenwise, tokenizer=model.cond_stage_model.tokenizer), None
    if method == 'ip-gni':
        return None, partial(add_gaussian_noise, std=1.5)
    if method == 'ip-wen':
        # output_path is bound per sample in main(). These are the Objaverse
        # settings; LAION runs used steps=150 instead.
        return None, partial(optimize_embedding_wen, lr=0.1, steps=100, target_loss=0.0, loss_history=True)
    if method == 'ip-be':
        # output_path is bound per sample in main().
        return None, partial(optimize_embedding_bright_ending, lr=0.01, steps=9)
    raise ValueError(f"Method '{method}' is not defined.")


def _wandb_log_sample(stage, dataset_name, result, prompt, prompt_idx, seed_idx, caption_suffix=""):
    """Log one sample's numeric metrics and image to W&B under ``<stage>/<dataset_name>``.

    No explicit ``step`` is passed: W&B ignores logs whose step is lower than
    the current one, and per-seed/per-prompt/per-dataset counters here are not
    monotonic. ``prompt_idx`` and ``seed`` are logged as fields instead so they
    can be used as custom x-axes.
    """
    log_data = {
        f"{stage}/{dataset_name}/metrics/{k}": v
        for k, v in result["metrics"].items()
        if isinstance(v, (int, float))
    }
    log_data[f"{stage}/{dataset_name}/images"] = wandb.Image(
        result["image"],
        caption=f"Prompt: {prompt}\nSeed: {seed_idx}{caption_suffix}"
    )
    log_data["prompt_idx"] = prompt_idx
    log_data["seed"] = seed_idx
    wandb.log(log_data)


def _write_cross_seed_json(path, payload):
    """Write ``payload`` as indented JSON to ``path``, creating its parent directory."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        json.dump(payload, f, indent=2)


def main(args):
    """Evaluate one input-perturbation mitigation method on the configured datasets.

    For every prompt of every dataset, optionally run an unperturbed baseline
    (``--run_baseline``) and always run the perturbed ("mitigated") generation,
    once per seed. Per-seed outputs are saved via ``save_run_outputs``. A
    per-prompt cross-seed diversity JSON is written under
    ``output/baseline/<dataset>`` and ``output/<method_name>/<dataset>``.
    Everything is optionally logged to Weights & Biases (``--use_wandb``).

    Args:
        args: Parsed CLI namespace with ``method`` (str), ``use_wandb`` (bool)
            and ``run_baseline`` (bool).

    Raises:
        ValueError: If ``args.method`` is not a known perturbation method.
    """
    # --- Configuration ---
    config = {
        "method_name": f"perturb_{args.method}_4",
        "device": "cuda",
        "model_name": "sd-v1.5-4view",
        "num_seeds_or_tokens": -1 if "tokenwise" in args.method else 4,
        "num_frames": 4,
    }
    config["output_dir"] = f"output/{config['method_name']}"
    num_seeds = config["num_seeds_or_tokens"]
    if num_seeds < 1:
        # NOTE(review): tokenwise methods set num_seeds_or_tokens=-1, which
        # presumably means "one sample per prompt token". Nothing in this
        # script resolves that count, so range(num_seeds) is empty and no
        # samples are generated. Warn loudly instead of silently doing nothing.
        print(f"WARNING: num_seeds_or_tokens={num_seeds}; no baseline or mitigated samples will be generated.")

    datasets = [
        {"name": "laion_memorized", "type": "csv", "path": "data/nemo-prompts/memorized_laion_prompts.csv", "is_memorized": True},
        {"name": "laion_unmemorized", "path": "data/nemo-prompts/unmemorized_laion_prompts.csv", "is_memorized": False},
    ]

    # --- W&B Initialization ---
    if args.use_wandb:
        wandb.init(
            project="mvdream-evaluation",
            name=f"eval-{config['method_name']}-{time.strftime('%Y%m%d-%H%M%S')}",
            config={
                "method": config["method_name"],
                "model_name": config["model_name"],
                "num_seeds": config["num_seeds_or_tokens"],
                "num_frames": config["num_frames"],
                "datasets": [d["name"] for d in datasets],
                "perturbation_method": args.method,
            }
        )

    script_start_time = time.time()

    # --- Setup ---
    torch.manual_seed(42)
    model = build_model(config["model_name"]).to(config["device"])
    model.device = config["device"]
    model.dtype = torch.float32
    model.image_size = 256
    model.num_frames = config["num_frames"]

    sampler = DDIMSampler(model)
    metrics = [NoiseDiffNormMetric(), HessianMetric(), HessianMetricSlow(), BrightEndingMetric(), XAttnEntropyMetric(), DiversityMetric(device=config["device"])]
    diversity_metric = metrics[-1]
    evaluator = ModelEvaluator(sampler, metrics, device=config["device"])

    # --- Setup Perturbation Functions ---
    token_perturb_func, embed_perturb_func = _build_perturbation_funcs(args.method, model)
    # These embedding optimizers need a per-sample output_path bound at call time.
    needs_output_path = args.method in ('ip-wen', 'ip-be')

    # --- Loop and Evaluate Each Dataset ---
    for dataset in datasets:
        dataset_name = dataset['name']
        print(f"\n--- Evaluating with method '{args.method}' on: {dataset_name} ---")

        # Load prompts data (JSON cluster files need a "concept_key" entry).
        if dataset['path'].endswith('.json'):
            uids_for_eval = load_uids_from_clusters(dataset["path"], dataset["concept_key"])
            prompts_data = uids_to_prompts(uids_for_eval)
        else:
            prompts_data = pd.read_csv(dataset['path'], sep=';').to_dict('records')

        # Create a wandb.Table for this dataset's prompts
        if args.use_wandb:
            dataset_table = wandb.Table(columns=["prompt_idx", "prompt", "is_memorized"])

        for idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {dataset_name}")):
            prompt = prompt_data["Caption"]
            safe_prompt = re.sub(r'\W+', '_', prompt)[:32]
            prompt_start_time = time.time()

            if args.use_wandb:
                dataset_table.add_data(idx, prompt, dataset["is_memorized"])

            baseline_images = []
            mitigated_images = []
            baseline_diversity = None

            # --- STAGE 1: Run BASELINE for all seeds ---
            if args.run_baseline:
                print(f"\n{'>'*50}\nRunning BASELINE for prompt {idx}...\n{'<'*50}")
                for seed_idx in range(num_seeds):
                    base_filename = f"prompt_{idx:04d}_{seed_idx:02d}_{safe_prompt}"
                    sample_dir = os.path.join("output/baseline", dataset_name, base_filename)
                    # Create the sample dir up front, as the mitigation stage does.
                    os.makedirs(sample_dir, exist_ok=True)

                    controller = AttentionController(output_dir=sample_dir)
                    register_attention_control(model, controller)
                    artifacts = {"controller": controller}

                    baseline_result = evaluator._process_single_prompt_single_seed(
                        prompt=prompt, prompt_idx=idx, seed=seed_idx, num_frames=config['num_frames'],
                        base_output_dir=sample_dir, unlearning_artifacts=artifacts
                    )

                    baseline_images.append(baseline_result["image"])
                    baseline_result["metrics"].update({"prompt": prompt, "memorized": dataset["is_memorized"]})
                    save_run_outputs(baseline_result, os.path.dirname(sample_dir), base_filename)

                    if args.use_wandb:
                        _wandb_log_sample("baseline", dataset_name, baseline_result, prompt, idx, seed_idx)

                    del baseline_result, controller
                    torch.cuda.empty_cache()

                # --- Cross-Seed Metrics Calculation (Baseline) ---
                if baseline_images:
                    baseline_diversity = diversity_metric.measure(images=baseline_images)
                    baseline_cross_seed_json = os.path.join(f"output/baseline/{dataset_name}", f"prompt_{idx:04d}_{safe_prompt}_cross_seed.json")
                    _write_cross_seed_json(
                        baseline_cross_seed_json,
                        {"prompt": prompt, "memorized": dataset["is_memorized"], diversity_metric.name: baseline_diversity},
                    )
                    print("Baseline Cross-Seed Artefacts:", baseline_cross_seed_json)

                    if args.use_wandb:
                        wandb.log({
                            f"baseline/{dataset_name}/diversity": baseline_diversity,
                            "prompt_idx": idx
                        })
                    print(f"Baseline diversity for prompt {idx}: {baseline_diversity}")

                del baseline_images

            # --- STAGE 2: Run MITIGATION with perturbations ---
            print(f"\n{'>'*50}\nRunning perturbation mitigation ('{args.method}') for prompt {idx}...\n{'<'*50}")
            perturbed_prompt = prompt
            for seed_idx in range(num_seeds):
                base_filename = f"prompt_{idx:04d}_{seed_idx:02d}_{safe_prompt}"
                mitigated_sample_dir = os.path.join(f"output/{config['method_name']}", dataset_name, base_filename)
                os.makedirs(mitigated_sample_dir, exist_ok=True)

                controller = AttentionController(output_dir=mitigated_sample_dir)
                register_attention_control(model, controller)
                artifacts = {"controller": controller}

                if token_perturb_func is not None:
                    # NOTE(review): the evaluator still receives the ORIGINAL prompt
                    # plus token_perturb_func, so it presumably draws its own
                    # perturbation. This sample is only printed and saved, and may
                    # differ from what was generated. Confirm in ModelEvaluator.
                    perturbed_prompt = token_perturb_func(prompt)
                    artifacts["token_perturb_func"] = token_perturb_func
                    print(f"  Original prompt: {prompt}")
                    print(f"  Perturbed prompt: {perturbed_prompt}")

                if embed_perturb_func is not None:
                    if needs_output_path:
                        artifacts["embed_perturb_func"] = partial(embed_perturb_func, output_path=mitigated_sample_dir)
                    else:
                        artifacts["embed_perturb_func"] = embed_perturb_func
                print(artifacts)
                mitigated_result = evaluator._process_single_prompt_single_seed(
                    prompt=prompt, prompt_idx=idx, seed=seed_idx, num_frames=config['num_frames'],
                    base_output_dir=mitigated_sample_dir, unlearning_artifacts=artifacts
                )

                mitigated_images.append(mitigated_result["image"])
                # NOTE(review): "memorized" is hard-coded False regardless of
                # dataset["is_memorized"]. Presumably this labels post-mitigation
                # outputs as non-memorized; confirm intent.
                mitigated_result["metrics"].update({"prompt": prompt, "memorized": False})
                save_run_outputs(mitigated_result, os.path.dirname(mitigated_sample_dir), base_filename)

                if args.use_wandb:
                    _wandb_log_sample(
                        "mitigated", dataset_name, mitigated_result, prompt, idx, seed_idx,
                        caption_suffix=f"\nMethod: {args.method}",
                    )

                del mitigated_result, controller
                torch.cuda.empty_cache()

            # --- Cross-Seed Metrics Calculation (Mitigated) ---
            if mitigated_images:
                mitigated_diversity = diversity_metric.measure(images=mitigated_images)
                mitigated_cross_seed_json = os.path.join(f"output/{config['method_name']}/{dataset_name}", f"prompt_{idx:04d}_{safe_prompt}_cross_seed.json")
                _write_cross_seed_json(
                    mitigated_cross_seed_json,
                    {"prompt": prompt, "perturbed_prompt": perturbed_prompt,
                     "memorized": False, diversity_metric.name: mitigated_diversity},
                )
                print("Mitigated Cross-Seed Artefacts:", mitigated_cross_seed_json)

                if args.use_wandb:
                    wandb.log({
                        f"mitigated/{dataset_name}/diversity": mitigated_diversity,
                        "prompt_idx": idx
                    })

                    # Bar chart comparing diversities (only if the baseline was run).
                    if baseline_diversity is not None:
                        diversity_plot = wandb.plot.bar(
                            wandb.Table(
                                columns=["stage", "diversity"],
                                data=[["baseline", baseline_diversity], ["mitigated", mitigated_diversity]]
                            ),
                            "stage",
                            "diversity",
                            title=f"Diversity Comparison for Prompt {idx}"
                        )
                        wandb.log({f"plots/{dataset_name}/diversity_comparison": diversity_plot})

                print(f"Mitigated diversity for prompt {idx}: {mitigated_diversity}")

            del mitigated_images

            # --- Timing ---
            prompt_duration = time.time() - prompt_start_time
            print(f"Prompt {idx} processed in {prompt_duration:.2f} seconds.")
            if args.use_wandb:
                wandb.log({
                    f"timing/{dataset_name}/prompt_duration": prompt_duration,
                    "prompt_idx": idx
                })

        # Log the dataset-specific table
        if args.use_wandb:
            wandb.log({f"prompts/{dataset_name}_prompts": dataset_table})

    # --- Final Timing Log ---
    total_duration_minutes = (time.time() - script_start_time) / 60
    print(f"Total evaluation time: {total_duration_minutes:.2f} minutes.")

    if args.use_wandb:
        wandb.finish()

    print(f"\nEvaluation for perturbation method '{args.method}' complete.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run input perturbation evaluation with optional W&B logging.")
    parser.add_argument(
        "--method", type=str, required=True,
        choices=['ip-rt', 'ip-rna', 'ip-cwr', 'ip-rt_tokenwise',
                 'ip-gni', 'ip-be', 'ip-wen'],
        help="The input perturbation method to use."
    )
    parser.add_argument("--use_wandb", action="store_true", help="Enable Weights & Biases logging.")
    parser.add_argument("--run_baseline", action="store_true", help="Run baseline for sanity check.")
    args = parser.parse_args()
    main(args)