import os
import re
import sys
import json
import time
import torch
import wandb
import numpy as np
import random
import requests
import argparse
import pandas as pd
from tqdm import tqdm
from PIL import Image

# Put the parent of this script's directory on the import path; presumably this
# is the repository root, so that the `mvdream` and `evaluation` packages
# imported below resolve when the script is run directly (verify layout).
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from mvdream.model_zoo import build_model
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.attention_control import AttentionController, register_attention_control
from mvdream.camera_utils import get_camera  

from evaluation.utils import NearestNeighborSearch, load_uids_from_clusters, uids_to_prompts
from evaluation.evaluator import ModelEvaluator, save_run_outputs
from evaluation.metrics import (
    NoiseDiffNormMetric, HessianMetric, HessianMetricSlow, DiversityMetric,
    BrightEndingMetric, XAttnEntropyMetric
)

def _load_training_data(csv_path, max_rows=50, max_images=10):
    """Collect a handful of reference images/captions for nearest-neighbour search.

    Reads ``csv_path`` (``;``-separated, with a ``URL`` column and an optional
    ``Caption`` column) and downloads images from the first ``max_rows`` rows
    until ``max_images`` have been fetched. If the CSV cannot be read, or no
    image could be downloaded at all, ten white 256x256 placeholder images are
    returned instead so the caller always gets non-empty lists.

    Args:
        csv_path: Path of the prompt CSV to read.
        max_rows: Maximum number of CSV rows to try.
        max_images: Stop as soon as this many images were downloaded.

    Returns:
        A ``(images, captions)`` tuple of two equally long lists.
    """
    training_images = []
    training_captions = []

    try:
        training_data_df = pd.read_csv(csv_path, sep=';')
    except Exception as e:
        print(f"Warning: Could not load training data: {e}")
    else:
        for i, row in training_data_df.head(max_rows).iterrows():
            try:
                # The context manager releases the streamed connection even when
                # the status check or the image decoding fails.
                with requests.get(row['URL'], stream=True, timeout=10) as response:
                    if response.status_code != 200:
                        continue
                    # Raw streams are not content-decoded (gzip/deflate) by default.
                    response.raw.decode_content = True
                    # convert() forces the full read while the stream is still open.
                    img = Image.open(response.raw).convert('RGB')
            except Exception as e:
                print(f"Failed to load image {i}: {e}")
                continue

            training_images.append(img)
            training_captions.append(row.get('Caption', f"Image {i}"))
            if len(training_images) >= max_images:
                break

        print(f"Successfully loaded {len(training_images)} training images")

    if not training_images:
        # Covers both "CSV unreadable" and "every download failed".
        print("Creating dummy training data...")
        training_images = [Image.new('RGB', (256, 256), color=(255, 255, 255)) for _ in range(10)]
        training_captions = [f"dummy_caption_{i}" for i in range(10)]

    return training_images, training_captions


def _write_cross_seed_json(output_dir, idx, safe_prompt, payload):
    """Write ``payload`` as the cross-seed JSON file of prompt ``idx`` and return its path."""
    os.makedirs(output_dir, exist_ok=True)
    json_path = os.path.join(output_dir, f"prompt_{idx:04d}_{safe_prompt}_cross_seed.json")
    with open(json_path, 'w') as f:
        json.dump(payload, f, indent=2)
    return json_path


def _log_seed_to_wandb(stage, dataset_name, result, prompt, prompt_idx, seed_idx, caption_suffix=""):
    """Log the numeric metrics and the image of one (prompt, seed) run to W&B.

    No explicit ``step`` is passed: W&B requires steps to increase
    monotonically, and this script interleaves per-seed and per-prompt logs.
    ``prompt_idx`` and ``seed_idx`` are logged as ordinary fields instead.
    """
    log_data = {
        f"{stage}/{dataset_name}/metrics/{k}": v
        for k, v in result["metrics"].items()
        if isinstance(v, (int, float))
    }
    log_data[f"{stage}/{dataset_name}/images"] = wandb.Image(
        result["image"],
        caption=f"Prompt: {prompt}\nSeed: {seed_idx}{caption_suffix}"
    )
    log_data["prompt_idx"] = prompt_idx
    log_data["seed_idx"] = seed_idx
    wandb.log(log_data)


def _run_amg_seed(model, sampler, nn_search, metrics, controller, prompt):
    """Sample one multi-view image with ``sampler.sample_amg`` and score it.

    Builds the conditional/unconditional contexts and a random camera, runs the
    AMG sampler, decodes the latents into a single image (views concatenated
    along the width) and evaluates every metric whose ``metric_type`` is
    ``"per_seed"``.

    Returns:
        ``(image, latents, metrics_dict)`` where ``image`` is a PIL image,
        ``latents`` is what ``sample_amg`` returned and ``metrics_dict`` maps
        metric names to scores.
    """
    model.batch_size = max(4, model.num_frames)
    c = model.get_learned_conditioning([prompt])
    uc = model.get_learned_conditioning([""])
    c_ = {"context": c.repeat(model.batch_size, 1, 1)}
    uc_ = {"context": uc.repeat(model.batch_size, 1, 1)}
    camera = get_camera(
        model.num_frames,
        elevation=random.randint(-15, 30),
        azimuth_start=random.randint(0, 360),
        azimuth_span=360,
    )
    c_["camera"] = uc_["camera"] = camera.repeat(model.batch_size // model.num_frames, 1).to(model.device)
    c_["num_frames"] = uc_["num_frames"] = model.num_frames

    generated_latents, intermediates = sampler.sample_amg(
        S=50,
        batch_size=model.batch_size,
        shape=(model.batch_size, 4, 32, 32),
        conditioning=c_,
        unconditional_conditioning=uc_,
        nn_search=nn_search,
        guidance_scale=7.5,
        c_sim=0.3,
        c_spe=0.3,
        c_dup=0.3,
    )

    # Decode latents: map from [-1, 1] to [0, 255] and lay the views side by side.
    x_decoded = model.decode_first_stage(generated_latents)
    x_decoded = torch.clamp((x_decoded + 1.0) / 2.0, min=0.0, max=1.0)
    x_decoded_np = 255. * x_decoded.permute(0, 2, 3, 1).detach().cpu().numpy()
    combined_img = np.concatenate([img.astype(np.uint8) for img in x_decoded_np], axis=1)
    decoded_image_pil = Image.fromarray(combined_img)

    # Calculate all per-seed metrics.
    calculated_metrics = {}
    for metric in metrics:
        if metric.metric_type == "per_seed":
            calculated_metrics[metric.name] = metric.measure(
                intermediates=intermediates, model=model,
                conditioning_context=c_, unconditioning_context=uc_,
                controller=controller,
                # Pass the directory where this sample's maps were saved
                attention_map_dir=controller.output_dir if controller else None
            )

    return decoded_image_pil, generated_latents, calculated_metrics


def main(args):
    """Evaluate AMG mitigation (and optionally the baseline) on the prompt datasets.

    For every prompt of every configured dataset the script optionally runs
    the unmodified baseline for ``num_seeds`` seeds, then runs the AMG sampler
    for the same seeds, saves per-seed outputs and a cross-seed diversity JSON
    under ``output/``, and optionally mirrors metrics/images to Weights & Biases.

    Args:
        args: Parsed command-line namespace. ``args.use_wandb`` enables W&B
            logging, ``args.run_baseline`` enables the baseline stage.
    """
    # --- Configuration ---
    config = {
        "method": "amg",
        "device": "cuda",
        "model_name": "sd-v1.5-4view",
        "num_seeds": 4,
        "num_frames": 4
    }

    datasets = [
        {
            "name": "laion_memorized",
            "type": "csv",
            "path": "data/nemo-prompts/memorized_laion_prompts.csv",
            "is_memorized": True
        },
        {"name": "laion_unmemorized", "path": "data/nemo-prompts/unmemorized_laion_prompts.csv", "is_memorized": False},
    ]

    # --- W&B Initialization ---
    if args.use_wandb:
        wandb.init(
            project="mvdream-evaluation",
            name=f"eval-{config['method']}-{time.strftime('%Y%m%d-%H%M%S')}",
            config={
                "method": config["method"],
                "model_name": config["model_name"],
                "num_seeds": config["num_seeds"],
                "num_frames": config["num_frames"],
                "datasets": [d["name"] for d in datasets],
            }
        )

    script_start_time = time.time()

    # --- Setup ---
    torch.manual_seed(42)
    model = build_model(config["model_name"]).to(config["device"])
    model.device = config["device"]
    model.dtype = torch.float32
    model.image_size = 256
    model.num_frames = config["num_frames"]
    sampler = DDIMSampler(model)

    # --- AMG Specific Setup ---
    print("Loading training data for AMG...")
    # Load a small subset of a dataset to act as "training data" for NN search
    training_images, training_captions = _load_training_data(
        "data/nemo-prompts/unmemorized_laion_prompts.csv"
    )

    # Initialize nearest neighbor search
    nn_search = NearestNeighborSearch(model, similarity_metric='clip')
    nn_search.precompute_training_embeddings(training_images, training_captions)

    # --- Standard Evaluator for Final Metrics ---
    metrics = [
        NoiseDiffNormMetric(),
        HessianMetric(),
        HessianMetricSlow(),
        BrightEndingMetric(),
        XAttnEntropyMetric(),
        DiversityMetric(device=config["device"])
    ]
    diversity_metric = metrics[-1]
    evaluator = ModelEvaluator(sampler, metrics, device=config["device"])

    # --- Main Loop ---
    for dataset in datasets:
        print(f"\n--- Evaluating AMG for: {dataset['name']} ---")

        # Load prompts data
        if dataset['path'].endswith('.json'):
            uids_for_eval = load_uids_from_clusters(dataset["path"], dataset["concept_key"])
            prompts_data = uids_to_prompts(uids_for_eval[:10])  # Limit for testing
        else:
            # All rows of the CSV are evaluated (no limit is applied here).
            prompts_data = pd.read_csv(dataset['path'], sep=';').to_dict('records')

        # Create a wandb.Table for prompts
        if args.use_wandb:
            dataset_table = wandb.Table(columns=["prompt_idx", "prompt", "is_memorized"])

        for idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {dataset['name']}")):
            prompt = prompt_data["Caption"]
            if not isinstance(prompt, str):
                # Empty CSV cells are read as NaN; they cannot be used as a prompt.
                print(f"Skipping prompt {idx}: caption is not a string ({prompt!r})")
                continue
            safe_prompt = re.sub(r'\W+', '_', prompt)[:32]
            prompt_start_time = time.time()

            if args.use_wandb:
                dataset_table.add_data(idx, prompt, dataset["is_memorized"])

            baseline_images = []
            mitigated_images = []
            baseline_diversity = None

            # --- STAGE 1: Run BASELINE for all seeds ---
            if args.run_baseline:
                print(f"\n{'>'*50}\nRunning BASELINE for prompt {idx}...\n{'<'*50}")
                for seed_idx in range(config['num_seeds']):
                    # Seed torch AND Python's RNG so every random draw is reproducible.
                    torch.manual_seed(seed_idx * 42)
                    random.seed(seed_idx * 42)

                    base_filename = f"prompt_{idx:04d}_{seed_idx:02d}_{safe_prompt}"
                    sample_dir = os.path.join("output/baseline", dataset['name'], base_filename)

                    controller = AttentionController(output_dir=sample_dir)
                    register_attention_control(model, controller)

                    artifacts = {"controller": controller}

                    baseline_result = evaluator._process_single_prompt_single_seed(
                        prompt=prompt,
                        prompt_idx=idx,
                        seed=seed_idx,
                        num_frames=config['num_frames'],
                        base_output_dir=sample_dir,
                        unlearning_artifacts=artifacts
                    )

                    baseline_images.append(baseline_result["image"])
                    baseline_result["metrics"].update({"prompt": prompt, "memorized": dataset["is_memorized"]})
                    save_run_outputs(baseline_result, os.path.dirname(sample_dir), base_filename)

                    # --- W&B Logging (Baseline) ---
                    if args.use_wandb:
                        _log_seed_to_wandb(
                            "baseline", dataset['name'], baseline_result, prompt, idx, seed_idx
                        )

                    del baseline_result, controller, artifacts
                    torch.cuda.empty_cache()

                # --- Cross-Seed Metrics Calculation (Baseline) ---
                if baseline_images:
                    baseline_diversity = diversity_metric.measure(images=baseline_images)
                    baseline_cross_seed_json = _write_cross_seed_json(
                        f"output/baseline/{dataset['name']}",
                        idx,
                        safe_prompt,
                        {
                            "prompt": prompt,
                            "memorized": dataset["is_memorized"],
                            diversity_metric.name: baseline_diversity
                        },
                    )
                    print(f"Baseline Cross-Seed Artifacts: {baseline_cross_seed_json}")

                    # --- W&B Logging (Baseline Diversity) ---
                    if args.use_wandb:
                        wandb.log({
                            f"baseline/{dataset['name']}/diversity": baseline_diversity,
                            "prompt_idx": idx
                        })
                    print(f"Baseline diversity for prompt {idx}: {baseline_diversity}")

                del baseline_images

            # --- STAGE 2: Run AMG MITIGATION ---
            print(f"\n{'>'*50}\nRunning AMG mitigation for prompt {idx}...\n{'<'*50}")
            for seed_idx in range(config['num_seeds']):
                # Seed torch AND Python's RNG: the camera pose is drawn from `random`.
                torch.manual_seed(seed_idx * 42)
                random.seed(seed_idx * 42)

                base_filename = f"prompt_{idx:04d}_{seed_idx:02d}_{safe_prompt}"
                mitigated_sample_dir = os.path.join(f"output/{config['method']}", dataset['name'], base_filename)
                os.makedirs(mitigated_sample_dir, exist_ok=True)

                controller = AttentionController(output_dir=mitigated_sample_dir)
                register_attention_control(model, controller)

                # Use the AMG sampler - this is the key difference
                decoded_image_pil, generated_latents, calculated_metrics = _run_amg_seed(
                    model, sampler, nn_search, metrics, controller, prompt
                )

                # Same result structure as the baseline; the label follows the dataset.
                calculated_metrics.update({"prompt": prompt, "memorized": dataset["is_memorized"]})
                mitigated_result = {
                    "image": decoded_image_pil,
                    "latents": generated_latents,
                    "metrics": calculated_metrics
                }
                mitigated_images.append(mitigated_result["image"])
                save_run_outputs(mitigated_result, os.path.dirname(mitigated_sample_dir), base_filename)

                # --- W&B Logging (Mitigated) ---
                if args.use_wandb:
                    _log_seed_to_wandb(
                        "mitigated", dataset['name'], mitigated_result, prompt, idx, seed_idx,
                        caption_suffix="\nMethod: AMG"
                    )

                del mitigated_result, controller, generated_latents
                torch.cuda.empty_cache()

            # --- Cross-Seed Metrics Calculation (Mitigated) ---
            if mitigated_images:
                mitigated_diversity = diversity_metric.measure(images=mitigated_images)
                mitigated_cross_seed_json = _write_cross_seed_json(
                    f"output/{config['method']}/{dataset['name']}",
                    idx,
                    safe_prompt,
                    {
                        "prompt": prompt,
                        "memorized": dataset["is_memorized"],
                        diversity_metric.name: mitigated_diversity
                    },
                )
                print(f"Mitigated Cross-Seed Artifacts: {mitigated_cross_seed_json}")

                # --- W&B Logging (Mitigated Diversity & Comparison) ---
                if args.use_wandb:
                    wandb.log({
                        f"mitigated/{dataset['name']}/diversity": mitigated_diversity,
                        "prompt_idx": idx
                    })

                    # Log a bar chart comparing diversities (only if baseline was run)
                    if baseline_diversity is not None:
                        diversity_plot = wandb.plot.bar(
                            wandb.Table(
                                columns=["stage", "diversity"],
                                data=[["baseline", baseline_diversity], ["mitigated", mitigated_diversity]]
                            ),
                            "stage",
                            "diversity",
                            title=f"Diversity Comparison for Prompt {idx}"
                        )
                        wandb.log({f"plots/{dataset['name']}/diversity_comparison": diversity_plot})

                print(f"Mitigated diversity for prompt {idx}: {mitigated_diversity}")

            del mitigated_images

            # --- W&B Logging (Timing) ---
            prompt_duration = time.time() - prompt_start_time
            print(f"Prompt {idx} processed in {prompt_duration:.2f} seconds.")
            if args.use_wandb:
                wandb.log({
                    f"timing/{dataset['name']}/prompt_duration": prompt_duration,
                    "prompt_idx": idx
                })

        # Log the dataset-specific table
        if args.use_wandb:
            wandb.log({f"prompts/{dataset['name']}_prompts": dataset_table})

    # --- Final Timing Log ---
    total_duration_minutes = (time.time() - script_start_time) / 60
    print(f"Total evaluation time: {total_duration_minutes:.2f} minutes.")

    if args.use_wandb:
        wandb.finish()

    print("\nAMG evaluation complete.")


if __name__ == "__main__":
    # Command-line entry point: both switches are opt-in boolean flags.
    cli = argparse.ArgumentParser(
        description="Run AMG evaluation with optional W&B logging."
    )
    for flag, help_text in (
        ("--use_wandb", "Enable Weights & Biases logging."),
        ("--run_baseline", "Run baseline for sanity check."),
    ):
        cli.add_argument(flag, action="store_true", help=help_text)
    main(cli.parse_args())