# evaluation/utils/perturbations.py
import torch
import torch.nn.functional as F
import random
import numpy as np
import matplotlib.pyplot as plt

# --- Metric-Agnostic Perturbations (Somepalli et al.) ---

def add_gaussian_noise(embedding: torch.Tensor, std: float = 0.5, **kwargs):
    """Return ``embedding`` perturbed by zero-mean Gaussian noise scaled by ``std`` (GNI).

    The noise tensor matches the shape, dtype and device of ``embedding``; the
    input itself is not modified. Extra keyword arguments are accepted and ignored.
    """
    scaled_noise = std * torch.randn_like(embedding)
    return embedding + scaled_noise

def perturb_prompt_random_tokens(prompt: str, tokenizer, num_tokens: int = 4, **kwargs):
    """Insert ``num_tokens`` randomly drawn vocabulary tokens into ``prompt`` (RT).

    Each round draws a token id uniformly from [1000, 40000], turns it into text
    with ``tokenizer.decode``, and splices that text (with one space on each side)
    at a uniformly random character offset of the current string. Extra keyword
    arguments are accepted and ignored.
    """
    perturbed = prompt
    for _ in range(num_tokens):
        token_text = tokenizer.decode(random.randint(1000, 40000))
        cut = random.randint(0, len(perturbed))
        head, tail = perturbed[:cut], perturbed[cut:]
        perturbed = f"{head} {token_text} {tail}"
    return perturbed

def perturb_prompt_word_repetition(prompt: str, tokenizer, num_repeats: int = 10, **kwargs):
    """Duplicate randomly chosen words of ``prompt`` at random positions (CWR).

    The prompt is split on whitespace. ``num_repeats`` times, a word is picked
    from the current (growing) word list and inserted at a random index. The
    result is re-joined with single spaces. A prompt with no words is returned
    unchanged. ``tokenizer`` and extra keyword arguments are unused.
    """
    tokens = prompt.split()
    if not tokens:
        return prompt
    for _ in range(num_repeats):
        picked = random.choice(tokens)
        tokens.insert(random.randint(0, len(tokens)), picked)
    return " ".join(tokens)

def perturb_prompt_random_numbers(prompt: str, num_numbers: int = 10, **kwargs):
    """Insert ``num_numbers`` random integers into ``prompt`` (RNA).

    Each integer is drawn uniformly from [0, 1_000_000] and spliced, with one
    space on each side, at a uniformly random character offset of the current
    string. Extra keyword arguments are accepted and ignored.
    """
    perturbed = prompt
    for _ in range(num_numbers):
        number = random.randint(0, 1_000_000)
        cut = random.randint(0, len(perturbed))
        perturbed = f"{perturbed[:cut]} {number} {perturbed[cut:]}"
    return perturbed

def perturb_prompt_tokenwise(prompt: str, tokenizer, token_idx_to_perturb: int):
    """Replace the k-th meaningful token of ``prompt`` with a random token id.

    The prompt is tokenized (truncated/padded to 77 positions). Positions whose
    attention mask is 1 are considered, and the first and last of them are
    dropped. Those are assumed to be the BOS/EOS special tokens, so confirm this
    for tokenizers that do not add them. The ``token_idx_to_perturb``-th
    remaining position is overwritten with an id drawn uniformly from
    [1000, 40000], and the sequence is decoded with special tokens skipped.

    Args:
        prompt: Text prompt to perturb.
        tokenizer: Callable tokenizer returning ``input_ids`` / ``attention_mask``
            tensors and exposing ``decode``.
        token_idx_to_perturb: Zero-based index among the meaningful tokens.

    Returns:
        The decoded perturbed prompt. The original ``prompt`` is returned
        unchanged when the index is out of range, which includes negative
        indices.
    """
    batch_encoding = tokenizer(
        prompt,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    token_ids = batch_encoding["input_ids"].squeeze()

    # Positions of all non-padding tokens, minus the leading/trailing special tokens.
    meaningful_indices = (batch_encoding["attention_mask"].squeeze() == 1).nonzero(as_tuple=True)[0]
    meaningful_indices = meaningful_indices[1:-1]

    # A negative index would otherwise wrap around and silently perturb a token
    # counted from the end, so reject it like any other out-of-range index.
    if not 0 <= token_idx_to_perturb < len(meaningful_indices):
        return prompt

    # Map the meaningful-token index back to its position in the padded sequence.
    pos_to_perturb = meaningful_indices[token_idx_to_perturb]

    perturbed_token_ids = token_ids.clone()
    perturbed_token_ids[pos_to_perturb] = random.randint(1000, 40000)

    return tokenizer.decode(perturbed_token_ids, skip_special_tokens=True)

# --- Metric-Aware Perturbation (Wen et al.) ---

def optimize_embedding_wen(embedding: torch.Tensor, model, output_path: str, **kwargs):
    """
    Perturb a text embedding by minimizing the text-conditional noise prediction magnitude.

    The loss is the L2 norm of the difference between the model's noise
    prediction conditioned on the optimized embedding and its unconditional
    (empty-prompt) prediction. It is evaluated at a single timestep (t=999) on
    one fixed random latent. The embedding is updated with Adam until ``steps``
    iterations have run or the loss drops below ``target_loss``. The per-step
    loss curve is plotted and saved to ``output_path``.

    Args:
        embedding: Text embedding to start from. It is cloned, not modified in place.
        model: Diffusion model exposing ``image_size``, ``device``,
            ``get_learned_conditioning`` and ``apply_model``.
        output_path: File path the loss-curve figure is written to.
        **kwargs: Optional ``lr`` (default 0.05), ``steps`` (default 10) and
            ``target_loss`` (default 3.0).

    Returns:
        The optimized embedding, detached from the autograd graph. When early
        stopping triggers, this is the embedding whose loss fell below
        ``target_loss``; no optimizer step is taken after that check.
    """
    lr = kwargs.get("lr", 0.05)
    steps = kwargs.get("steps", 10)
    target_loss = kwargs.get("target_loss", 3.0)

    e_star = embedding.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([e_star], lr=lr)

    t = torch.tensor([999], device=embedding.device)
    # NOTE(review): if ``model.image_size`` is already the latent resolution (as in
    # some latent-diffusion configs), ``// 8`` shrinks the latent a second time. Confirm.
    latent_hw = model.image_size // 8
    x_t = torch.randn((1, 4, latent_hw, latent_hw), device=embedding.device)

    # The unconditional prediction does not depend on e_star. Compute it once
    # with autograd disabled instead of re-running the model (and building an
    # unused graph) on every step.
    with torch.no_grad():
        uc = model.get_learned_conditioning([""]).to(model.device)
        uncond_noise = model.apply_model(x_t, t, {"context": uc})

    loss_history = []
    print("  > Starting Wen et al. metric-aware optimization...")
    for i in range(steps):
        optimizer.zero_grad()

        cond_noise = model.apply_model(x_t, t, {"context": e_star})
        loss = torch.linalg.norm(cond_noise - uncond_noise)
        loss_value = loss.item()
        loss_history.append(loss_value)

        # Stop before stepping so the returned embedding is the one that met the target.
        if loss_value < target_loss:
            print(f"  > Optimization stopped early at step {i+1}.")
            break

        loss.backward()
        optimizer.step()

    print("  > Wen et al. optimization finished.")

    # --- Plot and save the loss curve ---
    if loss_history:
        fig = plt.figure(figsize=(8, 5))
        try:
            plt.plot(range(len(loss_history)), loss_history, marker='o', linestyle='-')
            plt.title("Wen et al. Optimization Loss Curve")
            plt.xlabel("Optimization Step")
            plt.ylabel("Global Noise Difference (D) Loss")
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(output_path)
        finally:
            # Release the figure even if saving fails, so repeated calls do not leak figures.
            plt.close(fig)
        print(f"  > Saved Wen et al. loss curve to {output_path}")

    return e_star.detach()


    
def optimize_embedding_bright_ending(embedding: torch.Tensor, model, controller, output_path: str, **kwargs):
    """
    Perturb an embedding by minimizing a Bright Ending (BE)-masked noise difference.

    Each step works as follows:
    - The conditional noise prediction is computed with ``controller`` attached.
    - The attention maps that ``controller`` stored under key 0 for ``'down_0'``
      and ``'down_1'`` are averaged.
    - The slice at the last token position is reshaped into a spatial map,
      averaged over heads and min-max normalized into a mask.
    - The mask is resized to the noise resolution.
    - The loss is the L2 norm of the masked (conditional - unconditional) noise
      difference, divided by the mask mean. It is minimized with Adam.

    Steps where the maps are missing are skipped and logged as NaN in the loss
    curve, which is saved to ``output_path``.

    Args:
        embedding: Text embedding to start from. It is cloned, not modified in place.
        model: Diffusion model exposing ``image_size``, ``device``,
            ``get_learned_conditioning`` and ``apply_model``, which must accept
            a ``controller=`` keyword.
        controller: Attention store exposing ``reset()`` and ``latest_attention_maps``.
        output_path: File path the loss-curve figure is written to.
        **kwargs: Optional ``lr`` (default 0.05), ``steps`` (default 10) and
            ``num_heads`` (default 8, the head count used to un-flatten the
            attention maps).

    Returns:
        The optimized embedding, detached from the autograd graph.

    Raises:
        ValueError: If ``controller`` is None.
    """
    if controller is None:
        raise ValueError("Bright Ending optimization requires an AttentionStore controller.")

    lr = kwargs.get("lr", 0.05)
    steps = kwargs.get("steps", 10)
    # Previously hard-coded; exposed so models with a different head count can be used.
    num_heads = kwargs.get("num_heads", 8)

    e_star = embedding.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([e_star], lr=lr)

    t = torch.tensor([999], device=embedding.device)
    # NOTE(review): if ``model.image_size`` is already the latent resolution (as in
    # some latent-diffusion configs), ``// 8`` shrinks the latent a second time. Confirm.
    x_t = torch.randn((1, 4, model.image_size // 8, model.image_size // 8), device=embedding.device)
    uc = model.get_learned_conditioning([""]).to(model.device)
    uc_ = {"context": uc}

    loss_history = []

    print("  > Starting Bright Ending (BE) optimization...")
    for _ in range(steps):
        optimizer.zero_grad()
        controller.reset()

        cond_noise = model.apply_model(x_t, t, {"context": e_star}, controller=controller)
        # Kept inside the loop, after the controller-attached call, to preserve the
        # original call order relative to the controller's recorded state.
        uncond_noise = model.apply_model(x_t, t, uc_)
        noise_diff = cond_noise - uncond_noise

        step_0_maps = controller.latest_attention_maps.get(0, {})
        attn_map1 = step_0_maps.get('down_0')
        attn_map2 = step_0_maps.get('down_1')

        if attn_map1 is None or attn_map2 is None:
            print("Warning: BE optimizer could not find attention maps. Skipping step.")
            loss_history.append(None)  # None marks a skipped step
            continue
        # Only report success once both maps are actually present.
        print("BE: Grabbed Attn Map from down_0 and down_1")

        final_attn_map = (attn_map1 + attn_map2) / 2.0
        # Assumes maps are (batch*heads, query_pixels, tokens); -1 selects the last
        # token position. TODO confirm this is the intended "ending" token.
        be_map = final_attn_map[:, :, -1]

        latent_dim = int((be_map.shape[1]) ** 0.5)
        be_map = be_map.reshape(-1, num_heads, latent_dim, latent_dim).mean(1)
        be_map = (be_map - be_map.min()) / (be_map.max() - be_map.min() + 1e-6)

        resized_mask = F.interpolate(be_map.unsqueeze(0), size=noise_diff.shape[2:], mode='bilinear').squeeze(0).to(noise_diff.device)
        # Normalize by the mask mean so the loss scale does not depend on mask coverage.
        masked_loss = torch.linalg.norm(noise_diff * resized_mask) / (torch.mean(resized_mask) + 1e-6)

        loss_history.append(masked_loss.item())
        masked_loss.backward()
        optimizer.step()

    print("  > BE Optimization finished.")
    if loss_history and all(v is None for v in loss_history):
        print("Warning: BE optimizer found no attention maps at any step; returning the unchanged embedding.")

    # --- Plot and save the loss curve ---
    if loss_history:
        # Skipped steps become NaN so they appear as gaps in the curve.
        plotted = [float("nan") if v is None else v for v in loss_history]
        fig = plt.figure(figsize=(8, 5))
        try:
            plt.plot(range(len(plotted)), plotted, marker='o', linestyle='-')
            plt.title("BE Optimization Loss Curve")
            plt.xlabel("Optimization Step")
            plt.ylabel("Localized Detection (LD) Loss")
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(output_path)
        finally:
            # Release the figure even if saving fails, so repeated calls do not leak figures.
            plt.close(fig)
        print(f"  > Saved BE loss curve to {output_path}")

    return e_star.detach()