#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Validate KeyCorridor triplets generated by gen_keycorridor_triplets.py.
Hard checks:
- 3 lines per group (full -> nocue -> cf)
- identity: task/env_id/seed/group_id agree across the three variants
- action parity: actions_id is identical across variants
- length parity: len(frames) == len(state_seq) == T+1, where T = len(actions_id)
- physics consistency: agent pos/dir identical across variants at every step
- outcomes: full succeeds with reward > 0; cf fails with reward == 0; nocue succeeds like full
- NoCue audit: on window_steps frames, pixel diffs vs. full lie only in key tiles; no diff elsewhere
- nocue_meta schema aligned with NoCue_spec
"""

from __future__ import annotations

import argparse
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
from PIL import Image

from minigrid.core.constants import OBJECT_TO_IDX, IDX_TO_OBJECT, COLOR_TO_IDX

# Exact set of record fields a model may consume: every record's
# "model_input_fields" must equal this set (enforced in validate_group).
ALLOWED_INPUTS = {"frames", "mission", "terminated", "truncated"}
# Record fields that must never be listed among "model_input_fields".
PROHIBITED_INPUTS = {"state_seq", "actions_id", "success", "reward", "cf_meta", "nocue_meta", "seed"}
# NOTE(review): MASK_RGB is not referenced anywhere in this script; presumably
# the fill colour the generator uses for masked NoCue tiles -- confirm or remove.
MASK_RGB = (0, 0, 0)

def _die(msg: str) -> None:
    raise SystemExit(msg)

def _load_png(path: Path) -> np.ndarray:
    """Decode the image file at *path* into an RGB numpy array.

    The image is converted to RGB mode (dropping alpha / palette) before it is
    turned into an array. The file handle is closed by the ``with`` block; the
    array is built from the converted in-memory copy.
    """
    with Image.open(path) as handle:
        rgb_image = handle.convert("RGB")
    return np.asarray(rgb_image)

def _cue_tile_mask(state_encoding: List[List[List[int]]], cue_type: str, cue_color: str) -> np.ndarray:
    """Mark the cells of a state encoding that hold the cue object.

    A cell matches when its channel 0 equals ``OBJECT_TO_IDX[cue_type]`` and its
    channel 1 equals ``COLOR_TO_IDX[cue_color]``. The encoding is converted to
    uint8 first (the original comment notes an expected (7, 7, 3) shape).

    Returns:
        Boolean array with the encoding's first two dimensions, or an all-False
        (7, 7) array when either the cue type or the cue colour is unknown.
    """
    grid = np.asarray(state_encoding, dtype=np.uint8)  # expected (7, 7, 3)
    type_idx = OBJECT_TO_IDX.get(cue_type)
    color_idx = COLOR_TO_IDX.get(cue_color)
    if type_idx is None or color_idx is None:
        return np.zeros((7, 7), dtype=bool)
    same_type = grid[:, :, 0] == int(type_idx)
    same_color = grid[:, :, 1] == int(color_idx)
    return np.logical_and(same_type, same_color)

def _pixel_mask_from_tile_mask(tile_mask: np.ndarray, frame_h: int, frame_w: int) -> np.ndarray:
    if frame_h % 7 != 0 or frame_w % 7 != 0:
        _die(f"Frame shape not divisible by 7: {(frame_h, frame_w)}")
    th, tw = frame_h // 7, frame_w // 7
    return np.kron(tile_mask.astype(np.uint8), np.ones((th, tw), dtype=np.uint8)).astype(bool)

def validate_state_encoding(enc) -> bool:
    """Check that *enc* is a well-formed (H, W, 3) integer cell encoding.

    Each cell is read as (object_idx, color_idx, state). The encoding is
    rejected when it is None, not 3-D, has a last axis other than 3, or has a
    non-integer dtype. A cell is rejected when its object index is not in
    ``IDX_TO_OBJECT``, when a door has a state outside {0, 1, 2}, or when any
    object other than door/empty/unseen has a non-zero state. The colour
    channel is not checked.
    """
    if enc is None:
        return False
    grid = np.array(enc)
    if grid.ndim != 3 or grid.shape[2] != 3:
        return False
    if not np.issubdtype(grid.dtype, np.integer):
        return False

    # Walk cells in row-major order; tolist() yields plain Python ints.
    for obj_idx, _color, state in grid.reshape(-1, 3).tolist():
        if obj_idx not in IDX_TO_OBJECT:
            return False
        kind = IDX_TO_OBJECT[obj_idx]
        if kind == "door":
            if state not in (0, 1, 2):
                return False
        elif kind not in ("empty", "unseen") and state != 0:
            return False
    return True

def get_tile_mask(enc, target_name: str = "key") -> np.ndarray:
    """Return an (H, W) boolean mask of the cells whose object type is *target_name*.

    *enc* must be an (H, W, C) array-like whose channel 0 holds object indices
    as defined by ``OBJECT_TO_IDX``. An unknown *target_name* yields an
    all-False mask.

    Raises:
        ValueError: if *enc* is not three-dimensional.
    """
    arr = np.array(enc)
    H, W, _ = arr.shape  # enforces a 3-D (H, W, C) encoding
    # OBJECT_TO_IDX maps name -> index, so translate the name once and compare
    # indices directly (looking an index up in OBJECT_TO_IDX never matches).
    obj_idx = OBJECT_TO_IDX.get(target_name)
    if obj_idx is None:
        return np.zeros((H, W), dtype=bool)
    return (arr[:, :, 0] == int(obj_idx)).astype(bool)

def _agent_traj(rec: Dict) -> List[Tuple[Tuple[int, int], int]]:
    traj = []
    for st in rec["state_seq"]:
        pos = tuple(st["agent"]["pos"])
        d = int(st["agent"]["dir"])
        traj.append((pos, d))
    return traj

def door_state_at(st: Dict, door_pos: Tuple[int, int]) -> Optional[int]:
    """Return the integer state of the door located at *door_pos* in state *st*.

    Scans ``st["objects"]`` (missing -> no objects) for the first entry of type
    "door" whose ``pos`` equals *door_pos*. A matching door without a "state"
    field reports -1; no matching door returns None.
    """
    door = next(
        (
            obj
            for obj in st.get("objects", [])
            if obj.get("type") == "door" and tuple(obj.get("pos", [])) == door_pos
        ),
        None,
    )
    if door is None:
        return None
    return int(door.get("state", -1))

def door_opens_after_action(states: List[Dict], t: int, door_pos: Tuple[int, int], horizon: int = 4) -> bool:
    """Report whether the door at *door_pos* is seen open shortly after step *t*.

    Inspects ``states[t + 1 : t + 1 + horizon]`` (clipped to the sequence
    length) and returns True as soon as ``door_state_at`` reports state 0
    (open, in MiniGrid's door state encoding) for that door.
    """
    window_end = min(t + 1 + horizon, len(states))
    return any(door_state_at(states[k], door_pos) == 0 for k in range(t + 1, window_end))

def encoding_has_object(enc: List, obj_type: str, obj_color: Optional[str] = None) -> bool:
    """Return True if any cell of the cell encoding *enc* holds *obj_type*.

    Channel 0 is compared with ``OBJECT_TO_IDX[obj_type]``. When *obj_color* is
    given and known to ``COLOR_TO_IDX``, channel 1 must match as well; an
    unknown colour falls back to a type-only match. An unknown type yields False.

    Malformed encodings (None, ragged, not integer-convertible, fewer than three
    dimensions, no colour channel) are treated as "object not present" and
    return False. Only those input errors are caught, so unrelated failures
    propagate instead of being reported as "not visible".
    """
    try:
        arr = np.array(enc, dtype=np.int64)
        type_plane = arr[:, :, 0]
    except (TypeError, ValueError, OverflowError, IndexError):
        return False

    obj_idx = OBJECT_TO_IDX.get(obj_type, None)
    if obj_idx is None:
        return False
    hits = type_plane == int(obj_idx)

    col_idx = None if obj_color is None else COLOR_TO_IDX.get(obj_color, None)
    if col_idx is not None:
        if arr.shape[2] < 2:
            return False  # no colour channel to compare against
        hits &= arr[:, :, 1] == int(col_idx)
    return bool(np.any(hits))

def parse_target_from_mission(mission: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract ``(object_type, color)`` from a mission string such as "pick up the red ball".

    The mission is lower-cased and split on whitespace. The last word is the
    object type if it is a known MiniGrid object (otherwise None). The
    second-to-last word is the colour if it is a known MiniGrid colour
    (otherwise None). An empty or blank mission gives ``(None, None)``.
    """
    words = mission.lower().split()
    if not words:
        return None, None
    last = words[-1]
    color = None
    if len(words) > 1 and words[-2] in COLOR_TO_IDX:
        color = words[-2]
    obj = last if last in OBJECT_TO_IDX else None
    return obj, color

def _pos_to_room_index(env_unwrapped, pos: Tuple[int, int]) -> Optional[int]:
    rooms = getattr(env_unwrapped, "rooms", None)
    if rooms is None:
        return None

    x, y = pos
    for i, r in enumerate(rooms):
        top = getattr(r, "top", None)
        size = getattr(r, "size", None)

        if top is None and isinstance(r, dict):
            top = r.get("top")
            size = r.get("size")
        if top is None and isinstance(r, (list, tuple)) and len(r) >= 2:
            top, size = r[0], r[1]

        if top is None or size is None:
            continue

        rx, ry = int(top[0]), int(top[1])
        rw, rh = int(size[0]), int(size[1])
        if rx <= x < rx + rw and ry <= y < ry + rh:
            return i

    return None

def find_goal_pos_from_objects(objects: List[Dict], target_obj: Optional[str], target_color: Optional[str]) -> Optional[Tuple[int, int]]:
    """Locate the mission goal among serialized objects and return its position.

    When *target_obj* is given, the first object of that type (and of
    *target_color*, if given) wins. Otherwise, or if nothing matches, the first
    ball/key/box in *objects* is used. Returns None when neither lookup finds
    anything.
    """
    if target_obj is not None:
        matches = [
            o for o in objects
            if o.get("type") == target_obj
            and (target_color is None or o.get("color") == target_color)
        ]
        if matches:
            return tuple(matches[0]["pos"])

    fallback = next((o for o in objects if o.get("type") in ("ball", "key", "box")), None)
    return None if fallback is None else tuple(fallback["pos"])

def _front_cell(st: Dict) -> Dict:
    """Return ``st["front_cell"]``, or ``{}`` when it is missing or null.

    Lets the event scans use ``.get`` safely instead of crashing on states
    whose front cell was not serialized as a dict.
    """
    return st.get("front_cell") or {}


def validate_group(gid: str, full: Dict, nocue: Dict, cf: Dict, root: Path, audit_stats: Dict) -> Tuple[bool, str]:
    """Run every hard check on one (full, nocue, cf) triplet.

    Checks run in a fixed order and the first failure short-circuits.

    Args:
        gid: Group id of the triplet (informational; not read by the checks).
        full, nocue, cf: The three parsed JSONL records of the group.
        root: Dataset root; ``frames`` entries are paths relative to it.
        audit_stats: Aggregate audit dict; on success ``len(window_steps)`` is
            appended to ``audit_stats["mask_counts"]``.

    Returns:
        ``(True, "pass")``, or ``(False, reason)`` naming the first failed check.
    """
    variants = (("full", full), ("nocue", nocue), ("cf", cf))

    # Identity: all three records must describe the same episode. Fields absent
    # from every record compare equal (None == None) and therefore pass.
    for key in ("task", "env_id", "seed", "group_id"):
        if not (full.get(key) == nocue.get(key) == cf.get(key)):
            return False, "identity_mismatch"

    # D1 Action identity
    ref_actions = full["actions_id"]
    if nocue["actions_id"] != ref_actions or cf["actions_id"] != ref_actions:
        return False, "action_mismatch"

    for v_name, v in variants:
        n_obs = len(v["actions_id"]) + 1  # T actions -> T+1 observations
        if len(v["frames"]) != n_obs:
            return False, f"{v_name}_len_mismatch"
        if len(v["state_seq"]) != n_obs:
            return False, f"{v_name}_state_len_mismatch"

        # Model-input contract: exactly ALLOWED_INPUTS, each present on the record.
        mif = set(v.get("model_input_fields", []))
        if mif != ALLOWED_INPUTS:
            return False, f"{v_name}_bad_contract_fields"
        if not all(k in v for k in mif):
            return False, f"{v_name}_missing_contract_data"
        # Defensive: cannot fire while ALLOWED_INPUTS and PROHIBITED_INPUTS are disjoint.
        if mif & PROHIBITED_INPUTS:
            return False, f"{v_name}_info_leak"

        for st in v["state_seq"]:
            if not validate_state_encoding(st.get("state_encoding")):
                return False, f"{v_name}_invalid_encoding"

    # Termination expectations
    if full["terminated"] is not True:
        return False, "full_not_terminated"
    if cf["terminated"] is True:
        return False, "cf_terminated_error"
    if cf["truncated"] is True:
        return False, "cf_truncated_error"
    if nocue["terminated"] != full["terminated"]:
        return False, "nocue_term_mismatch"

    # Outcomes: full and nocue succeed, cf fails; full must also be rewarded.
    if not full.get("success") or not nocue.get("success") or cf.get("success"):
        return False, "outcome_fail"
    if float(full.get("reward", 0.0)) <= 0.0:
        return False, "full_reward_nonpositive"
    if float(cf.get("reward", 0.0)) != 0.0:
        return False, "cf_reward_nonzero"

    # D6 Physics: agent position/direction sequences must coincide.
    ref_traj = _agent_traj(full)
    if _agent_traj(nocue) != ref_traj or _agent_traj(cf) != ref_traj:
        return False, "physics_drift"

    actions, states = full["actions_id"], full["state_seq"]
    fronts = [_front_cell(st) for st in states]

    # The door locked (state 2) at t=0 pins down the unlock-key colour and door position.
    door0 = next(
        (
            o for o in states[0].get("objects", [])
            if o.get("type") == "door" and int(o.get("state", -1)) == 2
        ),
        None,
    )
    if door0 is None:
        return False, "missing_locked_door"
    door_pos = tuple(door0.get("pos", []))
    door_color = str(door0.get("color"))

    # KeyCorridor semantics: pickup unlock key -> interact with door -> pickup target.
    # MiniGrid action ids: 2 = forward, 3 = pickup, 5 = toggle.
    pickup_key_t = next(
        (
            t for t, a in enumerate(actions)
            if a == 3
            and fronts[t].get("type") == "key"
            and str(fronts[t].get("color")) == door_color
        ),
        None,
    )
    door_interact_t = next(
        (
            t for t, a in enumerate(actions)
            if a in (2, 5)
            and fronts[t].get("type") == "door"
            and tuple(fronts[t].get("pos", (-999, -999))) == door_pos
        ),
        None,
    )

    tgt_obj, tgt_color = parse_target_from_mission(full.get("mission") or "")
    if tgt_obj is None:
        # Unparseable mission: accept any ball/key pickup, ignoring colour.
        goal_types, goal_color = ("ball", "key"), None
    else:
        goal_types, goal_color = (tgt_obj,), tgt_color
    pickup_goal_t = next(
        (
            t for t, a in enumerate(actions)
            if a == 3
            and fronts[t].get("type") in goal_types
            and (goal_color is None or str(fronts[t].get("color")) == str(goal_color))
        ),
        None,
    )

    # All three events are taken from the *full* trajectory, which was required
    # to succeed above; the CF intervention cannot remove them, so each must exist.
    if pickup_key_t is None or door_interact_t is None or pickup_goal_t is None:
        return False, "semantic_event_missing"
    if not (pickup_key_t < door_interact_t < pickup_goal_t):
        return False, "semantic_seq_fail"

    # The first door interaction must hit the locked door, which then opens soon after.
    if int(fronts[door_interact_t].get("state", -1)) != 2:
        return False, "door_not_locked"
    if not door_opens_after_action(states, door_interact_t, door_pos):
        return False, "door_not_open"

    # E2 KeyCorridor: room exploration check skipped since the key is in the
    # starting corridor; trajectory success implies reaching the target.

    # E3: the agent must step onto the door cell within 7 steps of the interaction.
    entry_steps = range(door_interact_t + 1, min(door_interact_t + 8, len(states)))
    if not any(tuple(states[j]["agent"]["pos"]) == door_pos for j in entry_steps):
        return False, "did_not_enter_after_door"

    # CF observability: the target object (type + colour) must never be visible in cf.
    target_obj = tgt_obj if tgt_obj is not None else "ball"
    target_color = tgt_color if tgt_obj is not None else None
    if any(
        encoding_has_object(st.get("state_encoding"), target_obj, target_color)
        for st in cf["state_seq"]
    ):
        return False, "cf_target_visible"

    # CF meta consistency
    cfm = cf.get("cf_meta") or {}
    if cfm.get("cf_mode") not in ("remove_goal_object", "replace_goal_with_wall"):
        return False, "cf_meta_missing_or_wrong"

    goal_pos = find_goal_pos_from_objects(states[0].get("objects", []), tgt_obj, tgt_color)
    raw_goal_from = cfm.get("goal_from")
    goal_from = tuple(raw_goal_from) if isinstance(raw_goal_from, list) else None
    if goal_pos is not None and goal_from is not None and goal_from != goal_pos:
        return False, "cf_goal_from_mismatch"

    # NoCue hard gates
    meta = nocue.get("nocue_meta") or {}
    masked = meta.get("window_steps", [])
    if not masked:
        return False, "empty_mask"

    # NoCue_spec schema check (keys must exist; extra keys allowed)
    required_keys = {
        "targets",
        "window_policy",
        "window_steps",
        "mask_strength_target",
        "mask_strength_actual",
        "mask_strength_threshold",
        "mask_type",
        "alignment_score",
        "alignment_threshold",
        "masked_frames",
        "physics_check_passed",
    }
    if not required_keys.issubset(meta):
        return False, "nocue_meta_missing"
    if meta.get("window_policy") != "EARLY":
        return False, "nocue_window_policy_bad"
    targets = meta.get("targets")
    if not isinstance(targets, list) or "key" not in targets:
        return False, "nocue_targets_bad"
    if not isinstance(masked, list):
        return False, "nocue_window_invalid"
    if int(meta.get("masked_frames")) != len(masked):
        return False, "nocue_mask_count_mismatch"
    if float(meta.get("mask_strength_actual")) > float(meta.get("mask_strength_threshold")):
        return False, "nocue_mask_strength_exceeded"
    if float(meta.get("alignment_score")) < float(meta.get("alignment_threshold")):
        return False, "nocue_alignment_fail"

    # window_steps must be valid indices strictly before the key pickup and must
    # exclude interaction frames (agent facing a key).
    if any(not isinstance(i, int) or i < 0 or i >= pickup_key_t for i in masked):
        return False, "nocue_window_invalid"
    if any(fronts[i].get("type") == "key" for i in masked):
        return False, "nocue_masks_interaction_frame"

    # Missing budget fields force a failure here (10**18 > -1 by default).
    if int(meta.get("masked_tiles_total", 10**18)) > int(meta.get("mask_budget_tiles", -1)):
        return False, "budget_exceeded"

    # Diff-in-Mask: for each masked frame, pixel diffs must lie within key tiles only.
    for idx in masked:
        fp_f = root / full["frames"][idx]
        fp_n = root / nocue["frames"][idx]
        if not fp_f.exists() or not fp_n.exists():
            return False, "missing_png"

        img_f = _load_png(fp_f)
        img_n = _load_png(fp_n)
        if img_f.shape != img_n.shape:
            return False, "shape_mismatch"
        diff = np.any(img_f != img_n, axis=2)  # (H, W): any channel differs

        tile_mask = _cue_tile_mask(states[idx]["state_encoding"], "key", door_color)
        if not tile_mask.any():
            return False, "masked_t_no_key_tiles"
        # NOTE(review): MiniGrid encodings are indexed [x, y] while images are
        # [row=y, col=x]; confirm the generator's state_encoding orientation
        # matches the rendered frames, otherwise tile_mask needs transposing.
        pix_mask = _pixel_mask_from_tile_mask(tile_mask, img_f.shape[0], img_f.shape[1])

        if np.any(diff & ~pix_mask):
            return False, "diff_outside_key_tiles"
        if not np.any(diff & pix_mask):
            return False, "no_diff_in_key_tiles"

    audit_stats["mask_counts"].append(len(masked))
    return True, "pass"

def main() -> None:
    """Command-line entry point.

    Reads ``--triplets`` (one JSON record per non-blank line) and groups records
    by ``group_id`` and ``variant``. It runs ``validate_group`` on every group
    that has exactly one full, one nocue and one cf record. The audit summary
    (pass/fail counts, failure reasons, NoCue mask counts) is written to
    ``--out-audit`` and printed to stdout.

    A group in which some variant appears more than once is failed with
    ``duplicate_variant`` instead of silently keeping the last record.

    Exits via ``_die`` (SystemExit) on invalid JSON, on a record without
    ``group_id``/``variant``, or when the record count is not a multiple of 3.
    """
    ap = argparse.ArgumentParser(description="Validate KeyCorridor full/nocue/cf triplets.")
    ap.add_argument("--triplets", type=str, required=True, help="Path to triplets.jsonl")
    ap.add_argument("--root", type=str, required=True, help="Dataset root dir that contains per-group folders")
    ap.add_argument("--out-audit", type=str, default="audit_report.json")
    args = ap.parse_args()

    data: List[Dict] = []
    with open(args.triplets, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as exc:
                _die(f"{args.triplets}:{lineno}: invalid JSON ({exc})")

    if len(data) % 3 != 0:
        _die(f"JSONL lines must be multiple of 3, got {len(data)}")

    groups: Dict[str, Dict[str, Dict]] = defaultdict(dict)
    duplicated = set()  # group ids in which some variant occurs more than once
    for rec in data:
        try:
            gid, variant = rec["group_id"], rec["variant"]
        except (KeyError, TypeError):
            _die(f"Record without group_id/variant: {str(rec)[:200]}")
        if variant in groups[gid]:
            duplicated.add(gid)
        groups[gid][variant] = rec

    audit = {"passed": 0, "failed": 0, "reasons": defaultdict(int), "mask_counts": []}
    root = Path(args.root)

    for gid, vs in groups.items():
        if gid in duplicated:
            ok, reason = False, "duplicate_variant"
        elif not {"full", "nocue", "cf"}.issubset(vs):
            ok, reason = False, "missing_variant"
        else:
            ok, reason = validate_group(gid, vs["full"], vs["nocue"], vs["cf"], root, audit)

        if ok:
            audit["passed"] += 1
        else:
            audit["failed"] += 1
            audit["reasons"][reason] += 1

    out = dict(audit)
    out["reasons"] = dict(out["reasons"])
    out_path = Path(args.out_audit)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    print(json.dumps(out, ensure_ascii=False, indent=2))

if __name__ == "__main__":
    main()
