import torch

@torch.no_grad()
def greedy_batch_decode(
    model,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    sos_id: int,
    eos_id: int,
    max_len: int = 256,
    language: str = None,
) -> tuple[list[list[int]], list[str]]:
    """Greedily decode a batch, emitting a language token as the first output.

    Every hypothesis starts with ``sos_id``.  The first decoding step emits a
    language token: either the one forced through ``language`` or, when
    ``language`` is ``None``, the highest-scoring id among the values of
    ``model.special_to_id``.  Every later step takes the arg-max over the
    whole output dimension.  Decoding stops once every hypothesis has emitted
    ``eos_id`` or after ``max_len`` steps (the language step counts as one).

    Args:
        model: Object exposing ``special_to_id`` (name -> id mapping),
            ``id_to_special`` (id -> name mapping) and
            ``attention_decoder.decoder(x, x_lens, memory, memory_lens)``.
            The decoder output is indexed as ``[:, -1, :]``, so it is assumed
            to be (batch, time, vocab) logits -- TODO confirm against the
            decoder implementation.
        encoder_out: Encoder output passed to the decoder as ``memory``; its
            first dimension is taken as the batch size.
        encoder_out_lens: Passed to the decoder as ``memory_lens``.
        sos_id: Token id every hypothesis starts with.
        eos_id: Token id that terminates a hypothesis.
        max_len: Maximum number of decoding steps.
        language: Key of ``model.special_to_id`` to force as language token.
            Unknown keys fall back to ``model.special_to_id['unk_lang']``.
            ``None`` lets the decoder scores pick the language.

    Returns:
        ``(sequences, languages)``.  ``sequences[i]`` holds the token ids of
        hypothesis ``i`` with the start token and the language token removed,
        cut before the first ``eos_id``.  ``languages[i]`` is the
        ``model.id_to_special`` entry of the chosen language id.  If no step
        runs (``max_len < 1``) the language ids keep their zero placeholder.

    Raises:
        KeyError: If ``language`` is unknown and the model defines no
            ``'unk_lang'`` entry, or a chosen id is missing from
            ``model.id_to_special``.
    """
    device = encoder_out.device
    batch_size = encoder_out.size(0)

    special_to_id = model.special_to_id
    if language is None:
        fixed_lang_id = None
        # Candidate ids are only needed when the language is predicted.
        lang_ids = torch.tensor(list(special_to_id.values()), device=device)
    else:
        lang_ids = None
        # Look up the fallback lazily: a valid language must not fail just
        # because the model has no 'unk_lang' entry.
        if language in special_to_id:
            fixed_lang_id = special_to_id[language]
        else:
            fixed_lang_id = special_to_id['unk_lang']

    tokens = torch.full((batch_size, 1), sos_id, dtype=torch.long, device=device)
    token_lens = torch.ones(batch_size, dtype=torch.long, device=device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    # Placeholder, overwritten by the language step (step 0).
    chosen_lang_ids = torch.zeros(batch_size, dtype=torch.long, device=device)

    for step in range(max_len):
        logits = model.attention_decoder.decoder(
            x=tokens,
            x_lens=token_lens,
            memory=encoder_out,
            memory_lens=encoder_out_lens,
        )
        last_logits = logits[:, -1, :]

        if step > 0:
            next_token = last_logits.argmax(dim=-1)
        else:
            if fixed_lang_id is None:
                # Restrict the arg-max to the special-token columns.
                best = last_logits[:, lang_ids].argmax(dim=-1)
                next_token = lang_ids[best]
            else:
                # Explicit dtype keeps the forced token consistent with
                # `tokens` / `chosen_lang_ids` regardless of fill-value
                # dtype inference.
                next_token = torch.full(
                    (batch_size,), fixed_lang_id, dtype=torch.long, device=device
                )
            chosen_lang_ids.copy_(next_token)

        tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
        # Only hypotheses that are still running grow; finished rows keep
        # their length while further tokens are appended past it.
        token_lens += (~finished).long()
        # Every row has length >= 2 from step 0 on, so no extra length guard
        # is needed here.
        finished |= next_token.eq(eos_id)
        if finished.all():
            break

    sequences = []
    for seq in tokens.tolist():
        # Search from index 1 so the leading start token is never mistaken
        # for the terminator; anything after the first eos is discarded.
        try:
            end = seq.index(eos_id, 1)
        except ValueError:
            end = len(seq)
        # Drop the start token and the language token.
        sequences.append(seq[2:end])

    id_to_special = model.id_to_special
    languages = [id_to_special[int(lang_id)] for lang_id in chosen_lang_ids.tolist()]

    return sequences, languages
