import logging

import einops
import torch
import torch.nn as nn
from jaxtyping import Float, Int
from torch import Tensor


def _linear_decay(length: int, end_value: float = 0.0) -> Tensor:
    return torch.linspace(1.0, end_value, length)

def _cosine_annealing_decay(length: int, end_value: float = 0.0) -> Tensor:
    return (torch.cos(torch.pi * torch.arange(length) / length) + 1) / 2 * (1 - end_value) + end_value

def _exponential_decay(length: int, end_value: float = 0.0) -> Tensor:
    return torch.exp(-torch.arange(length)*5 / length) * (1 - end_value) + end_value

def _create_weighting_scatter(x: torch.Tensor, positions: torch.Tensor, scaling_values: torch.Tensor, skip_id: int = -100, mode: str = "post_decay") -> torch.Tensor:
    """
    Memory-efficient version using torch.scatter operations.
    Supports both post_decay (forward) and pre_growth (backward) modes.
    """
    batch_size, seq_len = x.shape
    n_decay = len(scaling_values)
    
    # Process each batch to avoid memory explosion for very large tensors
    for b in range(batch_size):
        batch_positions = positions[b]
        valid_positions = batch_positions[batch_positions != skip_id]
        
        if len(valid_positions) == 0:
            continue
        
        # Sort positions to ensure first insertion is correctly identified
        valid_positions = torch.sort(valid_positions)[0]
        
        # Create indices for all decay positions  
        for pos in valid_positions:
            if mode == "post_decay":
                # Original logic: apply scaling starting from pos + 1
                start_idx = pos + 1
                end_idx = min(start_idx + n_decay, seq_len)
                if start_idx < seq_len:
                    x[b, start_idx:end_idx] = scaling_values[:end_idx - start_idx]
            elif mode == "pre_growth":
                # Pre-growth: apply scaling ending at pos
                # We want positions [max(0, pos-n+1), ..., pos] to get scaling values
                actual_start = max(0, pos - n_decay + 1)
                actual_end = pos + 1
                
                # If we had to clip the start, adjust the scaling values accordingly
                scaling_start_idx = max(0, n_decay - 1 - pos)
                scaling_end_idx = scaling_start_idx + (actual_end - actual_start)
                
                if actual_start < seq_len and actual_end > 0:
                    x[b, actual_start:actual_end] = scaling_values[scaling_start_idx:scaling_end_idx]
        
        # Final pass: ensure first insertion always hits 1.0
        if mode == "pre_growth" and len(valid_positions) > 0:
            first_pos = valid_positions[0]
            if first_pos >= 0 and first_pos < seq_len:
                # Take slice from end of scaling to ensure insertion point gets 1.0
                available_positions = first_pos + 1
                slice_length = min(n_decay, available_positions)
                scaling_slice = scaling_values[-slice_length:]
                start_pos = first_pos - slice_length + 1
                x[b, start_pos:first_pos + 1] = scaling_slice
    return x

def _create_weighting_batch_vectorized(x: torch.Tensor, positions: torch.Tensor, scaling_values: torch.Tensor, skip_id: int = -100, mode: str = "post_decay") -> torch.Tensor:
    """
    Batch-wise vectorized version - good balance of speed and memory usage.
    Supports both post_decay (forward) and pre_growth (backward) modes.
    """
    batch_size, seq_len = x.shape
    n_decay = len(scaling_values)
    
    for b in range(batch_size):
        batch_positions = positions[b]
        valid_positions = batch_positions[batch_positions != skip_id]
        
        if len(valid_positions) == 0:
            continue
        
        # Sort positions to ensure first insertion is correctly identified
        valid_positions = torch.sort(valid_positions)[0]
        
        # Vectorize within each batch
        n_valid = len(valid_positions)
        
        if mode == "post_decay":
            # Original logic: apply scaling starting from pos + 1
            offsets = torch.arange(n_decay, device=x.device).unsqueeze(0)
            start_positions = valid_positions.unsqueeze(1) + 1
            all_positions = start_positions + offsets  # Shape: (n_valid, n_decay)
            
            # Create mask for valid sequence positions
            seq_mask = all_positions < seq_len
            
            # Flatten and get valid indices
            valid_flat_positions = all_positions[seq_mask]
            valid_flat_values = scaling_values.unsqueeze(0).expand(n_valid, -1)[seq_mask]
            
            # Apply values
            x[b, valid_flat_positions] = valid_flat_values
            
        elif mode == "pre_growth":
            # Pre-growth: apply scaling ending at each position
            for i, pos in enumerate(valid_positions):
                actual_start = max(0, pos - n_decay + 1)
                actual_end = pos + 1
                
                # If we had to clip the start, adjust the scaling values accordingly
                scaling_start_idx = max(0, n_decay - 1 - pos)
                scaling_end_idx = scaling_start_idx + (actual_end - actual_start)
                
                if actual_start < seq_len and actual_end > 0:
                    x[b, actual_start:actual_end] = scaling_values[scaling_start_idx:scaling_end_idx]
            
            # Final pass: ensure first insertion always hits 1.0
            if len(valid_positions) > 0:
                first_pos = valid_positions[0]
                if first_pos >= 0 and first_pos < seq_len:
                    # Take slice from end of scaling to ensure insertion point gets 1.0
                    available_positions = first_pos + 1
                    slice_length = min(n_decay, available_positions)
                    scaling_slice = scaling_values[-slice_length:]
                    start_pos = first_pos - slice_length + 1
                    x[b, start_pos:first_pos + 1] = scaling_slice
    return x

class SeqLossWeighting:
    """
    Build token-wise loss weights of shape (batch, seq_len) around insertion positions.

    In ``post_decay`` mode the weight is 1.0 right after an insertion position and decays
    toward ``base_scaling`` over ``scaling_length`` tokens. In ``pre_growth`` mode the profile
    is flipped: it grows from ``base_scaling`` to 1.0 and peaks on the insertion position.
    The per-position writing is delegated to the function selected by ``opt_mode``.
    """
    _TYPES = {
        'linear': _linear_decay,
        'cosine': _cosine_annealing_decay,
        'exponential': _exponential_decay,
    }
    _OPT_MODES = {
        'default': _create_weighting_batch_vectorized,
        'mem': _create_weighting_scatter,
    }
    _LOSS_MODES = {
        'post_decay': None,
        'pre_growth': None
    }
    def __init__(
        self, 
        scaling_type: str = 'linear', 
        scaling_length: int = 20, 
        base_scaling: float = 0.0,
        mode: str = "post_decay",
        opt_mode: str = 'default',
        drop_weight: float = 1.0,
    ):
        """
        Creates a tensor matching the shape of batch x seqlen to provide tokenwise scaling of some
        loss, particularly either KL or Cross Entropy. The post decay mode provides a decay for the
        tokens following a particular insert position, and pre grow flips the scaling to increase
        up to the particular insert position where it is at its maximum. 

        Args:
            scaling_type (str, optional): type of scaling to apply to the loss; one of 'linear', 'cosine', 'exponential'. Defaults to 'linear'.
            scaling_length (int, optional): length of the decay. Defaults to 20.
            base_scaling (float, optional): scaling the profile decays toward / grows from; also used to fill
                tokens outside the profile windows (see ``get_weighting``). Defaults to 0.0.
            mode (str, optional): mode to apply the loss; one of 'post_decay', 'pre_growth'. Defaults to "post_decay".
            opt_mode (str, optional): mode to optimize the weighting; one of 'default', 'mem'. Defaults to 'default'.
            drop_weight (float, optional): weight to apply to the dropped rows. Defaults to 1.0 (no weighting).
                Stored only; not used inside this class.

        Raises:
            NotImplementedError: if ``scaling_type``, ``mode`` or ``opt_mode`` is not one of the supported keys.
        """
        # List the valid keys (not the dicts themselves, whose values are function objects).
        if scaling_type not in self._TYPES:
            raise NotImplementedError(f"Scaling type {scaling_type} not implemented; must be one of {list(self._TYPES)}")
        if mode not in self._LOSS_MODES:
            raise NotImplementedError(f"Loss mode {mode} not implemented; must be one of {list(self._LOSS_MODES)}")
        if opt_mode not in self._OPT_MODES:
            raise NotImplementedError(f"Optimization mode {opt_mode} not implemented; must be one of {list(self._OPT_MODES)}")

        self.scaling_type = scaling_type
        self.scaling_length = scaling_length
        self.mode = mode

        # Every schedule starts at 1.0 and heads toward base_scaling ...
        scaling = self._TYPES[scaling_type](scaling_length, base_scaling)
        if mode == "pre_growth":
            # ... so pre_growth flips it to [base_scaling, ..., 1.0].
            scaling = scaling.flip(0)
        self.scaling = scaling

        self.base_scaling = base_scaling
        self.opt_mode = opt_mode
        self._create_weighting = self._OPT_MODES[opt_mode]
        self.drop_weight = drop_weight
    
    def get_weighting(self, shape: torch.Size, insert_positions: torch.Tensor, skip_id: int = -100) -> torch.Tensor:
        """
        Return a (batch, seq_len) float weight tensor for the given insertion positions.

        Tokens start at 0.0. When ``base_scaling != 0``, a base fill is applied first:
            post_decay: every token after the first insertion position of a row gets ``base_scaling``.
            pre_growth: ``base_scaling`` fills the gap after each insertion position up to where the
                next position's growth window starts, and after the last position up to the end.
        The scaling profile is then written on top of that by the selected ``opt_mode`` function.

        Args:
            shape: (batch, seq_len) of the output.
            insert_positions: (batch, num_positions) insertion indices padded with ``skip_id``.
            skip_id: padding value in ``insert_positions``.

        Returns:
            Tensor of ``shape`` on ``insert_positions.device``. As a side effect ``self.scaling`` is
            moved to that device.
        """
        x = torch.zeros(shape, device=insert_positions.device)
        self.scaling = self.scaling.to(insert_positions.device)  
        
        if self.base_scaling != 0.0:
            valid_mask = insert_positions != skip_id
            if valid_mask.any():
                if self.mode == "post_decay":
                    # First index after the earliest insertion per row; rows without a valid position stay at +inf.
                    inf = torch.tensor(float('inf'), device=insert_positions.device)
                    min_pos = torch.where(valid_mask, insert_positions, inf).min(dim=1)[0] + 1
                    seq_idx = torch.arange(x.shape[1], device=x.device)[None, :]
                    fill_mask = (seq_idx >= min_pos[:, None]) & (min_pos != float('inf'))[:, None]
                    x[fill_mask] = self.base_scaling
                elif self.mode == "pre_growth":
                    # The region before the first growth window keeps the 0.0 that x was created with.
                    batch_size, seq_len = x.shape
                    for b in range(batch_size):
                        row = insert_positions[b]
                        # Python ints: avoids a device sync per comparison on GPU tensors.
                        valid_positions = sorted(row[row != skip_id].tolist())
                        if not valid_positions:
                            continue

                        for i, pos in enumerate(valid_positions):
                            # After each insertion point, apply base_scaling until the next growth window starts.
                            gap_start = pos + 1
                            if i + 1 < len(valid_positions):
                                gap_end = max(0, valid_positions[i + 1] - self.scaling_length + 1)
                            else:
                                gap_end = seq_len

                            if gap_start < gap_end and gap_start < seq_len:
                                x[b, gap_start:min(gap_end, seq_len)] = self.base_scaling
        
        return self._create_weighting(x, insert_positions, self.scaling, skip_id=skip_id, mode=self.mode)


class GradientStateManager:
    """Context manager that temporarily disables gradients for parameters matching name patterns.

    On enter, every parameter whose lower-cased name contains any (lower-cased) pattern in
    ``disable_grad_patterns`` gets ``requires_grad = False``. On exit, every parameter's original
    ``requires_grad`` flag is restored. Because this happens in ``__exit__``, it also runs when
    the ``with`` body raises, and the exception still propagates.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        disable_grad_patterns: substrings to match; a single string is treated as one pattern.
            ``None`` means the default ``["lora"]``; an empty list disables nothing.
    """
    
    def __init__(self, model, disable_grad_patterns=None):
        self.model = model
        if disable_grad_patterns is None:
            patterns = ["lora"]
        elif isinstance(disable_grad_patterns, str):
            # Avoid iterating a bare string character by character.
            patterns = [disable_grad_patterns]
        else:
            patterns = list(disable_grad_patterns)
        self.disable_grad_patterns = patterns
        self.original_states = {}
        
    def __enter__(self):
        patterns = [pattern.lower() for pattern in self.disable_grad_patterns]
        for name, param in self.model.named_parameters():
            # Save the original flag before (possibly) disabling it.
            self.original_states[name] = param.requires_grad
            lowered = name.lower()
            if any(pattern in lowered for pattern in patterns):
                param.requires_grad = False
        return self
        
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore only parameters we recorded; parameters created inside the context are left alone.
        for name, param in self.model.named_parameters():
            if name in self.original_states:
                param.requires_grad = self.original_states[name]
        self.original_states.clear()


def initialize_embedding(
    model, 
    target_idx: int, 
    path: str = None, 
    mode: str = None, 
    in_scale: float = 1., 
    out_scale: float = 1.,
):
    """
    Re-initialize a single embedding row (``target_idx``) of ``model`` in place.

    ``model`` must expose ``config.tie_word_embeddings``, ``get_input_embeddings()`` and
    ``get_output_embeddings()``. That interface looks like a Hugging Face ``PreTrainedModel``
    (assumption; confirm against callers).

    Supported scheme:
        mode="mean": the row is set to the mean of all rows whose L2 norm exceeds 1e-2,
            multiplied by ``in_scale`` (input embeddings) / ``out_scale`` (output embeddings).
            A scale of ``None`` skips that matrix. The output matrix is only touched when
            ``tie_word_embeddings`` is false.

    Args:
        model: model whose embeddings are modified.
        target_idx: row (token id) to initialize.
        path: file-based initialization (not implemented).
        mode: initialization scheme; "mean" is implemented, "random" is not.
        in_scale: multiplier for the input-embedding row, or None to skip.
        out_scale: multiplier for the output-embedding row, or None to skip.

    Returns:
        ``model`` (modified in place).

    Raises:
        RuntimeError: ``model.config`` has no ``tie_word_embeddings`` value.
        NotImplementedError: ``path`` is given, or ``mode == "random"``.
        ValueError: neither ``path`` nor ``mode`` is given, or ``mode`` is unknown.
    """
    tie_word_embeddings = getattr(model.config, "tie_word_embeddings", None)
    if tie_word_embeddings is None:
        raise RuntimeError(
            "model.config.tie_word_embeddings is not set; cannot tell whether the output "
            "embeddings must be initialized separately from the input embeddings."
        )

    def _set_row_to_scaled_mean(embed, scale, label):
        # Average only rows with a non-negligible norm, so all-zero rows do not pull the mean toward 0.
        old_norm = embed[target_idx].norm()
        nonzero = embed.norm(dim=1) > 1e-2
        embed[target_idx] = embed[nonzero].mean(dim=0) * scale
        new_norm = embed[target_idx].norm()
        logging.info(f"Updated {label} embeddings: norms updated {old_norm} -> {new_norm}")

    if path is not None:
        logging.info(f"Loading embedding initialization from file: {path}")
        # Not implemented yet. The file is deliberately not loaded: torch.load unpickles its
        # input, and the result would be discarded anyway.
        raise NotImplementedError("Embedding initialization from a file is not implemented.")

    if mode is None:
        raise ValueError("Unknown embedding initialization scheme; must specify 'path' or 'mode'.")

    logging.info(f"Using embedding initialization scheme: {mode}")

    if mode == "mean":
        with torch.no_grad():
            if in_scale is not None:
                _set_row_to_scaled_mean(model.get_input_embeddings().weight.data, in_scale, "INPUT")
            if not tie_word_embeddings and out_scale is not None:
                _set_row_to_scaled_mean(model.get_output_embeddings().weight.data, out_scale, "OUTPUT")
    elif mode == "random":
        raise NotImplementedError("Random embedding initialization is not implemented.")
    else:
        raise ValueError(f"Unknown embedding initialization scheme {mode!r}; expected 'mean' or 'random'.")
    return model
        

def negative_cross_entropy(
    logits,
    labels,
    vocab_size,
    weights: torch.Tensor = None, 
    upper_threshold=1000.0,
    lower_threshold=0.0,
    num_items_in_batch=None,
    ignore_index: int = -100,
):
    """
    Next-token cross entropy with optional per-token weights and soft thresholds.

    Despite the name, the value returned is the (non-negated) cross entropy. Presumably the
    caller negates or maximizes it; verify at call sites.

    Args:
        logits: (..., seq, vocab) logits; position t predicts label t+1.
        labels: (..., seq) token ids; ``ignore_index`` entries contribute nothing.
        vocab_size: size of the last logits dimension.
        weights: optional (..., seq) per-token weights. The weight at position t scales the loss
            of the prediction made at t (i.e. of label t+1).
        upper_threshold: if the loss is not below it, returns ``upper_threshold + 0.001 * loss``
            (soft clip: the gradient is scaled by 0.001). ``None`` disables.
        lower_threshold: if the loss is not above it, returns ``lower_threshold + 0.001 * loss``.
            ``None`` disables.
        num_items_in_batch: if given, the summed loss is divided by it. Otherwise the mean over
            non-ignored tokens is used in both the weighted and unweighted paths.
        ignore_index: label value to ignore.

    Returns:
        Scalar loss tensor (or the soft-clipped value).
    """
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n, then flatten the tokens
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    # Enable model parallelism: labels/weights follow the logits' device
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

    if weights is not None:
        shift_weights = weights[..., :-1].contiguous().view(-1).to(shift_logits.device)
        token_loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction='none')
        weighted_sum = (token_loss * shift_weights).sum()
        if num_items_in_batch is not None:
            loss = weighted_sum / num_items_in_batch
        else:
            # Average over non-ignored tokens (like the unweighted 'mean' reduction) rather than over
            # every position; otherwise padding dilutes the loss. clamp avoids 0/0 when all are ignored.
            num_valid = (shift_labels != ignore_index).sum().clamp(min=1)
            loss = weighted_sum / num_valid
    else:
        reduction = "sum" if num_items_in_batch is not None else "mean"
        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction=reduction)
        if reduction == "sum":
            loss = loss / num_items_in_batch

    # Soft clipping: beyond a threshold keep only a 0.001-scaled copy of the loss (and its gradient).
    if upper_threshold is not None and not upper_threshold > loss:
        loss = upper_threshold + 0.001 * (loss)
    if lower_threshold is not None and not lower_threshold < loss:
        loss = lower_threshold + 0.001 * (loss)
    return loss


def masked_mean(seq, mask=None, dim=1, keepdim=False):
    """
    Mean of ``seq`` along ``dim``, counting only entries where ``mask`` is true.

    Args:
        seq: tensor to average, e.g. (b, n) or (b, n, d).
        mask: optional (b, n) mask (bool, or 0/1 integers, which are converted to bool). For a 3-D
            ``seq`` a 2-D mask is broadcast over the last dimension. ``None`` averages everything.
        dim: dimension to reduce.
        keepdim: keep the reduced dimension (honoured with and without a mask).

    Returns:
        The masked mean; slices with no selected entries yield 0.0 instead of NaN.
    """
    if mask is None:
        return seq.mean(dim=dim, keepdim=keepdim)

    # `~` on an integer mask is bitwise NOT, and masked_fill requires bool, so normalize first.
    mask = mask.bool()
    if seq.ndim == 3 and mask.ndim == 2:
        mask = mask.unsqueeze(-1)  # "b n -> b n 1"

    numer = seq.masked_fill(~mask, 0.0).sum(dim=dim, keepdim=keepdim)
    denom = mask.sum(dim=dim, keepdim=keepdim)

    # denom is an integer count: clamping to 1 only affects empty slices, which are zeroed below.
    mean = numer / denom.clamp(min=1)
    return mean.masked_fill(denom == 0, 0.0)


def kl_div_fn(
    logits_a: Float[Tensor, "batch seq_pos d_vocab"] = None,
    logits_b: Float[Tensor, "batch seq_pos d_vocab"] = None,
    mask: Int[Tensor, "batch seq_pos"] = None,
    reduction: str = "mean",
    epsilon: Float = 1e-6,
    op_dtype: torch.dtype | None = None,
    filter_rf_entries: Int[Tensor, "batch"] = None,
    pre_reduce_scale: Float[Tensor, "batch seq_pos"] = None,
) -> Float[Tensor, "batch"]:
    r"""
    Compute the KL divergence KL(P_a || P_b) between two tensors of logits.
    The pointwise KL-divergence is defined as:

        D_{KL}(P || Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}

        = \sum_i P(i) [ \log P(i) - \log Q(i) ]

    with P = softmax(logits_a) and Q = softmax(logits_b), evaluated in log space.

    Args:
        logits_a, logits_b: (batch, seq_pos, d_vocab) logits.
        mask: optional (batch, seq_pos) boolean mask. NOTE(review): its polarity differs by reduction:
            for "none"/"sum", entries where mask is True are ZEROED; for "mean", only entries
            where mask is True are AVERAGED (see masked_mean). Confirm which callers rely on.
        reduction: "none" -> (batch, seq_pos, d_vocab) pointwise terms; "sum" -> (batch, seq_pos)
            per-token KL (summed over vocab only); "mean" -> per-token KL averaged over
            seq_pos. Without a mask the result has shape (batch,); with a mask it is
            additionally averaged over the batch into a scalar.
        epsilon: currently unused.
        op_dtype: optional dtype both logits are cast to before computing.
        filter_rf_entries: optional row selector applied to logits_a, logits_b and mask.
            NOTE(review): pre_reduce_scale is NOT filtered, so it must already match the
            selected rows.
        pre_reduce_scale: optional (batch, seq_pos) per-token multiplier applied before reduction.

    Raises:
        NotImplementedError: filter_rf_entries selects no rows and reduction is not "mean".
        ValueError: unknown reduction.
    """
    if filter_rf_entries is not None:
        logits_a = logits_a[filter_rf_entries]
        logits_b = logits_b[filter_rf_entries]
        if mask is not None:
            mask = mask[filter_rf_entries]
        if logits_a.shape[0] == 0:
            if reduction == "mean":
                return torch.tensor(0.0, device=logits_a.device)
            else:
                # Previously the exception object was returned instead of raised.
                raise NotImplementedError(
                    f"kl_div_fn: empty selection is only supported for reduction='mean', got {reduction!r}"
                )

    if op_dtype is not None:
        logits_a = logits_a.to(op_dtype)
        logits_b = logits_b.to(op_dtype)

    log_probs_a = torch.log_softmax(logits_a, dim=-1)
    log_probs_b = torch.log_softmax(logits_b, dim=-1)

    probs_a = torch.exp(log_probs_a)  # Avoid explicit softmax; use log-space

    kl_divs = probs_a * (log_probs_a - log_probs_b)
    if pre_reduce_scale is not None:
        kl_divs = kl_divs * pre_reduce_scale.unsqueeze(-1)

    if reduction == "none":
        if mask is not None:
            kl_divs[mask] = 0.0
        return kl_divs
    if reduction == "sum":
        if mask is not None:
            kl_divs[mask] = 0.0
        return kl_divs.sum(-1)
    elif reduction == "mean":
        kl_divs = kl_divs.sum(-1)
        if mask is None:
            return torch.mean(kl_divs, dim=-1)
        else:
            return masked_mean(kl_divs, mask).mean(dim=-1)
    else:
        raise ValueError(f"Invalid reduction: {reduction}")


def fill_except_indices_from_positions(
    tensor: Float[Tensor, "batch seq"],
    fill_value: float,
    keep_positions: Int[Tensor, "batch num_keep"],
) -> Float[Tensor, "batch seq"]:
    """
    Return a copy of ``tensor`` where every entry is ``fill_value`` except those at ``keep_positions``.

    Args:
        tensor: (batch, seq) input; it is not modified.
        fill_value: value written to every position that is not kept.
        keep_positions: (batch, num_keep) column indices to keep, per row. Indices outside
            ``[0, seq)`` (e.g. -100 padding) are ignored, so rows may be padded to a common length.

    Returns:
        New (batch, seq) tensor.

    Example:
        >>> t = torch.tensor([[0., 1., 2., 3., 4., 5.]])
        >>> fill_except_indices_from_positions(t, -1., torch.tensor([[1, 2, 5]])).tolist()
        [[-1.0, 1.0, 2.0, -1.0, -1.0, 5.0]]
    """
    batch_size, seq_len = tensor.shape
    device = tensor.device
    num_keep = keep_positions.shape[1]

    # Row index for every (row, keep-slot) pair, aligned with keep_positions.
    row_ids = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, num_keep)

    # Drop out-of-range indices (padding) before indexing.
    in_bounds = (keep_positions >= 0) & (keep_positions < seq_len)

    keep_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
    keep_mask[row_ids[in_bounds], keep_positions[in_bounds]] = True

    return torch.where(keep_mask, tensor, fill_value)


def fill_around_range(
    tensor: Float[Tensor, "batch seq"],
    fill_value: float,
    start_pos: Float[Tensor, "batch"],
    end_pos: Float[Tensor, "batch"] = None,
) -> Float[Tensor, "batch seq"]:
    """
    Fill everything outside a per-row range with ``fill_value``.

    With ``end_pos`` the kept range is the inclusive span ``[start_pos, end_pos]``; without it only
    the single column ``start_pos`` is kept. Positions are compared against column indices, so they
    are expected to hold integer values.

    Args:
        tensor: (batch, seq) input; it is not modified.
        fill_value: value written outside the kept range.
        start_pos: (batch,) first kept column per row.
        end_pos: optional (batch,) last kept column per row (inclusive).

    Returns:
        New (batch, seq) tensor.

    Example:
        >>> t = torch.tensor([[0., 1., 2., 3., 4., 5.]])
        >>> fill_around_range(t, -1., torch.tensor([2])).tolist()
        [[-1.0, -1.0, 2.0, -1.0, -1.0, -1.0]]
        >>> fill_around_range(t, -1., torch.tensor([2]), torch.tensor([4])).tolist()
        [[-1.0, -1.0, 2.0, 3.0, 4.0, -1.0]]
    """
    columns = torch.arange(tensor.shape[1], device=tensor.device).unsqueeze(0)  # (1, seq)
    start = start_pos.unsqueeze(1)  # (batch, 1)
    if end_pos is None:
        keep = columns == start
    else:
        keep = (columns >= start) & (columns <= end_pos.unsqueeze(1))
    return torch.where(keep, tensor, fill_value)


def insert_rf_logits(target_logits, output_logits, rf_positions, rf_entries):
    """
    Align clean reference logits with outputs for sequences that contain inserted red-flag tokens.

    For the rows selected by ``rf_entries``, the returned target starts as a copy of
    ``output_logits``, so red-flag positions compare the output with itself (zero KL). Every other
    position ``j`` is then overwritten with ``target_logits`` at the clean index
    ``j - (number of red-flag positions before j)``, where that index is in range.

    Args:
        target_logits: (batch, clean_len, vocab) reference logits for sequences without red flags.
        output_logits: (batch, seq_len_with_rf, vocab) logits for sequences with red flags inserted.
        rf_positions: (batch, num_rf) red-flag indices into the with-rf sequence. Out-of-range
            entries (e.g. -100 padding) are ignored; order and duplicates do not matter.
        rf_entries: row selector (e.g. a boolean mask over the batch) for rows that contain red flags.

    Returns:
        ``(target_logits, output_logits)``, both restricted to the selected rows and of shape
        (num_selected, seq_len_with_rf, vocab).
    """
    target_logits_clean = target_logits[rf_entries]
    output_logits = output_logits[rf_entries]

    # Start with output_logits as base, then overwrite non-redflag positions
    target_logits = output_logits.clone()  # This gives us zero KL at redflag positions

    batch_size_rf, seq_len_with_rf = output_logits.shape[:2]
    clean_len = target_logits_clean.shape[1]
    rf_positions_rel = rf_positions[rf_entries]

    # Overwrite non-redflag positions with reference values
    for i in range(batch_size_rf):
        row = rf_positions_rel[i]
        in_range = row[(row >= 0) & (row < seq_len_with_rf)]
        # searchsorted needs ascending boundaries, and each inserted token must be counted once:
        # unique() returns the positions sorted and de-duplicated.
        rf_pos_this_seq = torch.unique(in_range)

        # Create mask for non-redflag positions
        non_rf_mask = torch.ones(seq_len_with_rf, dtype=torch.bool, device=rf_positions.device)
        non_rf_mask[rf_pos_this_seq] = False

        # Map each non-redflag position to its clean position by subtracting the red flags before it.
        non_rf_indices = torch.where(non_rf_mask)[0]
        clean_indices = non_rf_indices - torch.searchsorted(rf_pos_this_seq, non_rf_indices, right=False)

        valid_clean = clean_indices < clean_len
        if valid_clean.any():
            target_logits[i, non_rf_indices[valid_clean]] = target_logits_clean[i, clean_indices[valid_clean]]
    return target_logits, output_logits

