#!/usr/bin/env python
"""Train Random/Eigen/VPS options on KeyLockEnv (tabular).

Two-stage training for VPS and Random options (buffer then offline Q),
eigenoptions from graph Laplacian eigenvectors. Produces saved Q-tables
compatible with downstream experiments and visualization.
"""
import os, argparse, random, collections, ast
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh
import matplotlib.pyplot as plt
from key_lock_env import KeyLockEnv
def state_to_index(
    x, y, dir,
    yellow_door_open, blue_door_open,
    yellow_key_on_map, blue_key_on_map,
    size=15,
):
    """Pack a factored state into a single mixed-radix integer.

    Starting from ``x``, each remaining field is folded in as the next
    (less significant) digit: ``y`` with radix ``size``, ``dir`` with radix 4,
    then the four binary flags (yellow door, blue door, yellow key on map,
    blue key on map) with radix 2 each.  Inputs are not range-checked.

    Returns:
        The packed index as a Python ``int``.
    """
    # (radix, digit) pairs in most-significant-first order, applied on top of x.
    trailing_fields = (
        (size, y),
        (4, dir),
        (2, yellow_door_open),
        (2, blue_door_open),
        (2, yellow_key_on_map),
        (2, blue_key_on_map),
    )
    packed = x
    for radix, digit in trailing_fields:
        packed = packed * radix + digit
    return int(packed)


def index_to_state(state_index, size=15):
    total = size * size * 4 * 2 * 2 * 2 * 2
    state_index = state_index % total

    blue_key_on_map = state_index % 2
    state_index //= 2

    yellow_key_on_map = state_index % 2
    state_index //= 2

    blue_door_open = state_index % 2
    state_index //= 2

    yellow_door_open = state_index % 2
    state_index //= 2

    dir = state_index % 4
    state_index //= 4

    y = state_index % size
    state_index //= size

    x = state_index % size

    return (
        x, y, dir,
        yellow_door_open, blue_door_open,
        yellow_key_on_map, blue_key_on_map,
    )


def set_seed(sd):
    """Seed NumPy's global RNG and Python's ``random`` module.

    Passing ``None`` leaves both generators untouched.
    """
    if sd is None:
        return
    np.random.seed(sd)
    random.seed(sd)


def _parse_num_opts_list(x) -> list[int]:
    if isinstance(x, int):
        return [int(x)]
    if isinstance(x, (list, tuple)):
        return [int(v) for v in x]
    s = str(x).strip()
    if not s:
        raise ValueError("--num_opts is empty")
    if "," in s and not (s.startswith("[") or s.startswith("(")):
        return [int(p.strip()) for p in s.split(",") if p.strip()]
    try:
        val = ast.literal_eval(s)
    except Exception:
        return [int(s)]
    if isinstance(val, int):
        return [int(val)]
    if isinstance(val, (list, tuple)):
        return [int(v) for v in val]
    raise ValueError(f"Unsupported --num_opts format: {x!r}")


def option_rows_dup(arr, sign):
    """Return ``arr`` stacked on top of itself when ``sign`` is truthy.

    With a falsy ``sign`` the input object is returned unchanged (no copy).
    """
    if not sign:
        return arr
    return np.vstack([arr, arr])


def build_state_index_mapping(valid_states, N_full):
    """Build lookup tables between full state indices and a compact 0..N_valid-1 range.

    Args:
        valid_states: Full-space state indices to keep.  May be a list, tuple,
            NumPy array, or a set/frozenset (sets are materialised first because
            ``np.array`` cannot convert a set to an integer array).
        N_full: Size of the full state space; only used for the log output.

    Returns:
        ``(state_to_compressed, compressed_to_state, N_valid)`` where
        ``state_to_compressed`` is a dict ``full_index -> compressed_index``,
        ``compressed_to_state`` is an ascending int32 array
        ``compressed_index -> full_index`` and ``N_valid`` is its length.
    """
    if isinstance(valid_states, (set, frozenset)):
        valid_states = list(valid_states)

    # Sorting gives a deterministic compressed ordering regardless of input order;
    # np.sort returns a fresh array, so the caller's data is never aliased.
    compressed_to_state = np.sort(np.asarray(valid_states, dtype=np.int32))
    N_valid = len(compressed_to_state)

    # full_index -> compressed_index
    state_to_compressed = {
        int(full_idx): compressed_idx
        for compressed_idx, full_idx in enumerate(compressed_to_state)
    }

    # Guard the ratio so an empty full space does not raise ZeroDivisionError.
    ratio = N_valid / N_full * 100 if N_full else 0.0
    print(f"[State Mapping] Created mapping: {N_valid} valid states out of {N_full} total states")
    print(f"[State Mapping] Compression ratio: {ratio:.2f}%")

    return state_to_compressed, compressed_to_state, N_valid


def convert_buffer_to_compressed(buffer, state_to_compressed):
    """Re-index a transition buffer from full to compressed state indices.

    Args:
        buffer: Iterable of ``(s, a, sn, done)`` tuples using full indices.
        state_to_compressed: Dict mapping full index -> compressed index.

    Returns:
        A new list of ``(s, a, sn, done)`` tuples in compressed indices.
        Transitions whose source or next state is missing from the mapping
        are silently dropped.
    """
    lookup = state_to_compressed.get
    remapped = []
    for src, action, dst, terminal in buffer:
        src_key = int(src)
        dst_key = int(dst)
        src_idx = lookup(src_key)
        dst_idx = lookup(dst_key)
        # Keep only transitions whose both endpoints are known states.
        if src_idx is not None and dst_idx is not None:
            remapped.append((src_idx, action, dst_idx, terminal))
    return remapped


def expand_compressed_features(features_compressed, compressed_to_state, N_full):
    """Scatter features defined on compressed states back into the full state space.

    Args:
        features_compressed: Array of shape ``(N_valid,)`` or ``(k, N_valid)``.
        compressed_to_state: Sequence mapping compressed index -> full index;
            only its first ``N_valid`` entries are used.
        N_full: Number of states in the full space.

    Returns:
        Array of shape ``(N_full,)`` or ``(k, N_full)`` (matching the input
        rank) with the input dtype; positions not referenced by
        ``compressed_to_state`` are zero.

    Raises:
        IndexError: if ``compressed_to_state`` has fewer than ``N_valid``
            entries or contains an index outside ``[−N_full, N_full)``.
    """
    is_1d = features_compressed.ndim == 1
    if is_1d:
        features_compressed = features_compressed[np.newaxis, :]

    k, N_valid = features_compressed.shape

    full_indices = np.asarray(compressed_to_state[:N_valid], dtype=np.intp)
    if full_indices.shape[0] != N_valid:
        raise IndexError(
            f"compressed_to_state has {full_indices.shape[0]} entries, "
            f"expected at least {N_valid}"
        )

    # Single vectorised scatter instead of a per-column Python loop.
    features_full = np.zeros((k, N_full), dtype=features_compressed.dtype)
    features_full[:, full_indices] = features_compressed

    return features_full[0] if is_1d else features_full


def get_wall_positions(env, size):
    """Return the ``(x, y)`` coordinates of every ``Wall`` cell in the grid.

    Scans ``env.grid`` over ``0 <= x, y < size`` (x outer, y inner) and
    collects the cells whose object is a minigrid ``Wall``.
    """
    from minigrid.core.world_object import Wall

    return [
        (col, row)
        for col in range(size)
        for row in range(size)
        # isinstance() is already False for empty (None) cells.
        if isinstance(env.grid.get(col, row), Wall)
    ]


def visualize_feature_distribution(
    features, size: int, save_dir: str, prefix: str,
    iteration: int = 0, max_options_to_plot: int = 4, env=None
):
    """Save a grid of heatmaps showing per-state feature values on the map.

    One row is drawn per feature/option (at most ``max_options_to_plot``) and
    one column per hand-picked key/door configuration.  Each cell shows the
    mean of the feature over the four agent directions at that ``(x, y)``.
    When ``env`` is given, wall cells are left as NaN and rendered gray.

    Args:
        features: Array of shape ``(N,)`` or ``(k, N)`` indexed by the full
            state index produced by ``state_to_index``.
        size: Grid side length.
        save_dir: Output directory (created if missing).
        prefix: Used in subplot titles and in the output file name.
        iteration: Zero-padded into the output file name.
        max_options_to_plot: Upper bound on the number of rows plotted.
        env: Optional environment used only to locate walls.

    The figure is written to
    ``<save_dir>/<prefix>_distribution_iter<iteration:04d>.png`` and is always
    closed afterwards, even if plotting or saving raises.
    """
    # Handle both 1D (eigenvector) and 2D (VPS features) cases
    if features.ndim == 1:
        features = features[np.newaxis, :]  # (1, N)

    k, _ = features.shape
    num_to_plot = min(k, max_options_to_plot)

    state_configs = [
        {
            "name": "No keys",
            "yellow_door_open": 0,
            "blue_door_open": 0,
            "yellow_key_on_map": 1,
            "blue_key_on_map": 1,
        },
        {
            "name": "Blue key picked",
            "yellow_door_open": 0,
            "blue_door_open": 0,
            "yellow_key_on_map": 1,  # Yellow key still on map
            "blue_key_on_map": 0,     # Blue key removed from map when picked
        },
        {
            "name": "Blue door open",
            "yellow_door_open": 0,
            "blue_door_open": 1,
            "yellow_key_on_map": 1,  # Yellow key still on map
            "blue_key_on_map": 0,     # Blue key was consumed, not on map
        },
        {
            "name": "Yellow key picked",
            "yellow_door_open": 0,
            "blue_door_open": 1,  # Blue door must be open to reach yellow key
            "yellow_key_on_map": 0,  # Yellow key removed from map when picked
            "blue_key_on_map": 0,     # Blue key was consumed, not on map
        },
        {
            "name": "Both doors open",
            "yellow_door_open": 1,
            "blue_door_open": 1,
            "yellow_key_on_map": 0,  # Both keys consumed, not on map
            "blue_key_on_map": 0,
        },
    ]
    num_configs = len(state_configs)

    # squeeze=False always yields a 2-D axes array, so no per-shape special cases.
    fig, axes = plt.subplots(
        num_to_plot, num_configs,
        figsize=(4 * num_configs, 4 * num_to_plot),
        squeeze=False,
    )

    try:
        # A set gives O(1) membership tests inside the per-cell loops below.
        wall_positions = set(get_wall_positions(env, size)) if env is not None else set()

        # The colormap is identical for every panel; build it once.
        cmap = plt.cm.viridis.copy()
        cmap.set_bad(color='gray', alpha=0.5)

        for opt_idx in range(num_to_plot):
            for config_idx, config in enumerate(state_configs):
                heatmap = np.full((size, size), np.nan)

                for x in range(size):
                    for y in range(size):
                        if (x, y) in wall_positions:
                            continue
                        # Average the feature over the four agent directions.
                        vals = [
                            features[
                                opt_idx,
                                state_to_index(
                                    x, y, dir_fixed,
                                    config["yellow_door_open"],
                                    config["blue_door_open"],
                                    config["yellow_key_on_map"],
                                    config["blue_key_on_map"],
                                    size,
                                ),
                            ]
                            for dir_fixed in range(4)
                        ]
                        heatmap[y, x] = np.mean(vals)

                ax = axes[opt_idx, config_idx]
                im = ax.imshow(
                    np.ma.masked_invalid(heatmap), origin='upper', cmap=cmap,
                    aspect='auto', extent=[0, size, 0, size],
                )
                ax.set_title(f"{prefix} opt {opt_idx}\n{config['name']}", fontsize=10)
                ax.set_xlabel('X')
                ax.set_ylabel('Y')
                fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        fig.tight_layout()

        # Save figure
        os.makedirs(save_dir, exist_ok=True)
        filename = os.path.join(save_dir, f"{prefix}_distribution_iter{iteration:04d}.png")
        fig.savefig(filename, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure so repeated calls do not accumulate figures.
        plt.close(fig)
    print(f"  [Viz] Saved {filename}")


def reset_env_to_state(env, x, y, dir,
                      yellow_door_open, blue_door_open,
                      yellow_key_on_map=1, blue_key_on_map=1):
    """Reset ``env`` and overwrite its state with the given factored state.

    After ``env.reset()`` the key-on-map flags are set and
    ``env._place_keys_doors_goal()`` is called, then the agent pose, the
    carried object and both door objects are overwritten:

    * The agent carries the yellow key if that key is off the map and the
      yellow door is closed; otherwise the blue key under the analogous
      condition; otherwise nothing.
    * Each door found in the grid is opened/closed as requested.  It is
      unlocked when its key is off the map, otherwise locked exactly when
      closed.  An open door additionally clears the matching key flag and
      removes the key cell from the grid.
    """
    env.reset()
    env.yellow_key_on_map = int(yellow_key_on_map)
    env.blue_key_on_map = int(blue_key_on_map)
    env._place_keys_doors_goal()
    env.agent_pos = (x, y)
    env.agent_dir = dir

    from minigrid.core.world_object import Key, Door
    env.carrying = None

    holds_yellow = yellow_key_on_map == 0 and yellow_door_open == 0
    holds_blue = blue_key_on_map == 0 and blue_door_open == 0
    if holds_yellow:
        env.carrying = Key('yellow')
    elif holds_blue:
        env.carrying = Key('blue')

    # Yellow is processed before blue, matching the attribute access order
    # of a straight-line implementation.
    door_specs = (
        ("yellow", yellow_door_open, yellow_key_on_map),
        ("blue", blue_door_open, blue_key_on_map),
    )
    for color, door_open, key_on_map in door_specs:
        door = env.grid.get(*getattr(env, f"{color}_door_pos"))
        if not isinstance(door, Door):
            continue
        door.is_open = bool(door_open)
        if key_on_map == 0:
            door.is_locked = False
        else:
            door.is_locked = not door_open
        if door_open:
            setattr(env, f"{color}_key_on_map", 0)
            env.grid.set(*getattr(env, f"{color}_key_pos"), None)


# ---------- phase-0 : collect buffer using transition matrix ------------------------
def collect_buffer(T_next, T_done, valid_states, *, episodes=1000, max_len=200, use_random_init=True, size=15):
    """Roll out a uniform-random policy through tabulated transitions.

    Args:
        T_next: Table indexed as ``T_next[s, a]`` giving the next state.
        T_done: Table indexed as ``T_done[s, a]`` giving the terminal flag.
        valid_states: Candidate start states (a set is converted to an int32
            array; anything else is passed to ``np.random.choice`` as-is).
        episodes: Number of episodes to roll out.
        max_len: Maximum number of steps per episode.
        use_random_init: If true, each episode starts from a state sampled
            from ``valid_states``; otherwise from
            ``state_to_index(1, 1, 0, 0, 0, 1, 1, size)``.
        size: Grid side length, used only for the fixed start state.

    Returns:
        ``(buffer, visited)``: the list of ``(s, a, sn, done)`` transitions
        and the set of all states encountered (start states included).
    """
    transitions = []
    seen = set()

    # np.random.choice needs a sequence, so a set is materialised first.
    if isinstance(valid_states, set):
        start_candidates = np.array(list(valid_states), dtype=np.int32)
    else:
        start_candidates = valid_states

    for _episode in range(episodes):
        if use_random_init:
            state = int(np.random.choice(start_candidates))
        else:
            # Position (1, 1), direction 0, both doors closed, both keys on the map.
            state = state_to_index(1, 1, 0, 0, 0, 1, 1, size)
        seen.add(state)

        for _step in range(max_len):
            # Uniform random action in 0..5 (up, down, left, right, pickup, toggle).
            action = random.randint(0, 5)

            next_state = int(T_next[state, action])
            terminal = bool(T_done[state, action])

            transitions.append((state, action, next_state, terminal))
            seen.add(next_state)
            state = next_state

            if terminal:
                break

    return transitions, seen


def collect_buffer_intrinsic_reward(
    T_next, T_done, valid_states, N, *,
    episodes=1000, max_len=200, alpha_m=0.1, alpha_q=0.1,
    gamma=0.99, epsilon=0.1, init_m=None, init_q=None, size=15,
):
    """Collect transitions with an epsilon-greedy policy driven by an intrinsic reward.

    While rolling out, two tables are learned online:

    * ``M`` (N x N): TD-updated toward ``onehot(sn) + gamma * M[sn]`` (the
      bootstrap term is dropped on terminal transitions).
    * ``Q`` (N x 6): SARSA-style update with the intrinsic reward
      ``M[s, sn] - sum(|M[s, :]|)`` evaluated after the ``M`` update.

    Args:
        T_next, T_done: Transition tables indexed as ``[s, a]``.
        valid_states: Candidate start states (a set is converted to an array).
        N: Number of states (sizes ``M`` and ``Q``).
        episodes, max_len: Rollout budget.
        alpha_m, alpha_q: Step sizes for ``M`` and ``Q``.
        gamma: Discount factor.
        epsilon: Exploration probability.
        init_m, init_q: Optional initial tables; they are copied, not mutated.
        size: Not used by this function.

    Returns:
        ``(buffer, visited, M, Q)``.
    """
    num_actions = 6
    transitions = []
    seen = set()

    if init_m is None:
        M = np.zeros((N, N), dtype=np.float32)
    else:
        M = init_m.copy()
    if init_q is None:
        Q = np.zeros((N, num_actions), dtype=np.float32)
    else:
        Q = init_q.copy()

    # np.random.choice needs a sequence, so a set is materialised first.
    if isinstance(valid_states, set):
        start_candidates = np.array(list(valid_states), dtype=np.int32)
    else:
        start_candidates = valid_states

    def uniform_action():
        return random.randint(0, num_actions - 1)

    def epsilon_greedy(state):
        # The exploration coin is always drawn first, then (maybe) the action.
        if random.random() < epsilon:
            return uniform_action()
        return int(np.argmax(Q[state]))

    for _episode in range(episodes):
        state = int(np.random.choice(start_candidates))
        seen.add(state)
        action = epsilon_greedy(state)

        for _step in range(max_len):
            next_state = int(T_next[state, action])
            terminal = bool(T_done[state, action])
            seen.add(next_state)

            # TD target for M: indicator of the next state, plus the discounted
            # row of the next state unless the transition is terminal.
            target_m = np.zeros(N, dtype=np.float32)
            target_m[next_state] = 1.0
            if not terminal:
                target_m += gamma * M[next_state, :]
            delta_m = target_m - M[state, :]
            M[state, :] += alpha_m * delta_m

            intrinsic_reward = M[state, next_state] - np.sum(np.abs(M[state, :]))

            if terminal:
                target_q = intrinsic_reward
                # Drawn to keep RNG consumption uniform; unused after the break.
                next_action = uniform_action()
            else:
                next_action = epsilon_greedy(next_state)
                target_q = intrinsic_reward + gamma * Q[next_state, next_action]
            Q[state, action] += alpha_q * (target_q - Q[state, action])

            transitions.append((state, action, next_state, terminal))

            state = next_state
            action = next_action

            if terminal:
                break

    return transitions, seen, M, Q


def _visualize_sr_snapshot(
    psi, reward_weights_viz, *, total_steps, use_compressed,
    compressed_to_state, N, size, save_dir, viz_iteration, env,
):
    """Plot ``V = W @ psi.T`` for the fixed visualization weights under ``<save_dir>/sr``."""
    V_snapshot = reward_weights_viz @ psi.T
    if use_compressed:
        V_snapshot = expand_compressed_features(V_snapshot, compressed_to_state, N)
    sr_save_dir = os.path.join(save_dir, "sr")
    os.makedirs(sr_save_dir, exist_ok=True)
    visualize_feature_distribution(
        V_snapshot, size, sr_save_dir, f"sr_V_step{total_steps}", viz_iteration,
        max_options_to_plot=min(4, reward_weights_viz.shape[0]), env=env
    )


def train_successor_representation(
    buffer, *, gamma, alpha, sr_epochs=1, sr_lambda=0.0,
    state_to_compressed=None, compressed_to_state=None, N_valid=None, N=None,
    visualize=False, save_dir=None, viz_freq=500, env=None, size=None,
    k_base_viz=4, viz_iteration=0, valid_states=None,
):
    """Learn a tabular successor representation ψ from a transition buffer.

    Each transition ``(s, a, sn, done)`` uses the TD error
    ``onehot(s) + gamma * psi[sn] - psi[s]`` (no bootstrap term when ``done``).

    * ``sr_lambda == 0``: TD(0); the buffer is visited in a fresh random
      permutation every epoch.
    * ``sr_lambda > 0``: TD(λ) with accumulating eligibility traces; the
      buffer is visited in stored order.  Traces are cleared after terminal
      transitions and whenever a transition does not start where the previous
      one ended, so credit never leaks between unrelated trajectories (e.g.
      episodes truncated without ``done``, or gaps left by dropped
      transitions).

    Args:
        buffer: Sequence of ``(s, a, sn, done)`` tuples in full state indices.
        gamma: Discount factor.
        alpha: Step size.
        sr_epochs: Number of passes over the buffer.
        sr_lambda: Trace-decay parameter (0 selects TD(0)).
        state_to_compressed, compressed_to_state, N_valid: When all three are
            given, training runs on compressed indices.
        N: Size of the full state space.
        visualize, save_dir, size, viz_freq, env, k_base_viz, viz_iteration:
            When ``visualize`` is true and ``save_dir``/``size`` are given, a
            plot of ``k_base_viz`` fixed random value functions is saved every
            ``viz_freq`` update steps.
        valid_states: Accepted for call compatibility; not used here.

    Returns:
        ``(psi, state_space_size, use_compressed)`` where ``psi`` is a float32
        array of shape ``(state_space_size, state_space_size)``.
    """
    use_compressed = (state_to_compressed is not None and
                      compressed_to_state is not None and N_valid is not None)

    if use_compressed:
        print(f"[SR] Using compressed state space: {N_valid} valid states (reduced from {N})")
        buffer_to_use = convert_buffer_to_compressed(buffer, state_to_compressed)
        print(f"[SR] Converted buffer: {len(buffer)} -> {len(buffer_to_use)} transitions")
        state_space_size = N_valid
    else:
        print(f"[SR] Using full state space: {N} states")
        buffer_to_use = buffer
        state_space_size = N

    psi = np.zeros((state_space_size, state_space_size), dtype=np.float32)
    alpha_sr = alpha

    # Fixed (seeded) orthonormal reward weights, used only for visualization.
    reward_weights_viz = None
    if visualize and save_dir is not None and size is not None:
        rng_viz = np.random.RandomState(42)
        randR_raw = rng_viz.randn(k_base_viz, state_space_size).astype(np.float32)
        Q_mat, _ = np.linalg.qr(randR_raw.T)
        reward_weights_viz = Q_mat.T.astype(np.float32)
        print(f"[SR] Generated {k_base_viz} fixed orthogonal reward weights for visualization")

    def td_error(s, sn, done):
        # onehot(s) + gamma * psi[sn] - psi[s]; terminal transitions do not bootstrap.
        target = np.zeros(state_space_size, dtype=np.float32)
        target[s] = 1.0
        if not done:
            target += gamma * psi[sn]
        return target - psi[s]

    def maybe_visualize(step):
        if reward_weights_viz is not None and step % viz_freq == 0:
            _visualize_sr_snapshot(
                psi, reward_weights_viz, total_steps=step,
                use_compressed=use_compressed, compressed_to_state=compressed_to_state,
                N=N, size=size, save_dir=save_dir, viz_iteration=viz_iteration, env=env,
            )

    total_steps = 0

    if sr_lambda > 0.0:
        print(f"[SR] Using TD(λ) with λ={sr_lambda}")
        for _epoch in range(sr_epochs):
            eligibility = np.zeros(state_space_size, dtype=np.float32)
            prev_next = None  # next state of the previous non-terminal transition
            for s, _a, sn, done in buffer_to_use:
                s, sn = int(s), int(sn)

                # A transition that does not continue the previous one starts
                # a new trajectory: stale traces must not receive its TD error.
                if prev_next is not None and s != prev_next:
                    eligibility.fill(0.0)

                eligibility *= gamma * sr_lambda
                eligibility[s] += 1.0

                delta = td_error(s, sn, done)

                # Vectorised trace update over all states with a live trace.
                active = np.flatnonzero(eligibility > 0)
                psi[active] += alpha_sr * eligibility[active, np.newaxis] * delta

                if done:
                    eligibility.fill(0.0)
                    prev_next = None
                else:
                    prev_next = sn

                total_steps += 1
                maybe_visualize(total_steps)
    else:
        print(f"[SR] Using TD(0)")
        for _epoch in range(sr_epochs):
            for idx in np.random.permutation(len(buffer_to_use)):
                s, _a, sn, done = buffer_to_use[idx]
                s, sn = int(s), int(sn)
                psi[s] += alpha_sr * td_error(s, sn, done)
                total_steps += 1
                maybe_visualize(total_steps)

    print(f"[SR] Training completed: {total_steps} total steps, {sr_epochs} epochs")
    return psi, state_space_size, use_compressed


# ---------- 1. VPS-Option -----------------------------------
def train_vps_options(
    size,
    N,
    buffer,  # Shared buffer passed in (uses full state indices)
    *,
    k_base,
    sign,
    gamma,
    alpha,
    lam=0.9,   # kept for backward-compatibility, not used in SR updates
    sr_epochs=1,
    phi_epochs=1,
    q_epochs=1,
    sr_lambda=0.0,  # TD(λ) parameter for SR learning (0.0 = TD(0), >0 = TD(λ))
    value_type: str = "sr",  # "sr": V=w^T psi; "td": learn V via TD under random rewards
    visualize=False,
    save_dir=None,
    viz_iteration=0,
    viz_freq=500,
    env=None,  # Optional: only needed for visualization
    state_to_compressed=None,  # Optional: dict mapping full index -> compressed index
    compressed_to_state=None,  # Optional: array mapping compressed index -> full index
    N_valid=None,  # Optional: number of valid states
    psi=None,  # Optional: pre-trained successor representation (if None, will train SR)
    state_space_size=None,  # Optional: state space size (required if psi is provided)
    use_compressed=None,  # Optional: whether compressed state space is used (required if psi is provided)
):
    """
    Two stages (tabular VPS-options with SR-based value predictions):
      (1) Construct k_base mutually orthogonal reward weight vectors over states
          via QR decomposition. If psi is not provided, learn SR from buffer.
      (2) Compute option-specific value functions V_i(s) = w_i^T ψ(s),
          build VPS features φ_i(s) ≈ E[(V_i(s') - V_i(s))^2] from the
          buffer, and finally run offline Q-learning with intrinsic
          rewards r_i = φ_i(s') - φ_i(s).

    With ``value_type="td"`` the value functions are instead learned directly
    by TD(0) under the state rewards r_i(s) = w_i[s], using ``sr_epochs``
    shuffled sweeps over the buffer.

    If psi is provided, SR training is skipped (optimization: SR is shared across all VPS options).
    If state_to_compressed is provided, uses compressed state space for efficiency.

    When ``sign`` is true a second set of options is trained on the negated
    rewards and stored at rows ``k_base ..`` of the Q-table.  ``lam`` is
    accepted for backward compatibility and not used.

    Returns:
        ``(Q, V, phi)`` indexed by FULL state index: ``Q`` has shape
        ``(total_opts, N, 6)`` (or ``(total_opts, state_space_size, 6)`` when
        no compression is used); ``V`` and ``phi`` are passed through
        ``option_rows_dup`` so their rows are duplicated when ``sign`` is true.

    Raises:
        ValueError: for an unknown ``value_type`` or a ``psi`` whose shape does
            not match ``state_space_size``.
    """
    num_actions = 6  # up, down, left, right, pickup, toggle
    total_opts = k_base * (2 if sign else 1)

    # Determine if using compressed state space
    if use_compressed is None:
        use_compressed = (state_to_compressed is not None and
                          compressed_to_state is not None and
                          N_valid is not None)

    # The buffer is converted to the working index space exactly once here.
    if use_compressed:
        print(f"[VPS] Using compressed state space: {N_valid} valid states (reduced from {N})")
        buffer_to_use = convert_buffer_to_compressed(buffer, state_to_compressed)
        print(f"[VPS] Converted buffer: {len(buffer)} -> {len(buffer_to_use)} transitions")
        if state_space_size is None:
            state_space_size = N_valid
    else:
        print(f"[VPS] Using full state space: {N} states")
        buffer_to_use = buffer
        if state_space_size is None:
            state_space_size = N

    value_type = str(value_type).strip().lower()
    if value_type not in ("sr", "td"):
        raise ValueError(f"Unsupported value_type={value_type!r}; expected 'sr' or 'td'.")

    # ------- Phase-1 : construct reward weights & value functions -------

    # QR-based orthogonal reward weights over states.
    # QR on transpose → columns of Q_mat are orthonormal → rows of Q_mat.T are orthonormal
    randR_raw = np.random.randn(k_base, state_space_size).astype(np.float32)
    Q_mat, _ = np.linalg.qr(randR_raw.T)           # (state_space_size, k_base)
    reward_weights = Q_mat.T.astype(np.float32)    # (k_base, state_space_size) orthonormal rows

    if value_type == "sr":
        # Use pre-trained SR if provided, otherwise train SR
        if psi is None:
            print(f"[VPS] value_type=sr → training SR from scratch (not provided)")
            psi, state_space_size, sr_compressed = train_successor_representation(
                buffer,
                gamma=gamma,
                alpha=alpha,
                sr_epochs=sr_epochs,
                sr_lambda=sr_lambda,
                state_to_compressed=state_to_compressed,
                compressed_to_state=compressed_to_state,
                N_valid=N_valid,
                N=N,
                visualize=False,  # Don't visualize during shared SR training
            )
            # SR training decides compression from the mapping arguments; only
            # if that disagrees with the caller's flag must the buffer be rebuilt.
            if bool(sr_compressed) != bool(use_compressed):
                use_compressed = sr_compressed
                if use_compressed:
                    buffer_to_use = convert_buffer_to_compressed(buffer, state_to_compressed)
                else:
                    buffer_to_use = buffer
        else:
            print(f"[VPS] value_type=sr → using pre-trained SR (shape: {psi.shape})")
            if psi.shape[0] != state_space_size or psi.shape[1] != state_space_size:
                raise ValueError(f"SR shape {psi.shape} does not match state_space_size {state_space_size}")

        # V_i(s) = w_i^T ψ(s)  where w_i is reward_weights[i]
        V_base = reward_weights @ psi.T
    else:
        # TD value learning under random (orthogonalized) Gaussian state rewards.
        # r_i(s) = reward_weights[i, s]
        print(f"[VPS] value_type=td → learning V by TD(0) under random state rewards")
        V_base = np.zeros((k_base, state_space_size), dtype=np.float32)
        # Reuse sr_epochs as the number of TD sweeps over the buffer.
        for _epoch in range(int(sr_epochs)):
            for idx in np.random.permutation(len(buffer_to_use)):
                s, _a, sn, done = buffer_to_use[idx]
                s = int(s)
                sn = int(sn)
                target = reward_weights[:, s]  # (k_base,)
                if not done:
                    target = target + gamma * V_base[:, sn]
                V_base[:, s] += alpha * (target - V_base[:, s])

    # ------- Phase-1.5 : VPS features φ_i(s) ≈ E[(V_i(s') − V_i(s))^2 | s] -------
    phi_base = np.zeros_like(V_base)
    visit_counts = np.zeros(state_space_size, dtype=np.int32)

    # V_base is final from here on, so its full-space view is computed once and
    # reused for visualization and for the return value.
    if use_compressed:
        V_base_full = expand_compressed_features(V_base, compressed_to_state, N)
    else:
        V_base_full = V_base

    # Create VPS-specific save directory
    vps_save_dir = os.path.join(save_dir, "vps") if (visualize and save_dir is not None) else None
    if vps_save_dir is not None:
        os.makedirs(vps_save_dir, exist_ok=True)

    def save_snapshots(phi_full, step):
        # One plot for the VPS features and one for the value functions.
        for tag, feats in (("phi", phi_full), ("V", V_base_full)):
            visualize_feature_distribution(
                feats, size, vps_save_dir,
                f"vps_{tag}_step{step}", viz_iteration,
                max_options_to_plot=min(4, k_base), env=env
            )

    # Step counter for strict viz_freq visualization (SR training steps are separate)
    total_vps_steps = 0
    for _epoch in range(phi_epochs):
        for idx in np.random.permutation(len(buffer_to_use)):
            s, _a, sn, done = buffer_to_use[idx]
            s = int(s)
            sn = int(sn)
            V_s = V_base[:, s]
            # For terminal states, V(sn) = 0 (no future value)
            V_sn = np.zeros_like(V_s) if done else V_base[:, sn]
            phi_base[:, s] += (V_sn - V_s) ** 2
            visit_counts[s] += 1
            total_vps_steps += 1

            # Visualize strictly every viz_freq steps
            if vps_save_dir is not None and total_vps_steps % viz_freq == 0:
                # Running mean over the visits seen so far.
                phi_current = phi_base.copy()
                seen_so_far = visit_counts > 0
                if np.any(seen_so_far):
                    phi_current[:, seen_so_far] /= visit_counts[seen_so_far]
                if use_compressed:
                    phi_current = expand_compressed_features(
                        phi_current, compressed_to_state, N
                    )
                save_snapshots(phi_current, total_vps_steps)

    # Turn accumulated squared value changes into per-state means.
    non_zero = visit_counts > 0
    if np.any(non_zero):
        phi_base[:, non_zero] /= visit_counts[non_zero]

    if use_compressed:
        phi_base_full = expand_compressed_features(phi_base, compressed_to_state, N)
    else:
        phi_base_full = phi_base

    # Final visualization, unless the last step already produced one.
    if vps_save_dir is not None and total_vps_steps % viz_freq != 0:
        save_snapshots(phi_base_full, total_vps_steps)

    # ------- Phase-2 : learn Q options with VPS rewards (random sampling) ---------------
    Q = np.zeros((total_opts, state_space_size, num_actions), np.float32)
    # Views into Q: rows for +φ options and (when sign) the matching −φ options.
    # All options are updated together; each option's update is independent.
    n_feat = phi_base.shape[0]
    Q_pos = Q[:n_feat]
    Q_neg = Q[k_base:k_base + n_feat] if sign else None
    for _epoch in range(q_epochs):
        for idx in np.random.permutation(len(buffer_to_use)):
            s, a, sn, done = buffer_to_use[idx]
            s = int(s)
            sn = int(sn)
            r_vec = phi_base[:, sn] - phi_base[:, s]   # (+φ) reward

            # Q-learning update: if done, no bootstrap
            if done:
                target_pos = r_vec
            else:
                target_pos = r_vec + gamma * Q_pos[:, sn].max(axis=1)
            Q_pos[:, s, a] += alpha * (target_pos - Q_pos[:, s, a])

            if sign:
                if done:
                    target_neg = -r_vec
                else:
                    target_neg = -r_vec + gamma * Q_neg[:, sn].max(axis=1)
                Q_neg[:, s, a] += alpha * (target_neg - Q_neg[:, s, a])

    # Expand Q to the full state space for the return value (for compatibility).
    if use_compressed:
        full_indices = np.asarray(compressed_to_state[:N_valid], dtype=np.intp)
        Q_out = np.zeros((total_opts, N, num_actions), dtype=np.float32)
        Q_out[:, full_indices, :] = Q[:, :N_valid, :]
    else:
        Q_out = Q

    # For logging / analysis, return both V and φ with positive/negative
    # duplication when sign=True to stay consistent with the old API.
    return Q_out, option_rows_dup(V_base_full, sign), option_rows_dup(phi_base_full, sign)


# ---------- 2. Eigen-Option ---------------------------------
def train_eigen_options(
    size,
    N,
    buffer,  # Shared buffer passed in (uses full state indices)
    *,
    k_base,
    sign,
    gamma,
    alpha,
    q_epochs=1,
    visualize=False,
    save_dir=None,
    viz_iteration=0,
    viz_freq=500,
    env=None,  # Optional: only needed for visualization
    state_to_compressed=None,  # Optional: dict mapping full index -> compressed index
    compressed_to_state=None,  # Optional: array mapping compressed index -> full index
    N_valid=None,  # Optional: number of valid states
):
    """Train eigenoptions from the transitions stored in ``buffer``.

    An undirected graph is built from the buffered transitions, the
    ``k_base`` non-trivial eigenvectors with the smallest eigenvalues of its
    symmetric normalized Laplacian are used as potentials ``phi``, and one
    tabular Q-function per option is learned offline with the shaped reward
    ``r = sg * (phi[sn] - phi[s])``.

    Args:
        size: Grid size; only forwarded to the visualization helper.
        N: Size of the full state-index space.
        buffer: Sequence of ``(s, a, sn, done)`` transitions in full indices.
        k_base: Number of eigenvectors to turn into options.
        sign: If true, each eigenvector yields two options (``+phi`` then
            ``-phi``); otherwise one.
        gamma: Discount factor of the Q-learning target.
        alpha: Q-learning step size.
        q_epochs: Number of shuffled passes over the buffer per option.
        visualize: If true (and ``save_dir`` is given), plot up to four
            eigenvectors into ``<save_dir>/eigen``.
        save_dir: Root directory for visualization output.
        viz_iteration: Iteration tag forwarded to the visualization helper.
        viz_freq: Accepted for signature parity with the other trainers; not
            used by this function.
        env: Forwarded to the visualization helper.
        state_to_compressed, compressed_to_state, N_valid: When all three are
            given, learning runs on the compressed state space and the result
            is expanded back to ``N`` states before returning.

    Returns:
        ``float32`` array of shape ``(total, N, 6)`` where
        ``total = k_base * (2 if sign else 1)``. Rows of states that are not
        part of the compressed space stay zero.
    """
    # Compressed mode requires all three mapping arguments.
    use_compressed = (
        state_to_compressed is not None
        and compressed_to_state is not None
        and N_valid is not None
    )

    if use_compressed:
        print(f"[Eigen] Using compressed state space: {N_valid} valid states (reduced from {N})")
        # Convert buffer to compressed indices
        buffer_compressed = convert_buffer_to_compressed(buffer, state_to_compressed)
        print(f"[Eigen] Converted buffer: {len(buffer)} -> {len(buffer_compressed)} transitions")
        buffer_to_use = buffer_compressed
        state_space_size = N_valid
    else:
        print(f"[Eigen] Using full state space: {N} states")
        buffer_to_use = buffer
        state_space_size = N

    # ----- build a SYMMETRIC graph Laplacian for eigenoptions -----
    #
    # KeyLock transitions are generally non-reversible (e.g. picking up a key
    # is irreversible), so the empirical transition matrix is typically
    # non-symmetric and `eigsh` (a symmetric solver) must not be applied to it.
    # Instead, directed transition counts are symmetrized into an undirected
    # weighted adjacency A and the symmetric normalized Laplacian is used:
    #   L = I - D^{-1/2} A D^{-1/2}
    rows = np.fromiter((int(s) for s, _, _, _ in buffer_to_use), dtype=np.int64)
    cols = np.fromiter((int(sn) for _, _, sn, _ in buffer_to_use), dtype=np.int64)
    data = np.ones_like(rows, dtype=np.float32)
    A_dir = sp.coo_matrix(
        (data, (rows, cols)),
        shape=(state_space_size, state_space_size),
        dtype=np.float32,
    ).tocsr()
    # Symmetrize to make the operator valid for `eigsh`
    A = A_dir + A_dir.T
    # Remove self-loops to avoid diagonal dominance
    A.setdiag(0.0)
    A.eliminate_zeros()

    deg = np.asarray(A.sum(axis=1)).ravel().astype(np.float32)
    # Clamp so that states without any edge do not divide by zero; their
    # Laplacian row reduces to the identity row.
    deg = np.maximum(deg, 1e-8)
    Dinv_sqrt = sp.diags(1.0 / np.sqrt(deg), dtype=np.float32)
    L_sym = sp.eye(state_space_size, dtype=np.float32) - (Dinv_sqrt @ A @ Dinv_sqrt)

    # One extra eigenpair is requested because the eigenvalue-0 eigenvector
    # (proportional to D^{1/2} * 1 on a connected graph) carries no structure.
    k_need = k_base + 1
    # which="SM" = smallest magnitude; L is positive semi-definite, so these
    # are also the algebraically smallest eigenvalues.
    eig_vals, vecs = eigsh(L_sym, k=k_need, which="SM")
    # Make the ascending order explicit before dropping the trivial
    # eigenvector instead of relying on the solver's return order.
    order = np.argsort(eig_vals, kind="stable")
    eig_vecs = vecs[:, order[1:]]

    # ------- Visualization: plot eigenvector distribution ------------
    # Only visualize once after computing eigenvectors
    if visualize and save_dir is not None:
        eigen_save_dir = os.path.join(save_dir, "eigen")
        os.makedirs(eigen_save_dir, exist_ok=True)
        # Visualize first few eigenvectors (each as a separate figure)
        num_to_viz = min(4, k_base)
        for i in range(num_to_viz):
            eig_vec_viz = eig_vecs[:, i]
            # Expand to full state space if using compressed
            if use_compressed:
                eig_vec_viz = expand_compressed_features(
                    eig_vec_viz, compressed_to_state, N
                )
            visualize_feature_distribution(
                eig_vec_viz, size, eigen_save_dir,
                f"eigen_vec{i}", viz_iteration, max_options_to_plot=1, env=env
            )

    signs = (1, -1) if sign else (1,)
    total = k_base * len(signs)
    num_actions = 6
    Q = np.zeros((total, state_space_size, num_actions), np.float32)
    n_transitions = len(buffer_to_use)

    # Options are laid out eigenvector by eigenvector; with sign=True the
    # +phi option is immediately followed by its -phi counterpart.
    opt_id = 0
    for phi in eig_vecs.T:
        for sg in signs:
            for _ in range(q_epochs):
                # Fresh random order of the buffer for every epoch
                for idx in np.random.permutation(n_transitions):
                    s, a, sn, done = buffer_to_use[idx]
                    s = int(s)
                    sn = int(sn)
                    r = sg * (phi[sn] - phi[s])
                    # Q-learning update: if done, no bootstrap
                    if done:
                        target = r
                    else:
                        target = r + gamma * Q[opt_id, sn].max()
                    Q[opt_id, s, a] += alpha * (target - Q[opt_id, s, a])
            opt_id += 1

    if not use_compressed:
        return Q

    # Expand Q to the full state space for return (for compatibility).
    # Assumes compressed_to_state maps each compressed index to a distinct
    # full index, so a single indexed assignment matches a per-row copy.
    full_indices = np.fromiter(
        (int(compressed_to_state[i]) for i in range(N_valid)),
        dtype=np.int64,
        count=N_valid,
    )
    Q_full = np.zeros((total, N, num_actions), dtype=np.float32)
    Q_full[:, full_indices, :] = Q
    return Q_full


# ------------------------------------------------------------
#  Random-Option (KeyLockEnv)
#  · Each option uses a random potential φ ~ N(0,1)
#  · shaped-reward  r = φ(sn) − φ(s)
# ------------------------------------------------------------
def train_random_options(
    size,
    N,
    buffer,  # Shared buffer passed in
    *,
    k_base,
    gamma,
    alpha,
    q_epochs=1,
    seed=None,
    visualize=False,
    save_dir=None,
    viz_iteration=0,
    env=None,
):
    """Train options whose potentials are i.i.d. standard-normal fields.

    Each of the ``k_base`` options gets a random potential ``phi_i`` over the
    ``N`` states and a tabular Q-function learned offline from ``buffer`` with
    the shaped reward ``r_i = phi_i[sn] - phi_i[s]``. All options are updated
    together on every transition.

    Args:
        size: Grid size; only forwarded to the visualization helper.
        N: Number of state indices.
        buffer: Sequence of ``(s, a, sn, done)`` transitions.
        k_base: Number of options to train.
        gamma: Discount factor of the Q-learning target.
        alpha: Q-learning step size.
        q_epochs: Number of shuffled passes over the buffer.
        seed: Seed of the local generator that draws the potentials and the
            per-epoch shuffles, so a fixed seed reproduces both outputs
            regardless of the global ``np.random`` state. ``None`` draws
            fresh entropy.
        visualize: If true (and ``save_dir`` is given), plot up to four
            initial potentials into ``<save_dir>/random``.
        save_dir: Root directory for visualization output.
        viz_iteration: Iteration tag forwarded to the visualization helper.
        env: Forwarded to the visualization helper.

    Returns:
        ``(Q, phi)`` with ``Q`` of shape ``(k_base, N, 6)`` and ``phi`` of
        shape ``(k_base, N)``, both ``float32``.
    """
    rng = np.random.default_rng(seed)
    num_actions = 6

    phi = rng.standard_normal((k_base, N)).astype(np.float32)
    if visualize and save_dir is not None:
        random_save_dir = os.path.join(save_dir, "random")
        os.makedirs(random_save_dir, exist_ok=True)
        num_to_viz = min(4, k_base)
        for i in range(num_to_viz):
            phi_viz = phi[i]  # (N,)
            visualize_feature_distribution(
                phi_viz, size, random_save_dir,
                f"random_phi_initial_opt{i}", viz_iteration, max_options_to_plot=1, env=env
            )
        print(f"  [Viz] Saved initial random potential field visualization for {num_to_viz} options")

    Q = np.zeros((k_base, N, num_actions), np.float32)
    n_transitions = len(buffer)
    for _ in range(q_epochs):
        # Shuffle with the seeded local generator (not the global np.random
        # state) so that `seed` fully determines the training trajectory.
        for idx in rng.permutation(n_transitions):
            s, a, sn, done = buffer[idx]
            s = int(s)
            sn = int(sn)
            td = phi[:, sn] - phi[:, s]          # shaped reward, one per option
            if not done:
                td += gamma * Q[:, sn].max(1)    # bootstrap only if not done
            td -= Q[:, s, a]
            Q[:, s, a] += alpha * td

    return Q, phi


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as true; this converter interprets the text instead. The
    empty string still maps to ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Command-line entry point: collect a buffer and train the option sets.

    Steps: parse arguments, load (or generate and save) the transition matrix,
    build the compressed state mapping, collect the shared transition buffer,
    optionally pre-train a shared successor representation, then for every
    requested option count and outer index train and save the selected
    option type(s) as ``.npy`` Q-tables under the save directory.

    Raises:
        ValueError: If ``--outer_idx`` is negative or not below
            ``--outer_num`` (and whatever ``_parse_num_opts_list`` raises for
            a malformed ``--num_opts``).
    """
    pa = argparse.ArgumentParser()
    g = pa.add_argument
    g(
        "--num_opts",
        type=str,
        default="[4,8,16]",
        help='Number(s) of options to train. Examples: "4", "2,4,6", "[2,4,6]".',
    )
    g(
        "--sign",
        type=_str2bool,
        default=True,
        help="Also train the negated (-phi) counterpart of every option (true/false, default: %(default)s)",
    )
    g("--gamma", type=float, default=0.999)
    g("--alpha", type=float, default=0.1, help="Learning rate (reduced for larger state space for more stable SR learning)")
    g("--collect_ep", type=int, default=2000, help="random-walk episodes (increased for larger state space)")
    g("--ep_len", type=int, default=500, help="steps per episode (increased for larger state space)")
    g("--seed", type=int, default=0)
    g("--save_dir", default="keylock_option_results")
    g("--only", choices=["vps", "eigen", "rand", "all"], default="eigen")
    g(
        "--outer_num",
        type=int,
        default=5,
        help="number of independent option sets",
    )
    g(
        "--outer_idx",
        type=int,
        default=None,
        help=(
            "Only train a single outer index m (0-based), and save files with that suffix. "
            "Example: --outer_idx 4 produces '*_4.npy'. If omitted, trains m=0..outer_num-1."
        ),
    )
    g("--size", type=int, default=15, help="grid size")
    g("--yellow_key_pos", type=int, nargs=2, default=[12, 3], help="yellow key position (x, y)")
    g("--yellow_door_pos", type=int, nargs=2, default=[3, 8], help="yellow door position (x, y)")
    g("--blue_key_pos", type=int, nargs=2, default=[12, 12], help="blue key position (x, y)")
    g("--blue_door_pos", type=int, nargs=2, default=[9, 3], help="blue door position (x, y)")
    g("--goal_pos", type=int, nargs=2, default=[3, 12], help="goal position (x, y)")
    g(
        "--visualize",
        type=_str2bool,
        default=True,
        help="Enable visualization of VPS/Eigen features during training (true/false, default: %(default)s)",
    )
    g("--viz_freq", type=int, default=5000000, help="Visualization frequency (every N training steps, default: %(default)s)")
    g("--use_intrinsic_buffer", action="store_true", help="Use intrinsic reward exploration for buffer collection")
    g("--alpha_m", type=float, default=0.1, help="Learning rate for M matrix in intrinsic reward exploration")
    g("--alpha_q_buffer", type=float, default=0.1, help="Learning rate for Q function in intrinsic reward exploration")
    g("--epsilon_buffer", type=float, default=0.1, help="Epsilon for epsilon-greedy in intrinsic reward exploration")
    g("--sr_epochs", type=int, default=100, help="Number of epochs for SR learning (default: %(default)s)")
    g("--phi_epochs", type=int, default=5, help="Number of epochs for phi computation (default: %(default)s)")
    g("--q_epochs", type=int, default=20, help="Number of epochs for Q-learning (default: %(default)s)")
    g("--sr_lambda", type=float, default=0.0, help="TD(λ) parameter for SR learning (0.0 = TD(0), >0 = TD(λ), default: %(default)s)")
    g(
        "--value_type",
        type=str,
        default="td",
        choices=["sr", "td"],
        help=(
            "How to obtain the k value functions for VPS. "
            "'sr': learn successor representation psi(s) then V_i(s)=w_i^T psi(s). "
            "'td' (default): learn V_i(s) directly with TD(0) under k orthogonalized random Gaussian state rewards r_i(s)=w_i[s]."
        ),
    )
    g("--transition_matrix_dir", type=str, default="keylock_transition_matrix",
      help="Directory containing transition matrix files (T_next.npy, T_done.npy, valid_states.npy)")
    g("--generate_transition_matrix", action="store_true",
      help="Force regenerate transition matrix even if it exists")
    args = pa.parse_args()

    # Validate the cheap inputs first so that a typo fails before the
    # expensive transition-matrix / buffer / SR stages run.
    num_list = _parse_num_opts_list(args.num_opts)
    if args.outer_idx is not None:
        single_m = int(args.outer_idx)
        if single_m < 0:
            raise ValueError(f"--outer_idx must be >=0, got {single_m}")
        if single_m >= args.outer_num:
            raise ValueError(f"--outer_idx={single_m} out of range for --outer_num={args.outer_num}")
        outer_indices = [single_m]
    else:
        outer_indices = range(args.outer_num)

    def _make_env():
        # Build and reset a KeyLockEnv configured from the CLI arguments.
        new_env = KeyLockEnv(
            size=args.size,
            agent_start_pos=(1, 1),
            agent_start_dir=0,
            yellow_key_pos=tuple(args.yellow_key_pos),
            yellow_door_pos=tuple(args.yellow_door_pos),
            blue_key_pos=tuple(args.blue_key_pos),
            blue_door_pos=tuple(args.blue_door_pos),
            goal_pos=tuple(args.goal_pos),
            render_mode=None,
        )
        new_env.reset()
        return new_env

    # Size of the full state-index space: x, y, 4 directions, 4 binary flags.
    N = args.size * args.size * 4 * 2 * 2 * 2 * 2
    script_dir = os.path.dirname(os.path.abspath(__file__))
    save_dir = os.path.join(script_dir, args.save_dir)
    os.makedirs(save_dir, exist_ok=True)

    transition_matrix_dir = os.path.join(script_dir, args.transition_matrix_dir)
    T_next_path = os.path.join(transition_matrix_dir, "T_next.npy")
    T_done_path = os.path.join(transition_matrix_dir, "T_done.npy")
    valid_states_path = os.path.join(transition_matrix_dir, "valid_states.npy")

    cached_files_exist = (
        os.path.exists(T_next_path)
        and os.path.exists(T_done_path)
        and os.path.exists(valid_states_path)
    )
    if cached_files_exist and not args.generate_transition_matrix:
        print(f"[Transition Matrix] Loading from {transition_matrix_dir}...")
        T_next = np.load(T_next_path)
        T_done = np.load(T_done_path)
        valid_states = np.load(valid_states_path)
        print(f"[Transition Matrix] Loaded: T_next shape {T_next.shape}, {len(valid_states)} valid states")
    else:
        if args.generate_transition_matrix and os.path.exists(T_next_path):
            print("[Transition Matrix] Force regenerating transition matrix...")
        else:
            print("[Transition Matrix] Not found. Generating transition matrix...")
        gen_env = _make_env()

        from generate_keylock_transition_matrix import build_keylock_transition_matrix
        T_next, T_done, valid_states_set = build_keylock_transition_matrix(
            gen_env, args.size,
            tuple(args.yellow_key_pos),
            tuple(args.yellow_door_pos),
            tuple(args.blue_key_pos),
            tuple(args.blue_door_pos),
            tuple(args.goal_pos),
        )
        valid_states = np.array(list(valid_states_set), dtype=np.int32)

        os.makedirs(transition_matrix_dir, exist_ok=True)
        np.save(T_next_path, T_next)
        np.save(T_done_path, T_done)
        np.save(valid_states_path, valid_states)
        print(f"[Transition Matrix] Saved to {transition_matrix_dir}/")

    # Build state index mapping for compressed state space
    state_to_compressed, compressed_to_state, N_valid = build_state_index_mapping(valid_states, N)

    # The environment is only needed by the visualization helpers.
    env = _make_env() if args.visualize else None

    print(f"[Buffer] Collecting buffer with {args.collect_ep} episodes, max_len={args.ep_len}...")
    if args.use_intrinsic_buffer:
        buffer, _, _, _ = collect_buffer_intrinsic_reward(
            T_next, T_done, valid_states, N,
            episodes=args.collect_ep, max_len=args.ep_len,
            alpha_m=args.alpha_m,
            alpha_q=args.alpha_q_buffer,
            gamma=args.gamma,
            epsilon=args.epsilon_buffer,
            size=args.size,
        )
        print(f"[Buffer] Collected {len(buffer)} transitions using intrinsic reward exploration")
    else:
        buffer, _ = collect_buffer(
            T_next, T_done, valid_states,
            episodes=args.collect_ep, max_len=args.ep_len, size=args.size
        )
        print(f"[Buffer] Collected {len(buffer)} transitions using random walk")

    train_vps = args.only in ("vps", "all")
    use_sr = args.value_type == "sr"

    psi_shared = None
    state_space_size_shared = None
    use_compressed_shared = None
    if train_vps and use_sr:
        print("[SR] Training shared successor representation (will be reused for all VPS option sets)...")
        set_seed(int(args.seed))
        k_base_viz = min(num_list) if num_list else 4
        print(f"[SR] Using k_base={k_base_viz} for visualization (minimum from num_opts)")

        psi_shared, state_space_size_shared, use_compressed_shared = train_successor_representation(
            buffer,
            gamma=args.gamma,
            alpha=args.alpha,
            sr_epochs=args.sr_epochs,
            sr_lambda=args.sr_lambda,
            state_to_compressed=state_to_compressed,
            compressed_to_state=compressed_to_state,
            N_valid=N_valid,
            N=N,
            visualize=args.visualize,
            save_dir=save_dir,
            viz_freq=args.viz_freq,
            env=env,
            size=args.size,
            k_base_viz=k_base_viz,
            viz_iteration=0,
            valid_states=set(valid_states.tolist()) if valid_states is not None else None,
        )
        print(f"[SR] Shared SR training completed (shape: {psi_shared.shape})")
    elif train_vps and args.value_type == "td":
        print("[VPS] value_type=td → skipping shared SR pre-training")

    for k_base in num_list:
        for m in outer_indices:
            # One seed per (k_base, m) pair keeps every option set independent
            # yet reproducible for a given --seed.
            option_seed = int(args.seed) + 10000 * int(k_base) + m
            set_seed(option_seed)
            if train_vps:
                print(f"[VPS] k={k_base} training {m} …")
                Q_vps, _, _ = train_vps_options(
                    args.size,
                    N,
                    buffer,  # Pass shared buffer (uses full state indices)
                    k_base=k_base,
                    sign=args.sign,
                    gamma=args.gamma,
                    alpha=args.alpha,
                    sr_epochs=args.sr_epochs,  # Not used if psi is provided
                    phi_epochs=args.phi_epochs,
                    q_epochs=args.q_epochs,
                    sr_lambda=args.sr_lambda,  # Not used if psi is provided
                    value_type=args.value_type,
                    visualize=args.visualize,
                    save_dir=save_dir,
                    viz_iteration=m,
                    viz_freq=args.viz_freq,
                    env=env,  # Pass env only for visualization
                    state_to_compressed=state_to_compressed,
                    compressed_to_state=compressed_to_state,
                    N_valid=N_valid,
                    psi=(psi_shared if use_sr else None),
                    state_space_size=(state_space_size_shared if use_sr else None),
                    use_compressed=(use_compressed_shared if use_sr else None),
                )
                total = Q_vps.shape[0]
                fn = f"keylock_{total}_VPSOpt_{m}.npy"
                np.save(os.path.join(save_dir, fn), Q_vps)
                print("  saved", fn)

            if args.only in ("eigen", "all"):
                print(f"[Eigen] k={k_base} training {m} …")
                Q_eig = train_eigen_options(
                    args.size,
                    N,
                    buffer,  # Pass shared buffer (uses full state indices)
                    k_base=k_base,
                    sign=args.sign,
                    gamma=args.gamma,
                    alpha=args.alpha,
                    q_epochs=args.q_epochs,
                    visualize=args.visualize,
                    save_dir=save_dir,
                    viz_iteration=m,
                    viz_freq=args.viz_freq,
                    env=env,  # Pass env only for visualization
                    state_to_compressed=state_to_compressed,
                    compressed_to_state=compressed_to_state,
                    N_valid=N_valid,
                )
                total = Q_eig.shape[0]
                fn = f"keylock_{total}_EigenOpt_{m}.npy"
                np.save(os.path.join(save_dir, fn), Q_eig)
                print("  saved", fn)

            if args.only in ("rand", "all"):
                print(f"[Random] k={k_base} training {m} …")
                # Match the option count of the signed VPS/Eigen sets:
                # 2 * k_base with --sign, otherwise k_base.
                total = k_base * (2 if args.sign else 1)
                Q_rnd, _ = train_random_options(
                    args.size,
                    N,
                    buffer,  # Pass shared buffer
                    k_base=total,
                    gamma=args.gamma,
                    alpha=args.alpha,
                    q_epochs=args.q_epochs,
                    seed=option_seed,  # make the random potentials reproducible
                    visualize=args.visualize,
                    save_dir=save_dir,
                    viz_iteration=m,
                    env=env,
                )
                fn = f"keylock_{total}_RandomOption_{m}.npy"
                np.save(os.path.join(save_dir, fn), Q_rnd)
                print("  saved", fn)


# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
