#!/usr/bin/env python3
"""
jsonl_to_episode_json.py

Usage:
  python jsonl_to_episode_json.py --destination_path PATH --output_path DIR [--recursive] [--model MODEL] [--encoding ENCODING]

Outputs (schema):
  {
    "episode_id": ...,
    "task": ...,
    "game_file": ...,
    "source": ...,
    "success": ...,
    "num_turns": ...,
    "total_tokens": ...,
    "metadata": {"representation": "ascii"},
    "trajectory": [
      {"turn_idx": ..., "action": ..., "observation": ..., "reasoning_trace": ..., "reward": ..., "done": ...},
      ...
    ]
  }

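Key behavior:
- observation is ONLY `textual_representation` (everything else is dropped)
- total_tokens is computed with tiktoken over action + observation + reasoning_trace (the latter only if non-empty)
- per-step "reward" / "done" are emitted only when the source record provides reward / terminated / truncated
- success is currently always null (no success signal is derived yet)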
"""

from __future__ import annotations

import argparse
import json
import re
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import tiktoken


# -------------------- helpers --------------------

def _to_int_or_none(x: Any) -> Optional[int]:
    if x is None:
        return None
    return int(x)


def _to_bool_or_none(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    s = str(x).strip().lower()
    if s in {"true", "t", "1", "yes", "y"}:
        return True
    if s in {"false", "f", "0", "no", "n"}:
        return False
    return None


def _to_float_or_none(x: Any) -> Optional[float]:
    if x is None:
        return None
    return float(x)


def _first_present(d: Dict[str, Any], keys: Iterable[str]) -> Any:
    for k in keys:
        if k in d:
            return d[k]
    return None


def _guess_source_from_path(path: Path) -> str:
    """Guess source from full path (handles timestamp leaf dirs)."""
    full = str(path).lower()
    if "minihack" in full:
        return "minihack"
    if "babaisai" in full or "baba_is_ai" in full or "baba-is-ai" in full or re.search(r"\bbaba\b", full):
        return "babaisai"
    if "crafter" in full:
        return "crafter"

    ignore = {
        "cache", "caches", "log", "logs", "run", "runs", "output", "outputs",
        "result", "results", "tmp", "temp", "data"
    }
    timestamp_pat = re.compile(r"^\d{8}_\d{6}$")
    digits_pat = re.compile(r"^\d+$")

    for p in [path] + list(path.parents):
        if not p.exists() or not p.is_dir():
            continue
        name = p.name
        if not name:
            continue
        nl = name.lower()

        if nl in ignore:
            continue
        if timestamp_pat.fullmatch(nl) or digits_pat.fullmatch(nl):
            continue
        if nl.startswith(("gpt", "llama", "qwen", "gemini")):
            continue
        if "checkpoint" in nl or "ckpt" in nl:
            continue

        return nl

    return "unknown"


def _extract_textual_representation(value: Any) -> str:
    """
    Return ONLY textual_representation and drop everything else.
      - dict: use dict["textual_representation"] if present else ""
      - str: if JSON dict string -> parse -> use ["textual_representation"] else ""
      - otherwise: ""
    """
    if value is None:
        return ""

    if isinstance(value, dict):
        tr = value.get("textual_representation")
        return "" if tr is None else str(tr)

    if isinstance(value, str):
        s = value.strip()
        if not s:
            return ""
        if s[0] == "{" and s[-1] == "}":
            obj = json.loads(s)
            if isinstance(obj, dict):
                tr = obj.get("textual_representation")
                return "" if tr is None else str(tr)
        return ""  # drop non-textual_representation strings

    return ""


def _extract_observation(rec: Dict[str, Any]) -> str:
    """
    Return the step's textual_representation, or "" if none can be found.

    Candidate fields are tried in order; the first one yielding non-empty
    text wins, and a missing or empty candidate falls through to the next:
      - rec["textual_representation"]: a plain string is the text itself;
        other values are unwrapped like the nested fields below
      - rec["agent_observation"], rec["observation"], rec["obs"],
        rec["agent_observation_raw"]: a dict or JSON-string dict whose
        "textual_representation" entry is used
    """
    if "textual_representation" in rec:
        direct = rec["textual_representation"]
        # A top-level string IS the textual representation. The nested-field
        # extractor only unwraps dicts / JSON-dict strings and would return "".
        text = direct if isinstance(direct, str) else _extract_textual_representation(direct)
        if text:
            return text

    for key in ("agent_observation", "observation", "obs", "agent_observation_raw"):
        if key in rec:
            text = _extract_textual_representation(rec[key])
            if text:
                return text

    return ""


def _get_encoder(model: str, encoding_name: Optional[str]) -> tiktoken.Encoding:
    """
    Return the tiktoken encoding used for token counting.

    Resolution order:
      1. *encoding_name*, if given: an explicit override (an unknown name raises).
      2. tiktoken.encoding_for_model(model).
      3. If tiktoken cannot map *model* (it raises KeyError), fall back to
         o200k_base, then cl100k_base (older tiktoken releases lack o200k_base).

    Raises:
        ValueError: if an explicit *encoding_name* is unknown, or if neither
            fallback encoding is available.
    """
    if encoding_name:
        return tiktoken.get_encoding(encoding_name)

    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        pass  # unknown model name -> try the generic encodings below

    last_err: Optional[Exception] = None
    for fallback in ("o200k_base", "cl100k_base"):
        try:
            return tiktoken.get_encoding(fallback)
        except ValueError as err:
            last_err = err
    raise ValueError(
        f"Could not resolve a tiktoken encoding for model {model!r}; "
        "pass --encoding explicitly"
    ) from last_err


def _count_tokens(enc: tiktoken.Encoding, text: str) -> int:
    if not text:
        return 0
    return len(enc.encode(text))


# -------------------- core conversion --------------------

def _record_step(rec: Dict[str, Any]) -> Optional[int]:
    """Return a record's step: "step" if convertible, else "turn_idx", else None."""
    step = _to_int_or_none(rec.get("step"))
    return step if step is not None else _to_int_or_none(rec.get("turn_idx"))


def _episode_file_suffix(ep_id: str) -> str:
    """Make an episode id safe to embed in one file name (replace path separators)."""
    return re.sub(r"[\\/]", "_", ep_id)


def _load_records_by_episode(
    jsonl_path: Path, logs: List[str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Read *jsonl_path* and group its JSON-object lines by "episode_id".

    Blank lines are ignored. A line that is not valid JSON, or whose JSON is
    not an object, is skipped with a message appended to *logs* instead of
    aborting the whole file (e.g. a truncated last line from an interrupted
    run). Records without an episode_id are grouped under "unknown_episode".
    Episodes and the records within each keep file order.
    """
    by_ep: Dict[str, List[Dict[str, Any]]] = {}
    n_lines = 0

    with jsonl_path.open("r", encoding="utf-8", errors="replace") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            n_lines += 1
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as err:
                logs.append(f"Skipping malformed JSON on line {lineno} of {jsonl_path.name}: {err}")
                continue
            if not isinstance(rec, dict):
                logs.append(f"Skipping non-object JSON on line {lineno} of {jsonl_path.name}")
                continue

            ep = rec.get("episode_id")
            ep_str = "unknown_episode" if ep is None else str(ep)
            by_ep.setdefault(ep_str, []).append(rec)

    logs.append(f"Loaded {n_lines} jsonl line(s) from {jsonl_path.name}")
    logs.append(f"Found {len(by_ep)} episode_id(s) in {jsonl_path.name}")
    return by_ep


def _build_episode(
    ep_id: str,
    recs: List[Dict[str, Any]],
    jsonl_path: Path,
    source_hint: str,
    enc: tiktoken.Encoding,
    logs: List[str],
) -> Dict[str, Any]:
    """Build one output episode object (schema in the module docstring) from its raw records."""
    # Resolve each record's step once; it drives both ordering and turn_idx.
    steps = [_record_step(r) for r in recs]
    present_steps = [s for s in steps if s is not None]
    # turn_idx is re-based so that the smallest recorded step becomes 0.
    step_offset = min(present_steps) if present_steps else 0

    # Sort by step where present; a record without a step uses its file
    # position as its key. sorted() is stable, so ties keep file order.
    order = sorted(range(len(recs)), key=lambda i: steps[i] if steps[i] is not None else i)

    # Source: first non-blank "source" field in step order, else the path hint.
    src = source_hint
    for i in order:
        raw_src = recs[i].get("source")
        if raw_src is not None and str(raw_src).strip():
            src = str(raw_src).strip()
            break

    action_keys = ("agent_action", "action", "act", "assistant_action")
    reasoning_keys = ("thought", "reasoning_trace", "reasoning", "rationale")

    trajectory: List[Dict[str, Any]] = []
    total_tokens = 0

    for pos, i in enumerate(order):
        r = recs[i]
        sv = steps[i]
        turn_idx = sv - step_offset if sv is not None else pos

        action = _first_present(r, action_keys)
        reasoning = _first_present(r, reasoning_keys)
        action_s = "" if action is None else str(action)
        reasoning_s = "" if reasoning is None else str(reasoning)
        obs_s = _extract_observation(r)  # textual_representation only

        step_obj: Dict[str, Any] = {
            "turn_idx": turn_idx,
            "action": action_s,
            "observation": obs_s,
            "reasoning_trace": reasoning_s,
        }

        rew = _to_float_or_none(r.get("reward"))
        if rew is not None:
            step_obj["reward"] = rew

        terminated = _to_bool_or_none(r.get("terminated"))
        truncated = _to_bool_or_none(r.get("truncated"))
        # Emit "done" only when at least one flag is known; an unknown flag counts as False.
        if terminated is not None or truncated is not None:
            step_obj["done"] = bool(terminated) or bool(truncated)

        trajectory.append(step_obj)

        # Token profile: action + observation + reasoning_trace (only if non-blank).
        total_tokens += _count_tokens(enc, action_s) + _count_tokens(enc, obs_s)
        if reasoning_s.strip():
            total_tokens += _count_tokens(enc, reasoning_s)

    logs.append(f"Computed total_tokens={total_tokens} for episode {ep_id} using tiktoken")

    return {
        "episode_id": ep_id,
        "task": jsonl_path.stem,
        "game_file": str(jsonl_path),
        "source": str(src),
        # TODO: derive a real success signal (e.g. from perf_score, reward,
        # or info["score"] / info["total_score"]).
        "success": None,
        "num_turns": len(trajectory),
        "total_tokens": total_tokens,
        "metadata": {"representation": "ascii"},  # TODO: make this configurable
        "trajectory": trajectory,
    }


def convert_one_jsonl(
    jsonl_path: Path,
    source_hint: str,
    enc: tiktoken.Encoding,
) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[str]]:
    """
    Convert one JSONL log into one episode object per distinct episode_id.

    Args:
        jsonl_path: the JSONL file to read (one JSON object per line).
        source_hint: fallback "source" when no record carries a non-blank one.
        enc: tiktoken encoding used to compute total_tokens.

    Returns:
        (outputs, logs). outputs is a list of (output file stem, episode
        object). The stem is the JSONL stem, suffixed with "_ep<episode_id>"
        (path separators replaced by "_") when the file holds more than one
        episode. logs holds human-readable progress messages.

    Exceptions from opening/reading the file or from the field-conversion
    helpers propagate to the caller.
    """
    logs: List[str] = []
    by_ep = _load_records_by_episode(jsonl_path, logs)
    multiple_eps = len(by_ep) > 1

    outputs: List[Tuple[str, Dict[str, Any]]] = []
    for ep_id, recs in by_ep.items():
        out_obj = _build_episode(ep_id, recs, jsonl_path, source_hint, enc, logs)
        out_stem = (
            f"{jsonl_path.stem}_ep{_episode_file_suffix(ep_id)}"
            if multiple_eps
            else jsonl_path.stem
        )
        outputs.append((out_stem, out_obj))

    return outputs, logs


# -------------------- CLI --------------------

def main() -> None:
    """
    CLI entry point: convert every selected JSONL file into per-episode JSON files.

    A file that fails to convert is reported and skipped so the remaining files
    are still processed. The process exits with a non-zero status if any file
    failed.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Both paths are mandatory; without required=True, omitting one crashed in Path(None).
    ap.add_argument("--destination_path", type=str, required=True, help="Directory containing JSONL files OR a single .jsonl file")
    ap.add_argument("--output_path", type=str, required=True, help="Directory to write JSON files into")
    ap.add_argument("--recursive", action="store_true", help="Recursively search for JSONL files (default: non-recursive)")
    ap.add_argument("--model", type=str, default="gpt-4o-mini", help="Model name for tiktoken.encoding_for_model()")
    ap.add_argument("--encoding", type=str, default=None, help="Override encoding name (e.g., o200k_base, cl100k_base)")
    args = ap.parse_args()

    src = Path(args.destination_path).expanduser().resolve()
    out_dir = Path(args.output_path).expanduser().resolve()

    # Validate the input before creating anything on disk.
    if src.is_file():
        jsonl_files = [src]
        source_hint = _guess_source_from_path(src.parent)
    elif src.is_dir():
        jsonl_files = sorted(src.rglob("*.jsonl") if args.recursive else src.glob("*.jsonl"))
        source_hint = _guess_source_from_path(src)
    else:
        raise SystemExit(f"destination_path does not exist or is not a directory/file: {src}")

    out_dir.mkdir(parents=True, exist_ok=True)
    enc = _get_encoder(args.model, args.encoding)

    print(f"Found {len(jsonl_files)} JSONL file(s) under {src}")
    print(f"Source hint from destination path: {source_hint}")
    print(f"tiktoken model={args.model}, encoding={enc.name}")

    # Output path -> JSONL it was written from. In --recursive mode two files
    # with the same stem in different folders map to the same output file.
    written: Dict[Path, Path] = {}
    failures = 0

    for jsonl_path in jsonl_files:
        print(f"- JSONL: {jsonl_path.name}")
        try:
            out_list, logs = convert_one_jsonl(jsonl_path, source_hint=source_hint, enc=enc)
            for line in logs:
                print(f"  {line}")

            for out_stem, out_obj in out_list:
                out_path = out_dir / f"{out_stem}.json"
                if out_path in written:
                    print(f"  WARNING: overwriting {out_path} (previously written from {written[out_path]})")
                with out_path.open("w", encoding="utf-8") as f:
                    json.dump(out_obj, f, ensure_ascii=False, indent=2)
                written[out_path] = jsonl_path
                print(f"  Wrote: {out_path}")

        except Exception as e:  # top-level boundary: report, then continue with the next file
            failures += 1
            print(f"  ERROR converting {jsonl_path}: {e}")

    print("Done.")
    if failures:
        raise SystemExit(f"{failures} of {len(jsonl_files)} JSONL file(s) failed to convert")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
