import os
import sys
import json
import argparse
import re
import time
import torch
import pandas as pd
from tqdm import tqdm
import wandb
from copy import deepcopy

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from mvdream.model_zoo import build_model
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.editing.nemo import NeMoEditor
from mvdream.attention_control import AttentionController, register_attention_control

from evaluation.evaluator import ModelEvaluator, save_run_outputs
from evaluation.metrics import NoiseDiffNormMetric, HessianMetric, HessianMetricSlow, DiversityMetric, BrightEndingMetric, XAttnEntropyMetric
from evaluation.utils import load_uids_from_clusters, uids_to_prompts, load_prompts_from_csv

def _generate_seeds(model, evaluator, prompt, idx, safe_prompt, dataset, config,
                    out_root, tag, use_wandb, caption_suffix=""):
    """Generate, save and (optionally) log one sample per seed for ``prompt``.

    Each seed ``0..config["num_seeds"]-1`` gets its own directory
    ``<out_root>/<dataset name>/prompt_<idx>_<seed>_<safe_prompt>`` and a fresh
    ``AttentionController`` registered on ``model`` before generation.

    Args:
        model: The diffusion model the attention controller is registered on.
        evaluator: ``ModelEvaluator`` used to run a single prompt/seed.
        prompt: Text prompt to generate.
        idx: Index of the prompt within its dataset (used in filenames/logs).
        safe_prompt: Filesystem-safe prompt prefix used in filenames.
        dataset: Dataset descriptor dict (reads ``"name"`` and ``"is_memorized"``).
        config: Run configuration (reads ``"num_seeds"`` and ``"num_frames"``).
        out_root: Root output directory for this run type (baseline / method).
        tag: W&B key prefix (``"baseline"`` or ``"mitigated"``).
        use_wandb: Whether to log per-seed metrics and images to W&B.
        caption_suffix: Extra text appended to the W&B image caption.

    Returns:
        list: The ``result["image"]`` of every seed, in seed order.
    """
    images = []
    for seed_idx in range(config["num_seeds"]):
        base_filename = f"prompt_{idx:04d}_{seed_idx:02d}_{safe_prompt}"
        sample_dir = os.path.join(out_root, dataset["name"], base_filename)
        os.makedirs(sample_dir, exist_ok=True)

        controller = AttentionController(output_dir=sample_dir)
        register_attention_control(model, controller)

        result = evaluator._process_single_prompt_single_seed(
            prompt=prompt, prompt_idx=idx, seed=seed_idx, num_frames=config["num_frames"],
            base_output_dir=sample_dir, unlearning_artifacts={"controller": controller},
        )

        images.append(result["image"])
        result["metrics"].update({"prompt": prompt, "memorized": dataset["is_memorized"]})
        save_run_outputs(result, os.path.dirname(sample_dir), base_filename)

        if use_wandb:
            log_data = {
                f"{tag}/{dataset['name']}/metrics/{k}": v
                for k, v in result["metrics"].items()
                if isinstance(v, (int, float))
            }
            log_data[f"{tag}/{dataset['name']}/images"] = wandb.Image(
                result["image"],
                caption=f"Prompt: {prompt}\nSeed: {seed_idx}{caption_suffix}",
            )
            # No explicit ``step=``: W&B drops any log whose step is lower than
            # one already logged, and baseline/mitigated/per-prompt logs used to
            # collide or go backwards. Prompt/seed indices are logged as fields
            # instead so they can be used as chart x-axes.
            log_data["prompt_idx"] = idx
            log_data["seed_idx"] = seed_idx
            wandb.log(log_data)

        del result, controller
        torch.cuda.empty_cache()
    return images


def _save_cross_seed(diversity_metric, images, prompt, idx, safe_prompt, dataset,
                     out_root, tag, label, use_wandb):
    """Measure cross-seed diversity of ``images``, write it to JSON and log it.

    The JSON is written to
    ``<out_root>/<dataset name>/prompt_<idx>_<safe_prompt>_cross_seed.json``.

    Args:
        diversity_metric: Metric object exposing ``.name`` and ``.measure(images=...)``.
        images: Images of the same prompt generated with different seeds.
        prompt: The text prompt (stored in the JSON).
        idx: Index of the prompt within its dataset.
        safe_prompt: Filesystem-safe prompt prefix used in the filename.
        dataset: Dataset descriptor dict (reads ``"name"`` and ``"is_memorized"``).
        out_root: Root output directory for this run type.
        tag: W&B key prefix (``"baseline"`` or ``"mitigated"``).
        label: Human-readable label used in console output.
        use_wandb: Whether to log the diversity value to W&B.

    Returns:
        Whatever ``diversity_metric.measure`` returned.
    """
    diversity = diversity_metric.measure(images=images)
    json_path = os.path.join(
        out_root, dataset["name"], f"prompt_{idx:04d}_{safe_prompt}_cross_seed.json"
    )
    os.makedirs(os.path.dirname(json_path), exist_ok=True)
    with open(json_path, "w") as f:
        json.dump(
            {"prompt": prompt, "memorized": dataset["is_memorized"], diversity_metric.name: diversity},
            f,
            indent=2,
        )
    print(f"{label} Cross-Seed Artefacts:", json_path)

    if use_wandb:
        wandb.log({f"{tag}/{dataset['name']}/diversity": diversity, "prompt_idx": idx})
    print(f"{label} diversity for prompt {idx}: {diversity}")
    return diversity


def main(args):
    """Evaluate NeMo memorization mitigation on the memorized-prompt datasets.

    For every memorized prompt this:
      1. optionally (``--run_baseline``) generates ``num_seeds`` unmitigated
         samples and their cross-seed diversity under ``output/baseline``;
      2. asks ``NeMoEditor.find_neurons`` which neurons to block;
      3. if any were found, generates ``num_seeds`` samples with the NeMo hooks
         active under ``output/<method_name>`` and their cross-seed diversity.

    Args:
        args: Parsed CLI namespace with boolean ``use_wandb`` and ``run_baseline``.
    """
    # --- Configuration ---
    config = {
        "method_name": "nemo",
        "device": "cuda",
        "model_name": "sd-v1.5-4view",
        "num_seeds": 4,
        "num_frames": 4,
    }
    config["output_dir"] = f"output/{config['method_name']}"
    baseline_dir = "output/baseline"

    datasets_to_eval = [
        {"name": "laion_memorized", "type": "csv", "path": "data/nemo-prompts/memorized_laion_prompts.csv", "is_memorized": True},
        {"name": "laion_unmemorized", "type": "csv", "path": "data/nemo-prompts/unmemorized_laion_prompts.csv", "is_memorized": False},
    ]
    non_mem_dataset_path = "data/nemo-prompts/unmemorized_laion_prompts.csv"

    # --- W&B Initialization ---
    if args.use_wandb:
        wandb.init(
            project="mvdream-evaluation",
            name=f"eval-{config['method_name']}-{time.strftime('%Y%m%d-%H%M%S')}",
            config={
                "method": config["method_name"],
                "model_name": config["model_name"],
                "num_seeds": config["num_seeds"],
                "num_frames": config["num_frames"],
                "datasets": [d["name"] for d in datasets_to_eval],
            }
        )

    script_start_time = time.time()

    # --- Setup ---
    torch.manual_seed(42)
    model = build_model(config["model_name"]).to(config["device"])
    model.device = config["device"]
    model.dtype = torch.float32
    model.image_size = 256
    model.num_frames = config["num_frames"]
    sampler = DDIMSampler(model)

    # --- Instantiate NeMo Editor ---
    # NeMo needs a hold-out set of non-memorized prompts to learn baseline activations.
    print("Loading non-memorized prompts for NeMo calibration...")
    non_mem_prompts = pd.read_csv(non_mem_dataset_path, sep=';')['Caption'].tolist()[:100]  # Use a subset for speed
    nemo_editor = NeMoEditor(model, non_mem_prompts=non_mem_prompts, device=config["device"])

    # Standard evaluator will be used for generation
    metrics = [NoiseDiffNormMetric(), HessianMetric(), HessianMetricSlow(), BrightEndingMetric(), XAttnEntropyMetric(), DiversityMetric(device=config["device"])]
    diversity_metric = metrics[-1]
    evaluator = ModelEvaluator(sampler, metrics, device=config["device"])

    # --- Custom NeMo Evaluation Loop ---
    for dataset in datasets_to_eval:
        if not dataset["is_memorized"]:
            continue  # NeMo is only applied to memorized prompts

        print(f"\n--- Evaluating NeMo for: {dataset['name']} ---")

        # Load prompts data
        if dataset["path"].endswith(".json"):
            uids_for_eval = load_uids_from_clusters(dataset["path"], dataset["concept_key"])
            prompts_data = uids_to_prompts(uids_for_eval)
        else:
            prompts_data = pd.read_csv(dataset["path"], sep=';').to_dict('records')

        # One wandb.Table per dataset, recording whether NeMo blocked neurons.
        dataset_table = (
            wandb.Table(columns=["prompt_idx", "prompt", "is_memorized", "neurons_blocked"])
            if args.use_wandb else None
        )

        for idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {dataset['name']}")):
            prompt = prompt_data["Caption"]
            safe_prompt = re.sub(r'\W+', '_', prompt)[:32]
            prompt_start_time = time.time()

            print(f"\n{'='*20} Processing Prompt {idx}: {prompt[:60]} {'='*20}")

            # --- STAGE 1: Run BASELINE for all seeds (no NeMo hooks) ---
            if args.run_baseline:
                print(f"\n{'>'*50}\nRunning BASELINE for prompt {idx}...\n{'<'*50}")
                baseline_images = _generate_seeds(
                    model, evaluator, prompt, idx, safe_prompt, dataset, config,
                    baseline_dir, "baseline", args.use_wandb,
                )
                if baseline_images:
                    _save_cross_seed(
                        diversity_metric, baseline_images, prompt, idx, safe_prompt,
                        dataset, baseline_dir, "baseline", "Baseline", args.use_wandb,
                    )
                del baseline_images

            # --- STAGE 2: Find memorization neurons for THIS prompt ---
            # NOTE(review): the AttentionController registered for the last
            # generated seed is still attached to the model here; confirm that
            # find_neurons is unaffected by it.
            print(f"\n{'>'*50}\nFinding neurons and running NeMo mitigation for prompt {idx}...\n{'<'*50}")
            neurons_to_block = nemo_editor.find_neurons(prompt)
            neurons_blocked = any(neurons_to_block.values())

            if dataset_table is not None:
                dataset_table.add_data(idx, prompt, dataset["is_memorized"], neurons_blocked)

            if not neurons_blocked:
                print(" > NeMo found no neurons to block for this prompt. Skipping unlearned generation.")
            else:
                # --- STAGE 3: Run MITIGATION with hooks active ---
                nemo_editor.register_hooks(neurons_to_block)
                try:
                    mitigated_images = _generate_seeds(
                        model, evaluator, prompt, idx, safe_prompt, dataset, config,
                        config["output_dir"], "mitigated", args.use_wandb,
                        caption_suffix="\nMethod: NeMo",
                    )
                finally:
                    # Always detach the hooks, even if generation raised, so they
                    # never leak into the next prompt (or the caller's model).
                    nemo_editor.remove_hooks()

                if mitigated_images:
                    _save_cross_seed(
                        diversity_metric, mitigated_images, prompt, idx, safe_prompt,
                        dataset, config["output_dir"], "mitigated", "Mitigated", args.use_wandb,
                    )
                del mitigated_images

            # --- Timing (recorded for every prompt, including skipped ones) ---
            prompt_duration = time.time() - prompt_start_time
            if args.use_wandb:
                wandb.log({
                    f"timing/{dataset['name']}/prompt_duration": prompt_duration,
                    "prompt_idx": idx,
                })
            print(f"Prompt {idx} processed in {prompt_duration:.2f} seconds.")

        # Log the dataset-specific table
        if dataset_table is not None:
            wandb.log({f"prompts/{dataset['name']}_prompts": dataset_table})

    # --- Final Timing Log ---
    total_duration_minutes = (time.time() - script_start_time) / 60
    print(f"Total evaluation time: {total_duration_minutes:.2f} minutes.")

    if args.use_wandb:
        wandb.finish()

    print("\nNeMo evaluation complete.")

if __name__ == "__main__":
    # Command-line entry point: both flags are opt-in booleans.
    cli = argparse.ArgumentParser(
        description="Run NeMo evaluation with optional W&B logging."
    )
    cli.add_argument(
        "--use_wandb",
        action="store_true",
        help="Enable Weights & Biases logging.",
    )
    cli.add_argument(
        "--run_baseline",
        action="store_true",
        help="Run baseline for sanity check.",
    )
    main(cli.parse_args())