"""
Validate that the transition matrices (T_next, T_done) match actual KeyLockEnv transitions.

Workflow:
- Build the transition matrices (T_next, T_done) for the configured layout and save them to disk.
- Map env observations to discrete indices with state_to_index, using the simplified state
  (has_yellow_key and has_blue_key are not part of it).
- Start each episode from a state drawn by sample_random_initial_state and applied with reset_env_to_state.
- Roll out random actions; after each step, compare the env-derived next index and done flag
  with the T_next / T_done lookup.
- On a mismatch, print detailed diagnostics (env observation vs. state decoded from the matrix).
"""

import os
import argparse
import numpy as np
import random

from key_lock_env import KeyLockEnv
from key_lock_options import (
    state_to_index,
    index_to_state,
    reset_env_to_state,
)
from generate_keylock_transition_matrix import build_keylock_transition_matrix


def sample_random_initial_state(env, size):
    """Draw a logical start state with the agent fixed at (1, 1), facing 0.

    Only the four door/key flags vary. They are chosen uniformly (one
    ``random.randint(0, 4)`` draw) from five fixed configurations.
    ``env`` and ``size`` are accepted for interface compatibility and are
    not used.

    Returns:
        A 7-tuple of ints ``(1, 1, 0, f0, f1, f2, f3)``. It is presumably in
        the same field order that ``obs_to_index`` unpacks (x, y, dir,
        yellow_door_open, blue_door_open, yellow_key_on_map,
        blue_key_on_map); confirm against ``reset_env_to_state``.
    """
    flag_table = (
        (0, 0, 1, 1),
        (0, 0, 1, 0),
        (0, 1, 1, 0),
        (0, 1, 0, 0),
        (1, 1, 0, 0),
    )
    agent_pose = (1, 1, 0)  # start x, start y, start direction
    flags = flag_table[random.randint(0, 4)]
    return agent_pose + flags


def obs_to_index(obs, size):
    """Flatten an env observation into its transition-matrix state index.

    ``obs`` must unpack into exactly seven fields, in the order
    (x, y, dir, yellow_door_open, blue_door_open, yellow_key_on_map,
    blue_key_on_map). This is the simplified state, without has_yellow_key
    or has_blue_key. Any other length raises ``ValueError`` from the unpack.
    The fields are forwarded positionally to ``state_to_index`` together
    with the grid ``size``.
    """
    (pos_x, pos_y, facing,
     y_door_open, b_door_open,
     y_key_present, b_key_present) = obs
    fields = (
        pos_x, pos_y, facing,
        y_door_open, b_door_open,
        y_key_present, b_key_present,
    )
    return state_to_index(*fields, size)


def load_or_build_matrix(args):
    """Build the KeyLock transition matrices, save them, and return them.

    Despite the name, nothing is read from disk. The matrices are always
    rebuilt, deliberately, so a debugging run never uses a stale cache. They
    are then written as ``T_next.npy``, ``T_done.npy`` and
    ``valid_states.npy`` into ``args.transition_matrix_dir``, which is
    resolved relative to this script's directory. The ``--generate`` CLI
    flag is not consulted.

    Returns:
        ``(T_next, T_done, valid_states)`` exactly as returned by
        ``build_keylock_transition_matrix``.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(here, args.transition_matrix_dir)

    # Rebuilt on every call on purpose: a cached matrix could silently be out of date.
    print("[TM] Building transition matrix (force regenerate for debug)")

    layout = {
        "yellow_key_pos": tuple(args.yellow_key_pos),
        "yellow_door_pos": tuple(args.yellow_door_pos),
        "blue_key_pos": tuple(args.blue_key_pos),
        "blue_door_pos": tuple(args.blue_door_pos),
        "goal_pos": tuple(args.goal_pos),
    }
    env = KeyLockEnv(
        size=args.size,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps=args.env_max_steps,
        render_mode=None,
        **layout,
    )
    env.reset()

    T_next, T_done, valid_states = build_keylock_transition_matrix(
        env,
        args.size,
        layout["yellow_key_pos"],
        layout["yellow_door_pos"],
        layout["blue_key_pos"],
        layout["blue_door_pos"],
        layout["goal_pos"],
    )

    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, "T_next.npy"), T_next)
    np.save(os.path.join(out_dir, "T_done.npy"), T_done)
    valid_array = np.array(list(valid_states), dtype=np.int32)
    np.save(os.path.join(out_dir, "valid_states.npy"), valid_array)
    print(f"[TM] Saved to {out_dir}")
    return T_next, T_done, valid_states


# Human-readable names for the six env actions sampled by run_checks (0..5).
_ACTION_NAMES = ("up", "down", "left", "right", "pickup", "toggle")

# Names of the door/key flags at positions 3..6 of a state / observation tuple.
_STATE_FLAG_NAMES = (
    "yellow_door_open",
    "blue_door_open",
    "yellow_key_on_map",
    "blue_key_on_map",
)


def _make_env(args):
    """Construct a KeyLockEnv from the CLI layout arguments (start (1, 1), dir 0)."""
    return KeyLockEnv(
        size=args.size,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        yellow_key_pos=tuple(args.yellow_key_pos),
        yellow_door_pos=tuple(args.yellow_door_pos),
        blue_key_pos=tuple(args.blue_key_pos),
        blue_door_pos=tuple(args.blue_door_pos),
        goal_pos=tuple(args.goal_pos),
        max_steps=args.env_max_steps,
        render_mode=None,
    )


def _action_name(a):
    """Return the label for action ``a``, or its number if it is out of range."""
    if 0 <= a < len(_ACTION_NAMES):
        return _ACTION_NAMES[a]
    return str(a)


def _same_pos(pos_a, pos_b):
    """Return True when both positions are set and name the same grid cell.

    Both sides are normalised to tuples, so tuple/list/array positions compare
    by value. A plain ``tuple == list`` comparison would always be False.
    """
    if pos_a is None or pos_b is None:
        return False
    return tuple(pos_a) == tuple(pos_b)


def _print_field_comparison(tm_state, env_state):
    """Print the decoded TM state against the env observation, field by field."""
    print("  Field-by-field comparison:")
    print(
        f"    (x, y, dir): TM=({tm_state[0]}, {tm_state[1]}, {tm_state[2]}) "
        f"vs Env=({env_state[0]}, {env_state[1]}, {env_state[2]})"
    )
    for offset, name in enumerate(_STATE_FLAG_NAMES, start=3):
        tm_val, env_val = tm_state[offset], env_state[offset]
        marker = "  <-- differs" if tm_val != env_val else ""
        print(f"    {name}: TM={tm_val} vs Env={env_val}{marker}")


def _print_cell_details(label, kind, cell, attrs):
    """Print a grid cell and selected attributes ('N/A' when an attribute is missing)."""
    print(f"    {label} cell: {cell}, type: {type(cell)}")
    if cell is None:
        return
    for attr in attrs:
        print(f"      {kind} {attr}: {getattr(cell, attr, 'N/A')}")


def _print_pickup_debug(env, obs):
    """Print key/carrying diagnostics for a mismatched pickup action.

    These values are queried after ``env.step``, so they describe the
    post-action env state. ``obs`` is the observation from before the action.
    """
    front_x, front_y = env._get_front_pos()
    yellow_key_pos = env.yellow_key_pos
    blue_key_pos = env.blue_key_pos
    can_pickup_yellow = env._can_pickup_key('yellow')
    can_pickup_blue = env._can_pickup_key('blue')
    carrying = env.carrying
    print("  Pickup action debug (env state after the step):")
    print(f"    Agent position: ({obs[0]}, {obs[1]}), direction: {obs[2]}")
    print(f"    Front position: ({front_x}, {front_y})")
    print(f"    Yellow key position: {yellow_key_pos}")
    print(f"    Blue key position: {blue_key_pos}")
    print(f"    Can pickup yellow: {can_pickup_yellow}")
    print(f"    Can pickup blue: {can_pickup_blue}")
    print(f"    Currently carrying: {carrying} (type: {type(carrying)})")
    for label, pos in (("Yellow key", yellow_key_pos), ("Blue key", blue_key_pos)):
        if _same_pos((front_x, front_y), pos):
            _print_cell_details(label, "Key", env.grid.get(front_x, front_y), ("color",))


def _print_toggle_debug(env, obs):
    """Print door diagnostics for a mismatched toggle action.

    These values are queried after ``env.step``, so they describe the
    post-action env state. ``obs`` is the observation from before the action.
    """
    front_x, front_y = env._get_front_pos()
    yellow_door_pos = env.yellow_door_pos
    blue_door_pos = env.blue_door_pos
    can_toggle_yellow = env._can_toggle_door('yellow')
    can_toggle_blue = env._can_toggle_door('blue')
    print("  Toggle action debug (env state after the step):")
    print(f"    Agent position: ({obs[0]}, {obs[1]}), direction: {obs[2]}")
    print(f"    Front position: ({front_x}, {front_y})")
    print(f"    Yellow door position: {yellow_door_pos}")
    print(f"    Blue door position: {blue_door_pos}")
    print(f"    Can toggle yellow: {can_toggle_yellow}")
    print(f"    Can toggle blue: {can_toggle_blue}")
    door_attrs = ("color", "is_open", "is_locked")
    for label, pos in (("Yellow door", yellow_door_pos), ("Blue door", blue_door_pos)):
        if _same_pos((front_x, front_y), pos):
            _print_cell_details(label, "Door", env.grid.get(front_x, front_y), door_attrs)


def _print_trace_tail(tm_traj, env_traj):
    """Print the buffered trajectory entries (oldest first) for TM and env side by side.

    Each entry is ``(action, index, state)``; the action is None for the start state.
    """
    print("Trajectory tail (oldest first):")
    for (a, tm_idx, tm_state), (_, env_idx, env_obs) in zip(tm_traj, env_traj):
        label = "start" if a is None else _action_name(a)
        print(f"  [{label:>6}] TM: idx={tm_idx} state={tm_state} | Env: idx={env_idx} obs={env_obs}")


def run_checks(args):
    """Roll out random actions and compare env transitions with T_next / T_done.

    For each of ``args.episodes`` episodes, the env is reset to a state from
    ``sample_random_initial_state``. Then, for up to ``args.max_steps`` steps,
    a uniformly random action in 0..5 is chosen. The predicted next index and
    done flag are read from the matrices, the real env is stepped, and the two
    are compared. The first mismatch in an episode is reported in detail and
    ends that episode, because the matrix trajectory has diverged. If
    ``args.trace_len`` > 0, the last ``trace_len`` trajectory entries are
    printed with each report. Checking stops once ``args.max_mismatch_print``
    mismatches have been reported. A summary line is always printed.
    """
    from collections import deque

    # Missing/negative trace_len -> 0, i.e. no trajectory tail is printed.
    trace_len = max(0, int(getattr(args, "trace_len", 0)))

    env = _make_env(args)
    T_next, T_done, _valid_states = load_or_build_matrix(args)
    mismatches = 0

    for ep in range(args.episodes):
        state = sample_random_initial_state(env, args.size)
        reset_env_to_state(env, *state)
        obs = env._get_obs()
        s_idx = obs_to_index(obs, args.size)   # index derived from the env observation
        s_idx_tm = s_idx                       # index followed through T_next (aligned at start)

        # Bounded buffers: only the tail that may be printed is retained.
        tm_traj = deque([(None, s_idx_tm, index_to_state(s_idx_tm, args.size))], maxlen=trace_len)
        env_traj = deque([(None, s_idx, obs)], maxlen=trace_len)
        mismatch_found = False

        for step in range(args.max_steps):
            a = random.randint(0, 5)

            # Prediction from the transition matrix.
            s_idx_next_tm = int(T_next[s_idx_tm, a])
            done_tm = bool(T_done[s_idx_tm, a])
            decoded_tm = index_to_state(s_idx_next_tm, args.size)

            # Ground truth from the environment.
            obs_next, _reward, terminated, truncated, _info = env.step(a)
            s_idx_next_env = obs_to_index(obs_next, args.size)
            done_env = terminated or truncated

            tm_traj.append((a, s_idx_next_tm, decoded_tm))
            env_traj.append((a, s_idx_next_env, obs_next))

            idx_mismatch = s_idx_next_tm != s_idx_next_env
            done_mismatch = done_tm != done_env

            if idx_mismatch or done_mismatch:
                mismatches += 1
                bar = "=" * 80
                print("\n" + bar)
                print(f"[MISMATCH #{mismatches}] Episode {ep}, Step {step}")
                print(bar)
                print(f"Action: {a} ({_action_name(a)})")
                print()

                print("Current State (before action):")
                print(f"  TM index: {s_idx_tm}")
                print(f"  TM decoded: {index_to_state(s_idx_tm, args.size)}")
                print(f"  Env obs: {obs}")
                print(f"  Env index: {s_idx}")
                print()

                print("Next State (after action):")
                print(f"  TM -> index: {s_idx_next_tm}, done: {done_tm}")
                print(f"  TM -> decoded: {decoded_tm}")
                print(f"  Env -> index: {s_idx_next_env}, done: {done_env}")
                print(f"  Env -> obs: {obs_next}")
                print()

                if idx_mismatch:
                    print(f"  ❌ INDEX MISMATCH: TM={s_idx_next_tm} vs Env={s_idx_next_env}")
                    _print_field_comparison(decoded_tm, obs_next)
                    if a == 4:
                        _print_pickup_debug(env, obs)
                    elif a == 5:
                        _print_toggle_debug(env, obs)
                if done_mismatch:
                    print(f"  ❌ DONE MISMATCH: TM={done_tm} vs Env={done_env}")
                print()

                if trace_len:
                    _print_trace_tail(tm_traj, env_traj)
                    print()

                print(bar, flush=True)  # flush so the report is visible immediately
                mismatch_found = True
                break

            # Advance both trajectories in lockstep.
            obs = obs_next
            s_idx = s_idx_next_env
            s_idx_tm = s_idx_next_tm
            if done_env:
                break

        if mismatch_found and mismatches >= args.max_mismatch_print:
            print(f"Reached max mismatch prints ({args.max_mismatch_print}); stopping.")
            break

    if mismatches == 0:
        print("All checked transitions match between env and transition matrix.")
    else:
        print(f"Finished with {mismatches} mismatches (see logs above).")


def main():
    """Parse command-line arguments, optionally seed the RNGs, and run the checks.

    ``--seed`` makes the sampled start states and random action sequences
    reproducible, so a reported mismatch can be replayed. When it is
    omitted, the RNGs are left unseeded, as before.
    """
    pa = argparse.ArgumentParser()
    pa.add_argument("--size", type=int, default=15)
    pa.add_argument("--yellow_key_pos", type=int, nargs=2, default=[12, 3])
    pa.add_argument("--yellow_door_pos", type=int, nargs=2, default=[3, 8])
    pa.add_argument("--blue_key_pos", type=int, nargs=2, default=[12, 12])
    pa.add_argument("--blue_door_pos", type=int, nargs=2, default=[9, 3])
    pa.add_argument("--goal_pos", type=int, nargs=2, default=[3, 12])
    pa.add_argument("--transition_matrix_dir", type=str, default="keylock_transition_matrix")
    pa.add_argument(
        "--generate",
        action="store_true",
        help="force regenerate transition matrix (currently the matrix is always regenerated)",
    )
    pa.add_argument("--episodes", type=int, default=20)
    pa.add_argument("--max_steps", type=int, default=5000)
    pa.add_argument("--env_max_steps", type=int, default=1000000, help="prevent truncation during test")
    pa.add_argument("--max_mismatch_print", type=int, default=10)
    pa.add_argument("--trace_len", type=int, default=20, help="tail length of trajectory to print on mismatch")
    pa.add_argument(
        "--seed",
        type=int,
        default=None,
        help="seed for Python's `random` and numpy's global RNG (default: unseeded)",
    )
    args = pa.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)

    run_checks(args)


if __name__ == "__main__":
    main()

