import os
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import argparse
import time

# Ensure project root is on sys.path.
# The root is taken to be the parent of the directory that contains this file;
# it is prepended (only if not already present) before the absolute
# `visual_gridworld.*` imports below are executed.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_ROOT_DIR = os.path.dirname(_THIS_DIR)
if _ROOT_DIR not in sys.path:
    sys.path.insert(0, _ROOT_DIR)

from visual_gridworld.fpvr_agent import FPVRVisualAgent
from visual_gridworld.config import Config, Presets

# Optional dependency: minigrid stack (allow `--help` without installing).
# Any exception raised while importing the environment module is tolerated
# here and SimpleEnv is left as None; the code paths that actually need the
# environment check for None and raise ModuleNotFoundError with install hints.
try:
    from visual_gridworld.visual_minigrid import SimpleEnv  # type: ignore
except Exception:
    SimpleEnv = None  # type: ignore


# ============================================================================
# Environment and preprocessing
# ============================================================================

def make_env(seed=0, size=10):
    """Build a ``SimpleEnv`` that renders RGB arrays and reset it with ``seed``.

    Args:
        seed: Seed forwarded to the environment's ``reset``.
        size: Grid size forwarded to the ``SimpleEnv`` constructor.

    Returns:
        The freshly reset environment instance.

    Raises:
        ModuleNotFoundError: If the optional environment import failed at
            module load time (``SimpleEnv`` is ``None``).
    """
    if SimpleEnv is not None:
        grid_env = SimpleEnv(size=size, render_mode="rgb_array", highlight=False)
        grid_env.reset(seed=seed)
        return grid_env

    raise ModuleNotFoundError(
        "Missing optional dependency for visual_gridworld environment (minigrid). "
        "Please install it (e.g., `pip install minigrid`) to run visual_gridworld."
    )


def preprocess(obs, max_side=128):
    """Preprocess observation: RGB -> grayscale single-channel [1,H',W'].

    Args:
        obs: Either an image array (``H x W`` or ``H x W x C``) or a dict that
            stores such an array under the key ``"image"``.
        max_side: Upper bound on the longer spatial side. Larger images are
            subsampled with an integer stride (no interpolation).

    Returns:
        A ``uint8`` array of shape ``[1, H', W']``.

    Raises:
        ValueError: If a dict observation has no ``"image"`` entry, or the
            image is neither 2-D nor 3-D.
    """
    if isinstance(obs, dict):
        img = obs.get("image", None)
        if img is None:
            raise ValueError("Expected RGB array or dict with key 'image'")
    else:
        img = obs

    # Accept array-likes (lists, tuples) as well as ndarrays.
    img = np.asarray(img)
    if img.ndim not in (2, 3):
        raise ValueError(
            f"Expected an image of shape (H, W) or (H, W, C), got shape {img.shape}"
        )

    h, w = img.shape[:2]
    # Downsample to limit resolution
    step = max(1, int(np.ceil(max(h, w) / float(max_side))))
    if step > 1:
        img = img[::step, ::step]

    if img.ndim == 3:
        if img.shape[2] >= 3:
            # RGB -> grayscale (ITU-R BT.601 luma weights); extra channels
            # such as alpha are ignored.
            gray = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(np.uint8)
        else:
            # H x W x 1 (or gray + alpha): use the first channel as intensity
            # so the result stays 2-D instead of leaking a trailing channel
            # axis into the returned [1, H', W'] array.
            gray = img[..., 0].astype(np.uint8)
    else:
        gray = img.astype(np.uint8)

    return gray[None, :, :]


# ============================================================================
# FPVR training
# ============================================================================

def run_fpvr_training(env, agent, steps, k_train, visualize=False, render_delay=0.05, feat_dim=None, env_seed=0, reset_interval=0, lambda_c=None, print_interval=None):
    """Run FPVR training and collect statistics.

    Args:
        env: Environment exposing ``height``, ``width``, ``agent_pos``,
            ``agent_dir``, ``reset()`` and a 5-tuple ``step()``.
        agent: Agent exposing ``act``, ``update_c_from_phi``, ``store``,
            ``train_step`` and ``sf_target_mode``.
        steps: Number of environment steps to run.
        k_train: Number of ``agent.train_step`` calls per environment step.
        visualize: If true, mirror the agent in a second, human-rendered env.
        render_delay: Seconds to sleep after each rendered frame.
        feat_dim: Unused; kept for interface compatibility.
        env_seed: Seed used to reset the visualization environment.
        reset_interval: If > 0, the windowed visited-set is restarted every
            ``reset_interval`` steps and per-interval curves are recorded.
        lambda_c: Value forwarded to ``agent.update_c_from_phi``. When
            ``None`` it is taken from a module-level ``config`` if one exists
            (script execution), otherwise from a default ``Config()``.
        print_interval: Print the latest SR loss every N env steps
            (0 disables). When ``None`` it is taken from a module-level
            ``config`` if one exists, otherwise 0.

    Returns:
        Dict with ``coverage``, ``coverage_cumulative``, ``coverage_windowed``
        (per-step lists, including the initial state), ``counts`` (visit
        counts indexed ``[y, x]``), ``obs_at_pos`` (latest preprocessed
        observation per visited position) and ``sr_losses``; plus
        ``coverage_intervals`` and ``reset_interval`` when windowing is on.

    Raises:
        ModuleNotFoundError: If ``visualize`` is requested but ``SimpleEnv``
            could not be imported.
    """
    # These settings used to be read unconditionally from a module-level
    # `config`, which only exists when this file is run as a script; callers
    # that import the module hit a NameError. Resolve them explicitly instead.
    module_cfg = globals().get("config")
    if lambda_c is None:
        lambda_c = (module_cfg if module_cfg is not None else Config()).lambda_c
    if print_interval is None:
        print_interval = int(getattr(module_cfg, "print_interval", 0)) if module_cfg is not None else 0

    env_vis = None
    if visualize:
        if SimpleEnv is None:
            raise ModuleNotFoundError(
                "Missing optional dependency for visual_gridworld environment (minigrid). "
                "Please install it (e.g., `pip install minigrid`) to use visualization."
            )
        env_vis = SimpleEnv(size=env.height, render_mode="human", highlight=False)

    visited_cum = set()
    visited_win = set()
    counts = np.zeros((env.height, env.width), dtype=np.int32)
    coverage = []
    coverage_cum = []
    coverage_win = []
    obs_at_pos = {}
    sr_losses = []

    coverage_intervals = [] if reset_interval > 0 else None
    current_interval_coverage = [] if reset_interval > 0 else None

    # try/finally guarantees the render window is released even if training
    # raises part-way through.
    try:
        if env_vis is not None:
            env_vis.reset(seed=env_seed)
            env_vis.render()

        obs, _ = env.reset()
        start_pos = tuple(env.agent_pos)
        visited_cum.add(start_pos)
        visited_win.add(start_pos)
        counts[env.agent_pos[1], env.agent_pos[0]] += 1
        coverage_cum.append(len(visited_cum))
        coverage_win.append(len(visited_win))
        coverage.append(len(visited_win) if reset_interval > 0 else len(visited_cum))

        train_counter = 0

        for t in range(steps):
            x = preprocess(obs)
            a, phi = agent.act(x, greedy=False)
            obs_next, _, terminated, truncated, _ = env.step(a)

            if env_vis is not None:
                # The visual env is not stepped; it only mirrors the pose.
                env_vis.agent_pos = env.agent_pos
                env_vis.agent_dir = env.agent_dir
                env_vis.render()
                if render_delay > 0:
                    time.sleep(render_delay)

            x_next = preprocess(obs_next)
            # NOTE(review): this is a separate policy query; the action that is
            # actually executed next iteration comes from another act() call.
            a_next, _ = agent.act(x_next, greedy=False)

            with torch.no_grad():
                agent.update_c_from_phi(phi, lambda_c=lambda_c)

            agent.store(
                x,
                a,
                x_next,
                done=(terminated or truncated),
                next_action=(a_next if agent.sf_target_mode == "current_policy" else None),
            )

            for _ in range(int(k_train)):
                ret = agent.train_step(step_idx=train_counter)
                train_counter += 1
                if ret is not None:
                    sr_losses.append(float(ret))

            if print_interval > 0 and (t + 1) % print_interval == 0 and len(sr_losses) > 0:
                print(f"[Train] env_step={t+1} train_updates={train_counter} SR={sr_losses[-1]:.6f}")

            pos = tuple(env.agent_pos)
            obs_at_pos[pos] = x_next
            visited_cum.add(pos)
            visited_win.add(pos)
            counts[env.agent_pos[1], env.agent_pos[0]] += 1
            coverage_cum.append(len(visited_cum))
            coverage_win.append(len(visited_win))
            coverage.append(len(visited_win) if reset_interval > 0 else len(visited_cum))

            if reset_interval > 0:
                current_interval_coverage.append(len(visited_win))

            if reset_interval > 0 and (t + 1) % reset_interval == 0:
                print(f"\n[Step {t+1}] Reset windowed coverage | windowed={len(visited_win)} | total_visits={counts.sum()}")
                coverage_intervals.append(current_interval_coverage.copy())
                # A new window starts from the cell the agent currently occupies.
                visited_win = set()
                visited_win.add(tuple(env.agent_pos))
                current_interval_coverage = []

            obs = obs_next
            if terminated or truncated:
                obs, _ = env.reset()
                if env_vis is not None:
                    env_vis.reset()
                    env_vis.render()
                    if render_delay > 0:
                        time.sleep(render_delay)
                # The post-reset cell counts as visited, but no extra entry is
                # appended to the per-step coverage curves.
                pos0 = tuple(env.agent_pos)
                visited_cum.add(pos0)
                visited_win.add(pos0)
                counts[env.agent_pos[1], env.agent_pos[0]] += 1
    finally:
        if env_vis is not None:
            env_vis.close()

    # Keep the trailing, incomplete interval as well.
    if reset_interval > 0 and len(current_interval_coverage) > 0:
        coverage_intervals.append(current_interval_coverage)

    result = {
        'coverage': coverage,
        'coverage_cumulative': coverage_cum,
        'coverage_windowed': coverage_win,
        'counts': counts,
        'obs_at_pos': obs_at_pos,
        'sr_losses': sr_losses,
    }
    if reset_interval > 0:
        result['coverage_intervals'] = coverage_intervals
        result['reset_interval'] = reset_interval
    return result


# ============================================================================
# Baselines
# ============================================================================

def run_random_baseline(env, steps, reset_interval=0):
    """Drive ``env`` with uniformly sampled actions and record coverage.

    Args:
        env: Environment exposing ``height``, ``width``, ``agent_pos``,
            ``action_space.sample()``, ``reset()`` and a 5-tuple ``step()``.
        steps: Number of environment steps to take.
        reset_interval: If > 0, the windowed visited-set restarts every
            ``reset_interval`` steps and per-interval curves are recorded.

    Returns:
        Dict with ``coverage``, ``coverage_cumulative``, ``coverage_windowed``
        and ``counts``; plus ``coverage_intervals`` and ``reset_interval``
        when windowing is enabled.
    """
    windowed = reset_interval > 0

    seen_ever = set()
    seen_window = set()
    counts = np.zeros((env.height, env.width), dtype=np.int32)
    coverage, coverage_cum, coverage_win = [], [], []
    coverage_intervals = [] if windowed else None
    current_interval_coverage = [] if windowed else None

    def mark_current_cell():
        # Register the agent's current cell in both visited-sets and the
        # visit-count grid (indexed [y, x]).
        cell = tuple(env.agent_pos)
        seen_ever.add(cell)
        seen_window.add(cell)
        counts[env.agent_pos[1], env.agent_pos[0]] += 1

    def log_coverage():
        # Append one sample to each per-step coverage curve.
        n_ever = len(seen_ever)
        n_window = len(seen_window)
        coverage_cum.append(n_ever)
        coverage_win.append(n_window)
        coverage.append(n_window if windowed else n_ever)

    _, _ = env.reset()
    mark_current_cell()
    log_coverage()

    for t in range(steps):
        action = env.action_space.sample()
        _, _, terminated, truncated, _ = env.step(action)
        mark_current_cell()
        log_coverage()

        if windowed:
            current_interval_coverage.append(len(seen_window))
            if (t + 1) % reset_interval == 0:
                print(f"[Random] Step {t+1}: Reset windowed coverage | windowed={len(seen_window)}")
                coverage_intervals.append(current_interval_coverage.copy())
                # Start the next window from the cell currently occupied.
                seen_window.clear()
                seen_window.add(tuple(env.agent_pos))
                current_interval_coverage = []

        if terminated or truncated:
            _, _ = env.reset()
            # Count the post-reset cell without adding a curve sample.
            mark_current_cell()

    # Keep the trailing, incomplete interval too.
    if windowed and len(current_interval_coverage) > 0:
        coverage_intervals.append(current_interval_coverage)

    result = {
        'coverage': coverage,
        'coverage_cumulative': coverage_cum,
        'coverage_windowed': coverage_win,
        'counts': counts,
    }
    if windowed:
        result['coverage_intervals'] = coverage_intervals
        result['reset_interval'] = reset_interval
    return result


def run_tabular_fpvr(env, steps, n_actions, feat_dim, sf_gamma, beta, reset_interval=0, alpha_sr=0.1, lambda_c=0.999):
    """Run tabular FPVR (choose actions by minimizing future-past redundancy).

    States are grid cells encoded one-hot (index ``y * env.width + x``). A
    successor table ``M[s, a, :]`` is learned by TD(0) and compared with a
    decayed visitation trace ``C`` to score actions.

    Args:
        env: Environment exposing ``height``, ``width``, ``agent_pos``,
            ``reset()`` and a 5-tuple ``step()``.
        steps: Number of environment steps to take.
        n_actions: Number of discrete actions.
        feat_dim: Size of the one-hot state encoding; must be at least
            ``env.height * env.width``.
        sf_gamma: Discount used in the successor-table TD target.
        beta: Inverse temperature of the softmax over standardized costs.
        reset_interval: If > 0, the windowed visited-set and the trace ``C``
            restart every ``reset_interval`` steps.
        alpha_sr: Step size of the successor-table update.
        lambda_c: Per-step decay of the visitation trace ``C``.

    Returns:
        Dict with ``coverage``, ``coverage_cumulative``, ``coverage_windowed``
        and ``counts``; plus ``coverage_intervals`` and ``reset_interval``
        when windowing is enabled.

    Raises:
        ValueError: If ``feat_dim`` is too small to index every grid cell.
    """
    n_cells = env.height * env.width
    if feat_dim < n_cells:
        # Fail up front instead of with an IndexError deep into the run.
        raise ValueError(
            f"feat_dim={feat_dim} is smaller than the number of grid cells ({n_cells})"
        )

    visited_cum = set()
    visited_win = set()
    counts = np.zeros((env.height, env.width), dtype=np.int32)
    coverage = []
    coverage_cum = []
    coverage_win = []

    coverage_intervals = [] if reset_interval > 0 else None
    current_interval_coverage = [] if reset_interval > 0 else None

    M = np.zeros((feat_dim, n_actions, feat_dim), dtype=np.float32)
    C = np.zeros((feat_dim,), dtype=np.float32)

    def pos_to_idx(y, x):
        return y * env.width + x

    def choose_action(s_idx):
        # Cost of an action grows with the alignment between its predicted
        # future occupancy M[s, a] and the past trace C, scaled by |M[s, a]|.
        costs = np.empty((n_actions,), dtype=np.float32)
        c_norm = float(np.linalg.norm(C) + 1e-12)
        for a in range(n_actions):
            m = M[s_idx, a, :]
            m_norm = float(np.linalg.norm(m) + 1e-12)
            cos = float(np.dot(m, C) / (m_norm * c_norm))
            costs[a] = (cos + 1.0) * m_norm

        # Standardize so that beta acts on a scale-free quantity.
        mean_cost = np.mean(costs)
        std_cost = np.std(costs)
        costs = (costs - mean_cost) / (std_cost + 1e-8)

        # Numerically stable softmax over negative costs.
        logits = -beta * (costs - costs.min())
        logits = logits - logits.max()
        expv = np.exp(logits)
        probs = expv / np.clip(expv.sum(), 1e-8, None)
        return int(np.random.choice(n_actions, p=probs))

    _, _ = env.reset()
    start_pos = tuple(env.agent_pos)
    visited_cum.add(start_pos)
    visited_win.add(start_pos)
    counts[env.agent_pos[1], env.agent_pos[0]] += 1
    coverage_cum.append(len(visited_cum))
    coverage_win.append(len(visited_win))
    coverage.append(len(visited_win) if reset_interval > 0 else len(visited_cum))

    for t in range(steps):
        s_idx = pos_to_idx(env.agent_pos[1], env.agent_pos[0])

        C *= lambda_c
        C[s_idx] += 1.0

        a = choose_action(s_idx)
        _, _, terminated, truncated, _ = env.step(a)

        s_idx_next = pos_to_idx(env.agent_pos[1], env.agent_pos[0])

        # TD(0) update for the transition that was just executed:
        #   M[s, a] <- (1 - alpha) * M[s, a] + alpha * (e[s'] + gamma * mean_a' M[s', a'])
        # Updating (s, a) directly keeps the bootstrap state s' equal to the
        # true successor of (s, a); pairing the *previous* step's (s, a) with
        # this step's s' would skip one state and miss each episode's first
        # transition.
        done_flag = float(terminated or truncated)
        target = sf_gamma * (1.0 - done_flag) * M[s_idx_next, :, :].mean(axis=0)
        target[s_idx_next] += 1.0
        M[s_idx, a, :] = (1.0 - alpha_sr) * M[s_idx, a, :] + alpha_sr * target

        pos = tuple(env.agent_pos)
        visited_cum.add(pos)
        visited_win.add(pos)
        counts[env.agent_pos[1], env.agent_pos[0]] += 1
        coverage_cum.append(len(visited_cum))
        coverage_win.append(len(visited_win))
        coverage.append(len(visited_win) if reset_interval > 0 else len(visited_cum))

        if reset_interval > 0:
            current_interval_coverage.append(len(visited_win))

        if reset_interval > 0 and (t + 1) % reset_interval == 0:
            print(f"[Tabular FPVR] Step {t+1}: Reset windowed coverage | windowed={len(visited_win)}")
            coverage_intervals.append(current_interval_coverage.copy())
            visited_win = set()
            visited_win.add(tuple(env.agent_pos))
            # Restart the trace together with the window.
            # NOTE(review): the trace is re-seeded with the pre-step state
            # while the visited-set is seeded with the post-step cell; kept
            # as-is, confirm this asymmetry is intended.
            C.fill(0.0)
            C[s_idx] = 1.0
            current_interval_coverage = []

        if terminated or truncated:
            _, _ = env.reset()
            # Count the post-reset cell without adding a curve sample.
            pos0 = tuple(env.agent_pos)
            visited_cum.add(pos0)
            visited_win.add(pos0)
            counts[env.agent_pos[1], env.agent_pos[0]] += 1

    # Keep the trailing, incomplete interval too.
    if reset_interval > 0 and len(current_interval_coverage) > 0:
        coverage_intervals.append(current_interval_coverage)

    result = {
        'coverage': coverage,
        'coverage_cumulative': coverage_cum,
        'coverage_windowed': coverage_win,
        'counts': counts,
    }
    if reset_interval > 0:
        result['coverage_intervals'] = coverage_intervals
        result['reset_interval'] = reset_interval
    return result


## NOTE:
# We intentionally expose only FPVR naming in the supplementary package to
# avoid confusion from historical internal names used during development.


# ============================================================================
# Visualization
# ============================================================================

def visualize_psi_maps(agent, obs_at_pos, env, k=6):
    """Visualize spatial maps of successor features ψ.

    Up to ``k`` visited positions are drawn at random (without replacement).
    For each one the action-averaged ψ vector is reshaped to the grid and
    shown as a heat map, with the source position marked by a cross.

    Args:
        agent: Agent exposing ``device``, ``model`` (callable, with a
            ``psi_head``), ``_apply_whitening`` and ``n_actions``.
        obs_at_pos: Mapping ``(x, y) -> preprocessed observation``.
        env: Environment providing ``height`` and ``width``.
        k: Maximum number of positions to display.

    Raises:
        ValueError: If the ψ vector length differs from
            ``env.height * env.width`` and so cannot be laid out on the grid.
    """
    keys = list(obs_at_pos.keys())
    if len(keys) == 0:
        return

    k = min(k, len(keys))
    if k <= 0:
        # Nothing requested; also avoids a division by zero in the layout.
        return
    sel_idx = np.random.choice(len(keys), size=k, replace=False)
    sel_keys = [keys[i] for i in sel_idx]

    # Compute every map before opening a figure so that a dimension mismatch
    # does not leave a half-drawn window behind.
    n_cells = env.height * env.width
    psi_maps = []
    with torch.no_grad():
        for key in sel_keys:
            xt = torch.from_numpy(obs_at_pos[key][None]).to(agent.device).float()
            phi_raw = agent.model(xt)
            phi_tilde = agent._apply_whitening(phi_raw)  # ZCA whitening

            psi_all = agent.model.psi_head(phi_tilde).view(-1, agent.n_actions, phi_tilde.size(-1))
            psi_exp = psi_all.mean(dim=1).squeeze(0).cpu().numpy()
            if psi_exp.size != n_cells:
                raise ValueError(
                    f"ψ has {psi_exp.size} entries but the grid has {n_cells} cells "
                    f"({env.height}x{env.width}); cannot display it as a spatial map"
                )
            psi_maps.append(psi_exp.reshape(env.height, env.width))

    ncols = min(3, k)
    nrows = int(np.ceil(k / ncols))
    # squeeze=False always yields a 2-D axes array, whatever the layout.
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    ax_list = axes.flatten()

    for ax, key, psi_map in zip(ax_list, sel_keys, psi_maps):
        im = ax.imshow(psi_map, origin='lower', cmap='magma')
        ax.scatter([key[0]], [key[1]], c='cyan', marker='x', s=80)
        ax.set_title(f"ψ at pos {key}")
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Hide grid slots left over when k is not a multiple of ncols.
    for ax in ax_list[k:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def compute_psi_dispersion(agent, obs_at_pos, env):
    """Print and plot how much ψ varies across states and across actions.

    For every stored observation the per-action ψ vectors are obtained from
    the agent. Two spreads are measured, each on raw and on L2-normalized
    vectors: the distance of a state's action-mean ψ to the mean over all
    states ("across-state"), and the RMS distance of a state's per-action ψ
    to that state's action-mean ("within-state"). The four distributions are
    summarized on stdout and shown as histograms.

    Args:
        agent: Agent exposing ``device``, ``model`` (callable, with a
            ``psi_head``), ``_apply_whitening`` and ``n_actions``.
        obs_at_pos: Mapping from position to preprocessed observation.
        env: Not used by this function.

    Returns:
        ``None`` (also when ``obs_at_pos`` is empty, in which case nothing is
        printed or plotted).
    """
    positions = list(obs_at_pos.keys())
    if len(positions) == 0:
        return None

    eps = 1e-8

    def rms_spread(rows):
        # Row-wise centroid and RMS Euclidean distance of rows to it.
        centroid = rows.mean(axis=0)
        offsets = rows - centroid[None, :]
        return centroid, float(np.sqrt((offsets ** 2).sum(axis=1).mean()))

    def unit_rows(rows):
        # Scale every row to (approximately) unit L2 norm.
        return rows / (np.linalg.norm(rows, axis=1, keepdims=True) + eps)

    def dist_to_centroid(rows):
        # Per-row Euclidean distance to the mean row.
        centroid = rows.mean(axis=0)
        return np.linalg.norm(rows - centroid[None, :], axis=1)

    psi_means = []
    within_spreads = []
    within_spreads_norm = []

    with torch.no_grad():
        for pos in positions:
            frame = obs_at_pos[pos]
            batch = torch.from_numpy(frame[None]).to(agent.device).float()
            phi_tilde = agent._apply_whitening(agent.model(batch))  # ZCA whitening

            per_action = agent.model.psi_head(phi_tilde).view(-1, agent.n_actions, phi_tilde.size(-1))
            per_action_np = per_action.squeeze(0).cpu().numpy()  # [A, D]

            state_mean, raw_spread = rms_spread(per_action_np)
            psi_means.append(state_mean)
            within_spreads.append(raw_spread)

            _, unit_spread = rms_spread(unit_rows(per_action_np))
            within_spreads_norm.append(unit_spread)

    psi_means = np.stack(psi_means, axis=0)
    across = dist_to_centroid(psi_means)
    across_norm = dist_to_centroid(unit_rows(psi_means))

    print(f"[FPVR] Across-state ψ dispersion: mean={across.mean():.4f}, std={across.std():.4f}")
    print(f"[FPVR] Within-state ψ dispersion: mean={np.mean(within_spreads):.4f}, std={np.std(within_spreads):.4f}")
    print(f"[FPVR] Across-state (norm.) ψ dispersion: mean={across_norm.mean():.4f}, std={across_norm.std():.4f}")
    print(f"[FPVR] Within-state (norm.) ψ dispersion: mean={np.mean(within_spreads_norm):.4f}, std={np.std(within_spreads_norm):.4f}")

    fig, axs = plt.subplots(2, 2, figsize=(12, 8))

    # (axes, data, extra hist kwargs, x label, title) for each panel.
    panels = (
        (axs[0, 0], across, {}, 'L2 distance to global mean', 'Across-state dispersion'),
        (axs[0, 1], within_spreads, {'color': 'tab:orange'}, 'Within-state mean L2', 'Within-state action dispersion'),
        (axs[1, 0], across_norm, {'color': 'tab:green'}, 'L2 distance (normalized)', 'Across-state dispersion (norm.)'),
        (axs[1, 1], within_spreads_norm, {'color': 'tab:red'}, 'Within-state L2 spread (norm.)', 'Within-state dispersion (norm.)'),
    )
    for ax, values, extra, x_label, title in panels:
        ax.hist(values, bins=30, alpha=0.7, **extra)
        ax.set_xlabel(x_label)
        ax.set_ylabel('Count')
        ax.set_title(title)
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()


def _plot_coverage_curves(results):
    """Draw the per-step 'coverage' series of every method in one figure."""
    plt.figure(figsize=(7, 4))
    for name, res in results.items():
        if 'coverage' in res:
            plt.plot(res['coverage'], label=name)
    plt.xlabel("Steps")
    plt.ylabel("Unique visited cells")
    plt.title("Coverage vs Steps (MiniGrid)")
    plt.grid(alpha=0.3)
    # Only request a legend when something labelled was drawn.
    if plt.gca().get_legend_handles_labels()[0]:
        plt.legend()
    plt.tight_layout()
    plt.show()


def _plot_interval_summary(results):
    """Draw mean±std within-interval coverage and final coverage per interval."""
    if not any('coverage_intervals' in res for res in results.values()):
        return

    reset_interval = next((res['reset_interval'] for res in results.values() if 'reset_interval' in res), None)

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    ax = axes[0]
    for name, res in results.items():
        # Empty intervals carry no curve and cannot be padded; skip them.
        intervals = [list(iv) for iv in res.get('coverage_intervals') or [] if len(iv) > 0]
        if not intervals:
            continue

        # Pad shorter (e.g. trailing, incomplete) intervals with their last
        # value so all curves can be averaged position-wise.
        max_len = max(len(iv) for iv in intervals)
        aligned_array = np.array([iv + [iv[-1]] * (max_len - len(iv)) for iv in intervals])
        mean_cov = aligned_array.mean(axis=0)
        std_cov = aligned_array.std(axis=0)

        steps = np.arange(len(mean_cov))
        ax.plot(steps, mean_cov, label=name, linewidth=2)
        ax.fill_between(steps, mean_cov - std_cov, mean_cov + std_cov, alpha=0.2)

    ax.set_xlabel('Steps (within interval)')
    ax.set_ylabel('# States Explored')
    ax.set_title(f'Average Coverage per Interval (reset every {reset_interval} steps)')
    ax.grid(alpha=0.3)
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

    ax = axes[1]
    for name, res in results.items():
        if 'coverage_intervals' in res:
            final_covs = [interval[-1] if len(interval) > 0 else 0 for interval in res['coverage_intervals']]
            ax.plot(final_covs, marker='o', label=name, alpha=0.7, linewidth=2, markersize=8)
    ax.set_xlabel('Interval Index')
    ax.set_ylabel('Final Coverage in Interval')
    ax.set_title('Final Coverage Comparison Across Intervals')
    ax.grid(alpha=0.3)
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

    plt.tight_layout()
    plt.show()


def _plot_visit_heatmaps(results):
    """Draw one visitation-count heat map per method, side by side."""
    heatmap_items = [(name, res['counts']) for name, res in results.items() if 'counts' in res]
    ncols = len(heatmap_items)
    if ncols == 0:
        return

    # squeeze=False keeps `axes` 2-D even for a single panel.
    fig, axes = plt.subplots(1, ncols, figsize=(7 * ncols, 4), squeeze=False)
    for ax, (title, mat) in zip(axes[0], heatmap_items):
        im = ax.imshow(mat, origin='lower', cmap='viridis')
        ax.set_title(f"{title} visitation count")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
    plt.tight_layout()
    plt.show()


def _plot_fpvr_losses(fpvr_res):
    """Draw the SR / inverse-dynamics / VICReg loss curves of the FPVR run."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[0].plot(fpvr_res['sr_losses'], alpha=0.7, color='blue')
    axes[0].set_title('SR Loss')
    axes[0].set_xlabel('Update step')
    axes[0].set_ylabel('Loss')
    axes[0].grid(alpha=0.3)

    if len(fpvr_res.get('inv_dyn_losses', ())) > 0:
        axes[1].plot(fpvr_res['inv_dyn_losses'], alpha=0.7, color='green')
        axes[1].set_title('Inverse Dynamics Loss')
        axes[1].set_xlabel('Update step')
        axes[1].set_ylabel('Loss')
        axes[1].grid(alpha=0.3)
    else:
        # Hide the panel rather than showing empty axes, as done for VICReg.
        axes[1].axis('off')

    if len(fpvr_res.get('vic_losses', ())) > 0:
        axes[2].plot(fpvr_res['vic_losses'], alpha=0.7, color='purple', label='vic_total')
        if 'vic_var_losses' in fpvr_res:
            axes[2].plot(fpvr_res['vic_var_losses'], alpha=0.5, color='tab:orange', label='vic_var')
        if 'vic_cov_losses' in fpvr_res:
            axes[2].plot(fpvr_res['vic_cov_losses'], alpha=0.5, color='tab:red', label='vic_cov')
        axes[2].set_title('VICReg Loss')
        axes[2].set_xlabel('Update step')
        axes[2].set_ylabel('Loss')
        axes[2].grid(alpha=0.3)
        axes[2].legend()
    else:
        axes[2].axis('off')

    plt.tight_layout()
    plt.show()


def plot_results(results):
    """Plot all results.

    Shows, in order: coverage-vs-steps curves; per-interval coverage summaries
    (only if some result has ``'coverage_intervals'``); visitation heat maps
    (for results with ``'counts'``); and the loss curves of the ``'FPVR'``
    entry (only if it has ``'sr_losses'``).

    Args:
        results: Mapping ``method name -> result dict`` as returned by the
            run_* functions in this module.
    """
    _plot_coverage_curves(results)
    _plot_interval_summary(results)
    _plot_visit_heatmaps(results)
    if 'FPVR' in results and 'sr_losses' in results['FPVR']:
        _plot_fpvr_losses(results['FPVR'])


# ============================================================================
# Main evaluation
# ============================================================================

def evaluate(config: Config = None):
    """Run FPVR and baselines, then plot comparisons.

    Trains the visual FPVR agent, then runs the random-walk and tabular FPVR
    baselines on the same environment instance, and finally shows the
    comparison plots and the ψ diagnostics of the trained agent.

    Args:
        config: Experiment configuration; a default ``Config()`` is used when
            omitted.
    """
    if config is None:
        config = Config()

    # run_fpvr_training looks up the module-level `config` (for `lambda_c` and
    # `print_interval`). That global only exists when this file is executed as
    # a script, so publish the active configuration here to make programmatic
    # calls of evaluate() work and use the configuration they passed in.
    globals()["config"] = config

    print(config)
    time.sleep(2)

    env = make_env(config.env_seed, config.env_size)
    try:
        obs, _ = env.reset(seed=config.env_seed)
        x = preprocess(obs)
        c, h, w = x.shape
        n_actions = env.action_space.n
        # One feature per grid cell for the tabular baseline.
        feat_dim = env.height * env.width

        agent = FPVRVisualAgent(
            (c, h, w), n_actions,
            lr=config.lr,
            beta=config.beta,
            phi_dim=config.phi_dim,
            sf_gamma=config.sf_gamma,
            capacity=config.capacity,
            batch_size=config.batch_size,
            update_after=config.update_after,
            update_every=config.update_every,
            whitening_update_every=config.whitening_update_every,
            psi_dim=config.psi_dim,
            sf_target_mode=config.sf_target,
        )

        results = {}
        print("\nRunning FPVR...")
        fpvr_res = run_fpvr_training(env, agent, config.steps, config.k_train, config.visualize, config.render_delay, feat_dim, config.env_seed, config.reset_interval)
        results['FPVR'] = fpvr_res

        print("Running Random baseline...")
        env.reset(seed=config.env_seed)
        random_res = run_random_baseline(env, config.steps, config.reset_interval)
        results['Random'] = random_res

        print("Running Tabular FPVR...")
        env.reset(seed=config.env_seed + 4)
        tabular_res = run_tabular_fpvr(env, config.steps, n_actions, feat_dim, config.sf_gamma, config.beta, config.reset_interval)
        results['FPVR (tabular)'] = tabular_res
    finally:
        # Release the environment even if one of the runs fails.
        env.close()

    plot_results(results)
    # Only env.height / env.width are read below, after the env was closed.
    visualize_psi_maps(agent, fpvr_res['obs_at_pos'], env, k=6)
    compute_psi_dispersion(agent, fpvr_res['obs_at_pos'], env)


# ============================================================================
# CLI entrypoint
# ============================================================================

if __name__ == "__main__":
    # --- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser(description='FPVR experiment (visual MiniGrid)')
    parser.add_argument('--config', type=str, help='Path to config JSON file')
    parser.add_argument('--preset', type=str, choices=['default', 'fast_test', 'high_exploration'],
                       help='Use a preset config')

    # Every option below defaults to None, meaning "keep the config's value".
    parser.add_argument('--beta', type=float, help='Exploration temperature')
    parser.add_argument('--lr', type=float, help='Learning rate')
    parser.add_argument('--steps', type=int, help='Number of environment steps')
    parser.add_argument('--print_interval', type=int, help='Print training loss every N env steps (0 disables)')
    parser.add_argument('--env_size', type=int, help='Grid size')
    parser.add_argument('--whitening_update_every', type=int, help='Whitening update frequency (in training steps)')
    # Any value other than "false" (case-insensitive) enables visualization.
    parser.add_argument('--visualize', type=lambda x: str(x).lower() != 'false', help='Enable visualization')
    parser.add_argument('--sf_target', type=str, choices=['uniform_policy', 'current_policy', 'min_redundancy'], help='SR target mode')
    parser.add_argument('--reset_interval', type=int, help='Reset windowed coverage stats every K steps (0 disables)')

    args = parser.parse_args()

    # --- Base configuration: JSON file > preset > defaults ------------------
    # `config` must stay a module-level name: the training loop reads it.
    if args.config:
        config = Config.from_json(args.config)
        print(f"Loaded config from file: {args.config}")
    elif args.preset:
        _preset_factories = {
            'default': Presets.default,
            'fast_test': Presets.fast_test,
            'high_exploration': Presets.high_exploration,
        }
        config = _preset_factories[args.preset]()
        print(f"Using preset config: {args.preset}")
    else:
        config = Config()
        print("Using default config")

    # --- Per-field overrides from the command line --------------------------
    _override_fields = (
        'beta',
        'lr',
        'steps',
        'print_interval',
        'env_size',
        'whitening_update_every',
        'visualize',
        'sf_target',
        'reset_interval',
    )
    for _field in _override_fields:
        _value = getattr(args, _field)
        if _value is None:
            continue
        setattr(config, _field, _value)
        if _field == 'env_size':
            # Feature sizes follow the grid: one dimension per cell.
            config.phi_dim = _value * _value
            config.psi_dim = config.phi_dim

    if config.reset_interval > 0:
        print(f"\n[Reset] Enabled: reset every {config.reset_interval} steps")
        print(f"  intervals = {config.steps // config.reset_interval}")

    evaluate(config)

