import os
import re
import sys
import json
import argparse
import time
from tqdm import tqdm
import pandas as pd
import torch
import wandb

# Put the parent of this script's directory on sys.path so the project imports
# that follow (mvdream, evaluation) can resolve — presumably they live at that
# root. abspath() guards against a relative __file__, for which the double
# dirname would collapse to "" (i.e. the CWD) instead of the repository root.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from mvdream.model_zoo import build_model
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.editing.sail import SAILOptimizer
from mvdream.attention_control import AttentionController, register_attention_control

from evaluation.evaluator import ModelEvaluator, save_run_outputs
from evaluation.metrics import (
    NoiseDiffNormMetric, HessianMetric, HessianMetricSlow, DiversityMetric,
    BrightEndingMetric, XAttnEntropyMetric
)
from evaluation.utils import load_uids_from_clusters, uids_to_prompts

def _load_prompts(dataset):
    """Load the prompt records for one dataset entry.

    ``.json`` paths are treated as concept-cluster files: UIDs are selected with
    ``dataset["concept_key"]`` (so JSON entries must define that key) and mapped
    to prompts via ``uids_to_prompts``. Any other path is read as a
    ``;``-separated CSV and returned as a list of per-row dicts.
    """
    if dataset["path"].endswith(".json"):
        uids_for_eval = load_uids_from_clusters(dataset["path"], dataset["concept_key"])
        return uids_to_prompts(uids_for_eval)
    return pd.read_csv(dataset["path"], sep=";").to_dict("records")


def _write_cross_seed_json(out_dir, idx, safe_prompt, payload):
    """Write *payload* as ``<out_dir>/prompt_<idx>_<safe_prompt>_cross_seed.json``.

    Creates ``out_dir`` if needed and returns the path of the written file.
    """
    path = os.path.join(out_dir, f"prompt_{idx:04d}_{safe_prompt}_cross_seed.json")
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        json.dump(payload, f, indent=2)
    return path


def _per_seed_wandb_payload(prefix, result, caption, idx, seed_idx, num_seeds):
    """Build the W&B log dict for one (prompt, seed) sample.

    Numeric entries of ``result["metrics"]`` go under ``<prefix>/metrics/<name>``
    and the image under ``<prefix>/images``. The x-axis values (``prompt_idx``,
    ``seed``, ``sample_step``) are logged as ordinary fields rather than passed
    as ``wandb.log(step=...)``: W&B drops any log whose step is lower than the
    current one, and these indices repeat across stages and datasets.
    """
    log_data = {
        f"{prefix}/metrics/{k}": v
        for k, v in result["metrics"].items()
        if isinstance(v, (int, float))
    }
    log_data[f"{prefix}/images"] = wandb.Image(result["image"], caption=caption)
    log_data.update({
        "prompt_idx": idx,
        "seed": seed_idx,
        "sample_step": idx * num_seeds + seed_idx,
    })
    return log_data


def main(args):
    """Evaluate the SAIL noise-optimisation mitigation on memorized/unmemorized prompts.

    For every prompt of every configured dataset:
      1. if ``args.run_baseline``, sample ``num_seeds`` images with the plain
         sampler, save per-seed outputs and a cross-seed diversity JSON under
         ``output/baseline/<dataset>/``;
      2. for each seed, run ``SAILOptimizer.optimize_noise`` and sample with the
         optimized latent passed as ``x_T_override``; save the same artefacts
         under ``output/sail/<dataset>/``.

    Args:
        args: Parsed CLI namespace; reads ``use_wandb`` (enable W&B logging) and
            ``run_baseline`` (also run the unmitigated baseline).
    """
    # --- Configuration ---
    config = {"method": "sail", "device": "cuda", "model_name": "sd-v1.5-4view", "num_seeds": 4, "num_frames": 4}
    # JSON datasets additionally need a "concept_key" entry (see _load_prompts).
    datasets = [
        {"name": "laion_memorized", "path": "data/nemo-prompts/memorized_laion_prompts.csv", "is_memorized": True},
        {"name": "laion_unmemorized", "path": "data/nemo-prompts/unmemorized_laion_prompts.csv", "is_memorized": False},
    ]
    num_seeds = config['num_seeds']
    num_frames = config['num_frames']

    # --- W&B Initialization ---
    if args.use_wandb:
        wandb.init(
            project="mvdream-evaluation",
            name=f"eval-{config['method']}-{time.strftime('%Y%m%d-%H%M%S')}",
            config={
                "method": config["method"],
                "model_name": config["model_name"],
                "num_seeds": num_seeds,
                "num_frames": num_frames,
                "datasets": [d["name"] for d in datasets],
            }
        )

    script_start_time = time.time()

    # --- Setup ---
    torch.manual_seed(42)
    model = build_model(config["model_name"]).to(config["device"])
    model.device = config["device"]
    model.dtype = torch.float32
    model.image_size = 256
    model.num_frames = num_frames
    sampler = DDIMSampler(model)

    # --- Instantiate SAIL Optimizer ---
    sail_optimizer = SAILOptimizer(model, optim_steps=20, lr=0.05)

    # --- Instantiate Metrics and Evaluator ---
    metrics = [NoiseDiffNormMetric(), HessianMetric(), HessianMetricSlow(), BrightEndingMetric(), XAttnEntropyMetric(), DiversityMetric()]
    # DiversityMetric is computed across seeds, so it is also kept separately.
    diversity_metric = metrics[-1]
    evaluator = ModelEvaluator(sampler, metrics, device=config["device"])

    # --- Main Evaluation Loop ---
    for dataset in datasets:
        ds_name = dataset['name']
        print(f"\n--- Evaluating SAIL for: {ds_name} ---")

        prompts_data = _load_prompts(dataset)

        # Create a wandb.Table for this dataset's prompts
        if args.use_wandb:
            dataset_table = wandb.Table(columns=["prompt_idx", "prompt", "is_memorized"])

        for idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {ds_name}")):
            prompt = prompt_data["Caption"]
            # Filesystem-safe, truncated prompt slug used in output file names.
            safe_prompt = re.sub(r'\W+', '_', prompt)[:32]
            prompt_start_time = time.time()

            if args.use_wandb:
                dataset_table.add_data(idx, prompt, dataset["is_memorized"])

            baseline_diversity = None

            # --- STAGE 1: Run BASELINE for all seeds ---
            if args.run_baseline:
                print(f"\n{'>'*50}\nRunning BASELINE for prompt {idx}...\n{'<'*50}")
                baseline_images = []
                for seed_idx in range(num_seeds):
                    base_filename = f"prompt_{idx:04d}_{seed_idx:02d}_{safe_prompt}"
                    sample_dir = os.path.join("output/baseline", ds_name, base_filename)

                    # A fresh controller per sample; registering it replaces the previous one's hooks
                    # (assumed from usage — TODO confirm in register_attention_control).
                    controller = AttentionController(output_dir=sample_dir)
                    register_attention_control(model, controller)

                    baseline_result = evaluator._process_single_prompt_single_seed(
                        prompt=prompt, prompt_idx=idx, seed=seed_idx, num_frames=num_frames,
                        base_output_dir=sample_dir, unlearning_artifacts={"controller": controller}
                    )

                    baseline_images.append(baseline_result["image"])
                    baseline_result["metrics"].update({"prompt": prompt, "memorized": dataset["is_memorized"]})
                    save_run_outputs(baseline_result, os.path.dirname(sample_dir), base_filename)

                    # --- W&B Logging (Baseline) ---
                    if args.use_wandb:
                        wandb.log(_per_seed_wandb_payload(
                            f"baseline/{ds_name}", baseline_result,
                            f"Prompt: {prompt}\nSeed: {seed_idx}",
                            idx, seed_idx, num_seeds,
                        ))

                    del baseline_result, controller
                    torch.cuda.empty_cache()

                # --- Cross-Seed Metrics Calculation (Baseline) ---
                if baseline_images:
                    baseline_diversity = diversity_metric.measure(images=baseline_images)
                    baseline_cross_seed_json = _write_cross_seed_json(
                        f"output/baseline/{ds_name}", idx, safe_prompt,
                        {"prompt": prompt, "memorized": dataset["is_memorized"], diversity_metric.name: baseline_diversity},
                    )
                    print("Baseline Cross-Seed Artefacts:", baseline_cross_seed_json)

                    # --- W&B Logging (Baseline Diversity) ---
                    if args.use_wandb:
                        wandb.log({
                            f"baseline/{ds_name}/diversity": baseline_diversity,
                            "prompt_idx": idx
                        })
                    print(f"Baseline diversity for prompt {idx}: {baseline_diversity}")

                del baseline_images

            # --- STAGE 2: Run SAIL optimization and mitigation ---
            print(f"\n{'>'*50}\nRunning SAIL optimization and mitigation for prompt {idx}...\n{'<'*50}")
            mitigated_images = []
            for seed_idx in range(num_seeds):
                print(f"\nRunning SAIL opt & inference for prompt {idx}, seed {seed_idx}...")
                base_filename = f"prompt_{idx:04d}_{seed_idx:02d}_{safe_prompt}"
                mitigated_sample_dir = os.path.join(f"output/{config['method']}", ds_name, base_filename)
                os.makedirs(mitigated_sample_dir, exist_ok=True)

                # --- SAIL Optimization for this specific sample ---
                # The noise search gets its own per-(prompt, seed) seed, distinct from
                # the `seed=seed_idx` passed to the evaluator below.
                base_seed = 42 * (idx + 1) + seed_idx
                optimized_x_T = sail_optimizer.optimize_noise(prompt, base_seed)

                # --- Run inference with the optimized noise ---
                controller = AttentionController(output_dir=mitigated_sample_dir)
                register_attention_control(model, controller)

                artifacts = {"x_T_override": optimized_x_T, "controller": controller}

                mitigated_result = evaluator._process_single_prompt_single_seed(
                    prompt=prompt, prompt_idx=idx, seed=seed_idx, num_frames=num_frames,
                    base_output_dir=mitigated_sample_dir,
                    unlearning_artifacts=artifacts
                )
                # NOTE(review): mitigated samples are always labelled memorized=False,
                # discarding dataset["is_memorized"] — confirm this is intended.
                mitigated_result["metrics"].update({"prompt": prompt, "memorized": False})
                mitigated_images.append(mitigated_result["image"])
                save_run_outputs(mitigated_result, os.path.dirname(mitigated_sample_dir), base_filename)

                # --- W&B Logging (Mitigated) ---
                if args.use_wandb:
                    wandb.log(_per_seed_wandb_payload(
                        f"mitigated/{ds_name}", mitigated_result,
                        f"Prompt: {prompt}\nSeed: {seed_idx}\nMethod: SAIL",
                        idx, seed_idx, num_seeds,
                    ))

                del mitigated_result, controller, optimized_x_T
                torch.cuda.empty_cache()

            # --- Cross-Seed Metrics Calculation (Mitigated) ---
            if mitigated_images:
                mitigated_diversity = diversity_metric.measure(images=mitigated_images)
                mitigated_cross_seed_json = _write_cross_seed_json(
                    f"output/{config['method']}/{ds_name}", idx, safe_prompt,
                    {"prompt": prompt, "memorized": False, diversity_metric.name: mitigated_diversity},
                )
                print("Mitigated Cross-Seed Artefacts:", mitigated_cross_seed_json)

                # --- W&B Logging (Mitigated Diversity & Comparison) ---
                if args.use_wandb:
                    wandb.log({
                        f"mitigated/{ds_name}/diversity": mitigated_diversity,
                        "prompt_idx": idx
                    })

                    # Log a bar chart comparing diversities (only if baseline was run)
                    if baseline_diversity is not None:
                        diversity_plot = wandb.plot.bar(
                            wandb.Table(
                                columns=["stage", "diversity"],
                                data=[["baseline", baseline_diversity], ["mitigated", mitigated_diversity]]
                            ),
                            "stage",
                            "diversity",
                            title=f"Diversity Comparison for Prompt {idx}"
                        )
                        wandb.log({
                            f"plots/{ds_name}/diversity_comparison": diversity_plot,
                            "prompt_idx": idx
                        })

                print(f"Mitigated diversity for prompt {idx}: {mitigated_diversity}")

            del mitigated_images

            # --- W&B Logging (Timing) ---
            prompt_duration = time.time() - prompt_start_time
            print(f"Prompt {idx} processed in {prompt_duration:.2f} seconds.")
            if args.use_wandb:
                wandb.log({
                    f"timing/{ds_name}/prompt_duration": prompt_duration,
                    "prompt_idx": idx
                })

        # Log the dataset-specific table
        if args.use_wandb:
            wandb.log({f"prompts/{ds_name}_prompts": dataset_table})

    # --- Final Timing Log ---
    total_duration_minutes = (time.time() - script_start_time) / 60
    print(f"Total evaluation time: {total_duration_minutes:.2f} minutes.")

    if args.use_wandb:
        wandb.finish()

    print("\nSAIL evaluation complete.")


if __name__ == "__main__":
    # CLI entry point: both switches are opt-in boolean flags forwarded to main().
    cli = argparse.ArgumentParser(
        description="Run evaluation for SAIL mitigation method.",
    )
    cli.add_argument(
        "--use_wandb",
        action="store_true",
        help="Enable Weights & Biases logging.",
    )
    cli.add_argument(
        "--run_baseline",
        action="store_true",
        help="Run baseline for sanity check.",
    )
    main(cli.parse_args())