import os
import sys
import argparse
import re
import json
import time

import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
import wandb
from PIL import Image

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from mvdream.model_zoo import build_model
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.attention_control import AttentionController, register_attention_control
from mvdream.editing.perturbation import optimize_embedding_bright_ending

from evaluation.evaluator import ModelEvaluator, save_run_outputs
from evaluation.metrics import NoiseDiffNormMetric, HessianMetric, HessianMetricSlow, DiversityMetric, BrightEndingMetric, XAttnEntropyMetric
from evaluation.utils import load_uids_from_clusters, uids_to_prompts

def _attention_capture_config():
    """Return a fresh attention-capture config used by the baseline and 'be' runs.

    A new dict is built on every call in case ``AttentionController`` mutates
    the config it is given (not visible from here).
    """
    return {'steps': [0, 49], 'blocks': ['down'], 'xattn_indices': [0, 1]}


def _load_prompts(dataset):
    """Load the prompt records for one dataset entry.

    ``.json`` paths are resolved through ``load_uids_from_clusters`` /
    ``uids_to_prompts`` and require a ``"concept_key"`` entry in ``dataset``.
    Any other path is read as a ``;``-separated CSV. Every record must carry a
    ``"Caption"`` key, which is what the evaluation loop reads.
    """
    if dataset['path'].endswith('.json'):
        uids_for_eval = load_uids_from_clusters(dataset["path"], dataset["concept_key"])
        return uids_to_prompts(uids_for_eval)
    return pd.read_csv(dataset['path'], sep=';').to_dict('records')


def _baseline_artifacts(model, sample_dir):
    """Register a capture-only attention controller and return evaluator artifacts."""
    controller = AttentionController(output_dir=sample_dir, config=_attention_capture_config())
    register_attention_control(model, controller)
    return {"controller": controller, "attention_map_dir": controller.output_dir}


def _mitigation_artifacts(method, model, prompt, sample_dir, device):
    """Register the attention controller for ``method`` and return evaluator artifacts.

    Args:
        method: ``'be'`` or ``'ca_entropy'``.
        model: the diffusion model whose attention layers get the controller.
        prompt: text prompt (used by ``'be'`` to build the initial embedding).
        sample_dir: per-seed output directory handed to the controller.
        device: device the initial embedding is moved to.

    Returns:
        dict passed to the evaluator as ``unlearning_artifacts``.

    Raises:
        ValueError: for an unsupported ``method``.
    """
    if method == 'be':
        controller = AttentionController(output_dir=sample_dir, config=_attention_capture_config())
        register_attention_control(model, controller)
        initial_embedding = model.get_learned_conditioning([prompt]).to(device)
        perturbed = optimize_embedding_bright_ending(
            initial_embedding, model, controller, output_path=sample_dir, lr=0.01, steps=20
        )
        # `perturbed` is never reassigned after this point, so the closure is
        # safe; it is released together with the returned dict.
        return {
            "embedding_perturb_func": lambda e, **k: perturbed,
            "controller": controller,
        }
    if method == 'ca_entropy':
        mitigation_params = {'type': 'ca_entropy', 'c_scale': 1.8}
        controller = AttentionController(output_dir=sample_dir, mitigation_params=mitigation_params)
        register_attention_control(model, controller)
        return {"controller": controller}
    raise ValueError(f"Unsupported method: {method!r}")


def _write_cross_seed_json(output_dir, idx, safe_prompt, prompt, memorized, metric_name, value):
    """Write the cross-seed metric for one prompt and return the JSON path."""
    path = os.path.join(output_dir, f"prompt_{idx:04d}_{safe_prompt}_cross_seed.json")
    os.makedirs(output_dir, exist_ok=True)
    with open(path, 'w') as f:
        json.dump({"prompt": prompt, "memorized": memorized, metric_name: value}, f, indent=2)
    return path


def _run_seeds(stage, output_root, make_artifacts, *, evaluator, dataset, prompt, idx,
               safe_prompt, num_seeds, num_frames, use_wandb, caption_extra=""):
    """Generate, score, save and (optionally) log every seed of one prompt.

    Args:
        stage: W&B key prefix (``"baseline"`` or ``"mitigated"``).
        output_root: root directory; per-seed dirs go to ``<root>/<dataset>/<file>``.
        make_artifacts: callable ``(sample_dir) -> dict`` that registers the
            attention controller and returns the evaluator artifacts. It is
            invoked synchronously, once per seed.
        caption_extra: text appended to the W&B image caption.
        Remaining keywords describe the prompt/dataset and run configuration.

    Returns:
        list of generated images, one per seed, in seed order.
    """
    images = []
    for seed_idx in range(num_seeds):
        base_filename = f"prompt_{idx:04d}_{seed_idx:02d}_{safe_prompt}"
        sample_dir = os.path.join(output_root, dataset['name'], base_filename)
        os.makedirs(sample_dir, exist_ok=True)

        artifacts = make_artifacts(sample_dir)
        result = evaluator._process_single_prompt_single_seed(
            prompt=prompt, prompt_idx=idx, seed=seed_idx, num_frames=num_frames,
            base_output_dir=sample_dir, unlearning_artifacts=artifacts,
        )
        images.append(result["image"])
        # "memorized" is the prompt's ground-truth dataset label for every
        # stage, so mitigated results stay comparable with the baseline.
        result["metrics"].update({"prompt": prompt, "memorized": dataset["is_memorized"]})
        save_run_outputs(result, os.path.dirname(sample_dir), base_filename)

        if use_wandb:
            log_data = {
                f"{stage}/{dataset['name']}/metrics/{k}": v
                for k, v in result["metrics"].items()
                if isinstance(v, (int, float))
            }
            log_data[f"{stage}/{dataset['name']}/images"] = wandb.Image(
                result["image"],
                caption=f"Prompt: {prompt}\nSeed: {seed_idx}{caption_extra}",
            )
            # No explicit `step=`: W&B discards logs whose step is lower than
            # the current one, so indices are logged as values instead.
            log_data.update({"prompt_idx": idx, "seed_idx": seed_idx})
            wandb.log(log_data)

        # Drop per-seed references (controller, perturbed embedding, outputs)
        # before the next seed allocates GPU memory.
        del result, artifacts
        torch.cuda.empty_cache()
    return images


def main(args):
    """Evaluate baseline and attention-based mitigation on the configured datasets.

    For every prompt of every dataset:
    - run ``num_seeds`` baseline generations (only when ``args.run_baseline``);
    - run ``num_seeds`` mitigated generations with ``args.method``;
    - write per-seed outputs and a cross-seed diversity JSON for each stage.

    Outputs go to ``output/baseline/<dataset>/`` and ``output/<method>/<dataset>/``.
    When ``args.use_wandb`` is set, metrics, images, diversity, a
    baseline-vs-mitigated bar chart, timings and a prompt table are logged.

    Args:
        args: parsed CLI namespace with ``method``, ``use_wandb`` and
            ``run_baseline`` attributes.
    """
    # --- Configuration & Setup ---
    config = {"method": args.method, "device": "cuda", "model_name": "sd-v1.5-4view", "num_seeds": 4, "num_frames": 4}
    datasets = [
        {"name": "laion_memorized", "path": "data/nemo-prompts/memorized_laion_prompts.csv", "is_memorized": True},
        {"name": "laion_unmemorized", "path": "data/nemo-prompts/unmemorized_laion_prompts.csv", "is_memorized": False}
    ]
    torch.manual_seed(42)

    # --- W&B Initialization ---
    if args.use_wandb:
        wandb.init(
            project="mvdream-evaluation",
            name=f"eval-{args.method}-{time.strftime('%Y%m%d-%H%M%S')}",
            config={
                "method": args.method,
                "model_name": config["model_name"],
                "num_seeds": config["num_seeds"],
                "num_frames": config["num_frames"],
                "datasets": [d["name"] for d in datasets],
            }
        )

    script_start_time = time.time()

    model = build_model(config["model_name"]).to(config["device"])
    # Attributes presumably read by the sampler/evaluator (not visible here).
    model.device = config["device"]
    model.dtype = torch.float32
    model.image_size = 256
    model.num_frames = config["num_frames"]

    sampler = DDIMSampler(model)
    # Keep a direct handle instead of relying on list position.
    diversity_metric = DiversityMetric()
    metrics = [NoiseDiffNormMetric(), HessianMetric(), HessianMetricSlow(), BrightEndingMetric(), XAttnEntropyMetric(), diversity_metric]
    evaluator = ModelEvaluator(sampler, metrics, device=config["device"])

    # --- Main Evaluation Loop ---
    for dataset in datasets:
        ds_name = dataset['name']
        prompts_data = _load_prompts(dataset)

        if args.use_wandb:
            dataset_table = wandb.Table(columns=["prompt_idx", "prompt", "is_memorized"])

        for idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {ds_name}")):
            prompt_start_time = time.time()
            prompt = prompt_data["Caption"]
            safe_prompt = re.sub(r'\W+', '_', prompt)[:32]

            if args.use_wandb:
                dataset_table.add_data(idx, prompt, dataset["is_memorized"])

            seed_kwargs = dict(
                evaluator=evaluator, dataset=dataset, prompt=prompt, idx=idx,
                safe_prompt=safe_prompt, num_seeds=config['num_seeds'],
                num_frames=config['num_frames'], use_wandb=args.use_wandb,
            )
            baseline_diversity = None

            # --- STAGE 1: Run BASELINE for all seeds ---
            if args.run_baseline:
                print(f"\n{'>'*50}\nRunning BASELINE for prompt {idx}...\n{'<'*50}")
                baseline_images = _run_seeds(
                    "baseline", "output/baseline",
                    lambda sample_dir: _baseline_artifacts(model, sample_dir),
                    **seed_kwargs,
                )
                if baseline_images:
                    baseline_diversity = diversity_metric.measure(images=baseline_images)
                    baseline_cross_seed_json = _write_cross_seed_json(
                        f"output/baseline/{ds_name}", idx, safe_prompt, prompt,
                        dataset["is_memorized"], diversity_metric.name, baseline_diversity,
                    )
                    print("Baseline Cross-Seed Artefacts:", baseline_cross_seed_json)
                    if args.use_wandb:
                        wandb.log({f"baseline/{ds_name}/diversity": baseline_diversity, "prompt_idx": idx})
                    print(f"Baseline diversity for prompt {idx}: {baseline_diversity}")
                del baseline_images

            # --- STAGE 2: Run MITIGATION for all seeds ---
            print(f"\n{'>'*50}\nRunning MITIGATION ('{args.method}') for prompt {idx}...\n{'<'*50}")
            mitigated_images = _run_seeds(
                "mitigated", f"output/{config['method']}",
                lambda sample_dir: _mitigation_artifacts(args.method, model, prompt, sample_dir, config['device']),
                caption_extra=f"\nMethod: {args.method}",
                **seed_kwargs,
            )

            # --- Cross-Seed Metrics Calculation (Mitigated) ---
            if mitigated_images:
                mitigated_diversity = diversity_metric.measure(images=mitigated_images)
                mitigated_cross_seed_json = _write_cross_seed_json(
                    f"output/{config['method']}/{ds_name}", idx, safe_prompt, prompt,
                    dataset["is_memorized"], diversity_metric.name, mitigated_diversity,
                )
                print("Mitigated Cross-Seed Artefacts:", mitigated_cross_seed_json)

                if args.use_wandb:
                    wandb.log({f"mitigated/{ds_name}/diversity": mitigated_diversity, "prompt_idx": idx})
                    # Comparison chart only makes sense when the baseline ran.
                    if baseline_diversity is not None:
                        diversity_plot = wandb.plot.bar(
                            wandb.Table(
                                columns=["stage", "diversity"],
                                data=[["baseline", baseline_diversity], ["mitigated", mitigated_diversity]]
                            ),
                            "stage",
                            "diversity",
                            title=f"Diversity Comparison for Prompt {idx}"
                        )
                        wandb.log({f"plots/{ds_name}/diversity_comparison": diversity_plot, "prompt_idx": idx})

                print(f"Mitigated diversity for prompt {idx}: {mitigated_diversity}")

            del mitigated_images

            # --- Timing ---
            prompt_duration = time.time() - prompt_start_time
            print(f"Prompt {idx} processed in {prompt_duration:.2f} seconds.")
            if args.use_wandb:
                wandb.log({f"timing/{ds_name}/prompt_duration": prompt_duration, "prompt_idx": idx})

        if args.use_wandb:
            wandb.log({f"prompts/{ds_name}_prompts": dataset_table})

    # --- Final Timing Log ---
    total_duration_minutes = (time.time() - script_start_time) / 60
    print(f"Total evaluation time: {total_duration_minutes:.2f} minutes.")

    if args.use_wandb:
        wandb.finish()


if __name__ == "__main__":
    # CLI entry point: collect the evaluation flags and hand them to main().
    cli = argparse.ArgumentParser(
        description="Run evaluation for attention-based unlearning methods with W&B logging."
    )
    cli.add_argument(
        "--method",
        type=str,
        required=True,
        choices=['be', 'ca_entropy'],
        help="Attention-based method to evaluate.",
    )
    cli.add_argument(
        "--use_wandb",
        action="store_true",
        help="Enable Weights & Biases logging.",
    )
    cli.add_argument(
        "--run_baseline",
        action="store_true",
        help="Run baseline for sanity check.",
    )
    main(cli.parse_args())