import logging
import os
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import einops
import torch
import torch.nn as nn
from jaxtyping import Float, Int
from torch import Tensor

from redflag.configs import AdvAttackConfig
from redflag.adversarial import EmbeddingSpaceAttack
from redflag.sft_trainer_utils import GradientStateManager

try:
    from peft import PeftModel

    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False


def get_clean_embeddings(model, ids):
    """
    Look up the model's input embeddings for a batch of token IDs.

    The lookup calls the embedding module itself (its forward pass) instead
    of indexing a weight matrix, so any special handling implemented by that
    module (e.g. trainable tokens) is applied.

    Negative IDs are replaced by 0 before the lookup; an info-level log line
    is emitted whenever at least one negative ID is present.

    Args:
        model: Model exposing ``get_input_embeddings()``, either directly or
            through a ``.module`` attribute (DDP/DataParallel-style wrapper)
        ids: Token IDs tensor of shape (batch_size, num_tokens)

    Returns:
        torch.Tensor: Embeddings tensor of shape (batch_size, num_tokens, embed_dim)
    """
    # Negative IDs cannot be looked up, so they are mapped to token 0
    negative = ids < 0
    if negative.any():
        logging.info(f"Found negative IDs in ids: {ids}")
    safe_ids = ids.masked_fill(negative, 0)

    # Wrapped models keep the real model (and its HuggingFace methods) under `.module`
    base_model = getattr(model, "module", model)

    embed_layer = base_model.get_input_embeddings()
    return embed_layer(safe_ids)


def transfer_masked_embeddings(embeds_input, mask_input, embeds_dest, mask_dest):
    """
    Copy the masked positions of one embedding tensor into the masked
    positions of another (e.g. adversarial embeddings into clean embeddings).

    Selected rows are paired in row-major order: within each batch row, the
    k-th selected position of ``mask_input`` is written to the k-th selected
    position of ``mask_dest``. The two sequence lengths may differ as long as
    every batch row selects the same number of positions in both masks.

    Args:
        embeds_input: Source embeddings, indexed by ``mask_input`` over its
            first two dimensions (batch, seq_in)
        mask_input: Mask over (batch, seq_in) marking the source positions.
            Non-boolean masks are converted with ``.bool()`` (non-zero = selected).
        embeds_dest: Destination embeddings of shape (batch, seq_dest, ...);
            this tensor is not modified
        mask_dest: Mask over (batch, seq_dest) marking the positions to
            overwrite. Converted with ``.bool()`` like ``mask_input``.

    Returns:
        torch.Tensor: A copy of ``embeds_dest`` whose masked positions hold
        the masked rows of ``embeds_input`` (cast to ``embeds_dest``'s dtype)

    Raises:
        AssertionError: If any batch row selects a different number of
            positions in ``mask_input`` than in ``mask_dest``
    """
    # Integer 0/1 masks would otherwise be treated as index tensors (a gather)
    # instead of selecting positions.
    mask_input = mask_input.bool()
    mask_dest = mask_dest.bool()

    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # the exception type is kept for existing callers.
    if not torch.all(torch.sum(mask_input, dim=1) == torch.sum(mask_dest, dim=1)):
        raise AssertionError("Number of True values must match in both masks")

    embeds_input_selected = embeds_input[mask_input]

    # Boolean-mask assignment visits positions in row-major order, the same
    # order in which the source rows were selected above.
    embeds_output = embeds_dest.clone()
    embeds_output[mask_dest] = embeds_input_selected.to(embeds_output.dtype)

    return embeds_output


@dataclass
class AdversarialAttackResult:
    """
    Bundle of everything produced by one call to ``adversarial_attack``.

    Holds the attack object together with the tensors needed to rebuild the
    adversarial embeddings later via ``get_adv_embeddings``.
    """

    # Attack object that produced this result; get_adv_embeddings delegates to it
    embedding_attack: EmbeddingSpaceAttack
    # Input embeddings passed to the attack object when rebuilding the
    # adversarial embeddings.
    # NOTE(review): the previous comment said "will contain gradient info", but
    # adversarial_attack() stores a detached clone here — confirm which is intended.
    adv_input_embeds: Tensor
    # Perturbation returned by the attack
    adv_perturbation: Tensor
    # Mask returned by the attack alongside the perturbation
    # (presumably marks the perturbed positions — confirm against EmbeddingSpaceAttack)
    adv_perturbation_mask: Tensor
    # Losses reported by the attack
    all_losses: List[float]
    # Taken as-is from the attack result's `affirmative_responses`
    affirmative_responses: Tensor
    # Taken as-is from the attack result's `nan_restarts_count`
    nan_restarts_count: int = 0

    def get_adv_embeddings(self):
        """
        Rebuild the adversarial embeddings from the stored tensors.

        Delegates to ``embedding_attack.get_adv_embeddings`` with the stored
        input embeddings, perturbation and perturbation mask, and returns
        whatever that call returns.
        """
        attack = self.embedding_attack
        return attack.get_adv_embeddings(
            self.adv_input_embeds,
            self.adv_perturbation,
            self.adv_perturbation_mask,
        )


def adversarial_attack(
    model,
    input_ids,
    target_ids,
    attention_mask,
    tokenizer,
    attack_config: AdvAttackConfig,
    perturbation_mask=None,
    maximize_loss_idx=None,
) -> AdversarialAttackResult:
    """
    Run an embedding-space adversarial attack and package its outputs.

    Builds an ``EmbeddingSpaceAttack`` around the model's input embedding
    layer, runs it inside a ``GradientStateManager`` context, and returns
    detached copies of the resulting tensors.

    Args:
        model: The model to attack. May be a DDP/DataParallel-style wrapper;
            the wrapper itself is used for the forward passes, the inner
            ``.module`` only for fetching the embedding layer.
        input_ids: Input token IDs
        target_ids: Target token IDs
        attention_mask: Attention mask
        tokenizer: The tokenizer to use
        attack_config: Attack configuration; ``iters``, ``eps``,
            ``opt_config``, ``attack_precision`` and ``maximize_loss_weight``
            are read from it
        perturbation_mask: Optional; forwarded unchanged to
            ``EmbeddingSpaceAttack.attack``
        maximize_loss_idx: Optional; forwarded unchanged to the
            ``EmbeddingSpaceAttack`` constructor

    Returns:
        AdversarialAttackResult: The attack object plus detached clones of
        the input embeddings, perturbation and perturbation mask, along with
        the losses, affirmative responses and NaN-restart count reported by
        the attack
    """
    # HuggingFace accessors live on the inner model when it is wrapped
    base_model = getattr(model, "module", model)
    embed_layer = base_model.get_input_embeddings()

    attack = EmbeddingSpaceAttack(
        embed_layer=embed_layer,
        tokenizer=tokenizer,
        iters=attack_config.iters,
        eps=attack_config.eps,
        opt_config=attack_config.opt_config,
        attack_precision=attack_config.attack_precision,
        maximize_loss_idx=maximize_loss_idx,
        maximize_loss_idx_weight=attack_config.maximize_loss_weight,
    )

    # Presumably switches off gradients for parameters matching these name
    # patterns while the attack runs and restores them on exit — confirm
    # against GradientStateManager.
    frozen_patterns = ["lora", "trainable_tokens"]
    with GradientStateManager(model, disable_grad_patterns=frozen_patterns):
        # The (possibly wrapped) model is passed so the wrapper handles forward()
        raw = attack.attack(
            model=model,
            input_ids=input_ids,
            target_ids=target_ids,
            attention_mask=attention_mask,
            perturbation_mask=perturbation_mask,
        )

    # Tensors are detached and cloned so the result carries no autograd
    # history from the attack (possibly redundant).
    return AdversarialAttackResult(
        embedding_attack=attack,
        adv_input_embeds=raw.input_embeds.detach().clone(),
        adv_perturbation=raw.adv_perturbation.detach().clone(),
        adv_perturbation_mask=raw.adv_perturbation_mask.detach().clone(),
        all_losses=raw.all_losses,
        affirmative_responses=raw.affirmative_responses,
        nan_restarts_count=raw.nan_restarts_count,
    )
