import os
import sys
import json
import argparse
import re
import time
from copy import deepcopy
import torch
import pandas as pd
from tqdm import tqdm
import wandb

# Add project root to path
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from mvdream.model_zoo import build_model
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.editing.subspace_pruner import SubspacePruner
from mvdream.attention_control import AttentionController, register_attention_control

from evaluation.evaluator import ModelEvaluator, save_run_outputs
from evaluation.metrics import NoiseDiffNormMetric, HessianMetric, HessianMetricSlow, DiversityMetric, BrightEndingMetric, XAttnEntropyMetric
from evaluation.utils import load_uids_from_clusters, uids_to_prompts, load_prompts_from_csv


def run_evaluation_for_prompts(prompts_data, dataset_name, is_memorized, baseline_model, config, args, metrics, diversity_metric, global_prompt_counter):
    """
    Run evaluation for a given set of prompts with a fresh pruned model.

    A pruning mask is computed from all captions in ``prompts_data`` and the
    pruned weights are applied to a deep copy of ``baseline_model``. For every
    prompt, ``config['num_seeds']`` samples are generated with the pruned model
    (preceded by the same seeds on ``baseline_model`` when ``args.run_baseline``
    is set). Per-seed outputs and a cross-seed diversity JSON are written under
    ``output/baseline/<dataset_name>`` and ``output/<method_name>/<dataset_name>``.

    Args:
        prompts_data: Sequence of dicts, each with at least a ``"Caption"`` key.
        dataset_name: Label used in output paths and W&B keys.
        is_memorized: Memorization label recorded for the prompt and baseline outputs.
        baseline_model: Unpruned model; pruning is applied to a deep copy of it.
        config: Run configuration (reads ``device``, ``pruning_sparsity``,
            ``num_seeds``, ``num_frames``, ``method_name``).
        args: CLI namespace (reads ``use_wandb`` and ``run_baseline``).
        metrics: Per-seed metrics handed to ``ModelEvaluator``.
        diversity_metric: Metric used for the cross-seed diversity score.
        global_prompt_counter: One-element list used as a mutable counter so
            prompt ids stay unique across batches; incremented once per prompt.

    Returns:
        The updated value of ``global_prompt_counter[0]``.

    Note:
        W&B calls in this function and its helpers deliberately omit ``step=``.
        W&B ignores any log whose step is lower than the current one, and mixing
        per-seed steps (``global_idx * num_seeds + seed``) with per-prompt steps
        (``global_idx``) used to drop most diversity, mitigation and timing logs.
        ``prompt_idx`` / ``seed`` are logged as fields to serve as x-axes instead.
    """
    print(f"\n--- Evaluating Subspace Pruning for: {dataset_name} (batch size: {len(prompts_data)}) ---")

    # --- Create Fresh Pruned Model ---
    print("Creating fresh pruned model for this batch/concept...")
    prompts_for_pruning = [prompt_data["Caption"] for prompt_data in prompts_data]

    pruner = SubspacePruner(baseline_model, device=config['device'], sparsity=config['pruning_sparsity'])
    pruning_masks = pruner.find_memorization_subspace(prompts_for_pruning)

    # Prune a copy so the baseline weights stay untouched for later batches.
    edited_model = deepcopy(baseline_model)
    pruner_edited = SubspacePruner(edited_model, device=config['device'])
    pruner_edited.prune_model_weights(pruning_masks)

    baseline_root = "output/baseline"
    mitigated_root = f"output/{config['method_name']}"

    # Create wandb table for this dataset/batch
    dataset_table = (
        wandb.Table(columns=["prompt_idx", "prompt", "is_memorized", "dataset"])
        if args.use_wandb else None
    )

    # --- Process Each Prompt ---
    for local_idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {dataset_name}")):
        prompt_start_time = time.time()
        prompt = prompt_data["Caption"]
        safe_prompt = re.sub(r'\W+', '_', prompt)[:32]

        # Use global counter for unique identification across all batches
        global_idx = global_prompt_counter[0]
        global_prompt_counter[0] += 1

        if dataset_table is not None:
            dataset_table.add_data(global_idx, prompt, is_memorized, dataset_name)

        # --- STAGE 1: Run BASELINE for all seeds (unmodified model) ---
        baseline_diversity = None
        if args.run_baseline:
            print(f"\n{'>'*50}\nRunning BASELINE for prompt {global_idx} (local {local_idx})...\n{'<'*50}")
            baseline_images = _run_seeds_for_stage(
                baseline_model, stage="baseline", output_root=baseline_root,
                dataset_name=dataset_name, prompt=prompt, safe_prompt=safe_prompt,
                global_idx=global_idx, memorized_label=is_memorized,
                config=config, args=args, metrics=metrics,
            )
            baseline_diversity = _save_cross_seed_diversity(
                baseline_images, stage="baseline", label="Baseline", output_root=baseline_root,
                dataset_name=dataset_name, prompt=prompt, safe_prompt=safe_prompt,
                global_idx=global_idx, memorized_label=is_memorized,
                diversity_metric=diversity_metric, args=args,
            )
            del baseline_images

        # --- STAGE 2: Run MITIGATION with pruned model ---
        # NOTE(review): mitigated outputs are labelled memorized=False regardless of
        # ``is_memorized`` (kept from the original) — confirm this is intended.
        print(f"\n{'>'*50}\nRunning subspace pruning mitigation for prompt {global_idx} (local {local_idx})...\n{'<'*50}")
        mitigated_images = _run_seeds_for_stage(
            edited_model, stage="mitigated", output_root=mitigated_root,
            dataset_name=dataset_name, prompt=prompt, safe_prompt=safe_prompt,
            global_idx=global_idx, memorized_label=False,
            config=config, args=args, metrics=metrics,
            caption_suffix="\nMethod: Subspace Pruning",
        )
        mitigated_diversity = _save_cross_seed_diversity(
            mitigated_images, stage="mitigated", label="Mitigated", output_root=mitigated_root,
            dataset_name=dataset_name, prompt=prompt, safe_prompt=safe_prompt,
            global_idx=global_idx, memorized_label=False,
            diversity_metric=diversity_metric, args=args,
        )
        del mitigated_images

        # Log a bar chart comparing diversities (only if both stages produced one)
        if args.use_wandb and baseline_diversity is not None and mitigated_diversity is not None:
            diversity_plot = wandb.plot.bar(
                wandb.Table(
                    columns=["stage", "diversity"],
                    data=[["baseline", baseline_diversity], ["mitigated", mitigated_diversity]]
                ),
                "stage",
                "diversity",
                title=f"Diversity Comparison for Prompt {global_idx}"
            )
            wandb.log({f"plots/{dataset_name}/diversity_comparison": diversity_plot, "prompt_idx": global_idx})

        # --- W&B Logging (Timing) ---
        prompt_duration = time.time() - prompt_start_time
        print(f"Prompt {global_idx} (local {local_idx}) processed in {prompt_duration:.2f} seconds.")
        if args.use_wandb:
            wandb.log({
                f"timing/{dataset_name}/prompt_duration": prompt_duration,
                "prompt_idx": global_idx,
            })

    # Log the dataset-specific table
    if dataset_table is not None:
        wandb.log({f"prompts/{dataset_name}_prompts": dataset_table})

    # Clean up the pruned model (and the masks, which may hold device tensors)
    del edited_model, pruner, pruner_edited, pruning_masks
    torch.cuda.empty_cache()

    return global_prompt_counter[0]


def _run_seeds_for_stage(model, *, stage, output_root, dataset_name, prompt, safe_prompt,
                         global_idx, memorized_label, config, args, metrics, caption_suffix=""):
    """Generate, save and optionally W&B-log one sample per seed for ``prompt``.

    Shared by the baseline and mitigation stages. For each seed a fresh
    ``AttentionController`` is registered on ``model``, the evaluator is run,
    and outputs are saved under ``<output_root>/<dataset_name>/``. Scalar
    (int/float) metrics and the image go to W&B under ``<stage>/<dataset_name>/``.

    Returns:
        List of generated images in seed order.
    """
    sampler = DDIMSampler(model)
    evaluator = ModelEvaluator(sampler, metrics, device=config["device"])
    images = []

    for seed_idx in range(config['num_seeds']):
        base_filename = f"prompt_{global_idx:04d}_{seed_idx:02d}_{safe_prompt}"
        sample_dir = os.path.join(output_root, dataset_name, base_filename)
        os.makedirs(sample_dir, exist_ok=True)

        controller = AttentionController(output_dir=sample_dir)
        register_attention_control(model, controller)

        result = evaluator._process_single_prompt_single_seed(
            prompt=prompt, prompt_idx=global_idx, seed=seed_idx, num_frames=config['num_frames'],
            base_output_dir=sample_dir, unlearning_artifacts={"controller": controller}
        )

        images.append(result["image"])
        result["metrics"].update({"prompt": prompt, "memorized": memorized_label})
        save_run_outputs(result, os.path.dirname(sample_dir), base_filename)

        if args.use_wandb:
            log_data = {
                f"{stage}/{dataset_name}/metrics/{k}": v for k, v in result["metrics"].items()
                if isinstance(v, (int, float))
            }
            log_data[f"{stage}/{dataset_name}/images"] = wandb.Image(
                result["image"],
                caption=f"Prompt: {prompt}\nSeed: {seed_idx}{caption_suffix}"
            )
            # No explicit step: W&B drops logs at non-increasing steps.
            log_data["prompt_idx"] = global_idx
            log_data["seed"] = seed_idx
            wandb.log(log_data)

        del result, controller
        torch.cuda.empty_cache()

    del sampler, evaluator
    return images


def _save_cross_seed_diversity(images, *, stage, label, output_root, dataset_name, prompt,
                               safe_prompt, global_idx, memorized_label, diversity_metric, args):
    """Measure cross-seed diversity of ``images``, write it to JSON and log it.

    Returns:
        The value from ``diversity_metric.measure``, or ``None`` when ``images``
        is empty (nothing is written or logged in that case).
    """
    if not images:
        return None

    diversity = diversity_metric.measure(images=images)
    cross_seed_json = os.path.join(
        output_root, dataset_name, f"prompt_{global_idx:04d}_{safe_prompt}_cross_seed.json"
    )
    os.makedirs(os.path.dirname(cross_seed_json), exist_ok=True)
    with open(cross_seed_json, 'w') as f:
        json.dump({"prompt": prompt, "memorized": memorized_label, diversity_metric.name: diversity}, f, indent=2)
    print(f"{label} Cross-Seed Artefacts:", cross_seed_json)

    if args.use_wandb:
        wandb.log({
            f"{stage}/{dataset_name}/diversity": diversity,
            "prompt_idx": global_idx,
        })
    print(f"{label} diversity for prompt {global_idx}: {diversity}")
    return diversity


def main(args):
    """Evaluate subspace pruning on every configured prompt dataset.

    Loads the baseline model once, then hands each dataset's prompts to
    ``run_evaluation_for_prompts``: ``.json`` datasets as a single group, other
    (CSV) datasets in chunks of ``config["pruning_batchsize"]`` prompts, with a
    freshly pruned model per call. Finishes by reporting total runtime and
    prompt count (also to W&B when ``args.use_wandb`` is set).

    Args:
        args: Parsed CLI namespace; ``use_wandb`` is read here, and the whole
            namespace is forwarded to ``run_evaluation_for_prompts``.
    """
    method_name = "subspace_prune"
    config = {
        "method_name": method_name,
        "device": "cuda",
        "model_name": "sd-v1.5-4view",
        "num_seeds": 4,
        "num_frames": 4,
        "pruning_sparsity": 0.0001,
        "pruning_batchsize": 50,  # prompts per pruned-model rebuild for CSV datasets
        "prompts_for_pruning_path": "data/nemo-prompts/memorized_laion_prompts.csv",
        "output_dir": f"output/{method_name}",
    }

    datasets_to_eval = [
        {"name": "laion_memorized", "type": "csv", "path": "data/nemo-prompts/memorized_laion_prompts.csv", "is_memorized": True},
        {"name": "laion_unmemorized", "path": "data/nemo-prompts/unmemorized_laion_prompts.csv", "is_memorized": False},
    ]

    if args.use_wandb:
        wandb.init(
            project="mvdream-evaluation",
            name=f"eval-{method_name}-{time.strftime('%Y%m%d-%H%M%S')}",
            config=dict(
                method=method_name,
                model_name=config["model_name"],
                num_seeds=config["num_seeds"],
                num_frames=config["num_frames"],
                datasets=[entry["name"] for entry in datasets_to_eval],
                pruning_sparsity=config["pruning_sparsity"],
                pruning_batchsize=config["pruning_batchsize"],
            ),
        )

    script_start_time = time.time()
    torch.manual_seed(42)

    # The baseline model is loaded once and shared by every dataset/batch.
    print("Loading base model...")
    device = config["device"]
    baseline_model = build_model(config["model_name"]).to(device)
    baseline_model.device = device
    baseline_model.dtype = torch.float32
    baseline_model.image_size = 256
    baseline_model.num_frames = config["num_frames"]

    metrics = [
        NoiseDiffNormMetric(),
        HessianMetric(),
        HessianMetricSlow(),
        BrightEndingMetric(),
        XAttnEntropyMetric(),
        DiversityMetric(device=device),
    ]
    # The last metric doubles as the cross-seed diversity scorer.
    diversity_metric = metrics[-1]

    # One-element list so run_evaluation_for_prompts can bump it in place.
    global_prompt_counter = [0]

    for dataset in datasets_to_eval:
        print(f"\n=== Processing Dataset: {dataset['name']} ===")

        if dataset['path'].endswith('.json'):
            # Cluster file: the whole concept is evaluated with one pruned model.
            concept_prompts = uids_to_prompts(
                load_uids_from_clusters(dataset["path"], dataset["concept_key"])
            )
            global_prompt_counter[0] = run_evaluation_for_prompts(
                concept_prompts, dataset["name"], dataset["is_memorized"],
                baseline_model, config, args, metrics, diversity_metric, global_prompt_counter
            )
            continue

        # CSV prompts: split into fixed-size chunks, one pruned model per chunk.
        records = pd.read_csv(dataset['path'], sep=';').to_dict('records')
        chunk_size = config["pruning_batchsize"]
        chunk_starts = range(0, len(records), chunk_size)
        total_batches = len(chunk_starts)

        print(f"Processing {len(records)} LAION prompts in {total_batches} batches of size {chunk_size}")

        for batch_idx, first in enumerate(chunk_starts):
            chunk = records[first:first + chunk_size]
            last_exclusive = first + len(chunk)
            batch_name = f"{dataset['name']}_batch_{batch_idx:03d}"

            print(f"\n--- Processing Batch {batch_idx + 1}/{total_batches}: {batch_name} ---")
            print(f"Batch contains prompts {first} to {last_exclusive - 1} ({len(chunk)} prompts)")

            global_prompt_counter[0] = run_evaluation_for_prompts(
                chunk, batch_name, dataset["is_memorized"],
                baseline_model, config, args, metrics, diversity_metric, global_prompt_counter
            )

    del baseline_model
    torch.cuda.empty_cache()

    total_duration_minutes = (time.time() - script_start_time) / 60
    processed = global_prompt_counter[0]
    print(f"Total evaluation time: {total_duration_minutes:.2f} minutes.")
    print(f"Total prompts processed: {processed}")

    if args.use_wandb:
        wandb.log({
            "total_duration_minutes": total_duration_minutes,
            "total_prompts_processed": processed,
        })
        wandb.finish()

    print("\nSubspace Pruning evaluation complete.")


if __name__ == "__main__":
    # Command-line entry point; both flags are opt-in booleans forwarded to main().
    cli = argparse.ArgumentParser(description="Run subspace pruning evaluation with optional W&B logging.")
    cli.add_argument("--use_wandb", action="store_true", help="Enable Weights & Biases logging.")
    cli.add_argument("--run_baseline", action="store_true", help="Run baseline for sanity check.")
    main(cli.parse_args())