from typing import List, Any, Dict
import numpy as np


def make_example(
    id: str,
    audio: str,
    audio_encoding: np.ndarray,
    audio_encoding_shape: List[int],
    prompt_question: str,
    response: str,
) -> Dict[str, Any]:
    """Assemble one example record as a plain dictionary.

    The identifier, audio reference, encoding and encoding shape are stored
    unchanged under keys of the same name. The question/answer pair is
    wrapped into a single-element list stored under ``"response"``.

    Args:
        id: Identifier stored under ``"id"``.
        audio: Audio reference stored under ``"audio"`` (presumably a path
            or URI -- confirm against callers).
        audio_encoding: Encoding array stored as-is (no copy is made).
        audio_encoding_shape: Shape stored alongside the encoding.
        prompt_question: Text stored as the ``"question"`` of the single turn.
        response: Text stored as the ``"answer"`` of the single turn.

    Returns:
        A dict with keys ``id``, ``audio``, ``audio_encoding``,
        ``audio_encoding_shape`` and ``response``, in that order.
    """
    qa_turn = dict(question=prompt_question, answer=response)
    example: Dict[str, Any] = dict(
        id=id,
        audio=audio,
        audio_encoding=audio_encoding,
        audio_encoding_shape=audio_encoding_shape,
    )
    # Added last so the key order matches a literal written top to bottom.
    example["response"] = [qa_turn]
    return example


def subsequence_pos(ary, subary):
    """Locate the first contiguous occurrence of ``subary`` inside ``ary``.

    Args:
        ary: The list to search in.
        subary: The list to search for as a contiguous run.

    Returns:
        A ``(start, end)`` tuple such that ``ary[start:end] == subary`` for
        the left-most match, or ``None`` when ``subary`` does not occur.

    Raises:
        AssertionError: If either argument is not a ``list``.
    """
    assert isinstance(ary, list)
    assert isinstance(subary, list)
    s = len(subary)
    # The last valid window starts at len(ary) - s, so the range must include
    # that index; otherwise a match at the very end of ``ary`` is missed.
    for start_idx in range(len(ary) - s + 1):
        if ary[start_idx : start_idx + s] == subary:
            return start_idx, start_idx + s
    return None


def extract_prompt_tokens(input_ids, end_seq):
    """Return the leading part of ``input_ids`` up to and including ``end_seq``.

    ``input_ids`` is converted with ``.tolist()`` and searched for ``end_seq``
    via ``subsequence_pos``; the slice ends at the end index of that match, so
    the marker itself is part of the returned prefix.

    Args:
        input_ids: Object exposing ``.tolist()`` and slicing (presumably a
            1-D tensor or array of token ids -- confirm against callers).
        end_seq: List of token ids marking where the prompt ends.

    Returns:
        ``input_ids[:end]`` where ``end`` is the match's end index.

    Raises:
        TypeError: When ``subsequence_pos`` returns ``None`` (no match), since
            the result cannot be unpacked.
    """
    id_list = input_ids.tolist()
    match_span = subsequence_pos(id_list, end_seq)
    _, cut = match_span
    return input_ids[:cut]


def extract_response_tokens(input_ids, end_seq):
    """Return the part of ``input_ids`` that follows ``end_seq``.

    ``input_ids`` is converted with ``.tolist()`` and searched for ``end_seq``
    via ``subsequence_pos``; the slice starts at the end index of that match,
    so the marker itself is excluded from the result.

    Args:
        input_ids: Object exposing ``.tolist()`` and slicing (presumably a
            1-D tensor or array of token ids -- confirm against callers).
        end_seq: List of token ids marking where the prompt ends.

    Returns:
        ``input_ids[end:]`` where ``end`` is the match's end index.

    Raises:
        TypeError: When ``subsequence_pos`` returns ``None`` (no match), since
            the result cannot be unpacked.
    """
    id_list = input_ids.tolist()
    match_span = subsequence_pos(id_list, end_seq)
    _, cut = match_span
    return input_ids[cut:]
