import os
import sys
from copy import deepcopy
import re
import json
import argparse
import time
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
import wandb

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

# MVDream and Evaluation Framework Imports
from mvdream.model_zoo import build_model
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.editing.uce import UCEEditor
from mvdream.attention_control import AttentionController, register_attention_control

from evaluation.evaluator import ModelEvaluator, save_run_outputs
from evaluation.metrics import (
    NoiseDiffNormMetric, HessianMetric, HessianMetricSlow, DiversityMetric, 
    BrightEndingMetric, XAttnEntropyMetric
)
from evaluation.utils import load_uids_from_clusters, uids_to_prompts

def _evaluate_seeds(
    model,
    evaluator,
    *,
    stage,
    output_root,
    dataset,
    prompt,
    prompt_idx,
    safe_prompt,
    num_seeds,
    num_frames,
    extra_metrics,
    caption_prefix,
    caption_suffix,
    use_wandb,
):
    """Sample one prompt for seeds ``0..num_seeds-1``, save every run, return the images.

    Args:
        model: model on which a fresh ``AttentionController`` is registered per seed.
        evaluator: ``ModelEvaluator`` wrapping a sampler built on ``model``.
        stage: W&B key prefix (``"baseline"`` or ``"mitigated"``).
        output_root: root directory; each run goes to ``<output_root>/<dataset>/<run>``.
        dataset: dataset config dict (reads ``"name"`` and ``"is_memorized"``).
        prompt: prompt text to sample.
        prompt_idx: index of the prompt within its dataset (used in file names).
        safe_prompt: filesystem-safe slug of the prompt (used in file names).
        num_seeds: number of seeds; seeds ``0..num_seeds-1`` are used.
        num_frames: number of views per sample.
        extra_metrics: extra key/values merged into each run's metrics dict.
        caption_prefix: text placed before the ``Seed:`` line of the W&B caption.
        caption_suffix: text placed after the ``Seed:`` line of the W&B caption.
        use_wandb: whether to log per-seed metrics and images to W&B.

    Returns:
        The ``"image"`` entry of every per-seed result, in seed order.
    """
    images = []
    for seed_idx in range(num_seeds):
        base_filename = f"prompt_{prompt_idx:04d}_{seed_idx:02d}_{safe_prompt}"
        sample_dir = os.path.join(output_root, dataset["name"], base_filename)
        os.makedirs(sample_dir, exist_ok=True)

        # Passive controller used only for metrics; a new one per seed so each
        # seed's attention artefacts land in that seed's own directory.
        controller = AttentionController(output_dir=sample_dir)
        register_attention_control(model, controller)

        result = evaluator._process_single_prompt_single_seed(
            prompt=prompt,
            prompt_idx=prompt_idx,
            seed=seed_idx,
            num_frames=num_frames,
            base_output_dir=sample_dir,
            unlearning_artifacts={"controller": controller},
        )
        images.append(result["image"])
        result["metrics"].update(
            {"prompt": prompt, "memorized": dataset["is_memorized"], **extra_metrics}
        )
        save_run_outputs(result, os.path.dirname(sample_dir), base_filename)

        if use_wandb:
            log_data = {
                f"{stage}/{dataset['name']}/metrics/{k}": v
                for k, v in result["metrics"].items()
                if isinstance(v, (int, float))
            }
            log_data[f"{stage}/{dataset['name']}/images"] = wandb.Image(
                result["image"],
                caption=f"{caption_prefix}\nSeed: {seed_idx}{caption_suffix}",
            )
            # No explicit ``step=``: W&B drops logs whose step is lower than the
            # current step, and baseline/mitigated/per-prompt logs used to reuse
            # overlapping step values. prompt_idx/seed_idx are logged as fields
            # so they can be selected as the x-axis instead.
            log_data["prompt_idx"] = prompt_idx
            log_data["seed_idx"] = seed_idx
            wandb.log(log_data)

        del result, controller
        torch.cuda.empty_cache()
    return images


def _write_cross_seed(output_root, dataset_name, prompt_idx, safe_prompt, payload):
    """Dump ``payload`` as the cross-seed JSON for one prompt and return the file path."""
    path = os.path.join(
        output_root,
        dataset_name,
        f"prompt_{prompt_idx:04d}_{safe_prompt}_cross_seed.json",
    )
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)
    return path


def main(args):
    """Run the per-prompt UCE erasure evaluation on MVDream.

    For every prompt of every configured dataset this:
      1. optionally (``--run_baseline``) samples the unmodified model for
         ``num_seeds`` seeds and records a cross-seed diversity score;
      2. deep-copies the baseline model and uses ``UCEEditor`` to erase the
         prompt itself, guided toward a generic concept;
      3. samples the edited model for the same seeds and records its diversity.
    Per-seed outputs and cross-seed JSON files are written under ``output/``.
    When ``--use_wandb`` is set, the results are also logged to Weights & Biases.

    Args:
        args: parsed CLI namespace with boolean ``use_wandb`` and ``run_baseline``.
    """
    # --- Configuration ---
    config = {
        "method": "uce",
        "device": "cuda",
        "model_name": "sd-v1.5-4view",
        "num_seeds": 4,  # each edit is evaluated with 4 different seeds
        "num_frames": 4,
        "uce_params": {  # base params; edit/guide concepts are set per prompt
            "preserve_concepts": ["car", "dog", "a painting of a flower"],
            "erase_scale": 1.0,
            "preserve_scale": 0.5,
            "lamb": 0.5,
        },
    }

    datasets = [
        {"name": "laion_memorized", "type": "csv", "path": "data/nemo-prompts/memorized_laion_prompts.csv", "is_memorized": True},
    ]

    # --- W&B Initialization ---
    if args.use_wandb:
        wandb.init(
            project="mvdream-evaluation",
            name=f"eval-{config['method']}-{time.strftime('%Y%m%d-%H%M%S')}",
            config={
                "method": config["method"],
                "model_name": config["model_name"],
                "num_seeds": config["num_seeds"],
                "num_frames": config["num_frames"],
                "datasets": [d["name"] for d in datasets],
                "uce_params": config["uce_params"],
            },
        )

    script_start_time = time.time()

    # --- Setup ---
    torch.manual_seed(42)
    # Load the baseline model only once; each prompt edits a deep copy of it.
    baseline_model = build_model(config["model_name"]).to(config["device"])
    baseline_model.device = config["device"]
    baseline_model.dtype = torch.float32
    baseline_model.image_size = 256
    baseline_model.num_frames = config["num_frames"]

    metrics = [
        NoiseDiffNormMetric(), HessianMetric(), HessianMetricSlow(), BrightEndingMetric(),
        XAttnEntropyMetric(), DiversityMetric(device=config["device"]),
    ]
    diversity_metric = metrics[-1]
    mitigated_root = f"output/{config['method']}"

    # --- Main Per-Prompt Evaluation Loop ---
    for dataset in datasets:
        print(f"\n--- Evaluating UCE Multi-Concept for: {dataset['name']} ---")

        # Load prompts data
        if dataset["path"].endswith(".json"):
            uids_for_eval = load_uids_from_clusters(dataset["path"], dataset["concept_key"])
            prompts_data = uids_to_prompts(uids_for_eval)
        else:
            prompts_data = pd.read_csv(dataset["path"], sep=";").to_dict("records")

        if args.use_wandb:
            dataset_table = wandb.Table(columns=["prompt_idx", "prompt", "is_memorized", "edit_concept", "guide_concept"])

        for idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {dataset['name']} with UCE")):
            prompt = prompt_data["Caption"]
            safe_prompt = re.sub(r"\W+", "_", prompt)[:32]
            prompt_start_time = time.time()
            baseline_diversity = None

            # Deep copy: a shallow ``.copy()`` shares the ``preserve_concepts``
            # list with ``config``, so any in-place change made to it during
            # editing would leak into every later prompt.
            dynamic_uce_params = deepcopy(config["uce_params"])
            dynamic_uce_params["edit_concepts"] = [prompt]
            dynamic_uce_params["guide_concepts"] = ["a high-quality photograph"]  # universal generic guide
            edit_concept = dynamic_uce_params["edit_concepts"][0]
            guide_concept = dynamic_uce_params["guide_concepts"][0]

            if args.use_wandb:
                dataset_table.add_data(idx, prompt, dataset["is_memorized"], edit_concept, guide_concept)

            # Arguments shared by the baseline and mitigated seed loops.
            seed_kwargs = {
                "dataset": dataset,
                "prompt": prompt,
                "prompt_idx": idx,
                "safe_prompt": safe_prompt,
                "num_seeds": config["num_seeds"],
                "num_frames": config["num_frames"],
                "use_wandb": args.use_wandb,
            }

            # --- STAGE 1: Run BASELINE for all seeds (unmodified model) ---
            if args.run_baseline:
                print(f"\n{'>'*50}\nRunning BASELINE for prompt {idx}...\n{'<'*50}")
                baseline_evaluator = ModelEvaluator(
                    DDIMSampler(baseline_model), metrics, device=config["device"]
                )
                baseline_images = _evaluate_seeds(
                    baseline_model,
                    baseline_evaluator,
                    stage="baseline",
                    output_root="output/baseline",
                    extra_metrics={},
                    caption_prefix=f"Prompt: {prompt}",
                    caption_suffix="",
                    **seed_kwargs,
                )

                # --- Cross-Seed Metrics Calculation (Baseline) ---
                if baseline_images:
                    baseline_diversity = diversity_metric.measure(images=baseline_images)
                    baseline_cross_seed_json = _write_cross_seed(
                        "output/baseline", dataset["name"], idx, safe_prompt,
                        {"prompt": prompt, "memorized": dataset["is_memorized"], diversity_metric.name: baseline_diversity},
                    )
                    print("Baseline Cross-Seed Artefacts:", baseline_cross_seed_json)
                    if args.use_wandb:
                        wandb.log({
                            f"baseline/{dataset['name']}/diversity": baseline_diversity,
                            "prompt_idx": idx,
                        })
                    print(f"Baseline diversity for prompt {idx}: {baseline_diversity}")

                del baseline_images, baseline_evaluator

            # --- STAGE 2: Create a fresh, uniquely edited model for THIS prompt ---
            print(f"\n{'>'*50}\nApplying UCE and running mitigation for prompt {idx}...\n{'<'*50}")
            model_to_edit = deepcopy(baseline_model)
            editor = UCEEditor(model_to_edit, device=config["device"])
            editor.erase_concept(**dynamic_uce_params)

            # --- STAGE 3: Evaluate the newly edited model on this single prompt ---
            evaluator = ModelEvaluator(DDIMSampler(model_to_edit), metrics, device=config["device"])
            mitigated_images = _evaluate_seeds(
                model_to_edit,
                evaluator,
                stage="mitigated",
                output_root=mitigated_root,
                extra_metrics={"edit_concept": edit_concept, "guide_concept": guide_concept},
                caption_prefix=f"Prompt: {prompt}\nEdit: {edit_concept}\nGuide: {guide_concept}",
                caption_suffix="\nMethod: UCE Multi",
                **seed_kwargs,
            )

            # --- Cross-Seed Metrics Calculation (Mitigated) ---
            if mitigated_images:
                mitigated_diversity = diversity_metric.measure(images=mitigated_images)
                mitigated_cross_seed_json = _write_cross_seed(
                    mitigated_root, dataset["name"], idx, safe_prompt,
                    {
                        "prompt": prompt,
                        "memorized": dataset["is_memorized"],
                        "edit_concept": edit_concept,
                        "guide_concept": guide_concept,
                        diversity_metric.name: mitigated_diversity,
                    },
                )
                print("Mitigated Cross-Seed Artefacts:", mitigated_cross_seed_json)

                if args.use_wandb:
                    wandb.log({
                        f"mitigated/{dataset['name']}/diversity": mitigated_diversity,
                        "prompt_idx": idx,
                    })
                    # Bar chart comparing diversities (only if the baseline ran).
                    if baseline_diversity is not None:
                        diversity_plot = wandb.plot.bar(
                            wandb.Table(
                                columns=["stage", "diversity"],
                                data=[["baseline", baseline_diversity], ["mitigated", mitigated_diversity]],
                            ),
                            "stage",
                            "diversity",
                            title=f"Diversity Comparison for Prompt {idx}",
                        )
                        wandb.log({
                            f"plots/{dataset['name']}/diversity_comparison": diversity_plot,
                            "prompt_idx": idx,
                        })

                print(f"Mitigated diversity for prompt {idx}: {mitigated_diversity}")

            del model_to_edit, editor, evaluator, mitigated_images
            torch.cuda.empty_cache()
            print("Cleaned up memory for next iteration.")

            # --- Timing ---
            prompt_duration = time.time() - prompt_start_time
            print(f"Prompt {idx} processed in {prompt_duration:.2f} seconds.")
            if args.use_wandb:
                wandb.log({
                    f"timing/{dataset['name']}/prompt_duration": prompt_duration,
                    "prompt_idx": idx,
                })

        # Log the dataset-specific table
        if args.use_wandb:
            wandb.log({f"prompts/{dataset['name']}_prompts": dataset_table})

    # --- Final Timing Log ---
    total_duration_minutes = (time.time() - script_start_time) / 60
    print(f"Total evaluation time: {total_duration_minutes:.2f} minutes.")

    if args.use_wandb:
        wandb.finish()

    print("\nUCE Multi-Concept evaluation complete.")

if __name__ == "__main__":
    # Command-line entry point: both switches are opt-in boolean flags.
    cli = argparse.ArgumentParser(
        description="Run UCE Multi-Concept evaluation with optional W&B logging."
    )
    for flag, help_text in (
        ("--use_wandb", "Enable Weights & Biases logging."),
        ("--run_baseline", "Run baseline for sanity check."),
    ):
        cli.add_argument(flag, action="store_true", help=help_text)
    main(cli.parse_args())