from __future__ import annotations
from pathlib import Path
from typing import List, Optional
import torch

def decode_req_texts_from_interaction(interaction, dataset, seq_sep: str = " ") -> List[str]:
    """Turn the ``req_text`` field of a batch back into one string per row.

    Args:
        interaction: Batch container supporting ``in`` and ``[]`` lookups by
            field name. Must contain ``"req_text"``.
        dataset: Object exposing ``id2token(field, id)``; used to map each
            id of the ``req_text`` field back to its token.
        seq_sep: Separator placed between the decoded tokens of a row.

    Returns:
        One string per row of ``req_text``. A row that yields no tokens
        becomes ``"request: (empty)"``.

        If the field cannot be converted to a tensor, it is returned
        unchanged when it is a non-empty list whose first element is a
        ``str``; otherwise one placeholder is returned per entry of
        ``interaction["user_id"]``.

    Raises:
        KeyError: If ``"req_text"`` is not present in ``interaction``.
    """
    placeholder = "request: (empty)"

    if "req_text" not in interaction:
        raise KeyError("req_text not present in Interaction (load_col must include it)")

    raw = interaction["req_text"]
    if isinstance(raw, torch.Tensor):
        token_ids = raw
    else:
        try:
            token_ids = torch.as_tensor(raw)
        except Exception:
            # Not tensor-convertible: either it already holds plain strings,
            # or there is nothing decodable and every row gets a placeholder.
            holds_strings = isinstance(raw, list) and bool(raw) and isinstance(raw[0], str)
            if holds_strings:
                return raw
            return [placeholder] * interaction["user_id"].size(0)

    def _lookup(token_id):
        # Best-effort: an id the dataset cannot resolve is simply dropped.
        try:
            return dataset.id2token("req_text", int(token_id))
        except Exception:
            return None

    decoded: List[str] = []
    for row_idx in range(token_ids.size(0)):
        # Id 0 is skipped without a lookup (presumably padding - confirm
        # against the dataset's vocabulary); falsy tokens are dropped too.
        candidates = (_lookup(t) for t in token_ids[row_idx].tolist() if t != 0)
        words = [word for word in candidates if word]
        decoded.append(seq_sep.join(words) if words else placeholder)
    return decoded


def load_sid_to_req_map(data_path: str, dataset_name: str) -> dict:
    """Build a ``sid -> req_text`` lookup from a dataset's ``.inter`` file.

    The file ``<data_path>/<dataset_name>/<dataset_name>.inter`` is read as
    tab-separated text whose first line is a header. The raw value of the
    ``sid:token`` column is mapped to the raw value of the
    ``req_text:token_seq`` column.

    Args:
        data_path: Directory that contains the dataset folder.
        dataset_name: Name of both the dataset folder and the ``.inter`` file.

    Returns:
        A dict from sid string to request-text string. Empty when the file
        does not exist or the header lacks either required column. Rows with
        too few fields are skipped; when a sid repeats, the last row wins.

    Side effects:
        Prints a one-line summary with the number of loaded entries whenever
        the file was parsed.
    """
    inter_path = Path(data_path) / dataset_name / f"{dataset_name}.inter"
    if not inter_path.exists():
        return {}

    sid2req = {}
    # "utf-8-sig" reads plain UTF-8 identically but strips a leading BOM.
    # With plain "utf-8" the BOM stays attached to the first header cell, the
    # column lookup fails, and a valid file silently yields an empty map.
    with inter_path.open("r", encoding="utf-8-sig") as f:
        cols = f.readline().rstrip("\n").split("\t")
        try:
            sid_idx = cols.index("sid:token")
            req_idx = cols.index("req_text:token_seq")
        except ValueError:
            return {}

        # A usable row must reach the right-most of the two columns.
        min_fields = max(sid_idx, req_idx) + 1
        for line in f:
            parts = line.rstrip("\n").split("\t")
            if len(parts) < min_fields:
                continue
            sid2req[parts[sid_idx]] = parts[req_idx]
    print(f"[Data][Fallback] sid->req_text loaded: {len(sid2req)} entries")
    return sid2req


def fetch_requests_for_batch(interaction, dataset, sid2req_map: Optional[dict]) -> List[str]:
    """Return one request string per batch row, with a sid-based fallback.

    The primary path decodes the ``req_text`` field through
    ``decode_req_texts_from_interaction``. If that raises for any reason,
    each row's ``sid`` is converted back to its token with
    ``dataset.id2token("sid", ...)`` and looked up in ``sid2req_map``.

    Args:
        interaction: Batch container supporting ``in`` and ``[]`` lookups by
            field name.
        dataset: Object exposing ``id2token(field, id)``.
        sid2req_map: Optional mapping from sid token to request text.

    Returns:
        A list of request strings. ``"request: (empty)"`` stands in for any
        row that cannot be resolved. When no map is given or the batch has
        no ``"sid"`` field, the list holds one placeholder per entry of
        ``interaction["user_id"]``.
    """
    placeholder = "request: (empty)"
    try:
        return decode_req_texts_from_interaction(interaction, dataset)
    except Exception:
        map_usable = sid2req_map is not None and "sid" in interaction
        if not map_usable:
            return [placeholder] * interaction["user_id"].size(0)

        raw_sids = interaction["sid"].cpu().tolist()
        # ``tolist()`` may hand back a bare scalar; normalise to a list.
        sid_ids = raw_sids if isinstance(raw_sids, list) else [raw_sids]

        def _request_for(sid_id):
            # Best-effort per row: any lookup failure becomes a placeholder.
            try:
                return sid2req_map.get(dataset.id2token("sid", int(sid_id)), placeholder)
            except Exception:
                return placeholder

        return [_request_for(sid_id) for sid_id in sid_ids]
