import os
import sys
import argparse
import re
import json
import time
from copy import deepcopy
import torch
import pandas as pd
from tqdm import tqdm
import wandb

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from mvdream.model_zoo import build_model
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.editing.uce import UCEEditor
from mvdream.attention_control import AttentionController, register_attention_control

from evaluation.evaluator import ModelEvaluator, save_run_outputs
from evaluation.metrics import NoiseDiffNormMetric, HessianMetric, HessianMetricSlow, DiversityMetric, BrightEndingMetric, XAttnEntropyMetric
from evaluation.utils import load_uids_from_clusters, uids_to_prompts

def _evaluate_stage(
    stage,
    model,
    metrics,
    diversity_metric,
    *,
    stage_root,
    dataset_name,
    prompt,
    prompt_idx,
    safe_prompt,
    memorized,
    config,
    caption_suffix="",
    use_wandb=False,
):
    """Generate one sample per seed for a prompt and compute cross-seed diversity.

    For every seed in ``range(config["num_seeds"])`` this registers a fresh
    ``AttentionController`` on ``model``, runs the evaluator, saves the run
    outputs under ``<stage_root>/<dataset_name>/`` and (optionally) logs the
    numeric metrics and the image to W&B under the ``<stage>/...`` namespace.
    Afterwards the diversity over all generated images is written to a
    ``*_cross_seed.json`` file next to the per-seed outputs.

    Args:
        stage: Namespace used for W&B keys and console labels
            (``"baseline"`` or ``"mitigated"``).
        model: Model to sample from.
        metrics: Metrics handed to ``ModelEvaluator``.
        diversity_metric: Metric used for the cross-seed measurement; its
            ``name`` attribute becomes the key in the JSON file.
        stage_root: Root output directory of this stage.
        dataset_name: Name of the dataset being evaluated.
        prompt: Text prompt.
        prompt_idx: Index of the prompt within the dataset.
        safe_prompt: Filesystem-safe, truncated version of ``prompt``.
        memorized: Value stored under ``"memorized"`` in the saved metrics.
        config: Run configuration; reads ``device``, ``num_seeds`` and
            ``num_frames``.
        caption_suffix: Extra text appended to the W&B image caption.
        use_wandb: Whether to log to Weights & Biases.

    Returns:
        The cross-seed diversity, or ``None`` if no image was generated.
    """
    sampler = DDIMSampler(model)
    evaluator = ModelEvaluator(sampler, metrics, device=config["device"])
    label = stage.capitalize()
    images = []

    for seed_idx in range(config["num_seeds"]):
        base_filename = f"prompt_{prompt_idx:04d}_{seed_idx:02d}_{safe_prompt}"
        sample_dir = os.path.join(stage_root, dataset_name, base_filename)
        # Create the directory for both stages so neither depends on the
        # controller/evaluator creating it implicitly.
        os.makedirs(sample_dir, exist_ok=True)

        # A new controller (with its own config dict) is registered per sample.
        controller = AttentionController(
            output_dir=sample_dir,
            config={'steps': [0, 49], 'blocks': ['down'], 'xattn_indices': 'all'},
        )
        register_attention_control(model, controller)

        result = evaluator._process_single_prompt_single_seed(
            prompt=prompt, prompt_idx=prompt_idx, seed=seed_idx, num_frames=config["num_frames"],
            base_output_dir=sample_dir, unlearning_artifacts={"controller": controller},
        )

        images.append(result["image"])
        result["metrics"].update({"prompt": prompt, "memorized": memorized})
        save_run_outputs(result, os.path.dirname(sample_dir), base_filename)

        if use_wandb:
            log_data = {
                f"{stage}/{dataset_name}/metrics/{k}": v
                for k, v in result["metrics"].items()
                if isinstance(v, (int, float))
            }
            log_data[f"{stage}/{dataset_name}/images"] = wandb.Image(
                result["image"],
                caption=f"Prompt: {prompt}\nSeed: {seed_idx}{caption_suffix}",
            )
            # No explicit `step`: W&B ignores data logged to a step lower than
            # the current one, and the stages revisit the same indices. The
            # indices are logged as plain fields instead.
            log_data["prompt_idx"] = prompt_idx
            log_data["seed"] = seed_idx
            wandb.log(log_data)

        del result, controller
        torch.cuda.empty_cache()

    if not images:
        return None

    # --- Cross-seed metrics ---
    diversity = diversity_metric.measure(images=images)
    cross_seed_json = os.path.join(
        stage_root, dataset_name, f"prompt_{prompt_idx:04d}_{safe_prompt}_cross_seed.json"
    )
    os.makedirs(os.path.dirname(cross_seed_json), exist_ok=True)
    with open(cross_seed_json, 'w', encoding='utf-8') as f:
        json.dump({"prompt": prompt, "memorized": memorized, diversity_metric.name: diversity}, f, indent=2)
    print(f"{label} Cross-Seed Artefacts:", cross_seed_json)

    if use_wandb:
        wandb.log({
            f"{stage}/{dataset_name}/diversity": diversity,
            "prompt_idx": prompt_idx,
        })
    print(f"{label} diversity for prompt {prompt_idx}: {diversity}")
    return diversity


def main(args):
    """Evaluate a UCE-edited model (and optionally the unedited baseline).

    Loads the base model, applies the UCE concept edit once to a deep copy,
    then for every prompt of every configured dataset generates
    ``num_seeds`` samples with the edited model, saves per-sample outputs and
    a cross-seed diversity JSON, and optionally mirrors everything to W&B.

    Args:
        args: Parsed command-line arguments. Reads ``args.use_wandb`` (enable
            Weights & Biases logging) and ``args.run_baseline`` (also run the
            unedited model for comparison).
    """
    # --- Configuration ---
    config = {
        "method_name": "uce",
        "device": "cuda",
        "model_name": "sd-v1.5-4view",
        "num_seeds": 4,
        "num_frames": 4,
        "uce_params": {
            "edit_concepts": ["teddy bear", "substancepainter"],
            "guide_concepts": ["plushie", "toy"],
            "preserve_concepts": ["car", "house", "tree"],
            "erase_scale": 1.0,
            "preserve_scale": 0.5,
        }
    }
    config["output_dir"] = f"output/{config['method_name']}"

    datasets = [
        {"name": "objaverse_fazbear_", "type": "json", "path": "data/objaverse-dupes/aggregated_clusters.json", "concept_key": "mario_", "is_memorized": True},
    ]

    # --- W&B Initialization ---
    if args.use_wandb:
        wandb.init(
            project="mvdream-evaluation",
            name=f"eval-{config['method_name']}-{time.strftime('%Y%m%d-%H%M%S')}",
            config={
                "method": config["method_name"],
                "model_name": config["model_name"],
                "num_seeds": config["num_seeds"],
                "num_frames": config["num_frames"],
                "datasets": [d["name"] for d in datasets],
                "uce_params": config["uce_params"],
            }
        )

    script_start_time = time.time()

    # --- Setup ---
    torch.manual_seed(42)

    # --- Load Baseline Model ---
    print("Loading base model...")
    baseline_model = build_model(config["model_name"]).to(config["device"])
    baseline_model.device = config["device"]
    baseline_model.dtype = torch.float32
    baseline_model.image_size = 256
    baseline_model.num_frames = config["num_frames"]

    metrics = [NoiseDiffNormMetric(), HessianMetric(), HessianMetricSlow(), BrightEndingMetric(), XAttnEntropyMetric(), DiversityMetric(device=config["device"])]
    diversity_metric = metrics[-1]

    # --- Apply UCE once ---
    # The edit depends only on `config["uce_params"]`, not on the prompt, so
    # the model is copied and edited a single time instead of once per prompt.
    # NOTE(review): assumes `erase_concept` is deterministic for a given model
    # and parameter set — confirm against UCEEditor.
    print("Applying UCE edit to a copy of the base model...")
    edit_start_time = time.time()
    edited_model = deepcopy(baseline_model)
    editor = UCEEditor(edited_model)
    editor.erase_concept(**config["uce_params"])
    print(f"UCE edit applied in {time.time() - edit_start_time:.2f} seconds.")

    # --- Loop and Evaluate Each Dataset ---
    for dataset in datasets:
        print(f"\n--- Evaluating UCE model on: {dataset['name']} ---")

        # Load prompts data
        if dataset['path'].endswith('.json'):
            uids_for_eval = load_uids_from_clusters(dataset["path"], dataset["concept_key"])
            prompts_data = uids_to_prompts(uids_for_eval)
        else:
            prompts_data = pd.read_csv(dataset['path'], sep=';').to_dict('records')

        # Create a wandb.Table for this dataset's prompts
        if args.use_wandb:
            dataset_table = wandb.Table(columns=["prompt_idx", "prompt", "is_memorized"])

        for idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {dataset['name']}")):
            prompt_start_time = time.time()
            prompt = prompt_data["Caption"]
            print(prompt_data)
            # Filesystem-safe, length-limited tag used in output names.
            safe_prompt = re.sub(r'\W+', '_', prompt)[:32]

            if args.use_wandb:
                dataset_table.add_data(idx, prompt, dataset["is_memorized"])

            # --- STAGE 1: Run BASELINE for all seeds (unmodified model) ---
            baseline_diversity = None
            if args.run_baseline:
                print(f"\n{'>'*50}\nRunning BASELINE for prompt {idx}...\n{'<'*50}")
                baseline_diversity = _evaluate_stage(
                    "baseline", baseline_model, metrics, diversity_metric,
                    stage_root="output/baseline",
                    dataset_name=dataset['name'],
                    prompt=prompt, prompt_idx=idx, safe_prompt=safe_prompt,
                    memorized=dataset["is_memorized"],
                    config=config,
                    use_wandb=args.use_wandb,
                )

            # --- STAGE 2: Run MITIGATION with the edited model ---
            print(f"\n{'>'*50}\nRunning UCE mitigation for prompt {idx}...\n{'<'*50}")
            # NOTE(review): mitigated outputs are always labelled
            # memorized=False, independent of the dataset flag; kept as in the
            # original — confirm this is intended.
            mitigated_diversity = _evaluate_stage(
                "mitigated", edited_model, metrics, diversity_metric,
                stage_root=config["output_dir"],
                dataset_name=dataset['name'],
                prompt=prompt, prompt_idx=idx, safe_prompt=safe_prompt,
                memorized=False,
                config=config,
                caption_suffix="\nMethod: UCE",
                use_wandb=args.use_wandb,
            )

            # Log a bar chart comparing diversities (only if baseline was run)
            if args.use_wandb and baseline_diversity is not None and mitigated_diversity is not None:
                diversity_plot = wandb.plot.bar(
                    wandb.Table(
                        columns=["stage", "diversity"],
                        data=[["baseline", baseline_diversity], ["mitigated", mitigated_diversity]]
                    ),
                    "stage",
                    "diversity",
                    title=f"Diversity Comparison for Prompt {idx}"
                )
                wandb.log({f"plots/{dataset['name']}/diversity_comparison": diversity_plot})

            torch.cuda.empty_cache()

            # --- W&B Logging (Timing) ---
            prompt_duration = time.time() - prompt_start_time
            print(f"Prompt {idx} processed in {prompt_duration:.2f} seconds.")
            if args.use_wandb:
                wandb.log({
                    f"timing/{dataset['name']}/prompt_duration": prompt_duration,
                    "prompt_idx": idx
                })

        # Log the dataset-specific table
        if args.use_wandb:
            wandb.log({f"prompts/{dataset['name']}_prompts": dataset_table})

    # --- Final Timing Log ---
    total_duration_minutes = (time.time() - script_start_time) / 60
    print(f"Total evaluation time: {total_duration_minutes:.2f} minutes.")

    if args.use_wandb:
        wandb.finish()

    print("\nUCE evaluation complete.")

if __name__ == "__main__":
    # Script entry point: build the CLI, parse it, and hand the result to main().
    cli = argparse.ArgumentParser(
        description="Run UCE evaluation with optional W&B logging.",
    )
    # Both options are boolean switches that are off unless given.
    for flag, help_text in (
        ("--use_wandb", "Enable Weights & Biases logging."),
        ("--run_baseline", "Run baseline for sanity check."),
    ):
        cli.add_argument(flag, action="store_true", help=help_text)
    main(cli.parse_args())