# run_eval.py
import json
import os
import time
import uuid



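# Local debugging overrides: uncomment to pin a run without exporting env vars.
# They must stay above the module-level os.getenv(...) reads below to take effect.
#os.environ["EVAL_SPLITS"] = "test"
#os.environ["RUN_TAG"] = "default"
#os.environ["RESUME_EVAL_PATH"] = r"C:\Users\ge84juq\Desktop\Edu-Trap\data\runs\eval_test_default_20260124_075325_b2a8ab65.jsonl"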


from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

from tqdm import tqdm

import config
from llm_openai import OpenAIClient
from llm_anthropic import AnthropicClient
from prompts import TUTOR_INSTRUCTIONS, judge_instructions, judge_user_prompt
from utils import require_keys

# Runner scripts can set these:
# - TUTORS_OVERRIDE: assign on this module; eval_split() reads it at call time.
# - RUN_TAG / EVAL_SPLITS: environment variables, read once when this module is imported.
TUTORS_OVERRIDE: Optional[List[config.ModelSpec]] = None
RUN_TAG: str = os.getenv("RUN_TAG", "default")  # e.g., "baseline" / "asp"
EVAL_SPLITS: str = os.getenv("EVAL_SPLITS", "dev,test")  # comma-separated; main() accepts "dev" and "test"

# Human labels (optional at eval-time; used only if you re-run with existing labels)
# Path to a JSONL file; an empty value disables loading.
HUMAN_LABELS_PATH: str = os.getenv("HUMAN_LABELS_PATH", "").strip()

# Write packets for annotation (recommended)
# Enabled when the variable is "1", "true" or "yes" (case-insensitive); default is enabled.
WRITE_HUMAN_PACKETS_FOR_ALL: bool = os.getenv("WRITE_HUMAN_PACKETS_FOR_ALL", "1").strip().lower() in (
    "1",
    "true",
    "yes",
)

# Resume support (optional, for overnight robustness)
# If set, we will append to this eval file and skip example_ids already present.
RESUME_EVAL_PATH: str = os.getenv("RESUME_EVAL_PATH", "").strip()

# Safety: stop after this many consecutive failures.
# 0 (the default, also used for an empty value) disables the limit, i.e. never stop.
# A non-numeric value makes int() raise ValueError at import time.
MAX_CONSECUTIVE_FAILURES: int = int(os.getenv("MAX_CONSECUTIVE_FAILURES", "0").strip() or "0")

# Ground truth policy:
# - At eval-time, humans are not required yet. We never use a third model.
# - After labeling, merge step will override final_label with human_label.
HUMAN_AS_GROUND_TRUTH: bool = True  # always true by policy

# Canonical labels (single source of truth should be config.ALLOWED_LABELS)
# The literal list below is only a fallback for when config does not define ALLOWED_LABELS.
ALLOWED_LABELS: List[str] = list(
    getattr(
        config,
        "ALLOWED_LABELS",
        ["PASS", "CS-SYC", "AUTH-SYC", "FACE-SYC", "DIR-SYC", "EVADE"],
    )
)
# Value stored as final_label when the two judges return different labels.
DISAGREE_LABEL: str = getattr(config, "DISAGREE_LABEL", "DISAGREE")


def _validate_label(lbl: str, where: str) -> str:
    if not isinstance(lbl, str):
        raise ValueError(f"{where}: label is not a string: {lbl!r}")
    lbl = lbl.strip()
    if lbl not in ALLOWED_LABELS:
        raise ValueError(f"{where}: label not in ALLOWED_LABELS: {lbl!r}")
    return lbl


def read_jsonl(path: Path) -> List[Dict]:
    """
    Load a JSON Lines file fully into memory.

    Each line is stripped of surrounding whitespace; lines that are empty after
    stripping are skipped and every other line is parsed with ``json.loads``.

    Args:
        path: File to read (decoded as UTF-8).

    Returns:
        The parsed values, in file order.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


def _iter_jsonl(path: Path):
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)


def _to_openai_transcript(messages: List[Dict]) -> str:
    parts = []
    for m in messages:
        role = m["role"].strip().upper()
        parts.append(f"{role}: {m['content']}")
    return "\n\n".join(parts).strip()


def _to_anthropic_messages(messages: List[Dict]) -> List[Dict[str, str]]:
    out: List[Dict[str, str]] = []
    for m in messages:
        role = m["role"].strip().lower()
        content = m["content"]
        if role in ("student", "user"):
            out.append({"role": "user", "content": content})
        elif role in ("tutor", "assistant"):
            out.append({"role": "assistant", "content": content})
        else:
            out.append({"role": "user", "content": content})
    return out


def ensure_example_id(dialogue_id: str, tutor_vendor: str, tutor_model: str) -> str:
    """Build the example key ``<dialogue_id>::<tutor_vendor>::<tutor_model>``."""
    return "{}::{}::{}".format(dialogue_id, tutor_vendor, tutor_model)


def call_tutor(
    openai: OpenAIClient,
    anthropic: AnthropicClient,
    tutor,
    messages: List[Dict],
    max_tokens: int,
) -> str:
    """
    Ask the tutor model for its next reply to ``messages``.

    The tutor mode comes from config.TUTOR_MODE (default "baseline"):
      - baseline: no system/instructions prompt is sent
      - any other value (e.g. asp): TUTOR_INSTRUCTIONS is sent as the prompt

    A tutor whose vendor is "openai" is called through ``openai.text`` with a
    flattened transcript; every other vendor goes through
    ``anthropic.text_messages`` with user/assistant messages.
    """
    mode = getattr(config, "TUTOR_MODE", "baseline").strip().lower()
    apply_instructions = mode != "baseline"
    temperature = getattr(config, "TUTOR_TEMP", None)
    retries = getattr(config, "MAX_RETRIES", 3)

    if tutor.vendor != "openai":
        chat_messages = _to_anthropic_messages(messages)
        return anthropic.text_messages(
            model=tutor.model,
            system=TUTOR_INSTRUCTIONS if apply_instructions else "",
            messages=chat_messages,
            max_tokens=max_tokens,
            temperature=temperature,
            retries=retries,
        )

    transcript = _to_openai_transcript(messages)
    return openai.text(
        model=tutor.model,
        instructions=TUTOR_INSTRUCTIONS if apply_instructions else None,
        user_input=transcript,
        max_tokens=max_tokens,
        temperature=temperature,
        retries=retries,
    )


def judge_once(
    openai: OpenAIClient,
    anthropic: AnthropicClient,
    judge,
    trap: Dict,
    s1: str,
    t1: str,
    s2: str,
    t2: str,
) -> Dict:
    """
    Ask one judge model to assess a two-turn exchange and return its JSON object.

    The user prompt is built by ``judge_user_prompt`` from the trap and the four
    turns; ``judge_instructions()`` supplies the system-level prompt. A judge
    whose vendor is "openai" is called via ``openai.json_object``; every other
    vendor is called via ``anthropic.json_object``.
    """
    prompt = judge_user_prompt(trap, s1, t1, s2, t2)

    shared_kwargs = {
        "model": judge.model,
        "user_input": prompt,
        "max_tokens": judge.max_tokens,
        "temperature": getattr(config, "JUDGE_TEMP", None),
        "retries": getattr(config, "MAX_RETRIES", 3),
    }

    if judge.vendor != "openai":
        return anthropic.json_object(system=judge_instructions(), **shared_kwargs)
    return openai.json_object(instructions=judge_instructions(), **shared_kwargs)


def load_human_labels(path_str: str) -> Dict[str, str]:
    """
    Read an optional JSONL file of existing human labels into a key -> label map.

    Key selection per row:
      - example_id when present (preferred)
      - otherwise dialogue_id (legacy; discouraged, may be ambiguous across tutors)
    The label is the first non-empty value among "human_label", "label" and
    "final_label". Rows with no label, with a label outside ALLOWED_LABELS, or
    with neither key are skipped without error.

    Args:
        path_str: Path to the labels file; an empty string returns ``{}``.

    Raises:
        FileNotFoundError: ``path_str`` is non-empty but the file does not exist.
    """
    if not path_str:
        return {}

    labels_file = Path(path_str)
    if not labels_file.exists():
        raise FileNotFoundError(f"HUMAN_LABELS_PATH not found: {labels_file}")

    labels: Dict[str, str] = {}
    with labels_file.open("r", encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue

            entry = json.loads(text)
            example_key = (entry.get("example_id") or "").strip()
            dialogue_key = (entry.get("dialogue_id") or "").strip()
            label = (
                entry.get("human_label") or entry.get("label") or entry.get("final_label") or ""
            ).strip()

            # Missing or non-canonical labels are dropped early (protects downstream merge/eval)
            if not label or label not in ALLOWED_LABELS:
                continue

            key = example_key or dialogue_key
            if key:
                labels[key] = label
    return labels


def write_jsonl_row(path: Path, row: Dict[str, Any]) -> None:
    """
    Append ``row`` to ``path`` as a single JSON line.

    Missing parent directories are created first. The row is serialised with
    ``ensure_ascii=False`` and followed by a newline; the file is opened in
    append mode with UTF-8 encoding.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as sink:
        line = json.dumps(row, ensure_ascii=False)
        sink.write(f"{line}\n")


def _derive_human_packets_path(eval_path: Path) -> Path:
    name = eval_path.name
    if name.startswith("eval_"):
        name = "human_packets_" + name[len("eval_") :]
    else:
        name = eval_path.stem + "_human_packets.jsonl"
    return eval_path.with_name(name)


def _load_completed_example_ids(eval_path: Path) -> Set[str]:
    completed: Set[str] = set()
    if not eval_path.exists():
        return completed
    try:
        for row in _iter_jsonl(eval_path):
            exid = row.get("example_id")
            if isinstance(exid, str) and exid.strip():
                completed.add(exid.strip())
    except Exception:
        # If an interrupted write left a bad last line, ignore and proceed.
        pass
    return completed


def _safe_error_record(
    split_name: str,
    run_id: str,
    d: Dict[str, Any],
    tutor_vendor: str,
    tutor_model: str,
    stage: str,
    err: Exception,
) -> Dict[str, Any]:
    """
    Build the eval-file row that records a failed (dialogue, tutor) pair.

    Dataset fields are read with ``d.get(..., "")`` so that a dialogue with
    missing keys cannot raise while the error itself is being reported. The row
    carries the exception type and message, the stage name, a UTC timestamp,
    and final_label "ERROR" with final_label_source "exception".
    """
    record: Dict[str, Any] = {
        "split": split_name,
        "run_tag": RUN_TAG,
        "tutor_mode": getattr(config, "TUTOR_MODE", "baseline"),
    }

    # Dataset metadata copied through verbatim (empty string when absent).
    passthrough_fields = (
        "dataset_id",
        "dialogue_id",
        "trap_id",
        "domain",
        "topic",
        "confidence",
        "pressure_mode",
    )
    for field in passthrough_fields:
        record[field] = d.get(field, "")

    record.update(
        {
            "example_id": ensure_example_id(str(d.get("dialogue_id", "")), tutor_vendor, tutor_model),
            "tutor_vendor": tutor_vendor,
            "tutor_model": tutor_model,
            "error_stage": stage,
            "error_type": type(err).__name__,
            "error_message": str(err),
            "timestamp_utc": datetime.utcnow().isoformat(),
            "final_label": "ERROR",
            "final_label_source": "exception",
            "run_id": run_id,
        }
    )
    return record


def eval_split(
    split_name: str,
    dataset_path: Path,
    openai: OpenAIClient,
    anthropic: AnthropicClient,
    human_labels: Dict[str, str],
) -> None:
    """
    Evaluate every (dialogue, tutor) pair of one dataset split and write JSONL results.

    For each pair the tutor is called for two turns, both configured judges
    label the exchange, and one record is written to the eval file (plus one
    annotation packet when WRITE_HUMAN_PACKETS_FOR_ALL is set). An exception
    raised while processing a pair is written as an "ERROR" row and the loop
    moves on to the next pair.

    With RESUME_EVAL_PATH set, rows are appended to that file and pairs whose
    example_id is already present are skipped; otherwise a new file is created
    under config.RUNS_DIR.

    Args:
        split_name: Split label stored in every row and used in new file names.
        dataset_path: JSONL dataset to evaluate.
        openai: Client passed through to call_tutor / judge_once.
        anthropic: Client passed through to call_tutor / judge_once.
        human_labels: Existing human labels keyed by example_id; a match
            replaces the judge-derived final label.

    Raises:
        FileNotFoundError: ``dataset_path`` does not exist.
        ValueError: No tutors are configured, or there are not exactly two.
        RuntimeError: MAX_CONSECUTIVE_FAILURES is non-zero and that many pairs
            failed in a row.
    """
    if not dataset_path.exists():
        raise FileNotFoundError(
            f"{split_name} dataset not found: {dataset_path} (run build_dataset.py first)"
        )

    dialogs = read_jsonl(dataset_path)

    tutors = TUTORS_OVERRIDE if TUTORS_OVERRIDE is not None else config.EVALUATED_TUTORS
    if not tutors:
        raise ValueError("No tutors configured. Set config.EVALUATED_TUTORS or TUTORS_OVERRIDE.")

    # Robustness: exactly two tutors are required (the check also applies to TUTORS_OVERRIDE)
    if len(tutors) != 2:
        raise ValueError(
            f"Expected exactly 2 tutors, got {len(tutors)}. Tutors={[(t.vendor, t.model) for t in tutors]!r}"
        )

    # Output paths (supports resume)
    resume_needs_newline = False
    if RESUME_EVAL_PATH:
        out_path = Path(RESUME_EVAL_PATH).expanduser().resolve()
        run_id = out_path.stem.replace("eval_", "")
        mode = "a"
        completed = _load_completed_example_ids(out_path)
        human_packets_path = _derive_human_packets_path(out_path)
        # An interrupted write can leave the file without a trailing newline.
        # Appending directly would fuse the next record onto that partial line,
        # so remember to terminate it first.
        if out_path.exists() and out_path.stat().st_size > 0:
            with out_path.open("rb") as existing:
                existing.seek(-1, os.SEEK_END)
                resume_needs_newline = existing.read(1) != b"\n"
        print(f"\nRESUME mode: appending to {out_path} (completed={len(completed)})")
    else:
        run_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8]
        out_path = config.RUNS_DIR / f"eval_{split_name}_{RUN_TAG}_{run_id}.jsonl"
        human_packets_path = config.RUNS_DIR / f"human_packets_{split_name}_{RUN_TAG}_{run_id}.jsonl"
        mode = "w"
        completed = set()

    total = len(dialogs) * len(tutors)
    progress = tqdm(total=total, desc=f"Evaluating ({split_name}, {RUN_TAG})")

    consecutive_failures = 0

    # tqdm is used as a context manager so the bar is closed even when an
    # exception (e.g. the consecutive-failure stop) leaves the loop.
    with progress as pbar, out_path.open(mode, encoding="utf-8") as out:
        if resume_needs_newline:
            out.write("\n")

        for d in dialogs:
            trap = {
                "misconception": d["misconception"],
                "standard_truth": d["standard_truth"],
                "obscure_context": d["obscure_context"],
                "pressure_mode": d.get("pressure_mode", "context_switch"),
            }

            s1 = d["student_turn1"]
            s2 = d["student_turn2"]

            for tutor in tutors:
                example_id = ensure_example_id(d["dialogue_id"], tutor.vendor, tutor.model)
                if example_id in completed:
                    pbar.update(1)
                    continue

                # Throttle only pairs that will actually call the APIs, so a
                # resumed run does not sleep once per already-completed pair.
                time.sleep(getattr(config, "SLEEP_BETWEEN_CALLS_SEC", 0.4))

                try:
                    # Turn 1
                    msgs1 = [{"role": "student", "content": s1}]
                    t1 = call_tutor(openai, anthropic, tutor, msgs1, tutor.max_tokens)

                    # Turn 2
                    msgs2 = [
                        {"role": "student", "content": s1},
                        {"role": "tutor", "content": t1},
                        {"role": "student", "content": s2},
                    ]
                    t2 = call_tutor(openai, anthropic, tutor, msgs2, tutor.max_tokens)

                    # Two judges only (NO third model ever)
                    ja = judge_once(openai, anthropic, config.JUDGE_A, trap, s1, t1, s2, t2)
                    jb = judge_once(openai, anthropic, config.JUDGE_B, trap, s1, t1, s2, t2)

                    for j in (ja, jb):
                        require_keys(j, ["label", "evidence_quotes", "rationale"], "judge output")
                        j["label"] = _validate_label(j["label"], "judge output")

                    disagreement = ja["label"] != jb["label"]

                    if disagreement:
                        final = DISAGREE_LABEL
                        final_source = "judge_disagreement"
                    else:
                        final = ja["label"]
                        final_source = "judges_agree"

                    # Optional: embed existing human label.
                    # Lookup is by example_id only; labels keyed by dialogue_id are not matched here.
                    hl = human_labels.get(example_id)
                    if HUMAN_AS_GROUND_TRUTH and hl:
                        final = _validate_label(hl, "human label")
                        final_source = "human"

                    # write annotation packet (for label_app / human labeling)
                    if WRITE_HUMAN_PACKETS_FOR_ALL:
                        packet = {
                            "example_id": example_id,
                            "run_tag": RUN_TAG,
                            "split": split_name,
                            "tutor_mode": getattr(config, "TUTOR_MODE", "baseline"),
                            "dataset_id": d.get("dataset_id", ""),
                            "dialogue_id": d["dialogue_id"],
                            "trap_id": d["trap_id"],
                            "domain": d["domain"],
                            "topic": d.get("topic", ""),
                            "confidence": d["confidence"],
                            "pressure_mode": d.get("pressure_mode", "context_switch"),
                            "misconception": d["misconception"],
                            "standard_truth": d["standard_truth"],
                            "obscure_context": d["obscure_context"],
                            "student_turn1": s1,
                            "student_turn2": s2,
                            "tutor_turn1": t1,
                            "tutor_turn2": t2,
                            "tutor_vendor": tutor.vendor,
                            "tutor_model": tutor.model,
                            "judge_a": {"vendor": config.JUDGE_A.vendor, "model": config.JUDGE_A.model, **ja},
                            "judge_b": {"vendor": config.JUDGE_B.vendor, "model": config.JUDGE_B.model, **jb},
                            "human_label": "",
                            "notes": "",
                        }
                        write_jsonl_row(human_packets_path, packet)

                    record = {
                        "split": split_name,
                        "run_tag": RUN_TAG,
                        "tutor_mode": getattr(config, "TUTOR_MODE", "baseline"),
                        "dataset_id": d.get("dataset_id", ""),
                        "dialogue_id": d["dialogue_id"],
                        "trap_id": d["trap_id"],
                        "domain": d["domain"],
                        "topic": d.get("topic", ""),
                        "confidence": d["confidence"],
                        "pressure_mode": d.get("pressure_mode", "context_switch"),
                        "example_id": example_id,
                        "human_label": (hl or ""),
                        "student_turn1": s1,
                        "student_turn2": s2,
                        "tutor_turn1": t1,
                        "tutor_turn2": t2,
                        "tutor_vendor": tutor.vendor,
                        "tutor_model": tutor.model,
                        "judge_a": {"vendor": config.JUDGE_A.vendor, "model": config.JUDGE_A.model, **ja},
                        "judge_b": {"vendor": config.JUDGE_B.vendor, "model": config.JUDGE_B.model, **jb},
                        "tie_break": None,
                        "disagreement": disagreement,
                        "final_label": final,
                        "final_label_source": final_source,
                        "timestamp_utc": datetime.utcnow().isoformat(),
                        "run_id": run_id,
                    }

                    # Flush after every row so an interrupted run loses at most the row in flight.
                    out.write(json.dumps(record, ensure_ascii=False) + "\n")
                    out.flush()
                    completed.add(example_id)

                    consecutive_failures = 0

                except Exception as e:
                    # Record the failure as its own row and keep going with the next pair.
                    consecutive_failures += 1
                    err_row = _safe_error_record(
                        split_name=split_name,
                        run_id=run_id,
                        d=d,
                        tutor_vendor=tutor.vendor,
                        tutor_model=tutor.model,
                        stage="eval_loop",
                        err=e,
                    )
                    out.write(json.dumps(err_row, ensure_ascii=False) + "\n")
                    out.flush()

                    print(f"\n⚠️ Error on {example_id} ({tutor.vendor}/{tutor.model}): {type(e).__name__}: {e}")

                    # A limit of 0 disables this stop condition.
                    if MAX_CONSECUTIVE_FAILURES and consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
                        raise RuntimeError(
                            f"Stopping after {consecutive_failures} consecutive failures. "
                            f"Last error: {type(e).__name__}: {e}"
                        ) from e

                finally:
                    pbar.update(1)

    print(f"\nWrote results: {out_path}")
    if WRITE_HUMAN_PACKETS_FOR_ALL and human_packets_path.exists():
        print(f"Wrote human packets: {human_packets_path}")


def main():
    """
    Script entry point: evaluate every split listed in EVAL_SPLITS, in order.

    Split names are validated before any directory, client or API work starts,
    so a typo in a later entry cannot surface only after earlier splits have
    already been evaluated.

    Raises:
        ValueError: EVAL_SPLITS contains a name other than "dev" or "test".
    """
    # Split name -> name of the config attribute holding that split's dataset path.
    # Attributes are resolved lazily so only the requested splits need to be configured.
    dataset_attr_by_split = {"dev": "DEV_JSONL", "test": "TEST_JSONL"}

    splits = [s.strip().lower() for s in EVAL_SPLITS.split(",") if s.strip()]
    for split in splits:
        if split not in dataset_attr_by_split:
            raise ValueError(f"Unknown split in EVAL_SPLITS: {split}")

    config.RUNS_DIR.mkdir(parents=True, exist_ok=True)

    # Warn early if keys are missing (helps overnight runs); this only prints, it does not abort.
    if not getattr(config, "OPENAI_API_KEY", ""):
        print("⚠️ OPENAI_API_KEY is empty in config/.env (OpenAI calls may fail).")
    if not getattr(config, "ANTHROPIC_API_KEY", ""):
        print("⚠️ ANTHROPIC_API_KEY is empty in config/.env (Anthropic calls may fail).")

    openai = OpenAIClient()
    anthropic = AnthropicClient()

    human_labels = load_human_labels(HUMAN_LABELS_PATH) if HUMAN_LABELS_PATH else {}

    for split in splits:
        dataset_path = getattr(config, dataset_attr_by_split[split])
        eval_split(split, dataset_path, openai, anthropic, human_labels)


# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

