import torch
import numpy as np

def beam_search(model, beam_width, sequence_length, x, device):
    """
    Performs beam search through the "GPT tree".

    At each step every kept sequence is extended by one token, and the
    ``beam_width`` extensions with the highest cumulative log-probability
    are kept for the next step.

    Args:
    model: A transformer model exposing ``generate_for_beam_search(seq)``.
        Its output (after ``squeeze(0)``) is assumed to be a 1-D tensor of
        next-token *probabilities* (not logits) over the vocabulary —
        the log is taken directly. TODO confirm against the model.
    beam_width (int): The number of sequences to keep at each level (>= 1).
    sequence_length (int): The number of new tokens to append to ``x``.
    x (tensor): The initial sequence; tokens are appended along dim 1,
        so it is expected to be shaped like ``(1, prompt_len)``.
    device: Device on which appended token tensors are created.

    Returns:
    tensor: The most probable sequence (``x`` followed by the generated
    tokens).

    Raises:
    ValueError: If ``beam_width`` is less than 1.
    """
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    sequences = [(x, 0.0)]  # (sequence, cumulative log-probability)

    for _ in range(sequence_length):
        # Candidates are kept as (parent_seq, token_id, score) so the
        # expensive torch.cat only happens for the survivors.
        candidates = []

        for seq, score in sequences:
            probs = model.generate_for_beam_search(seq).squeeze(0)
            # float64 keeps the same precision as the former np.log(float).
            log_probs = torch.log(probs.double())

            # Any candidate in the global top-`beam_width` must also be in its
            # own beam's top-`beam_width`, so the rest of the vocabulary can
            # be pruned here without changing the result.
            k = min(beam_width, log_probs.size(0))
            top_log_probs, top_ids = torch.topk(log_probs, k)

            for log_prob, token_id in zip(top_log_probs.tolist(), top_ids.tolist()):
                candidates.append((seq, token_id, score + log_prob))

        candidates.sort(key=lambda cand: cand[2], reverse=True)

        sequences = []
        for seq, token_id, new_score in candidates[:beam_width]:
            next_token = torch.tensor([[token_id]], dtype=torch.long, device=device)
            sequences.append((torch.cat((seq, next_token), dim=1), new_score))

    return max(sequences, key=lambda tup: tup[1])[0]
