import re
import os, sys
import json
import torch
import wandb
import random
import requests
from tqdm import tqdm
from typing import List, Dict
from PIL import Image
import numpy as np
import pandas as pd
from io import BytesIO
import matplotlib.pyplot as plt

from .metrics.base import BaseMetric
from .utils import uids_to_prompts
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from mvdream.camera_utils import get_camera  


# Helper to plot noise trajectory
def save_noise_plot(noise_norms: List[float], output_path: str):
    """Plot a noise-norm trajectory against its step index and save it.

    Args:
        noise_norms: Sequence of values, plotted in order (one per step).
        output_path: Destination file handed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(noise_norms, label='Text-Uncond Noise Norms', color='black')
        ax.set_xlabel("Denoising Step")
        ax.set_ylabel("Noise Norm")
        ax.set_title("Classifier-Free Guidance Noise Norms")
        ax.grid(True, alpha=0.5)
        fig.savefig(output_path)
    finally:
        # Close this exact figure even when plotting/saving fails; otherwise
        # pyplot keeps it alive and figures pile up over a long evaluation.
        plt.close(fig)


def save_eigenvalue_plot(eigval_data: Dict, title: str, output_path: str):
    """Plot unconditional vs. conditional eigenvalues and shade the gap between them.

    Args:
        eigval_data: Mapping read for the keys ``'uncond_eigvals'`` and
            ``'cond_eigvals'``; a missing key is treated as an empty series.
        title: Figure title.
        output_path: Destination file handed to ``Figure.savefig``.
    """
    uncond_eigvals = np.asarray(eigval_data.get('uncond_eigvals', []))
    cond_eigvals = np.asarray(eigval_data.get('cond_eigvals', []))

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        # Unconditional eigenvalues (lambda): dashed line.
        ax.plot(np.arange(len(uncond_eigvals)), uncond_eigvals, color='black',
                linestyle='--', label=r'$\lambda$ (unconditional)')
        # Conditional eigenvalues (lambda_c): solid line.
        ax.plot(np.arange(len(cond_eigvals)), cond_eigvals, color='black',
                label=r'$\lambda_c$ (conditional)')

        # Shade only the indices both series share. fill_between requires x,
        # y1 and y2 to have matching lengths, so a missing or shorter series
        # would otherwise raise instead of producing a plot.
        overlap = min(len(uncond_eigvals), len(cond_eigvals))
        ax.fill_between(
            np.arange(overlap),
            uncond_eigvals[:overlap],
            cond_eigvals[:overlap],
            color='coral',
            alpha=0.6,
            label='Eigenvalue Gap'
        )
        ax.set_ylabel("Eigenvalue")
        ax.set_xlabel("Eigenvalue Index (Sorted)")
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.5)
        fig.savefig(output_path)
    finally:
        # Release the figure even if plotting or saving raised.
        plt.close(fig)


def _make_json_safe(value):
    """Recursively replace tensors / NumPy values inside ``value`` with plain Python data.

    Dicts and lists are updated in place and returned, so the caller's metrics
    structure ends up JSON-serialisable as well; tuples come back as lists.
    """
    if isinstance(value, torch.Tensor):
        return value.tolist()
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    if isinstance(value, dict):
        for key in list(value):
            value[key] = _make_json_safe(value[key])
        return value
    if isinstance(value, list):
        value[:] = [_make_json_safe(item) for item in value]
        return value
    if isinstance(value, tuple):
        return [_make_json_safe(item) for item in value]
    return value


def _collect_scalar_metrics(metrics: dict, prefix: str = "") -> dict:
    """Flatten nested metric dicts into ``{"outer.inner": number}``, keeping only int/float leaves."""
    flat = {}
    for key, value in metrics.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(_collect_scalar_metrics(value, prefix=f"{name}."))
        elif isinstance(value, (int, float)):
            flat[name] = value
    return flat


def save_run_outputs(result: dict, output_dir: str, base_filename: str, 
                     wandb_run_context=None, save_plots=False, step: int=0):
    """Save the image, metrics JSON and optional plots of one generation run.

    Args:
        result: Must provide ``result["image"]`` (an object with ``.save(path)``)
            and ``result["metrics"]``. An optional ``result["prompt"]`` is used
            for the W&B caption; ``base_filename`` is used when it is absent.
        output_dir: Directory for all artefacts; created if missing.
        base_filename: Prefix shared by every file written here.
        wandb_run_context: When not ``None``, the artefacts are also logged
            through ``wandb.log``.
        save_plots: When true, writes the noise plot (if ``"NoiseDiffNorm"`` is
            present) and eigenvalue plots for ``"t1"`` / ``"t20"`` (if
            ``"HessianMetric"`` is present).
        step: Step value forwarded to ``wandb.log``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save the main generated image
    image_path = os.path.join(output_dir, f"{base_filename}_image.png")
    result["image"].save(image_path)

    # Convert before opening the file so a failure cannot leave a truncated
    # JSON behind, and recurse so tensors at any nesting depth are handled.
    metrics_to_save = _make_json_safe(result["metrics"])
    metrics_path = os.path.join(output_dir, f"{base_filename}_metrics.json")
    with open(metrics_path, 'w') as f:
        json.dump(metrics_to_save, f, indent=2)

    noise_plot_path = None
    if save_plots:
        if "NoiseDiffNorm" in metrics_to_save:
            noise_plot_path = os.path.join(output_dir, f"{base_filename}_noise_plot.png")
            save_noise_plot(
                metrics_to_save["NoiseDiffNorm"]["noise_diff_norm_traj"],
                noise_plot_path
            )
        if "HessianMetric" in metrics_to_save:
            for t_step in ["t1", "t20"]:
                if t_step in metrics_to_save["HessianMetric"]:
                    save_eigenvalue_plot(
                        metrics_to_save["HessianMetric"][t_step],
                        title=f"Hessian Eigenvalues @ {t_step}",
                        output_path=os.path.join(output_dir, f"{base_filename}_eigvals_{t_step}.png")
                    )

    print("Per Prompt Per Seed Artefacts --->", metrics_path)
    print("="*150, "\n")

    # --- Log to Weights & Biases (if enabled) ---
    if wandb_run_context is not None:
        prompt = result.get("prompt", base_filename)
        log_dict = {
            "prompt": prompt,
            "generated_image": wandb.Image(image_path, caption=prompt),
        }
        # The noise plot only exists when it was actually written above.
        if noise_plot_path is not None:
            log_dict["noise_plot"] = wandb.Image(noise_plot_path)
        # Log all scalar metrics
        for key, val in _collect_scalar_metrics(metrics_to_save).items():
            log_dict[f"metric_{key}"] = val

        wandb.log(log_dict, step=step)

# Helper to decode latents and save the batch as a single image (adapted from uce.py)
def save_visual_samples(model, x_t, save_path: str):
    """Decode ``x_t``, join the batch into one image strip, save it and return it.

    ``model.decode_first_stage(x_t)`` is rescaled from [-1, 1] to [0, 1] with
    clamping, moved to channels-last, scaled to 0..255 and cast to ``uint8``.
    The per-sample arrays are concatenated along their axis 1.
    (Presumably the decoded batch is (B, C, H, W), giving a horizontal strip;
    confirm against the model.)

    Args:
        model: Object exposing ``decode_first_stage``.
        x_t: Input forwarded to ``decode_first_stage``.
        save_path: Where the combined image is written.

    Returns:
        A PIL image built from the combined array (a separate instance from
        the one used for saving).
    """
    decoded = model.decode_first_stage(x_t)
    unit_range = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
    channels_last = unit_range.permute(0, 2, 3, 1).detach().cpu().numpy()
    scaled = 255. * channels_last

    frames = [frame.astype(np.uint8) for frame in scaled]
    strip = np.concatenate(frames, axis=1)

    Image.fromarray(strip).save(save_path)
    return Image.fromarray(strip)


class ModelEvaluator:
    def __init__(self, sampler, metrics: List[BaseMetric], device='cuda'):
        self.sampler = sampler
        self.metrics = metrics
        self.device = device
        self.model = sampler.model

    def _get_ground_truth(self, metadata: dict):
        """Fetch the reference image named by ``metadata['URL']``, if there is one.

        Args:
            metadata: Mapping that may hold a ``'URL'`` entry.

        Returns:
            The downloaded image converted to RGB, or ``None`` when the entry
            is missing or falsy, or when anything goes wrong while fetching or
            decoding (a warning is printed in that case).
        """
        if 'URL' not in metadata or not metadata['URL']:
            return None

        try:
            reply = requests.get(metadata['URL'], timeout=10)
            reply.raise_for_status()
            picture = Image.open(BytesIO(reply.content))
            return picture.convert("RGB")
        except Exception as e:
            # Best-effort: a missing ground truth must not abort the evaluation.
            print(f"Warning: Could not fetch URL {metadata['URL']}. Error: {e}")
            return None

    def _run_single_generation(self, c_, uc_, shape, num_frames, **kwargs):
        controller = kwargs.get("controller", None)
        generator = torch.Generator(device=self.device).manual_seed(kwargs.get("seed", 42))
        x_T_override = kwargs.get("x_T")
        
        batch_size = max(4, num_frames)

        x_T = x_T_override.to(self.device) if x_T_override is not None else torch.randn(
            [batch_size] + shape, generator=generator, device=self.device)
            
        samples, intermediates = self.sampler.sample(
            S=50,
            conditioning=c_,
            batch_size=batch_size,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=7.5,
            unconditional_conditioning=uc_,
            eta=0.0,
            x_T=x_T,
            controller=controller,
        )
        return samples, intermediates

    def _process_single_prompt_single_seed(self, prompt, prompt_idx: int, seed: int, num_frames: int,
                                        base_output_dir="output/", unlearning_artifacts=None,):
        """Generate one sample for ``prompt`` at ``seed`` and compute the per-seed metrics.

        Args:
            prompt: Text prompt to generate from.
            prompt_idx: Prompt index; used only in the output folder name.
            seed: Seeds the initial-noise generator and is also forwarded to
                ``token_perturb_func`` as ``token_idx_to_perturb``.
            num_frames: Number of views requested from ``get_camera``; the
                batch size is ``max(4, num_frames)``.
            base_output_dir: Parent directory; a per-sample folder is created in it.
            unlearning_artifacts: Optional dict. Keys read here are
                ``token_perturb_func``, ``embedding_perturb_func`` (with
                ``embed_perturb_func`` accepted as an alias, the spelling
                ``_process_single_prompt`` uses), ``controller`` and
                ``x_T_override``. ``None`` means no perturbation, no
                controller and freshly drawn noise.

        Returns:
            dict with ``image`` (PIL image of the batch concatenated along
            axis 1), ``intermediates``, ``conditioning_context``,
            ``unconditioning_context``, ``metrics`` (metric name -> score)
            and ``controller``.
        """
        # Default everything so the method also works without artifacts;
        # previously these locals were unbound when None was passed.
        artifacts = unlearning_artifacts if unlearning_artifacts is not None else {}
        token_perturb_func = artifacts.get("token_perturb_func", None)
        embedding_perturb_func = (
            artifacts.get("embedding_perturb_func", None)
            or artifacts.get("embed_perturb_func", None)
        )
        controller = artifacts.get("controller", None)

        safe_prompt = re.sub(r'\W+', '_', prompt)[:32]
        base_filename = f"prompt_{prompt_idx:04d}_{seed:02d}_{safe_prompt}"
        sample_output_dir = os.path.join(base_output_dir, base_filename)
        os.makedirs(sample_output_dir, exist_ok=True)
        if controller is not None:
            controller.output_dir = sample_output_dir
            print(f"CA maps will be saved to directory: {controller.output_dir}")

        if controller:
            print(controller)
            controller.reset()

        # 1. Apply prompt-level perturbation (if any)
        final_prompt = prompt
        if token_perturb_func:
            final_prompt = token_perturb_func(prompt, token_idx_to_perturb=seed)
        print("Sanity Check in _process_single_prompt_single_seed final_prompt:", final_prompt)

        # 2. Get initial text conditioning
        c = self.model.get_learned_conditioning([final_prompt]).to(self.device)

        # 3. Apply embedding-level perturbation (if any)
        if embedding_perturb_func:
            c = embedding_perturb_func(c, model=self.model, controller=controller, prompt=prompt, output_path=base_output_dir ,)

        # 4. Prepare conditioning dictionaries for the sampler
        batch_size = max(4, num_frames)
        uc = self.model.get_learned_conditioning([""]).to(self.device)
        c_ = {"context": c.repeat(batch_size, 1, 1)}
        uc_ = {"context": uc.repeat(batch_size, 1, 1)}
        # NOTE(review): elevation/azimuth come from the global `random` state,
        # not from `seed`, so the camera is not reproduced by the seed alone.
        camera = get_camera(num_frames, elevation=random.randint(-15, 30), azimuth_start=random.randint(0, 360), azimuth_span=360)
        c_["camera"] = uc_["camera"] = camera.repeat(batch_size // num_frames, 1).to(self.device)
        c_["num_frames"] = uc_["num_frames"] = num_frames

        # 5. Handle x_T override for methods like SAIL
        shape = [self.model.model.diffusion_model.in_channels, self.model.image_size // 8, self.model.image_size // 8]
        x_T_override = artifacts.get("x_T_override")
        if x_T_override is not None:
            print(x_T_override.shape)
            x_T = x_T_override.to(self.device)
        else:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            x_T = torch.randn([batch_size] + shape, generator=generator, device=self.device)

        # 6. Run the generation
        samples, intermediates = self.sampler.sample(
            S=50, conditioning=c_, batch_size=batch_size, shape=shape, verbose=False,
            unconditional_guidance_scale=7.5, unconditional_conditioning=uc_, eta=0.0, x_T=x_T, controller=controller,
        )

        # 7. Decode image: [-1, 1] -> [0, 1] -> uint8, samples joined along axis 1
        x_decoded = self.model.decode_first_stage(samples)
        x_decoded = torch.clamp((x_decoded + 1.0) / 2.0, min=0.0, max=1.0)
        x_decoded_np = 255. * x_decoded.permute(0, 2, 3, 1).detach().cpu().numpy()
        combined_img = np.concatenate([img.astype(np.uint8) for img in x_decoded_np], axis=1)
        gen_image_pil = Image.fromarray(combined_img)

        # 8. Calculate all per-seed metrics
        calculated_metrics = {}
        for metric in self.metrics:
            if metric.metric_type == "per_seed":
                score = metric.measure(
                    intermediates=intermediates, model=self.model,
                    conditioning_context=c_, unconditioning_context=uc_,
                    controller=controller,
                    # Pass the directory where this sample's maps were saved
                    attention_map_dir=controller.output_dir if controller else None
                )
                calculated_metrics[metric.name] = score

        return {
            "image": gen_image_pil,
            "intermediates": intermediates,
            "conditioning_context": c_,
            "unconditioning_context": uc_,
            "metrics": calculated_metrics,
            "controller": controller,
        }

    def _process_single_prompt(self, prompt_idx: int, prompt_data: dict, output_dir: str, 
                               is_memorized_source: bool, num_seeds: int, num_frames: int, 
                               unlearning_artifacts=None,):
        """Generate and score every seed (or token position) for one prompt.

        Args:
            prompt_idx: Prompt index; used in file names and as part of the
                ``optimized_noise`` lookup key.
            prompt_data: Record with a ``"Caption"`` entry; ``"uid"`` and
                ``"URL"`` are used when present.
            output_dir: Directory receiving images, plots and JSON files.
            is_memorized_source: Stored under ``"memorized"`` in each per-seed record.
            num_seeds: When > 0, iterate ``range(num_seeds)``. Otherwise
                iterate the prompt's valid token positions, excluding the
                first and last (BOS/EOS).
            num_frames: Number of camera views; batch size is ``max(4, num_frames)``.
            unlearning_artifacts: Optional dict. Keys read here are
                ``token_perturb_func``, ``embed_perturb_func`` (with
                ``embedding_perturb_func`` accepted as an alias),
                ``controller`` and ``optimized_noise`` (a mapping keyed by
                ``(prompt_idx, seed_idx)``).

        Returns:
            ``{"per_seed_metrics": ..., "per_prompt_metrics": ...}``.
        """
        original_prompt = prompt_data["Caption"]
        # Computed once up front: it is also needed after the loop, which may
        # run zero times (e.g. token mode on a prompt with no inner tokens).
        safe_prompt = re.sub(r'\W+', '_', original_prompt)[:32]
        generated_images_for_prompt = []
        all_seed_metrics = {"original_prompt": original_prompt}

        # --- Conditioning that is shared by all seeds of this prompt ---
        batch_size = max(4, num_frames)
        uc = self.model.get_learned_conditioning([""]).to(self.device)
        # c_["context"] is filled per seed inside the loop.
        c_ = {}
        uc_ = {"context": uc.repeat(batch_size, 1, 1)}

        # Add camera conditioning if applicable (can be extended with args)
        camera = get_camera(num_frames, elevation=random.randint(-15, 30), azimuth_start=random.randint(0, 360), azimuth_span=360)
        camera = camera.repeat(batch_size // num_frames, 1).to(self.device)
        c_["camera"] = uc_["camera"] = camera
        c_["num_frames"] = uc_["num_frames"] = num_frames

        # --- UNPACK PERTURBATION FUNCTIONS FROM ARTIFACTS ---
        artifacts = unlearning_artifacts if unlearning_artifacts else {}
        token_perturb_func = artifacts.get("token_perturb_func", None)
        embed_perturb_func = (
            artifacts.get("embed_perturb_func", None)
            or artifacts.get("embedding_perturb_func", None)
        )
        controller = artifacts.get("controller", None)
        optimized_noise = artifacts.get("optimized_noise") or {}
        if controller:
            controller.reset()

        # ----------------------------------------------------
        batch_encoding = self.model.cond_stage_model.tokenizer(prompt=None, text=original_prompt, truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt") if False else self.model.cond_stage_model.tokenizer(original_prompt, truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        token_ids = batch_encoding["input_ids"].squeeze()
        valid_indices = (batch_encoding["attention_mask"].squeeze() == 1).nonzero(as_tuple=True)[0]
        # Plain ints in both modes: tensor indices are unusable as JSON keys
        # and never match the (prompt_idx, seed_idx) keys of optimized_noise.
        if num_seeds > 0:
            loop_through = range(num_seeds)
        else:
            loop_through = valid_indices[1:-1].tolist()  # Exclude BOS/EOS tokens
        print("Before perturbation:", token_ids[:len(valid_indices)]) 
        print(f"\nLooping through {loop_through}")
        # ----------------------------------------------------

        # --- Loop over seeds ---
        for seed_idx in loop_through:
            prompt = original_prompt
            base_filename = f"prompt_{prompt_idx:04d}_{seed_idx:02d}_{safe_prompt}"
            sample_output_dir = os.path.join(output_dir, base_filename)
            os.makedirs(sample_output_dir, exist_ok=True)
            if controller is not None:
                controller.output_dir = sample_output_dir
                print(f"CA maps will be saved to directory: {controller.output_dir}")

            # --- Token Level Perturbation ---
            if token_perturb_func:
                prompt = token_perturb_func(prompt, token_idx_to_perturb=seed_idx)
                print(f"Perturbed prompt for position or seed {seed_idx}: {prompt}")

            c = self.model.get_learned_conditioning([prompt]).to(self.device)

            # --- Embedding Level Perturbation ---
            if embed_perturb_func:
                c = embed_perturb_func(c, model=self.model, output_path=sample_output_dir, 
                    prompt=prompt, controller=controller)

            # Build the sampler context only after both perturbations, so the
            # embedding-level one actually reaches the generation.
            c_["context"] = c.repeat(batch_size, 1, 1)

            # --- Sail-Like Initial Noise Optimization ---
            x_T_override = optimized_noise.get((prompt_idx, seed_idx))

            # NOTE(review): spatial size is fixed at 32 here, while
            # _process_single_prompt_single_seed uses image_size // 8 - confirm.
            shape = [self.model.model.diffusion_model.in_channels, 32, 32]

            samples, intermediates = self._run_single_generation(
                c_, uc_, shape, num_frames, seed=int(seed_idx), x_T=x_T_override, controller=controller,
            )

            # --- Save individual generated image using correct decoding ---
            image_path = os.path.join(output_dir, f"{base_filename}_image.png")
            gen_image = save_visual_samples(self.model, samples, image_path)
            generated_images_for_prompt.append(gen_image)

            # --- Calculate and save metrics ---
            # "prompt" keeps the unmodified caption; "perturbed_prompt" holds
            # the text that was actually encoded (identical when no token
            # perturbation is configured).
            seed_metrics = {"prompt": original_prompt, "memorized": is_memorized_source, "perturbed_prompt": prompt}

            if 'uid' in prompt_data: seed_metrics['uid'] = prompt_data['uid']

            for metric in self.metrics:
                if metric.metric_type == "per_seed":
                    score = metric.measure(
                        intermediates=intermediates, 
                        model=self.model,
                        conditioning_context=c_, # Pass context needed by Hessian metric
                        unconditioning_context=uc_,
                        controller=controller,
                        attention_map_dir=controller.output_dir if controller else None
                    )
                    seed_metrics[metric.name] = score

            all_seed_metrics[seed_idx] = seed_metrics

            # L2 norm of (text noise - unconditional noise) at every recorded step.
            noise_norms = [(tn - un).norm(p=2).item() for un, tn in zip(intermediates['uncond_noise'], intermediates['text_noise'])]
            save_noise_plot(noise_norms, os.path.join(output_dir, f"{base_filename}_noise_plot.png"))

            with open(os.path.join(output_dir, f"{base_filename}_metrics.json"), 'w') as f:
                json.dump(seed_metrics, f, indent=2)
            print("Per-Seed JSON saved to:", f"{base_filename}_metrics.json\n")

        # --- Calculate per-prompt metrics ---
        per_prompt_metrics = {}
        for metric in self.metrics:
            if metric.metric_type == "per_prompt_across_seeds":
                score = metric.measure(images=generated_images_for_prompt)
                per_prompt_metrics[metric.name] = score
        per_prompt_across_seeds_filename = f"prompt_{prompt_idx:04d}_{safe_prompt}"
        with open(os.path.join(output_dir, f"{per_prompt_across_seeds_filename}_cross_seed_metrics.json"), 'w') as f:
            json.dump(per_prompt_metrics, f, indent=2)
        print("Per-Prompt JSON saved to:", f"{per_prompt_across_seeds_filename}_cross_seed_metrics.json\n\n")

        # --- Save Ground Truth ---
        gt_image = self._get_ground_truth(prompt_data)
        if gt_image is not None:
            gt_image.save(os.path.join(output_dir, f"prompt_{prompt_idx:04d}_ground_truth.png"))

        return {"per_seed_metrics": all_seed_metrics, "per_prompt_metrics": per_prompt_metrics}

    def run(self, model, dataset_path: str, prompt_source_name: str, output_dir: str, 
            is_memorized_source: bool, num_seeds: int, num_frames: int, 
            uids: List = None, unlearning_artifacts=None, ):
        """
        unlearning_artifacts may contain token_perturb_func, embed_perturb_func, controller
        """
        # (This high-level run function remains mostly the same, just passes num_frames down)
        self.model = model
        self.sampler.model = model
        
        if dataset_path.endswith('.csv'):
            df = pd.read_csv(dataset_path, sep=';')
            prompts_data = df.to_dict('records')
        elif dataset_path.endswith('.json'): # For Objaverse
            uids = uids or []
            prompts_data = uids_to_prompts(uids)
            # print(">"*50, uids, prompts)
            # prompts_data = [{"Caption": p, "uid": u} for p, u in zip(prompts, uids)]
        else:
            raise ValueError("Unsupported dataset format. Use .csv or .json")
        
        full_output_dir = os.path.join(output_dir, prompt_source_name)
        os.makedirs(full_output_dir, exist_ok=True)
        
        all_results = {}
        for idx, prompt_data in enumerate(tqdm(prompts_data, desc=f"Processing {prompt_source_name}")):
            print(prompt_data)
            prompt_results = self._process_single_prompt(
                prompt_idx=idx, prompt_data=prompt_data, output_dir=full_output_dir,
                is_memorized_source=is_memorized_source, num_seeds=num_seeds,
                num_frames=num_frames, unlearning_artifacts=unlearning_artifacts,
            )
            all_results[int(idx)] = prompt_results

        agg_path = os.path.join(full_output_dir, "_aggregated_results.json")
        with open(agg_path, 'w') as f:
            json.dump(all_results, f, indent=2)
        print(f"Aggregated results for {prompt_source_name} saved to {agg_path}")