
import os,re
import json
import random
from typing import List, Dict, Optional

# Default data directory used by get_task(); split files are looked up under
# <DATA_ROOT>/<taskname>/. The path is relative, so it resolves against the
# current working directory of the process.
DATA_ROOT = "./data"

def _load_json_any(path: str):
    if path.endswith(".jsonl"):
        with open(path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)
        return obj


# Decimal numbers such as "2.5", "-0.5" or ".5" (after separators are removed).
_GSM8K_DECIMAL_RE = re.compile(r"-?\d*\.\d+")
# Last-resort number finder: integers with optional thousands separators and an
# optional fractional part, so "2.5" is captured whole rather than as "2" and "5".
_GSM8K_FALLBACK_NUM_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")


def _normalize_answer_number(raw: str) -> Optional[str]:
    """Return a canonical string for a numeric answer, or None if it is not numeric.

    Commas, whitespace and ``$`` are removed first. Integers are normalized via
    ``int`` (e.g. ``"1,000"`` -> ``"1000"``); decimals are kept verbatim.
    """
    cleaned = re.sub(r"[,\s$]", "", raw)
    try:
        return str(int(cleaned))
    except ValueError:
        pass
    if _GSM8K_DECIMAL_RE.fullmatch(cleaned):
        return cleaned
    return None


def _format_boxed(reasoning: str, value: str) -> str:
    """Join (lightly cleaned) reasoning and the boxed final-answer line."""
    # Undo stray spaces before punctuation left over from token re-joining.
    reasoning = reasoning.replace(" .", ".").replace(" ,", ",")
    prefix = reasoning + "\n" if reasoning else ""
    return prefix + f"The final answer is \\boxed{{{value}}}"


def gsm8k_to_boxed(answer_field) -> str:
    """Convert a GSM8K-style answer into reasoning followed by a boxed final answer.

    Two layouts are recognized:

    1. A trailing ``#### <number>`` marker, where ``####`` is its own
       whitespace-separated token immediately before the last token. The
       reasoning is everything before the marker, re-joined with single spaces
       (so original line breaks inside the reasoning are collapsed).
    2. Otherwise, the last number anywhere in the text is taken as the answer
       and removed from the reasoning.

    Integer answers are normalized (thousands separators dropped, ``int``
    formatting); decimal answers such as ``2.5`` are boxed as written instead of
    only their fractional digits being picked up.

    Args:
        answer_field: The raw answer; converted with ``str`` and stripped.

    Returns:
        ``"<reasoning>\\nThe final answer is \\\\boxed{<n>}"`` (the reasoning line is
        omitted when empty), or the stripped input unchanged if no number is found.
    """
    s = str(answer_field).strip()

    toks = s.split()
    if len(toks) >= 2 and toks[-2] == "####":
        value = _normalize_answer_number(toks[-1])
        if value is not None:
            return _format_boxed(" ".join(toks[:-2]), value)

    matches = list(_GSM8K_FALLBACK_NUM_RE.finditer(s))
    if matches:
        m = matches[-1]
        value = _normalize_answer_number(m.group(0))
        if value is not None:
            reasoning = (s[:m.start()] + s[m.end():]).strip()
            return _format_boxed(reasoning, value)

    return s

def _map_example(taskname: str, ex: Dict) -> Dict[str, str]:
    t = taskname.lower()
    if t in {"gpqa", "date", "salient"}:
        return ex

    if t == "gsm8k":
        q = ex["question"]
        tgt = gsm8k_to_boxed(ex["answer"])
        return {"input": q, "target": tgt}

    if t in {"fp", "financial_phrasebank"}:
        sent = ex["sentence"]
        lab = str(ex["label"]).strip().lower()
        return {"input": sent, "target": lab}

    if t == "xsum":
        doc = ex["document"]
        summ = ex["summary"]
        return {"input": doc, "target": summ}

    if "input" in ex and "target" in ex:
        return {"input": ex["input"], "target": ex["target"]}
    raise ValueError(f"Unrecognized task '{taskname}' and example schema: {list(ex.keys())}")

def get_task(
    taskname: str,
    split: str,                      # "train" | "val" | "test"
    data_root: str = DATA_ROOT,
    *,
    limit: Optional[int] = None,
    shuffle: bool = False,
    seed: int = 42,
) -> List[Dict[str, str]]:
    """Load one split of a task and map every example via ``_map_example``.

    The split is read from ``<data_root>/<taskname>/<split>.json``. If that file
    does not exist, ``<split>.jsonl`` in the same directory is tried instead
    (``_load_json_any`` parses ``.jsonl`` one record per line).

    Args:
        taskname: Task directory name; also selects the per-task schema mapping.
        split: Split file stem, e.g. "train", "val" or "test".
        data_root: Root directory containing one sub-directory per task.
        limit: If given, keep at most this many examples (applied after shuffling).
        shuffle: Shuffle the mapped examples before applying ``limit``.
        seed: Seed for the shuffle. A private ``random.Random`` instance is used,
            so the module-global ``random`` state is not modified.

    Returns:
        The list of mapped examples.

    Raises:
        ValueError: If ``limit`` is negative, or the file does not hold a list.
        FileNotFoundError: If neither the ``.json`` nor the ``.jsonl`` file exists.
    """
    # A negative slice bound would silently drop items from the end instead.
    if limit is not None and limit < 0:
        raise ValueError(f"limit must be non-negative, got {limit}")

    task_dir = os.path.join(data_root, taskname)
    candidates = [os.path.join(task_dir, f"{split}{ext}") for ext in (".json", ".jsonl")]
    path = next((p for p in candidates if os.path.exists(p)), None)
    if path is None:
        raise FileNotFoundError(
            f"Cannot find split file for task '{taskname}', split '{split}'; "
            f"tried: {', '.join(candidates)}"
        )

    data = _load_json_any(path)
    if not isinstance(data, list):
        raise ValueError(f"Expected a list in {path}, got {type(data)}")

    mapped = [_map_example(taskname, ex) for ex in data]

    if shuffle:
        # Same permutation as random.seed(seed); random.shuffle(...), but without
        # resetting the caller's global RNG as a side effect.
        random.Random(seed).shuffle(mapped)

    if limit is not None:
        mapped = mapped[:limit]

    return mapped


