import os
import sys
import argparse
import re
import json
import time

import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
import wandb
from PIL import Image

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from mvdream.model_zoo import build_model
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.attention_control import AttentionController, register_attention_control
from mvdream.editing.perturbation import optimize_embedding_bright_ending

from evaluation.evaluator import ModelEvaluator, save_run_outputs
from evaluation.metrics import NoiseDiffNormMetric, HessianMetric, HessianMetricSlow, DiversityMetric, BrightEndingMetric, XAttnEntropyMetric
from evaluation.utils import load_uids_from_clusters, uids_to_prompts

def main(args):
    """Run the baseline evaluation over every configured prompt dataset.

    For each prompt, the model is sampled once per seed with an
    ``AttentionController`` registered on it, the per-seed result is written
    with ``save_run_outputs``, and a cross-seed diversity value is dumped to a
    JSON file under ``output/baseline/<dataset name>/``. When
    ``args.use_wandb`` is set, per-seed metrics and images, the diversity
    value, per-prompt timing and a table of prompts are also logged to
    Weights & Biases.

    Args:
        args: Parsed CLI namespace. ``use_wandb`` toggles W&B logging.
            ``method`` is overwritten with ``"baseline"``. An optional
            ``start_idx`` attribute selects the first prompt index processed
            in each dataset; it defaults to 125, the value that used to be
            hard-coded.
    """
    args.method = "baseline"
    # --- Configuration & Setup ---
    config = {"method": args.method, "device": "cuda", "model_name": "sd-v1.5-4view", "num_seeds": 3, "num_frames": 4}
    datasets = [
        {"name": "laion_memorized", "type": "csv", "path": "data/nemo-prompts/memorized_laion_prompts.csv", "is_memorized": True},
        {"name": "laion_unmemorized", "path": "data/nemo-prompts/unmemorized_laion_prompts.csv", "is_memorized": False},
    ]
    torch.manual_seed(42)

    # Prompts with an index below this value are skipped in every dataset.
    start_idx = getattr(args, "start_idx", 125)

    # --- W&B Initialization ---
    if args.use_wandb:
        wandb.init(
            project="mvdream-evaluation",
            name=f"eval-{args.method}-{time.strftime('%Y%m%d-%H%M%S')}",
            config={
                "method": args.method,
                "model_name": config["model_name"],
                "num_seeds": config["num_seeds"],
                "num_frames": config["num_frames"],
                "datasets": [d["name"] for d in datasets],
            }
        )

    script_start_time = time.time()

    model = build_model(config["model_name"]).to(config["device"])
    model.device = config["device"]
    model.dtype = torch.float32
    model.image_size = 256
    model.num_frames = config["num_frames"]

    sampler = DDIMSampler(model)
    metrics = [NoiseDiffNormMetric(), HessianMetric(), HessianMetricSlow(), BrightEndingMetric(), XAttnEntropyMetric(), DiversityMetric()]
    # The diversity metric is also invoked directly for the cross-seed value.
    diversity_metric = metrics[-1]
    evaluator = ModelEvaluator(sampler, metrics, device=config["device"])

    # --- Main Evaluation Loop ---
    for dataset in datasets:
        if dataset['path'].endswith('.json'):
            uids_for_eval = load_uids_from_clusters(dataset["path"], dataset["concept_key"])
            prompts_data = uids_to_prompts(uids_for_eval)
        else:
            prompts_data = pd.read_csv(dataset['path'], sep=';').to_dict('records')

        # Create a wandb.Table for this dataset's prompts
        if args.use_wandb:
            dataset_table = wandb.Table(columns=["prompt_idx", "prompt", "is_memorized"])

        for idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {dataset['name']}")):
            if idx < start_idx:
                continue
            prompt_start_time = time.time()
            prompt = prompt_data["Caption"]
            # Filesystem-friendly, truncated form of the prompt for file names.
            safe_prompt = re.sub(r'\W+', '_', prompt)[:32]

            if args.use_wandb:
                dataset_table.add_data(idx, prompt, dataset["is_memorized"])

            baseline_images = []
            baseline_intermediates = []  # passed to the diversity metric below

            # ---  Run BASELINE for all seeds ---
            print(f"\n{'>'*50}\nRunning BASELINE for prompt {idx}...\n{'<'*50}")
            for seed_idx in range(config['num_seeds']):
                base_filename = f"prompt_{idx:04d}_{seed_idx:02d}_{safe_prompt}"
                sample_dir = os.path.join("output/baseline", dataset['name'], base_filename)

                attn_config = {'steps': [0, 49], 'blocks': ['down'], 'xattn_indices': [0, 1]}
                controller = AttentionController(output_dir=sample_dir, config=attn_config)
                register_attention_control(model, controller)

                artifacts = {"controller": controller, "attention_map_dir": controller.output_dir}

                baseline_result = evaluator._process_single_prompt_single_seed(
                    prompt=prompt, prompt_idx=idx, seed=seed_idx, num_frames=config['num_frames'],
                    base_output_dir=sample_dir, unlearning_artifacts=artifacts
                )

                baseline_images.append(baseline_result["image"])
                baseline_intermediates.append(baseline_result["intermediates"])
                baseline_result["metrics"].update({"prompt": prompt, "memorized": dataset["is_memorized"]})
                save_run_outputs(baseline_result, os.path.dirname(sample_dir), base_filename)

                # --- W&B Logging (Baseline) ---
                if args.use_wandb:
                    log_data = {
                        f"baseline/{dataset['name']}/metrics/{k}": v for k, v in baseline_result["metrics"].items()
                        if isinstance(v, (int, float))
                    }
                    log_data[f"baseline/{dataset['name']}/images"] = wandb.Image(
                        baseline_result["image"],
                        caption=f"Prompt: {prompt}\nSeed: {seed_idx}"
                    )
                    # W&B discards data logged with a step lower than the
                    # run's current step. The previous explicit steps
                    # (idx * num_seeds + seed_idx here, idx for diversity and
                    # timing, restarting for every dataset) went backwards, so
                    # let W&B auto-increment the step and carry the indices as
                    # ordinary fields instead.
                    log_data["prompt_idx"] = idx
                    log_data["seed_idx"] = seed_idx
                    wandb.log(log_data)

                del baseline_result, controller
                torch.cuda.empty_cache()

            # --- Cross-Seed Metrics Calculation (Baseline) ---
            if baseline_images:
                baseline_diversity = diversity_metric.measure(
                    images=baseline_images,
                    intermediates_list=baseline_intermediates
                )
                baseline_cross_seed_json = os.path.join(f"output/baseline/{dataset['name']}", f"prompt_{idx:04d}_{safe_prompt}_cross_seed.json")
                os.makedirs(os.path.dirname(baseline_cross_seed_json), exist_ok=True)
                with open(baseline_cross_seed_json, 'w') as f:
                    json.dump({"prompt": prompt, "memorized": dataset["is_memorized"], diversity_metric.name: baseline_diversity}, f, indent=2)
                print("Baseline Cross-Seed Artefacts:", baseline_cross_seed_json)

                # --- W&B Logging (Baseline Diversity) ---
                if args.use_wandb:
                    wandb.log({
                        f"baseline/{dataset['name']}/diversity": baseline_diversity,
                        "prompt_idx": idx
                    })
                print(f"Baseline diversity for prompt {idx}: {baseline_diversity}")

            # Drop the per-prompt collections before moving to the next prompt.
            del baseline_images, baseline_intermediates

            # --- W&B Logging (Timing) ---
            prompt_duration = time.time() - prompt_start_time
            print(f"Prompt {idx} processed in {prompt_duration:.2f} seconds.")
            if args.use_wandb:
                wandb.log({
                    f"timing/{dataset['name']}/prompt_duration": prompt_duration,
                    "prompt_idx": idx
                })

        # Log the dataset-specific table
        if args.use_wandb:
            wandb.log({f"prompts/{dataset['name']}_prompts": dataset_table})

    # --- Final Timing Log ---
    total_duration_minutes = (time.time() - script_start_time) / 60
    print(f"Total evaluation time: {total_duration_minutes:.2f} minutes.")

    if args.use_wandb:
        wandb.finish()


if __name__ == "__main__":
    # Script entry point: build the CLI, parse it, and hand off to main().
    cli = argparse.ArgumentParser(
        description="Run evaluation for attention-based unlearning methods with W&B logging.",
    )
    cli.add_argument(
        "--use_wandb",
        action="store_true",
        help="Enable Weights & Biases logging.",
    )
    main(cli.parse_args())