# build_dataset.py
import json
import uuid
import hashlib
import random
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Any, List

from tqdm import tqdm

import config
from llm_openai import OpenAIClient
from llm_anthropic import AnthropicClient
from prompts import (
    builder_instructions, builder_user_prompt,
    validator_instructions, validator_user_prompt,
    make_student_turn1, make_student_turn2,
    student_logic_ok,
    PRESSURE_MODES,
)
from utils import set_seed, require_keys


def utc_now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()


def append_jsonl(path: Path, row: Dict[str, Any]):
    """
    Append one record to a JSON Lines file.

    Parent directories are created when missing. The record is serialized
    with non-ASCII characters kept as-is and terminated by a newline.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as handle:
        handle.write(f"{json.dumps(row, ensure_ascii=False)}\n")


def write_jsonl(path: Path, rows: List[Dict[str, Any]]):
    """
    Write records to a JSON Lines file, replacing any existing content.

    Parent directories are created when missing. Each record becomes one
    line, serialized with non-ASCII characters kept as-is.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in rows
        )


def _stable_int(s: str) -> int:
    """
    Map a string to a deterministic non-negative integer
    (for prompt template selection).

    The value is the leading 48 bits of the SHA-256 digest of the UTF-8
    encoded input, so it is identical across runs and processes.
    """
    digest = hashlib.sha256(s.encode("utf-8")).digest()
    # The first 6 bytes read big-endian equal the first 12 hex characters
    # of the hex digest parsed as a base-16 integer.
    return int.from_bytes(digest[:6], "big")


def _coerce_str(x: Any) -> str:
    """Return ``str(x)`` with surrounding whitespace removed; ``None`` becomes ``""``."""
    return "" if x is None else str(x).strip()


def main():
    """
    Build the trap dataset end to end and write the traps/dev/test JSONL files.

    Steps, in order:
      1. Seed via ``set_seed`` and create the data/runs/audit directories.
      2. For every domain in ``config.DOMAINS``, repeatedly ask the builder
         model for a candidate trap, screen it, have the configured validator
         judge/normalize it, and keep it once it passes every check. Every
         rejection (and every validator verdict) is appended to the audit
         JSONL file.
      3. Expand each accepted trap into one dialogue per
         (confidence level, pressure mode) pair.
      4. Split trap families into dev/test within each domain and write
         ``config.DEV_JSONL``, ``config.TEST_JSONL`` and ``config.TRAPS_JSONL``.

    Note: the per-domain collection loop has no attempt cap; it runs until
    enough candidates have been accepted.
    """
    set_seed(config.RANDOM_SEED)

    # Paths
    config.DATA_DIR.mkdir(parents=True, exist_ok=True)
    config.RUNS_DIR.mkdir(parents=True, exist_ok=True)
    config.AUDIT_DIR.mkdir(parents=True, exist_ok=True)

    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12;
    # the formatted UTC stamp is the same.
    dataset_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    audit_path = config.AUDIT_DIR / f"audit_build_{dataset_id}.jsonl"
    batch_id = str(uuid.uuid4())[:8]

    def audit(event_type: str, domain: str, **fields: Any) -> None:
        # Every audit row starts with the same five keys, in this order,
        # followed by the event-specific fields in the order given.
        append_jsonl(audit_path, {
            "type": event_type,
            "timestamp_utc": utc_now_iso(),
            "dataset_id": dataset_id,
            "domain": domain,
            "batch_id": batch_id,
            **fields,
        })

    def is_blank(value: Any) -> bool:
        # Missing keys arrive here as None, which coerces to "".
        return not _coerce_str(value)

    openai = OpenAIClient()
    anthropic = AnthropicClient()

    # How many trap families per domain?
    per_domain = max(1, config.N_TRAP_FAMILIES_TOTAL // max(1, len(config.DOMAINS)))

    builder_keys = ["domain", "topic", "misconception", "standard_truth", "obscure_context", "student_logic"]
    required = ["topic", "misconception", "standard_truth", "obscure_context", "student_logic"]

    accepted_traps: List[Dict[str, Any]] = []

    for domain in config.DOMAINS:
        need = per_domain
        attempt = 0

        # Context manager so the bar is closed even if an exception escapes.
        with tqdm(total=need, desc=f"Traps accepted: {domain}") as pbar:
            while need > 0:
                # `need` and len(accepted_traps) do not change after a
                # rejection, so the attempt counter is part of the seed;
                # otherwise every retry would resend the identical prompt.
                attempt += 1
                seed = _stable_int(
                    f"{dataset_id}:{domain}:{need}:{len(accepted_traps)}:{attempt}"
                ) % 10_000_000

                cand = openai.json_object(
                    model=config.BUILDER.model,
                    instructions=builder_instructions(),
                    user_input=builder_user_prompt(domain, seed),
                    max_tokens=config.BUILDER.max_tokens,
                    temperature=getattr(config, "OPENAI_TEMP_BUILDER", 0.2),
                    retries=config.MAX_RETRIES,
                )

                # NOTE(review): require_keys presumably raises when "traps" is
                # missing, which aborts the whole run - confirm in utils.
                require_keys(cand, ["traps"], "builder output")
                if not isinstance(cand["traps"], list) or not cand["traps"]:
                    audit("builder_bad_format", domain, candidate=cand)
                    continue

                # Only the first trap of each builder response is considered.
                trap0 = cand["traps"][0]
                if not isinstance(trap0, dict):
                    # e.g. null or a bare string: the key checks and .get()
                    # calls below need a mapping.
                    audit("builder_bad_format", domain, candidate=cand)
                    continue

                missing_key = next((k for k in builder_keys if k not in trap0), None)
                if missing_key is not None:
                    audit("builder_missing_key", domain, missing_key=missing_key, candidate=trap0)
                    continue

                if not student_logic_ok(str(trap0.get("student_logic", ""))):
                    audit("builder_bad_student_logic", domain, candidate=trap0)
                    continue

                # Validate/normalize using configured VALIDATOR (json_object)
                validator = config.VALIDATOR
                if validator.vendor == "openai":
                    verdict = openai.json_object(
                        model=validator.model,
                        instructions=validator_instructions(),
                        user_input=validator_user_prompt(trap0),
                        max_tokens=validator.max_tokens,
                        temperature=config.VALIDATOR_TEMP,
                        retries=config.MAX_RETRIES,
                    )
                else:
                    verdict = anthropic.json_object(
                        model=validator.model,
                        system=validator_instructions(),
                        user_input=validator_user_prompt(trap0),
                        max_tokens=validator.max_tokens,
                        temperature=config.VALIDATOR_TEMP,
                        retries=config.MAX_RETRIES,
                    )
                require_keys(verdict, ["is_valid", "reason", "normalized"], "validator output")

                audit(
                    "validator_verdict",
                    domain,
                    validator_vendor=validator.vendor,
                    validator_model=validator.model,
                    candidate=trap0,
                    verdict=verdict,
                )

                if not verdict["is_valid"]:
                    continue

                # Some validators return a partial `normalized` object (e.g.,
                # omitting obscure_context / student_logic). Missing or blank
                # fields are copied through from the original candidate.
                norm = verdict.get("normalized") or {}
                if not isinstance(norm, dict):
                    audit("validator_bad_normalized_type", domain, candidate=trap0, normalized=norm)
                    continue

                missing_before = [k for k in required if is_blank(norm.get(k))]

                # Copy-through missing/empty fields from the original candidate trap
                for k in missing_before:
                    norm[k] = trap0.get(k)

                missing_after = [k for k in required if is_blank(norm.get(k))]

                if missing_before:
                    audit(
                        "validator_partial_normalized_filled",
                        domain,
                        missing_before=missing_before,
                        missing_after=missing_after,
                        candidate=trap0,
                        normalized_after_fill=norm,
                    )

                # If still missing, treat as invalid and continue (do not crash)
                if missing_after:
                    audit(
                        "validator_missing_keys_after_fill",
                        domain,
                        missing=missing_after,
                        candidate=trap0,
                        normalized=norm,
                    )
                    continue

                # Now safe to require keys
                require_keys(norm, required, "normalized trap")

                # Extra safety: student_logic quality after normalization
                if not student_logic_ok(_coerce_str(norm.get("student_logic"))):
                    audit("validator_bad_student_logic_after_norm", domain, candidate=trap0, normalized=norm)
                    continue

                trap_id = str(uuid.uuid4())
                accepted_traps.append({
                    "dataset_id": dataset_id,
                    "trap_id": trap_id,
                    "domain": domain,
                    "topic": _coerce_str(norm["topic"]),
                    "misconception": _coerce_str(norm["misconception"]),
                    "standard_truth": _coerce_str(norm["standard_truth"]),
                    "obscure_context": _coerce_str(norm["obscure_context"]),
                    "student_logic": _coerce_str(norm["student_logic"]),
                    "source_batch_id": batch_id,
                })

                need -= 1
                pbar.update(1)

    # Expand traps into dialogues (confidence variants x pressure modes)
    dialogues: List[Dict[str, Any]] = []
    for trap in accepted_traps:
        for c in config.CONFIDENCE_LEVELS:
            for mode in PRESSURE_MODES:
                dialogue_id = str(uuid.uuid4())

                # Turn-2 variant is a pure function of (trap_id, confidence,
                # mode), so variants differ across dialogues without extra
                # randomness.
                variant_idx = _stable_int(trap["trap_id"] + f":{c}:{mode}") % 10

                dialogues.append({
                    "dataset_id": dataset_id,
                    "dialogue_id": dialogue_id,
                    "trap_id": trap["trap_id"],
                    "domain": trap["domain"],
                    "topic": trap["topic"],
                    "confidence": c,
                    "pressure_mode": mode,
                    "student_turn1": make_student_turn1(trap, c),
                    "student_turn2": make_student_turn2(trap, variant_idx=variant_idx, mode=mode),
                    "misconception": trap["misconception"],
                    "standard_truth": trap["standard_truth"],
                    "obscure_context": trap["obscure_context"],
                    "source_batch_id": trap["source_batch_id"],
                })

    # ------------------------------------------------------------------
    # Dev/test split (trap-family level; stratified by domain)
    #
    # - Splits by trap_id, so no leakage across confidence/pressure variants.
    # - Stratifies within each domain to avoid domain imbalance in dev/test.
    # - Shuffles with an RNG seeded from RANDOM_SEED.
    # - Avoids empty dev/test per domain when possible.
    # ------------------------------------------------------------------
    rng = random.Random(config.RANDOM_SEED)

    traps_by_domain: Dict[str, List[str]] = {}
    for t in accepted_traps:
        traps_by_domain.setdefault(t["domain"], []).append(t["trap_id"])

    dev_traps: set[str] = set()
    test_traps: set[str] = set()

    for domain, ids in traps_by_domain.items():
        # Sort first so the shuffle starts from a canonical order.
        ids = sorted(set(ids))
        rng.shuffle(ids)

        n = len(ids)
        cutoff = int(n * config.DEV_FRACTION)

        if n >= 2:
            cutoff = max(1, min(n - 1, cutoff))  # ensure at least 1 dev and 1 test per domain
        else:
            cutoff = 1 if config.DEV_FRACTION >= 0.5 else 0  # if only 1 trap, assign based on fraction

        dev_traps.update(ids[:cutoff])
        test_traps.update(ids[cutoff:])

    dev = [d for d in dialogues if d["trap_id"] in dev_traps]
    test = [d for d in dialogues if d["trap_id"] in test_traps]

    write_jsonl(config.DEV_JSONL, dev)
    write_jsonl(config.TEST_JSONL, test)
    write_jsonl(config.TRAPS_JSONL, accepted_traps)

    # Summary: trap families and dialogues per domain in dev/test (quick sanity check)
    print("\nSplit summary by domain:")
    for domain in sorted(traps_by_domain.keys()):
        dom_trap_ids = set(traps_by_domain[domain])
        dev_fams = sum(1 for tid in dom_trap_ids if tid in dev_traps)
        test_fams = sum(1 for tid in dom_trap_ids if tid in test_traps)
        dev_dialogs = sum(1 for d in dev if d.get("domain") == domain)
        test_dialogs = sum(1 for d in test if d.get("domain") == domain)
        total_fams = len(dom_trap_ids)
        total_dialogs = dev_dialogs + test_dialogs
        print(f"  {domain}: families dev={dev_fams}, test={test_fams}, total={total_fams} | dialogues dev={dev_dialogs}, test={test_dialogs}, total={total_dialogs}")

    print("\nDONE")
    print(f"dataset_id:  {dataset_id}")
    print(f"Wrote traps: {config.TRAPS_JSONL} ({len(accepted_traps)} trap families)")
    print(f"Wrote dev:   {config.DEV_JSONL} ({len(dev)} dialogues)")
    print(f"Wrote test:  {config.TEST_JSONL} ({len(test)} dialogues)")


# Run the dataset build only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
