"""
Score function for automaton-guided diffusion.

This module implements differentiable score functions that can be used
to guide the diffusion process based on automaton constraints.
"""

import torch as th
from automaton_alignment import TokenAutomaton


def compute_token_score(distance, eps=1e-10, scoring_mode='inverse'):
    """
    Map a latent-to-token distance onto a "closeness" score.

    Swap this function out to experiment with other scoring schemes; callers
    only rely on "larger score means closer token".

    Args:
        distance: L2 distance(s) between a latent vector and token embedding(s).
            A Python scalar or a tensor of distances.
        eps: Stabilizer added to the denominator in 'inverse' mode only.
        scoring_mode: Which mapping to use:
            - 'inverse': 1/(d+eps) - weak gradients
            - 'neg_squared': -d^2 - stronger gradients
            - 'gaussian': exp(-d^2) - normalized, strong gradients

    Returns:
        score: Same shape as ``distance``; higher means closer.

    Raises:
        ValueError: If ``scoring_mode`` is not one of the names above.
    """
    if scoring_mode not in ('inverse', 'neg_squared', 'gaussian'):
        raise ValueError(f"Unknown scoring_mode: {scoring_mode}")

    if scoring_mode == 'inverse':
        return 1.0 / (distance + eps)

    # Both remaining modes are built on the squared distance.
    squared = distance ** 2
    return th.exp(-squared) if scoring_mode == 'gaussian' else -squared


def distance_score(latent_vector, tokenizer, token_automaton, token_embeddings, scoring_mode='inverse', temperature=1.0):
    """
    Compute a differentiable score based on automaton constraints.

    The automaton is simulated "softly" over the sequence. All probability
    mass starts on the initial state. At each position a transition matrix
    is built: the row for state i spreads that state's mass over its
    destination states, weighted by how close the latent vector at that
    position is to the embedding of each token labelling an outgoing edge.
    The returned score is the mass that ends in a final state, so it is
    differentiable with respect to ``latent_vector``.

    Note: This function operates on a single sequence (no batch dimension).
    For batched inputs, call this function separately for each item in the
    batch (or use ``distance_score_batched``).

    Args:
        latent_vector: Latent representation from diffusion process.
            Shape: [seq_len, embed_dim]
        tokenizer: Not referenced by this implementation.
        token_automaton: TokenAutomaton instance from automaton_alignment.py.
            This function reads ``state_list``, ``state_to_idx``,
            ``initial_state``, ``final_states`` and ``state_transitions``, where
            ``state_transitions[i]`` unpacks into two parallel lists
            (token ids, destination state indices) for state index ``i``.
        token_embeddings: Token embedding matrix from the model.
            Shape: [vocab_size, embed_dim]
        scoring_mode: How to convert distances to scores
            ('inverse', 'neg_squared', 'gaussian'); see ``compute_token_score``.
        temperature: Softmax temperature, used ONLY when
            ``scoring_mode == 'neg_squared'``. Lower values (e.g. 0.1) make the
            distribution more peaked toward the closest tokens. It has no
            effect for 'inverse' and 'gaussian', whose scores are normalized
            by their sum.

    Returns:
        score: Differentiable scalar tensor in [0, 1] (each transition row sums
            to at most 1). Higher scores indicate better alignment.
    """
    device = latent_vector.device
    dtype = latent_vector.dtype
    seq_len = latent_vector.shape[0]

    # State ordering comes from the automaton.
    num_states = len(token_automaton.state_list)
    initial_state_idx = token_automaton.state_to_idx[token_automaton.initial_state]

    # Probability 1 at the initial state, 0 elsewhere. Shape: [num_states]
    score_vector = th.zeros(num_states, dtype=dtype, device=device)
    score_vector[initial_state_idx] = 1.0

    # Numerical-stability constant (division guard and 'inverse' scoring).
    eps = 1e-10

    # Per-state (destination indices, token embeddings) pair, or None for states
    # without outgoing transitions. These do not depend on the position, so they
    # are built once outside the sequence loop.
    state_edges = []
    for state_idx in range(num_states):
        valid_token_list, dest_indices_list = token_automaton.state_transitions[state_idx]
        if len(valid_token_list) == 0:
            state_edges.append(None)
            continue
        token_ids = th.tensor(valid_token_list, dtype=th.long, device=device)
        dest_indices = th.tensor(dest_indices_list, dtype=th.long, device=device)
        state_edges.append((dest_indices, token_embeddings[token_ids]))

    for pos in range(seq_len):
        current_latent = latent_vector[pos]  # [embed_dim]

        # transition_matrix[i][j] = probability of moving from state i to state j
        transition_matrix = th.zeros((num_states, num_states), dtype=dtype, device=device)

        for state_idx, edges in enumerate(state_edges):
            if edges is None:
                # No outgoing transitions: this state's row stays all-zero, so
                # any mass sitting in it is dropped.
                continue
            dest_indices, valid_embeddings = edges

            # [1, embed_dim] - [num_valid, embed_dim] -> [num_valid] distances
            distances = th.norm(current_latent.unsqueeze(0) - valid_embeddings, p=2, dim=1)
            token_scores = compute_token_score(distances, eps, scoring_mode)  # [num_valid]

            if scoring_mode == 'neg_squared':
                # Scores (-d^2) can be negative, so normalize with a
                # temperature-scaled softmax.
                token_probs = th.softmax(token_scores / temperature, dim=0)
            else:
                # 'inverse' / 'gaussian' scores are non-negative: normalize by
                # their sum. Temperature is not applied on this path.
                # NOTE(review): in 'gaussian' mode exp(-d^2) underflows to 0 for
                # large distances, which zeroes this whole row (and its gradient);
                # eps only prevents division by zero.
                token_probs = token_scores / (token_scores.sum() + eps)

            # Several tokens may lead to the same destination: accumulate them.
            transition_matrix[state_idx].scatter_add_(0, dest_indices, token_probs)

        # score_vector[j] = sum_i score_vector[i] * transition_matrix[i][j]
        score_vector = score_vector @ transition_matrix

    # Probability mass that ends in an accepting state.
    final_state_indices = [token_automaton.state_to_idx[s] for s in token_automaton.final_states]

    if final_state_indices:
        return score_vector[final_state_indices].sum()
    # No final states: zero, but still attached to the autograd graph.
    return score_vector.sum() * 0.0


def distance_score_batched(latent_vectors, tokenizer, token_automaton, token_embeddings, scoring_mode='inverse', temperature=1.0):
    """
    Batched version of distance_score for improved GPU utilization.

    Computes the same soft automaton acceptance score as ``distance_score``
    for every sequence in the batch, processing all batch items in parallel.

    Args:
        latent_vectors: Batched latent representations from diffusion process.
            Shape: [batch_size, seq_len, embed_dim]
        tokenizer: Not referenced by this implementation.
        token_automaton: TokenAutomaton instance from automaton_alignment.py.
            This function reads ``state_list``, ``state_to_idx``,
            ``initial_state``, ``final_states`` and ``state_transitions``, where
            ``state_transitions[i]`` unpacks into two parallel lists
            (token ids, destination state indices) for state index ``i``.
        token_embeddings: Token embedding matrix from the model.
            Shape: [vocab_size, embed_dim]
        scoring_mode: How to convert distances to scores
            ('inverse', 'neg_squared', 'gaussian'); see ``compute_token_score``.
        temperature: Softmax temperature, used ONLY when
            ``scoring_mode == 'neg_squared'``. Lower values (e.g. 0.1) make the
            distribution more peaked toward the closest tokens. It has no
            effect for 'inverse' and 'gaussian', whose scores are normalized
            by their sum.

    Returns:
        scores: Differentiable score tensor, one score in [0, 1] per sample.
            Shape: [batch_size]
    """
    device = latent_vectors.device
    dtype = latent_vectors.dtype
    batch_size, seq_len = latent_vectors.shape[0], latent_vectors.shape[1]

    # State ordering comes from the automaton.
    num_states = len(token_automaton.state_list)
    initial_state_idx = token_automaton.state_to_idx[token_automaton.initial_state]

    # Probability 1 at the initial state, 0 elsewhere. Shape: [batch_size, num_states]
    score_vectors = th.zeros(batch_size, num_states, dtype=dtype, device=device)
    score_vectors[:, initial_state_idx] = 1.0

    # Numerical-stability constant (division guard and 'inverse' scoring).
    eps = 1e-10

    # Per-state (destination indices, token embeddings) pair, or None for states
    # without outgoing transitions. Shared by all positions and batch items.
    state_edges = []
    for state_idx in range(num_states):
        valid_token_list, dest_indices_list = token_automaton.state_transitions[state_idx]
        if len(valid_token_list) == 0:
            state_edges.append(None)
            continue
        token_ids = th.tensor(valid_token_list, dtype=th.long, device=device)
        dest_indices = th.tensor(dest_indices_list, dtype=th.long, device=device)
        state_edges.append((dest_indices, token_embeddings[token_ids]))

    for pos in range(seq_len):
        current_latents = latent_vectors[:, pos, :]  # [batch_size, embed_dim]

        # transition_matrices[b][i][j] = P(state i -> state j) for batch item b
        transition_matrices = th.zeros(
            (batch_size, num_states, num_states), dtype=dtype, device=device
        )

        for state_idx, edges in enumerate(state_edges):
            if edges is None:
                # No outgoing transitions: this state's row stays all-zero.
                continue
            dest_indices, valid_embeddings = edges

            # [batch_size, 1, D] vs [1, num_valid, D] -> [batch_size, num_valid].
            # The matmul-based Euclidean path (the default once num_valid > 25)
            # loses precision near zero distance; compute exact distances so the
            # results agree with distance_score's th.norm.
            distances = th.cdist(
                current_latents.unsqueeze(1),
                valid_embeddings.unsqueeze(0),
                p=2,
                compute_mode='donot_use_mm_for_euclid_dist',
            ).squeeze(1)

            token_scores = compute_token_score(distances, eps, scoring_mode)  # [batch_size, num_valid]

            if scoring_mode == 'neg_squared':
                # Scores (-d^2) can be negative: temperature-scaled softmax.
                token_probs = th.softmax(token_scores / temperature, dim=1)
            else:
                # Non-negative scores: normalize by their per-sample sum.
                # Temperature is not applied on this path.
                token_probs = token_scores / (token_scores.sum(dim=1, keepdim=True) + eps)

            # Same destinations for every batch item: [num_valid] -> [batch_size, num_valid]
            dest_indices_batch = dest_indices.unsqueeze(0).expand(batch_size, -1)
            # Accumulate per-token probabilities into the destination columns.
            transition_matrices[:, state_idx, :].scatter_add_(1, dest_indices_batch, token_probs)

        # score_vectors[b][j] = sum_i score_vectors[b][i] * transition_matrices[b][i][j]
        score_vectors = th.bmm(score_vectors.unsqueeze(1), transition_matrices).squeeze(1)

    # Probability mass that ends in an accepting state, per sample.
    final_state_indices = [token_automaton.state_to_idx[s] for s in token_automaton.final_states]

    if final_state_indices:
        return score_vectors[:, final_state_indices].sum(dim=1)
    # No final states: zeros, but still attached to the autograd graph.
    return score_vectors.sum(dim=1) * 0.0


def _scatter_logsumexp(src, index, dim_size, dim=-1):
    """
    Compute logsumexp aggregated by index using scatter operations.

    Uses the numerically stable formula: logsumexp = max + log(sum(exp(x - max)))

    Args:
        src: Source tensor [batch, num_elements]
        index: Index tensor [num_elements] mapping elements to output positions
        dim_size: Size of output dimension (number of groups)
        dim: Dimension to scatter along (default -1)

    Returns:
        result: [batch, dim_size] with logsumexp aggregated values
    """
    batch_size = src.shape[0]
    device = src.device
    dtype = src.dtype
    NEG_INF = -1e32

    # Expand index for batch dimension: [num_elements] -> [batch, num_elements]
    index_expanded = index.unsqueeze(0).expand(batch_size, -1)

    # Step 1: Find max per group using scatter_reduce with 'amax'
    max_vals = th.full((batch_size, dim_size), NEG_INF, dtype=dtype, device=device)
    max_vals.scatter_reduce_(1, index_expanded, src, reduce='amax', include_self=False)

    # Step 2: Compute exp(src - max) where max is gathered back
    max_gathered = max_vals.gather(1, index_expanded)  # [batch, num_elements]
    exp_shifted = th.exp(src - max_gathered)

    # Step 3: Sum the exp values per group
    sum_exp = th.zeros((batch_size, dim_size), dtype=dtype, device=device)
    sum_exp.scatter_reduce_(1, index_expanded, exp_shifted, reduce='sum', include_self=False)

    # Step 4: logsumexp = max + log(sum_exp)
    # Handle case where sum_exp is 0 (no elements mapped to that index)
    result = max_vals + th.log(sum_exp + 1e-45)

    # Where max_vals is still NEG_INF (no elements), keep it as NEG_INF
    result = th.where(max_vals > NEG_INF + 1, result, max_vals)

    return result


def logits_score_batched(log_probs, token_automaton):
    """
    Compute automaton alignment score in log-space to avoid numerical underflow.

    The forward recursion over automaton states is done entirely in
    log-space, with logsumexp replacing sums of products, so the score does
    not underflow to zero for long sequences. At each position the
    log-weight of the transition i -> j is the logsumexp of the
    log-probabilities of every token labelling an i -> j edge. The state
    log-distribution is then propagated with a log-space vector-matrix
    product.

    Args:
        log_probs: Log-probabilities from model's log_softmax(logits).
            Shape: [batch_size, seq_len, vocab_size]
        token_automaton: TokenAutomaton instance from automaton_alignment.py.
            This function reads ``state_list``, ``state_to_idx``,
            ``initial_state``, ``final_states`` and ``state_transitions``, where
            ``state_transitions[i]`` unpacks into two parallel lists
            (token ids, destination state indices) for state index ``i``.

    Returns:
        log_scores: LOG-score tensor (not exponentiated), shape [batch_size].
            To get actual scores: torch.exp(log_scores).
            "Impossible" is represented by very large negative values
            (-1e32 or below) instead of -inf. If the automaton has no final
            states, or has no transitions while seq_len > 0, the result is a
            constant not connected to ``log_probs`` in the autograd graph.
    """
    batch_size, seq_len = log_probs.shape[0], log_probs.shape[1]
    device = log_probs.device
    dtype = log_probs.dtype

    num_states = len(token_automaton.state_list)
    initial_state_idx = token_automaton.state_to_idx[token_automaton.initial_state]

    # log(1)=0 at the initial state, "log(0)" elsewhere. Shape: [batch_size, num_states]
    NEG_INF = -1e32  # Large negative instead of -inf for numerical stability
    log_score_vectors = th.full((batch_size, num_states), NEG_INF, dtype=dtype, device=device)
    log_score_vectors[:, initial_state_idx] = 0.0

    # Flatten every (src state, token, dest state) edge into parallel lists so
    # each position handles all edges at once, with no Python loop over states.
    all_token_ids = []
    all_src_states = []
    all_dest_states = []
    for state_idx in range(num_states):
        valid_token_list, dest_indices_list = token_automaton.state_transitions[state_idx]
        if len(valid_token_list) > 0:
            all_token_ids.extend(valid_token_list)
            all_src_states.extend([state_idx] * len(valid_token_list))
            all_dest_states.extend(dest_indices_list)

    if len(all_token_ids) == 0 and seq_len > 0:
        # No transitions at all: no non-empty sequence can be accepted.
        # (For seq_len == 0 fall through: the answer depends only on whether
        # the initial state is final.)
        return th.full((batch_size,), NEG_INF, dtype=dtype, device=device)

    all_token_ids = th.tensor(all_token_ids, dtype=th.long, device=device)
    all_src_states = th.tensor(all_src_states, dtype=th.long, device=device)
    all_dest_states = th.tensor(all_dest_states, dtype=th.long, device=device)

    # Row-major position of each edge in a flattened [num_states, num_states] matrix.
    flat_indices = all_src_states * num_states + all_dest_states

    for pos in range(seq_len):
        # Log-prob of each edge's token at this position: [batch_size, num_transitions]
        transition_log_probs = log_probs[:, pos, :][:, all_token_ids]

        # Edges sharing (src, dest) are combined by logsumexp:
        # [batch_size, num_states * num_states]
        flat_log_T = _scatter_logsumexp(
            transition_log_probs, flat_indices, dim_size=num_states * num_states
        )
        log_transition_matrices = flat_log_T.view(batch_size, num_states, num_states)

        # new_log_score[j] = logsumexp_i(log_score[i] + log_T[i, j])
        # [batch, from_states, 1] + [batch, from_states, to_states]
        log_contributions = log_score_vectors.unsqueeze(2) + log_transition_matrices
        log_score_vectors = th.logsumexp(log_contributions, dim=1)  # [batch, to_states]

    # Log-probability of ending in any accepting state.
    final_state_indices = [token_automaton.state_to_idx[s] for s in token_automaton.final_states]

    if final_state_indices:
        return th.logsumexp(log_score_vectors[:, final_state_indices], dim=1)
    # No final states - return large negative
    return th.full((batch_size,), NEG_INF, dtype=dtype, device=device)
