import torch
import numpy as np
from typing import Optional

def beam_search(
        model,
        beam_width: int,
        sequence_length: int,
        x: torch.LongTensor,
        attn_mask: Optional[torch.BoolTensor],
        device: torch.device,
        pad_id: int | None = None,
) -> torch.LongTensor:
    """
    Performs batched beam search through the "GPT tree".

    Args:
    model: A transformer model.
    beam_width (int): The number of sequences to keep at each level.
    sequence_length (int): The total length of the sequence to be generated.
    x (tensor): The initial sequence.
    device: cuda device

    Returns:
    tensor: The most probable sequence.
    """
    # Start with K=1 beam per example
    # sequences: [B, 1, T]
    sequences = x.unsqueeze(1)
    # scores: [B, 1]
    scores = torch.zeros(B, 1, device=device)

    for _ in range(sequence_length):
        # Current shape
        B, K, t = sequences.shape

        # Flatten beams into batch dimension: [B*K, t]
        flat_seqs = sequences.view(B * K, t)
        flat_mask = (attn_mask.unsqueeze(1).expand(B, K, t).reshape(B * K, t) if attn_mask is not None else None)

        # One forward pass for all beams
        # Returns probs of shape [B*K, V]
        probs = model.generate_for_beam_search(flat_seqs, max_new_tokens=1, attn_mask=flat_mask)  # (B*K, V)
        logp = torch.log(probs + 1e-12)  # convert to log probabilities for beam scoring 

        # Reshape back to [B, K, V]
        logp = logp.view(B, K, -1)

        # Compute new scores: [B, K, 1] + [B, K, V] = [B, K, V]
        new_scores = scores.unsqueeze(-1) + logp

        # Flatten last two dims to select top beams: [B, K*V]
        flat_scores = new_scores.view(B, K * logp.size(-1))
        top_scores, top_idxs = flat_scores.topk(beam_width, dim=-1)

        # Decode beam and token indices
        beam_idxs = top_idxs // logp.size(-1)  # [B, beam_width]
        token_idxs = top_idxs % logp.size(-1)  # [B, beam_width]

        # Gather beam sequences
        sequences = sequences.gather(1, beam_idxs.unsqueeze(-1).expand(-1, -1, t))  # [B, beam_width, t]

        # Append the next token
        sequences = torch.cat([sequences,token_idxs.unsqueeze(-1)], dim=-1)  # [B, beam_width, t+1]

        # Update scores
        scores = top_scores

        # Update attention mask if used
        if attn_mask is not None:
            # Gather the selected beam masks then append True for new token
            mask = attn_mask.unsqueeze(1).expand(B, K, t)
            new_mask = torch.ones(B, beam_width, 1, dtype=torch.bool, device=device)
            attn_mask = torch.cat([mask.gather(1, beam_idxs.unsqueeze(-1).expand(-1, -1, t)), new_mask], dim=-1)

    # Return the best beam (index 0) for each example
    best = sequences[:, 0, :]
    return best
