#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Validate Memory triplets generated by gen_memory_triplets.py.
Hard checks:
- 3 lines per group (full -> nocue -> cf)
- identity: task/env_id/seed/group_id
- action parity: actions_id identical
- length parity: frames/state_seq == T+1
- physics consistency: agent.pos/dir identical across variants
- outcomes: full succeeds with reward>0; cf fails with reward==0; nocue outcome matches full
- NoCue audit: at window_steps, nocue frames differ from full only inside cue tiles; all other frames and pixels are identical
- nocue_meta schema aligned with NoCue_spec
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, NoReturn, Tuple

import numpy as np
from PIL import Image

from minigrid.core.constants import OBJECT_TO_IDX, COLOR_TO_IDX

# Exact, ordered value every record's "model_input_fields" must equal.
ALLOWED_INPUTS = [
    "frames",
    "mission",
    "terminated",
    "truncated",
]
# NOTE(review): not referenced by the checks in this script; presumably the RGB
# fill the generator paints over cue tiles — confirm against the generator.
MASK_RGB = (0, 0, 0)

def _die(msg: str) -> NoReturn:
    """Abort validation by raising ``SystemExit(msg)``.

    An uncaught ``SystemExit`` with a string argument prints the message to
    stderr and exits with status 1. Annotated ``NoReturn`` (never ``None``) so
    type checkers treat code following a ``_die(...)`` call as unreachable.
    """
    raise SystemExit(msg)

def _load_png(path: Path) -> np.ndarray:
    """Decode the image file at *path* and return its pixels as an RGB array.

    The image is converted to RGB (dropping alpha / palette) while the file
    handle is open; the returned array does not depend on the file afterwards.
    """
    with Image.open(path) as handle:
        rgb_image = handle.convert("RGB")
        pixels = np.asarray(rgb_image)
    return pixels

def _cue_tile_mask(state_encoding: List[List[List[int]]], cue_type: str, cue_color: str) -> np.ndarray:
    """Return a boolean (rows, cols) mask of the cells that hold the cue object.

    A cell matches when channel 0 of ``state_encoding`` equals
    ``OBJECT_TO_IDX[cue_type]`` and channel 1 equals ``COLOR_TO_IDX[cue_color]``.
    An unknown cue type or colour yields an all-False mask with the encoding's
    own spatial shape (not a hard-coded 7x7).

    Exits via ``_die`` when the encoding is not a (rows, cols, >=2) array.

    NOTE(review): MiniGrid's ``Grid.encode`` fills ``array[i, j]`` with i = x
    (column) and j = y (row), whereas rendered frames index pixels as
    ``[row, col]``. If ``state_encoding`` is the raw ``obs["image"]``, this mask
    may need a transpose before it is mapped onto frame pixels — confirm the
    orientation the generator stores.
    """
    enc = np.asarray(state_encoding, dtype=np.uint8)  # typically (7, 7, 3)
    if enc.ndim != 3 or enc.shape[2] < 2:
        _die(f"state_encoding must have shape (rows, cols, >=2), got {enc.shape}")
    obj = OBJECT_TO_IDX.get(cue_type)
    col = COLOR_TO_IDX.get(cue_color)
    if obj is None or col is None:
        # Keep the mask shape consistent with the encoding it was derived from.
        return np.zeros(enc.shape[:2], dtype=bool)
    return (enc[:, :, 0] == int(obj)) & (enc[:, :, 1] == int(col))

def _pixel_mask_from_tile_mask(tile_mask: np.ndarray, frame_h: int, frame_w: int) -> np.ndarray:
    """Upsample a per-tile boolean mask to a per-pixel mask of shape (frame_h, frame_w).

    The tile grid is taken from ``tile_mask.shape`` (e.g. 7x7), so each tile
    covers a ``(frame_h // rows) x (frame_w // cols)`` pixel block. Tile (r, c)
    maps to pixel rows ``r*th:(r+1)*th`` and columns ``c*tw:(c+1)*tw``.

    Exits via ``_die`` when the frame does not split evenly into that grid.
    """
    rows, cols = tile_mask.shape
    if rows == 0 or cols == 0 or frame_h % rows != 0 or frame_w % cols != 0:
        _die(f"Frame shape not divisible by tile grid {(rows, cols)}: {(frame_h, frame_w)}")
    th, tw = frame_h // rows, frame_w // cols
    # Expand each tile over its pixel block: rows first, then columns.
    return np.repeat(np.repeat(tile_mask.astype(bool), th, axis=0), tw, axis=1)

def _read_jsonl_rows(jsonl: Path) -> List[Dict[str, Any]]:
    """Parse every non-blank line of *jsonl* as a JSON object, in file order.

    Exits via ``_die`` (citing the 1-based line number) on malformed JSON or
    on a line whose top-level value is not an object.
    """
    rows: List[Dict[str, Any]] = []
    with jsonl.open("r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as err:
                _die(f"{jsonl}:{lineno}: invalid JSON: {err}")
            if not isinstance(rec, dict):
                _die(f"{jsonl}:{lineno}: expected a JSON object, got {type(rec).__name__}")
            rows.append(rec)
    return rows


def _check_group_structure(
    gid: Any, full: Dict[str, Any], nocue: Dict[str, Any], cf: Dict[str, Any]
) -> int:
    """Run the non-image checks on one triplet and return T = len(actions_id).

    Covers variant order, identity fields, model_input_fields, action parity,
    frames/state_seq lengths (T + 1), agent pos/dir parity and outcomes.
    """
    variants = (("full", full), ("nocue", nocue), ("cf", cf))

    # ---- order ----
    if full.get("variant") != "full" or nocue.get("variant") != "nocue" or cf.get("variant") != "cf":
        _die(f"[{gid}] variant order must be full/nocue/cf")

    # ---- identity ----
    for k in ("task", "env_id", "seed", "group_id"):
        if not (full.get(k) == nocue.get(k) == cf.get(k)):
            _die(f"[{gid}] identity mismatch on {k}")
    # .get(): the identity check passes when "task" is absent from all three records.
    if full.get("task") != "memory":
        _die(f"[{gid}] task != memory: {full.get('task')}")

    # ---- model_input_fields ----
    for name, rec in variants:
        mif = rec.get("model_input_fields", None)
        if mif != ALLOWED_INPUTS:
            _die(f"[{gid}:{name}] model_input_fields must be {ALLOWED_INPUTS}, got {mif}")

    # ---- actions parity ----
    actions = full.get("actions_id")
    if nocue.get("actions_id") != actions or cf.get("actions_id") != actions:
        _die(f"[{gid}] actions_id mismatch across variants")
    # Parity also "holds" when the field is missing everywhere; reject that explicitly.
    if not isinstance(actions, list):
        _die(f"[{gid}] actions_id must be a list, got {type(actions).__name__}")
    T = len(actions)

    # ---- lengths ----
    for name, rec in variants:
        if len(rec.get("frames") or []) != T + 1:
            _die(f"[{gid}:{name}] frames len != T+1")
        if len(rec.get("state_seq") or []) != T + 1:
            _die(f"[{gid}:{name}] state_seq len != T+1")

    # ---- physics consistency ----
    for t in range(T + 1):
        for key in ("pos", "dir"):
            a = full["state_seq"][t]["agent"][key]
            b = nocue["state_seq"][t]["agent"][key]
            c = cf["state_seq"][t]["agent"][key]
            if a != b or a != c:
                _die(f"[{gid}] physics mismatch at t={t} agent.{key}")

    # ---- outcomes ----
    if not (full["success"] is True and float(full["reward"]) > 0.0):
        _die(f"[{gid}] full must succeed with reward>0")
    if not (cf["success"] is False and float(cf["reward"]) == 0.0):
        _die(f"[{gid}] cf must fail with reward==0")
    if not (nocue["success"] == full["success"] and float(nocue["reward"]) == float(full["reward"])):
        _die(f"[{gid}] nocue outcome must match full")

    return T


def _check_nocue_meta(gid: Any, nocue: Dict[str, Any], T: int) -> Tuple[List[int], str, str]:
    """Validate ``nocue["nocue_meta"]`` and return (window_steps, cue_type, cue_color)."""
    meta = nocue.get("nocue_meta", None)
    if not isinstance(meta, dict):
        _die(f"[{gid}] nocue missing nocue_meta dict")

    required_meta = [
        "targets", "window_policy", "window_steps",
        "mask_strength_target", "mask_strength_actual",
        "alignment_score", "alignment_threshold",
        "masked_frames", "mask_type", "physics_check_passed",
    ]
    for k in required_meta:
        if k not in meta:
            _die(f"[{gid}] nocue_meta missing {k}")

    if meta["window_policy"] != "EARLY":
        _die(f"[{gid}] window_policy must be EARLY")

    window_steps = [int(x) for x in meta["window_steps"]]
    if len(set(window_steps)) != len(window_steps):
        _die(f"[{gid}] window_steps contains duplicates")
    # Valid frame indices are 0..T inclusive (frames has T + 1 entries).
    if any(t < 0 or t > T for t in window_steps):
        _die(f"[{gid}] window_steps out of range")

    # Cue type+color drive the audit; the generator stores them in meta["cue"],
    # with the type falling back to targets[0]. A null "cue" is treated as empty.
    cue = meta.get("cue") or {}
    if not isinstance(cue, dict):
        _die(f"[{gid}] nocue_meta.cue must be a dict, got {type(cue).__name__}")
    targets = meta["targets"]
    cue_type = cue.get("type", targets[0] if targets else None)
    cue_color = cue.get("color", None)
    if cue_type is None or cue_color is None:
        _die(f"[{gid}] nocue_meta must include cue.type and cue.color for color-safe audit")

    return window_steps, cue_type, cue_color


def _audit_nocue_frames(
    gid: Any,
    root: Path,
    full: Dict[str, Any],
    nocue: Dict[str, Any],
    T: int,
    window_steps: List[int],
    cue_type: str,
    cue_color: str,
) -> None:
    """Compare full vs nocue frames: identical outside the window, cue-tile-only diffs inside."""
    wset = set(window_steps)
    window_end = max(window_steps) if window_steps else None

    # 1) Unmasked frames must be identical. Frames after window_end are a subset
    #    of these, so late leakage is reported here with its own message instead
    #    of re-loading those frames in a separate pass.
    for t in range(T + 1):
        if t in wset:
            continue
        f_img = _load_png(root / full["frames"][t])
        n_img = _load_png(root / nocue["frames"][t])
        if not np.array_equal(f_img, n_img):
            if window_end is not None and t > window_end:
                _die(f"[{gid}] NoCue leakage after window_end at t={t}")
            _die(f"[{gid}] NoCue illegal diff at unmasked t={t}")

    # 2) Masked frames: diff must be a non-empty subset of the cue tiles
    #    (tile-level mask upsampled to pixel level).
    for t in sorted(wset):
        f_img = _load_png(root / full["frames"][t])
        n_img = _load_png(root / nocue["frames"][t])
        if f_img.shape != n_img.shape:
            _die(f"[{gid}] frame shape mismatch at masked t={t}")

        st = full["state_seq"][t]
        tile_mask = _cue_tile_mask(st["state_encoding"], cue_type, cue_color)
        if tile_mask.sum() <= 0:
            _die(f"[{gid}] masked t={t}: cue tile mask empty (inconsistent)")

        pix_mask = _pixel_mask_from_tile_mask(tile_mask, f_img.shape[0], f_img.shape[1])

        # Per-pixel "any channel differs".
        diff = np.any(f_img != n_img, axis=2)
        if np.any(diff & (~pix_mask)):
            _die(f"[{gid}] masked t={t}: diff leaks outside cue tiles")
        if not np.any(diff & pix_mask):
            _die(f"[{gid}] masked t={t}: no diff inside cue tiles")


def main() -> None:
    """CLI entry point: validate triplets in ``<root>/triplets.jsonl``.

    Prints ``[OK] <group_id>`` per valid group and a final ``[DONE]`` line;
    stops at the first failure via ``_die`` (SystemExit with a message).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", type=str, required=True, help="dataset root (where triplets.jsonl exists)")
    ap.add_argument("--max-groups", type=int, default=-1,
                    help="validate at most this many groups (<= 0 means all)")
    args = ap.parse_args()

    root = Path(args.root).resolve()
    jsonl = root / "triplets.jsonl"
    if not jsonl.exists():
        _die(f"Not found: {jsonl}")

    rows = _read_jsonl_rows(jsonl)
    if len(rows) % 3 != 0:
        _die(f"JSONL lines must be multiple of 3, got {len(rows)}")

    n_groups = len(rows) // 3
    if args.max_groups > 0:
        n_groups = min(n_groups, args.max_groups)

    for gi in range(n_groups):
        full, nocue, cf = rows[3 * gi: 3 * gi + 3]
        gid = full.get("group_id", f"<idx={gi}>")

        T = _check_group_structure(gid, full, nocue, cf)
        window_steps, cue_type, cue_color = _check_nocue_meta(gid, nocue, T)
        _audit_nocue_frames(gid, root, full, nocue, T, window_steps, cue_type, cue_color)

        print(f"[OK] {gid}")

    print(f"[DONE] validated {n_groups} groups")

if __name__ == "__main__":
    # Run as a script only; main() returns None, so the exit status stays 0 on success.
    raise SystemExit(main())
