#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import tiktoken


def _norm(s: str) -> str:
    return "".join(ch for ch in s.strip().lower() if ch.isalnum())


def _guess_source_from_text(text: str) -> str:
    s = (text or "").lower()
    if "minihack" in s:
        return "minihack"
    if "babaisai" in s or "baba_is_ai" in s or "baba-is-ai" in s or "baba" in s:
        return "babaisai"
    if "crafter" in s:
        return "crafter"
    return "unknown"


def _coerce_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    s = str(x).strip().lower()
    if s in {"true", "t", "1", "yes", "y"}:
        return True
    if s in {"false", "f", "0", "no", "n"}:
        return False
    return None


def _coerce_float(x: Any) -> Optional[float]:
    if x is None:
        return None
    s = str(x).strip()
    if s == "":
        return None
    
    return float(s)



def _pick_columns(fieldnames: List[str]) -> Dict[str, Optional[str]]:
    """Map each logical field to the actual CSV header that carries it.

    Headers and aliases are compared after normalisation with ``_norm``. For
    every logical field the aliases are tried in order and the first one that
    matches a header wins; the value is ``None`` when no alias matches.

    Args:
        fieldnames: Header names as they appear in the CSV.

    Returns:
        A dict keyed by logical field name ("step", "action", "obs", ...)
        whose values are original header strings or ``None``.
    """
    # If two headers normalise to the same key, the later one is kept.
    headers_by_key: Dict[str, str] = {}
    for header in fieldnames:
        headers_by_key[_norm(header)] = header

    aliases = {
        "step": ("turn_idx", "turn", "step", "t", "idx", "index"),
        "action": ("action", "act", "assistant_action"),
        "obs": ("observation", "obs", "state", "screen", "screenshot_text"),
        "reasoning": ("reasoning_trace", "reasoning", "rationale", "thought", "analysis"),
        "reward": ("reward", "rew", "r"),
        "done": ("done", "terminal", "terminated", "is_done"),
        "task": ("task", "instruction", "goal", "query", "prompt"),
        "source": ("source", "dataset"),
        "episode_id": ("episode_id", "episode", "ep_id", "id"),
        "game_file": ("game_file", "gamepath", "file", "path"),
        "success": ("success", "is_success", "solved"),
    }

    resolved: Dict[str, Optional[str]] = {}
    for field, names in aliases.items():
        keys = (_norm(name) for name in names)
        resolved[field] = next(
            (headers_by_key[key] for key in keys if key in headers_by_key),
            None,
        )
    return resolved


def _first_nonempty(rows: List[Dict[str, str]], col: Optional[str]) -> Optional[str]:
    if not col:
        return None
    for r in rows:
        v = r.get(col, "")
        if v is not None and str(v).strip() != "":
            return str(v)
    return None


def _get_encoder(model: str, encoding_name: Optional[str]) -> tiktoken.Encoding:
    """Return the tiktoken encoding to use for token counting.

    Args:
        model: Model name used to look up an encoding when *encoding_name*
            is not given.
        encoding_name: Explicit encoding name; takes precedence over *model*
            when truthy.

    Returns:
        The encoding named by *encoding_name*, otherwise the encoding tiktoken
        associates with *model*, otherwise ``cl100k_base`` when tiktoken does
        not know the model.
    """
    if encoding_name:
        return tiktoken.get_encoding(encoding_name)

    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        # tiktoken raises KeyError for model names it cannot map to an
        # encoding (e.g. a model newer than the installed tiktoken). Fall
        # back to a widely available encoding instead of aborting the run;
        # pass an explicit encoding name to override this choice.
        return tiktoken.get_encoding("cl100k_base")


def _count_tokens(enc: tiktoken.Encoding, text: str) -> int:
    if not text:
        return 0
    return len(enc.encode(text, disallowed_special=()))


def _decide_success(
    source: str,
    explicit_success: Optional[bool],
    trajectory: List[Dict[str, Any]],
) -> bool:
    """
    New success decision rules:

    - If explicit success column exists and parses -> use it.
    - Else:
      * crafter:
          if last.done is True AND last turn positive reward AND last.turn_idx < 199 -> success True
          otherwise -> success False
      * minihack:
          if last.done is True AND last.turn_idx < 99 -> success True
          otherwise -> success False
      * babaisai:
          if last.done is True AND last.turn_idx < 99 -> success True
          otherwise -> success False
    """
    if explicit_success is not None:
        return explicit_success

    if not trajectory:
        return False

    last = trajectory[-1]
    last_done = (last.get("done") is True)

    last_turn_idx = int(last.get("turn_idx", len(trajectory) - 1))
    
    last_pos_reward = last.get("reward") > 0.0

    src = (source or "").lower()

    if src == "crafter":
        return bool(last_done and last_turn_idx < 199 and last_pos_reward)

    if src == "minihack":
        return bool(last_done and last_turn_idx < 99)

    if src == "babaisai":
        return bool(last_done and last_turn_idx < 99)

    # fallback: old minimal rule when no explicit success column
    return any(t.get("done") is True for t in trajectory)


def convert_one_csv(
    csv_path: Path,
    source_hint: str,
    enc: tiktoken.Encoding,
) -> Tuple[Dict[str, Any], List[str]]:
    """Convert one per-step CSV file into an episode dictionary.

    Columns are located by header name via ``_pick_columns``. Each CSV row
    becomes one trajectory step with "turn_idx", "action", "observation" and
    "reasoning_trace", plus "reward" and "done" when those cells parse.

    Args:
        csv_path: CSV file to read (decoded as UTF-8, undecodable bytes replaced).
        source_hint: Dataset name to use when the CSV has no usable source
            column. If the resulting source is "unknown", it is guessed from
            the file name and then from the names of its parent directories.
        enc: Encoding used to count tokens of action, observation and
            non-blank reasoning text.

    Returns:
        A tuple ``(episode, logs)`` where ``episode`` holds "episode_id",
        "task", "game_file", "source", "success", "num_turns",
        "total_tokens", "meta" and "trajectory", and ``logs`` is a list of
        human-readable progress messages.

    Raises:
        ValueError: If the CSV has no header row.
    """
    logs: List[str] = []
    rows: List[Dict[str, str]] = []

    with csv_path.open("r", encoding="utf-8", errors="replace", newline="") as f:
        reader = csv.DictReader(f)
        if not reader.fieldnames:
            raise ValueError(f"No header found in CSV: {csv_path}")
        col = _pick_columns(list(reader.fieldnames))
        for r in reader:
            # DictReader yields None for cells missing from short rows.
            rows.append({k: (v if v is not None else "") for k, v in r.items()})

    logs.append(f"Loaded {len(rows)} rows from {csv_path.name}")

    # Episode-level fields come from the first non-blank cell of their column,
    # with file-derived fallbacks.
    stem = csv_path.stem
    episode_id = _first_nonempty(rows, col["episode_id"]) or stem
    task = _first_nonempty(rows, col["task"]) or stem
    game_file = _first_nonempty(rows, col["game_file"]) or str(csv_path)

    source = _first_nonempty(rows, col["source"]) or source_hint

    if source == "unknown":
        source = _guess_source_from_text(csv_path.name)
        if source == "unknown":
            for p in csv_path.parents:
                source = _guess_source_from_text(p.name)
                if source != "unknown":
                    break

    trajectory: List[Dict[str, Any]] = []
    total_tokens = 0
    bad_step_cells = 0

    for i, r in enumerate(rows):
        # turn_idx defaults to the row position; a step column overrides it
        # when its cell holds a usable number.
        turn_idx = i
        if col["step"]:
            v = r.get(col["step"], "")
            if str(v).strip() != "":
                try:
                    turn_idx = int(float(v))
                except (TypeError, ValueError, OverflowError):
                    # Non-numeric, NaN or infinite step cell: keep the row
                    # index rather than failing the whole file.
                    bad_step_cells += 1

        action = (r.get(col["action"], "") if col["action"] else "") or ""
        obs = (r.get(col["obs"], "") if col["obs"] else "") or ""
        reasoning = (r.get(col["reasoning"], "") if col["reasoning"] else "") or ""

        step_obj: Dict[str, Any] = {
            "turn_idx": int(turn_idx),
            "action": str(action),
            "observation": str(obs),
            "reasoning_trace": str(reasoning),
        }

        # reward/done are only attached when the cell parses to a value.
        if col["reward"]:
            rew = _coerce_float(r.get(col["reward"], ""))
            if rew is not None:
                step_obj["reward"] = rew

        if col["done"]:
            dn = _coerce_bool(r.get(col["done"], ""))
            if dn is not None:
                step_obj["done"] = dn

        trajectory.append(step_obj)

        # token profiling: action + observation + reasoning_trace(if non-empty)
        total_tokens += _count_tokens(enc, step_obj["action"])
        total_tokens += _count_tokens(enc, step_obj["observation"])
        if step_obj["reasoning_trace"].strip():
            total_tokens += _count_tokens(enc, step_obj["reasoning_trace"])

    if bad_step_cells:
        logs.append(
            f"Warning: {bad_step_cells} row(s) had an unparseable step value; used row index instead"
        )

    # explicit success if present
    explicit_success = None
    if col["success"]:
        explicit_success = _coerce_bool(_first_nonempty(rows, col["success"]))

    success = _decide_success(source=str(source), explicit_success=explicit_success, trajectory=trajectory)

    logs.append(f"Computed total_tokens={total_tokens} using tiktoken")

    out: Dict[str, Any] = {
        "episode_id": str(episode_id),
        "task": str(task),
        "game_file": str(game_file),
        "source": str(source),
        "success": bool(success),
        "num_turns": int(len(trajectory)),
        "total_tokens": int(total_tokens),
        "meta": {"representation": "text"},
        "trajectory": trajectory,
    }
    return out, logs


def main() -> None:
    """Command-line entry point: convert every CSV in a directory to JSON.

    The input and output directories may be given positionally or through
    ``--destination_path`` / ``--output_path``; a flag value takes precedence
    over the positional one. One ``<stem>.json`` file is written into the
    output directory per CSV. A failure on one CSV is printed and does not
    stop the remaining files from being processed.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("destination_path", nargs="?", help="Directory containing CSV files")
    parser.add_argument("output_path", nargs="?", help="Directory to write JSON files into")
    parser.add_argument("--destination_path", dest="destination_path_flag", type=str)
    parser.add_argument("--output_path", dest="output_path_flag", type=str)

    parser.add_argument("--recursive", action="store_true")
    parser.add_argument("--model", type=str, default="gpt-4o-mini")
    parser.add_argument("--encoding", type=str, default=None)

    opts = parser.parse_args()

    # Flags win over positionals when both are supplied.
    in_arg = opts.destination_path_flag or opts.destination_path
    out_arg = opts.output_path_flag or opts.output_path
    if not (in_arg and out_arg):
        raise SystemExit("Need destination_path and output_path (positional or via --destination_path/--output_path).")

    src_dir = Path(in_arg).expanduser().resolve()
    out_dir = Path(out_arg).expanduser().resolve()
    if not (src_dir.exists() and src_dir.is_dir()):
        raise SystemExit(f"destination_path does not exist or is not a directory: {src_dir}")

    out_dir.mkdir(parents=True, exist_ok=True)

    enc = _get_encoder(opts.model, opts.encoding)

    source_hint = _guess_source_from_text(src_dir.name)

    # rglob descends into sub-directories, glob stays at the top level.
    find_csvs = src_dir.rglob if opts.recursive else src_dir.glob
    csv_files = sorted(find_csvs("*.csv"))

    print(f"Found {len(csv_files)} CSV file(s) under {src_dir}")
    print(f"Source hint from destination dir name: {source_hint}")
    print(f"tiktoken model={opts.model}, encoding={enc.name}")

    for csv_path in csv_files:
        print(f"- CSV: {csv_path.name}")
        try:
            episode, messages = convert_one_csv(csv_path, source_hint=source_hint, enc=enc)
            for message in messages:
                print(f"  {message}")

            target = out_dir / (csv_path.stem + ".json")
            with target.open("w", encoding="utf-8") as fh:
                json.dump(episode, fh, ensure_ascii=False, indent=2)

            print(f"  Wrote: {target}")
        except Exception as exc:
            # Batch boundary: report the failure and carry on with the next file.
            print(f"  ERROR converting {csv_path}: {exc}")

    print("Done.")


# Run the converter only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
