"""
Visual gridworld exploration comparison.

Methods compared (coverage = # unique visited states vs steps):
  1) FPVR (deep neural network version using FPVRVisualAgent)
  2) Tabular FPVR (tabular version from run.py)
  3) Random walk baseline
  4) SR (Successor Representation) + DQN  
  5) SP (Successor-Predecessor) + DQN

Each method runs on the same visual gridworld environment and we compare:
- Cumulative coverage curves
- Windowed coverage curves (periodic reset for visualization)
- State visitation heatmaps (position-based, ignoring agent orientation)

Usage:
  python exploration_comparison.py --total_steps 20000 --n_seeds 3 --coverage_reset_interval 2000
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import ListedColormap

# Work around occasional Windows OpenMP runtime duplication (torch + MKL/numpy).
# Safe in this script context; avoids hard-crash on import for some environments.
if os.name == "nt" and "KMP_DUPLICATE_LIB_OK" not in os.environ:
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
# Ensure project root for absolute imports
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_ROOT_DIR = os.path.dirname(_THIS_DIR)
if _ROOT_DIR not in sys.path:
    sys.path.insert(0, _ROOT_DIR)

# Local imports
# The project modules are optional: when they cannot be imported, every name
# they would have provided is bound to None so callers can test availability
# with a plain `is None` check instead of hitting a NameError.
try:
    from visual_minigrid_maze.visual_minigrid import SimpleEnv
    from visual_minigrid_maze.fpvr_agent import FPVRVisualAgent
    from visual_minigrid_maze.config import Config, Presets
except Exception:
    # Deliberately broad: a failing import of these modules must not prevent
    # the rest of the script from loading.
    SimpleEnv = None
    FPVRVisualAgent = None
    Config = None
    Presets = None


@dataclass
class ComparisonConfig:
    """Settings shared by every exploration method in the comparison."""

    # Environment
    env_size: int = 20          # side length passed to the gridworld environment
    seed_base: int = 1
    n_seeds: int = 3

    # Training
    total_steps: int = 20000
    learning_starts: int = 100

    # Coverage analysis
    coverage_reset_interval: int = 2000  # Reset windowed coverage every K steps

    # Output
    out_dir: str = "runs/exploration_comparison"

    # Method-specific configs (will be passed to individual scripts).
    # None means "no extra arguments"; the value type is intentionally open.
    common_args: Optional[Dict[str, object]] = None


# Consistent color mapping across all methods (matching fourrooms_exploration.py).
# Keys are the display names used as legend labels; values are matplotlib
# color-cycle aliases ("C0".."C4"). The trailing notes name the color each alias
# has in matplotlib's default cycle.
METHOD_COLORS = {
    "FPVR": "C0",              # blue - 1st position
    "Tabular FPVR": "C1",      # orange - 2nd position
    "SP + DQN": "C2",          # green - 3rd position
    "SR + DQN": "C3",          # red - 4th position
    "Random Walk": "C4",       # purple - 5th position
}


def _lighten(color, amount=0.5):
    """Lighten a matplotlib color for confidence bands."""
    try:
        import matplotlib.colors as mc
        import colorsys
        c = mc.cnames.get(color, color)
        c = colorsys.rgb_to_hls(*mc.to_rgb(c))
        return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])
    except:
        return color


def load_coverage_data(cumulative_path: str, windowed_path: str, total_points: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Read the cumulative and windowed coverage curves stored in two .npy files.

    Each curve is brought to exactly ``total_points`` entries: a curve that is
    one entry too long is truncated, one that is one entry short is extended by
    repeating its last value, and any larger mismatch is rejected.

    Args:
        cumulative_path: File holding the never-reset coverage curve
        windowed_path: File holding the periodically reset coverage curve
        total_points: Required curve length (typically total_steps+1)

    Returns:
        cov: cumulative coverage [total_points]
        cov_win: windowed coverage [total_points]

    Raises:
        FileNotFoundError: If either file does not exist
        ValueError: If a curve is off by more than one entry
    """
    # Fail fast, cumulative file first, before anything is loaded.
    if not os.path.exists(cumulative_path):
        raise FileNotFoundError(f"Missing cumulative coverage data: {cumulative_path}")

    if not os.path.exists(windowed_path):
        raise FileNotFoundError(f"Missing windowed coverage data: {windowed_path}")

    raw_cumulative = np.load(cumulative_path)
    raw_windowed = np.load(windowed_path)

    def adjust_length(data, expected_length, data_name):
        """Tolerate an off-by-one length; anything else is an error."""
        surplus = len(data) - expected_length

        if surplus == 0:
            return data

        if surplus == 1:
            print(f"  Warning: {data_name} has {len(data)} points, truncating to {expected_length}")
            return data[:expected_length]

        if surplus == -1:
            print(f"  Warning: {data_name} has {len(data)} points, padding to {expected_length}")
            # Repeat the final value (or 0 for an empty curve) in the missing slot.
            filler = data[-1] if len(data) > 0 else 0
            extended = np.zeros(expected_length, dtype=data.dtype)
            extended[:len(data)] = data
            extended[len(data):] = filler
            return extended

        raise ValueError(f"{data_name} length {len(data)} significantly differs from expected {expected_length}")

    cov = adjust_length(raw_cumulative, total_points, "Cumulative coverage")
    cov_win = adjust_length(raw_windowed, total_points, "Windowed coverage")
    return cov, cov_win


def extract_position_counts(counts_npy_path: str, env_size: int) -> np.ndarray:
    """
    Load per-position visit counts from a counts.npy file.

    Returns an [env_size, env_size] array. A missing file yields all zeros; an
    array of a different shape is copied into the top-left corner of a zero
    array (cropping or zero-padding as needed) after printing a warning.
    """
    wanted_shape = (env_size, env_size)

    # A missing file is not an error here: report "never visited" quietly.
    if not os.path.exists(counts_npy_path):
        return np.zeros(wanted_shape)

    loaded = np.load(counts_npy_path)
    if loaded.shape == wanted_shape:
        return loaded

    print(f"Warning: counts shape {loaded.shape} != expected {(env_size, env_size)}")
    canvas = np.zeros(wanted_shape)
    rows = min(loaded.shape[0], env_size)
    cols = min(loaded.shape[1], env_size)
    # Keep only the overlapping top-left region.
    canvas[:rows, :cols] = loaded[:rows, :cols]
    return canvas


def sanity_check_outputs(
    *,
    method_name: str,
    cov: np.ndarray,
    cov_win: np.ndarray,
    pos_counts: np.ndarray,
    total_points: int,
    env_size: int,
    reset_interval: int,
) -> None:
    """Print (never raise) warnings about implausible coverage / count outputs.

    Checks, in order: curve lengths, monotonicity of the cumulative curve,
    coverage exceeding env_size^2, a missing drop in the windowed curve right
    after a reset boundary (only the first offender is reported), and the
    shape / total of the visit-count grid.
    """
    state_cap = int(env_size * env_size)

    # --- curve lengths -----------------------------------------------------
    n_cumulative = cov.shape[0]
    if n_cumulative != total_points:
        print(f"  Warning: {method_name} cumulative length={cov.shape[0]} expected={total_points}")
    n_windowed = cov_win.shape[0]
    if n_windowed != total_points:
        print(f"  Warning: {method_name} windowed length={cov_win.shape[0]} expected={total_points}")

    # --- cumulative coverage can never go down -----------------------------
    if cov.size > 1:
        deltas = np.diff(cov)
        if np.any(deltas < 0):
            print(f"  Warning: {method_name} cumulative coverage is not non-decreasing (unexpected).")

    # --- coverage cannot exceed the number of grid cells -------------------
    if cov.size > 0 and int(np.max(cov)) > state_cap:
        print(f"  Warning: {method_name} cumulative max={int(np.max(cov))} > env_size^2={state_cap}")
    if cov_win.size > 0 and int(np.max(cov_win)) > state_cap:
        print(f"  Warning: {method_name} windowed max={int(np.max(cov_win))} > env_size^2={state_cap}")

    # --- windowed curve should not rise right after a reset ----------------
    if reset_interval > 0 and cov_win.size > (reset_interval + 1):
        # Reset occurs after the boundary step is recorded, so compare the
        # boundary sample with its successor. The upper bound keeps
        # boundary + 1 inside both the expected horizon and the actual array.
        last_boundary_exclusive = min(total_points - 1, cov_win.size - 1)
        for boundary in range(reset_interval, last_boundary_exclusive, reset_interval):
            if cov_win[boundary + 1] > cov_win[boundary]:
                print(
                    f"  Warning: {method_name} windowed did not drop after reset boundary at step={boundary} "
                    f"(cov_win[{boundary}]={int(cov_win[boundary])}, cov_win[{boundary+1}]={int(cov_win[boundary+1])})"
                )
                break

    # --- visit-count grid ---------------------------------------------------
    if pos_counts is not None and pos_counts.shape == (env_size, env_size):
        total_visits = int(np.sum(pos_counts))
        if total_visits <= 0:
            print(f"  Warning: {method_name} total visit counts is 0 (unexpected).")
        # counts includes step visits + reset-position visits, so it should be >= total_points (usually >)
        if total_visits < total_points:
            print(f"  Warning: {method_name} total_visits={total_visits} < total_points={total_points} (unexpected).")
    else:
        print(f"  Warning: {method_name} counts shape={None if pos_counts is None else pos_counts.shape} expected={(env_size, env_size)}")


def run_sp_dqn(cfg: ComparisonConfig, seed: int, *, timeout_s: int) -> Tuple[str, str]:
    """Run SP+DQN exploration in a subprocess and return its output file paths.

    Args:
        cfg: Comparison settings (steps, env size, output root, reset interval).
        seed: Random seed forwarded to the child script.
        timeout_s: Maximum wall-clock seconds the child may run.

    Returns:
        ``(coverage_path, counts_path)`` under ``<cfg.out_dir>/sp_seed<seed>``,
        or ``("", "")`` if the child exits non-zero or times out.
    """
    out_dir = os.path.join(cfg.out_dir, f"sp_seed{seed}")

    # Use an absolute path to ensure correct execution directory
    script_dir = os.path.dirname(os.path.abspath(__file__))
    script_path = os.path.join(script_dir, "sp_dqn_explore.py")
    cmd = [
        sys.executable, script_path,
        "--total_steps", str(cfg.total_steps),
        "--learning_starts", str(cfg.learning_starts),
        "--seed", str(seed),
        "--env_size", str(cfg.env_size),
        # The child runs with cwd=script_dir, so a relative out_dir would be
        # resolved against a different directory than the one this process
        # uses for the returned paths. Hand it an absolute path instead.
        "--out_dir", os.path.abspath(out_dir),
        "--coverage_reset_interval", str(cfg.coverage_reset_interval),
        "--log_every", "0",  # Disable logging for cleaner output
        "--save_recon_every", "0",  # Disable recon saving
        "--no_render"
    ]

    print(f"Running SP+DQN (seed {seed})...")
    try:
        result = subprocess.run(cmd, cwd=script_dir,
                                capture_output=True, text=True, timeout=int(timeout_s))
    except subprocess.TimeoutExpired:
        print(f"SP+DQN timeout for seed {seed}")
        return "", ""

    if result.returncode != 0:
        print(f"SP+DQN failed: {result.stderr}")
        return "", ""

    coverage_path = os.path.join(out_dir, "coverage.npy")
    counts_path = os.path.join(out_dir, "counts.npy")
    return coverage_path, counts_path


def run_sr_dqn(cfg: ComparisonConfig, seed: int, *, timeout_s: int) -> Tuple[str, str]:
    """Run SR+DQN exploration in a subprocess and return its output file paths.

    Args:
        cfg: Comparison settings (steps, env size, output root, reset interval).
        seed: Random seed forwarded to the child script.
        timeout_s: Maximum wall-clock seconds the child may run.

    Returns:
        ``(coverage_path, counts_path)`` under ``<cfg.out_dir>/sr_seed<seed>``,
        or ``("", "")`` if the child exits non-zero or times out.
    """
    out_dir = os.path.join(cfg.out_dir, f"sr_seed{seed}")

    # Use an absolute path to ensure correct execution directory
    script_dir = os.path.dirname(os.path.abspath(__file__))
    script_path = os.path.join(script_dir, "sr_dqn_explore.py")
    cmd = [
        sys.executable, script_path,
        "--total_steps", str(cfg.total_steps),
        "--learning_starts", str(cfg.learning_starts),
        "--seed", str(seed),
        "--env_size", str(cfg.env_size),
        # The child runs with cwd=script_dir, so a relative out_dir would be
        # resolved against a different directory than the one this process
        # uses for the returned paths. Hand it an absolute path instead.
        "--out_dir", os.path.abspath(out_dir),
        "--coverage_reset_interval", str(cfg.coverage_reset_interval),
        "--log_every", "0",
        "--save_recon_every", "0",
        "--no_render"
    ]

    print(f"Running SR+DQN (seed {seed})...")
    try:
        result = subprocess.run(cmd, cwd=script_dir,
                                capture_output=True, text=True, timeout=int(timeout_s))
    except subprocess.TimeoutExpired:
        print(f"SR+DQN timeout for seed {seed}")
        return "", ""

    if result.returncode != 0:
        print(f"SR+DQN failed: {result.stderr}")
        return "", ""

    coverage_path = os.path.join(out_dir, "coverage.npy")
    counts_path = os.path.join(out_dir, "counts.npy")
    return coverage_path, counts_path


def run_random_walk(cfg: ComparisonConfig, seed: int) -> Tuple[str, str]:
    """Run the random-walk baseline once and save its coverage and count arrays.

    Writes ``coverage.npy``, ``coverage_windowed.npy`` and ``counts.npy`` into
    ``<cfg.out_dir>/random_seed<seed>``.

    Args:
        cfg: Comparison settings (env size, steps, reset interval, output root).
        seed: Seed used for NumPy and the environment reset.

    Returns:
        ``(coverage_path, counts_path)`` on success, ``("", "")`` if the
        environment class is unavailable or the run fails.
    """
    if SimpleEnv is None:
        return "", ""

    out_dir = os.path.join(cfg.out_dir, f"random_seed{seed}")
    os.makedirs(out_dir, exist_ok=True)

    print(f"Running Random Walk (seed {seed})...")

    # Import the random baseline function. Only extend sys.path once so that
    # repeated calls (one per seed) do not keep growing it.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    if this_dir not in sys.path:
        sys.path.insert(0, this_dir)
    from fpvr_run import run_random_baseline

    try:
        np.random.seed(seed)
        env = SimpleEnv(size=cfg.env_size, render_mode="rgb_array", highlight=False)
        try:
            env.reset(seed=seed)
            # Run once; runner internally tracks both cumulative + windowed
            result = run_random_baseline(env=env, steps=cfg.total_steps, reset_interval=cfg.coverage_reset_interval)
        finally:
            # Release the environment even when the run itself fails.
            env.close()

        coverage_path = os.path.join(out_dir, "coverage.npy")
        windowed_path = os.path.join(out_dir, "coverage_windowed.npy")
        counts_path = os.path.join(out_dir, "counts.npy")

        # Fall back to the legacy "coverage" key only when the specific key is
        # absent; looking it up unconditionally would raise KeyError for
        # results that carry only the new keys.
        cum_key = "coverage_cumulative" if "coverage_cumulative" in result else "coverage"
        win_key = "coverage_windowed" if "coverage_windowed" in result else "coverage"
        cov_cum = np.asarray(result[cum_key], dtype=np.int32)
        cov_win = np.asarray(result[win_key], dtype=np.int32)
        np.save(coverage_path, cov_cum)
        np.save(windowed_path, cov_win)
        np.save(counts_path, result["counts"].astype(np.int32))

        return coverage_path, counts_path

    except Exception as e:
        print(f"Random walk failed: {e}")
        import traceback
        traceback.print_exc()
        return "", ""


def run_deep_fpvr(
    cfg: ComparisonConfig,
    seed: int,
    *,
    fpvr_k_train: Optional[int],
    fpvr_whitening_update_every: Optional[int],
) -> Tuple[str, str]:
    """Run deep FPVR using FPVRVisualAgent and return paths to output files.

    Writes ``coverage.npy``, ``coverage_windowed.npy`` and ``counts.npy`` into
    ``<cfg.out_dir>/deep_fpvr_seed<seed>``.

    Args:
        cfg: Comparison settings (env size, steps, reset interval, output root).
        seed: Seed for NumPy, torch and the environment.
        fpvr_k_train: Optional override for the number of training updates per
            environment step; ``None`` keeps the value from ``Config``.
        fpvr_whitening_update_every: Optional override for the agent's
            whitening update period; ``None`` keeps the value from ``Config``.

    Returns:
        ``(coverage_path, counts_path)`` on success, ``("", "")`` when
        dependencies are missing or training fails.
    """
    if SimpleEnv is None or FPVRVisualAgent is None:
        print(f"Missing dependencies for deep FPVR (seed {seed})")
        return "", ""

    # Make sibling modules importable; only extend sys.path once so repeated
    # calls (one per seed) do not keep growing it.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    if this_dir not in sys.path:
        sys.path.insert(0, this_dir)

    # Try to import Config
    try:
        from config import Config
        default_config = Config()
    except ImportError as e:
        print(f"Failed to import Config for deep FPVR (seed {seed}): {e}")
        return "", ""

    import builtins
    import types

    out_dir = os.path.join(cfg.out_dir, f"deep_fpvr_seed{seed}")
    os.makedirs(out_dir, exist_ok=True)

    print(f"Running Deep FPVR (seed {seed})...")

    # Create env once (avoid leaks) and get a valid observation for shape
    np.random.seed(seed)
    torch.manual_seed(seed)
    env = SimpleEnv(size=cfg.env_size, render_mode="rgb_array", highlight=False)
    try:
        obs, _ = env.reset(seed=seed)
        if obs is None:
            print(f"Failed to get valid observation from environment (seed {seed})")
            return "", ""

        # Use the preprocess and training loop shipped in fpvr_run
        from fpvr_run import preprocess as preprocess_run
        from fpvr_run import run_fpvr_training

        obs_shape = preprocess_run(obs).shape
        n_actions = env.action_space.n

        whitening_every = (
            int(fpvr_whitening_update_every)
            if fpvr_whitening_update_every is not None
            else default_config.whitening_update_every
        )
        agent = FPVRVisualAgent(
            obs_shape=obs_shape,
            n_actions=n_actions,
            lr=default_config.lr,
            phi_dim=default_config.phi_dim,
            psi_dim=default_config.psi_dim,
            sf_gamma=default_config.sf_gamma,
            beta=default_config.beta,
            capacity=default_config.capacity,
            batch_size=default_config.batch_size,
            update_after=default_config.update_after,
            update_every=default_config.update_every,
            whitening_update_every=whitening_every,  # override-able for speed
        )

        # Settings object exposed to the training code for this run.
        config = types.SimpleNamespace(
            print_interval=0,  # Disable printing for comparison
            lambda_c=default_config.lambda_c,
            reset_interval=cfg.coverage_reset_interval,
            env_size=cfg.env_size,
            env_seed=seed,
            steps=cfg.total_steps,
            k_train=int(fpvr_k_train) if fpvr_k_train is not None else default_config.k_train,
            phi_dim=default_config.phi_dim,
            psi_dim=default_config.psi_dim,
            beta=default_config.beta,
            sf_gamma=default_config.sf_gamma,
        )

        try:
            # Published as a builtin for the duration of the run — presumably
            # fpvr_run looks up a global named `config`; verify against fpvr_run.
            builtins.config = config
            try:
                # Run once; runner internally tracks both cumulative + windowed
                result = run_fpvr_training(
                    env=env,
                    agent=agent,
                    steps=cfg.total_steps,
                    k_train=config.k_train,
                    visualize=False,
                    feat_dim=default_config.phi_dim,
                    env_seed=seed,
                    reset_interval=cfg.coverage_reset_interval,
                )
            finally:
                # Never leave the temporary builtin behind, even on failure.
                if hasattr(builtins, "config"):
                    delattr(builtins, "config")

            coverage_path = os.path.join(out_dir, "coverage.npy")
            windowed_path = os.path.join(out_dir, "coverage_windowed.npy")
            counts_path = os.path.join(out_dir, "counts.npy")

            # Fall back to the legacy "coverage" key only when the specific
            # key is absent; an unconditional lookup would raise KeyError for
            # results that carry only the new keys.
            cum_key = "coverage_cumulative" if "coverage_cumulative" in result else "coverage"
            win_key = "coverage_windowed" if "coverage_windowed" in result else "coverage"
            cov_cum = np.asarray(result[cum_key], dtype=np.int32)
            cov_win = np.asarray(result[win_key], dtype=np.int32)
            np.save(coverage_path, cov_cum)
            np.save(windowed_path, cov_win)
            np.save(counts_path, result['counts'].astype(np.int32))

            return coverage_path, counts_path

        except Exception as e:
            print(f"Deep FPVR failed: {e}")
            return "", ""
    finally:
        # Single cleanup point: covers early returns, training failures and
        # errors raised while building the agent.
        env.close()


def run_tabular_fpvr(cfg: ComparisonConfig, seed: int) -> Tuple[str, str]:
    """Run tabular FPVR (the runner from fpvr_run) and return paths to output files.

    Writes ``coverage.npy``, ``coverage_windowed.npy`` and ``counts.npy`` into
    ``<cfg.out_dir>/tabular_fpvr_seed<seed>``.

    Args:
        cfg: Comparison settings (env size, steps, reset interval, output root).
        seed: Seed used for NumPy and the environment reset.

    Returns:
        ``(coverage_path, counts_path)`` on success, ``("", "")`` if the
        environment class is unavailable or the run fails.
    """
    if SimpleEnv is None:
        print(f"Missing SimpleEnv for tabular FPVR (seed {seed})")
        return "", ""

    # Import config to get exact parameters from config.py. Only extend
    # sys.path once so repeated calls (one per seed) do not keep growing it.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    if this_dir not in sys.path:
        sys.path.insert(0, this_dir)
    from config import Config
    default_config = Config()

    out_dir = os.path.join(cfg.out_dir, f"tabular_fpvr_seed{seed}")
    os.makedirs(out_dir, exist_ok=True)

    print(f"Running Tabular FPVR (seed {seed})...")

    # Import the tabular runner under an alias so it does not shadow this
    # wrapper, which has the same name.
    from fpvr_run import run_tabular_fpvr as tabular_runner

    try:
        np.random.seed(seed)
        env = SimpleEnv(size=cfg.env_size, render_mode="rgb_array", highlight=False)
        try:
            env.reset(seed=seed)
            result = tabular_runner(
                env=env,
                steps=cfg.total_steps,
                n_actions=env.action_space.n,
                feat_dim=cfg.env_size * cfg.env_size,  # Use position-based features
                sf_gamma=default_config.sf_gamma,
                beta=default_config.beta,
                reset_interval=cfg.coverage_reset_interval,
            )
        finally:
            # Close here rather than in the except handler: if SimpleEnv()
            # itself raised, `env` would be unbound there and mask the error.
            env.close()

        coverage_path = os.path.join(out_dir, "coverage.npy")
        windowed_path = os.path.join(out_dir, "coverage_windowed.npy")
        counts_path = os.path.join(out_dir, "counts.npy")

        # Fall back to the legacy "coverage" key only when the specific key is
        # absent; an unconditional lookup would raise KeyError for results
        # that carry only the new keys.
        cum_key = "coverage_cumulative" if "coverage_cumulative" in result else "coverage"
        win_key = "coverage_windowed" if "coverage_windowed" in result else "coverage"
        cov_cum = np.asarray(result[cum_key], dtype=np.int32)
        cov_win = np.asarray(result[win_key], dtype=np.int32)
        np.save(coverage_path, cov_cum)
        np.save(windowed_path, cov_win)
        np.save(counts_path, result['counts'].astype(np.int32))

        return coverage_path, counts_path

    except Exception as e:
        print(f"Tabular FPVR failed: {e}")
        return "", ""




def plot_comparison(all_curves: Dict[str, List[np.ndarray]],
                   ylabel: str,
                   out_png: str,
                   out_eps: str,
                   reset_interval: Optional[int] = None):
    """Plot mean coverage curves per method with +/- one std bands.

    Args:
        all_curves: Method name -> list of per-seed curves. All curves of one
            method must have the same length (they are stacked).
        ylabel: Label for the y axis.
        out_png: Destination of the PNG figure (parent dirs are created).
        out_eps: Destination of the EPS figure (parent dirs are created).
        reset_interval: If positive, draw a dashed vertical line at every
            multiple of this step count.
    """
    # Ensure output directories exist
    os.makedirs(os.path.dirname(os.path.abspath(out_png)), exist_ok=True)
    os.makedirs(os.path.dirname(os.path.abspath(out_eps)), exist_ok=True)

    # Typography aligned with fourrooms_exploration.py
    label_fs = 18
    tick_fs = 14
    legend_fs = 14

    # Plot methods in a fixed order to keep legend order stable. Methods not
    # listed here are appended afterwards instead of being silently dropped.
    desired_order = ["FPVR", "Tabular FPVR", "SP + DQN", "SR + DQN", "Random Walk"]
    extra_methods = [name for name in all_curves if name not in desired_order]
    # Extras get cycle colors that start after the ones reserved in METHOD_COLORS.
    extra_colors = {
        name: f"C{(len(METHOD_COLORS) + offset) % 10}"
        for offset, name in enumerate(extra_methods)
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for method_name in desired_order + extra_methods:
            curves = all_curves.get(method_name)
            if curves is None or len(curves) == 0:
                continue

            curves_array = np.stack(curves, axis=0)  # [n_seeds, n_steps]
            mean = curves_array.mean(axis=0)
            std = curves_array.std(axis=0)
            x = np.arange(len(mean))

            color = METHOD_COLORS.get(method_name, extra_colors.get(method_name, "C0"))
            band_color = _lighten(color, amount=0.88)

            ax.fill_between(x, mean - std, mean + std, color=band_color, linewidth=0.0, alpha=0.3, zorder=1)
            ax.plot(x, mean, label=method_name, linewidth=2.6, color=color, zorder=2)

        # Add dark-red dashed vertical lines at reset boundaries (if enabled)
        if reset_interval is not None and reset_interval > 0:
            # Determine maximum horizon from available curves
            valid_curves = [curves for curves in all_curves.values() if len(curves) > 0]
            if valid_curves:
                max_steps = max(len(curves[0]) for curves in valid_curves)
                for reset_step in range(reset_interval, max_steps, reset_interval):
                    ax.axvline(x=reset_step, color='darkred', linestyle='--', alpha=0.7, linewidth=1.5, zorder=0)

        ax.set_xlabel("Steps", fontsize=label_fs)
        ax.set_ylabel(ylabel, fontsize=label_fs)
        ax.tick_params(axis="both", which="major", labelsize=tick_fs)
        ax.grid(alpha=0.3)

        # Only add legend if there are labeled artists
        handles, _labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend(fontsize=legend_fs, frameon=True, loc="lower right")

        # No title: keep plots clean for paper-ready figures

        fig.tight_layout()
        fig.savefig(out_png, dpi=200)
        fig.savefig(out_eps, format="eps")
    finally:
        # Release the figure even if stacking or saving fails.
        plt.close(fig)


def get_wall_positions(env_size: int) -> np.ndarray:
    """Detect wall cells by instantiating the environment and scanning its grid.

    Args:
        env_size: Side length of the grid; cells (x, y) with
            0 <= x, y < env_size are inspected.

    Returns:
        A boolean [env_size, env_size] array indexed as ``mask[y, x]`` where
        True marks a cell whose object's class name contains "Wall". An
        all-False mask is returned when the environment cannot be imported or
        the detection fails for any other reason.
    """
    empty_mask = np.zeros((env_size, env_size), dtype=bool)

    try:
        # Ensure this file's directory is on the Python path so the local
        # visual_minigrid module can be imported.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        if current_dir not in sys.path:
            sys.path.insert(0, current_dir)

        from visual_minigrid import SimpleEnv

        # Create an environment instance to detect walls
        env = SimpleEnv(size=env_size, render_mode=None)
        try:
            env.reset()

            wall_mask = np.zeros((env_size, env_size), dtype=bool)
            for x in range(env_size):
                for y in range(env_size):
                    cell = env.grid.get(x, y)
                    if cell is not None and 'Wall' in cell.__class__.__name__:
                        wall_mask[y, x] = True  # Note: MiniGrid(x,y) -> numpy[y,x]
        finally:
            # Close the environment even if the scan fails part-way.
            env.close()

        print(f"Dynamically detected {wall_mask.sum()} wall positions")
        return wall_mask

    except ImportError as e:
        # Missing modules (e.g. minigrid not installed): walls are simply not drawn.
        print(f"Info: Could not import required modules for wall detection: {e}")
        print("This is expected if minigrid is not installed. Walls will not be visualized.")
        return empty_mask

    except Exception as e:
        # Environment creation or scanning failed: degrade to "no walls".
        print(f"Warning: Could not detect walls from environment: {e}")
        print("Falling back to empty wall mask")
        return empty_mask


def plot_heatmaps(all_counts: Dict[str, List[np.ndarray]], out_dir: str, env_size: int):
    """Save one seed-averaged visit-count heatmap PNG per method.

    Args:
        all_counts: Method name -> list of per-seed [env_size, env_size] count
            arrays. Methods with an empty list are skipped.
        out_dir: Directory that receives ``heatmap_<method>.png`` (created if
            missing).
        env_size: Side length of the grid, used for axis extents and ticks.
    """
    if not all_counts:
        print("No data available for heatmaps generation.")
        return

    # Get wall positions (dynamic detection)
    wall_mask = get_wall_positions(env_size)
    has_walls = wall_mask.sum() > 0

    if not has_walls:
        print("Note: No walls detected or wall detection unavailable. Heatmaps will show visit counts only.")

    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    # NOTE(review): with origin='upper' and this extent, array row 0 is drawn
    # at the top while the y tick labels increase upward, so tick values do
    # not equal row indices — confirm this is the intended orientation.
    extent = [0, env_size, 0, env_size]

    for method_name, counts_list in all_counts.items():
        if len(counts_list) == 0:
            print(f"Skipping heatmap for {method_name}: no valid data")
            continue

        # Average across seeds; raw visit counts are shown (no log transform)
        visit_counts = np.stack(counts_list, axis=0).mean(axis=0)

        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            if has_walls:
                # With walls: set wall locations to NaN so they can be rendered in gray
                masked_counts = visit_counts.astype(float)
                masked_counts[wall_mask] = np.nan

                im = ax.imshow(masked_counts, cmap='hot', origin='upper', extent=extent)

                # Overlay walls in deep gray (#404040)
                wall_display = np.where(wall_mask, 1, np.nan)
                wall_cmap = ListedColormap(['#404040'])
                ax.imshow(wall_display, cmap=wall_cmap, origin='upper',
                          extent=extent, alpha=1.0, vmin=0, vmax=1)
            else:
                # Without walls: plot the heatmap directly
                im = ax.imshow(visit_counts, cmap='hot', origin='upper', extent=extent)

            # Integer tick marks, one every 2 units
            ax.set_xticks(range(0, env_size+1, 2))
            ax.set_yticks(range(0, env_size+1, 2))
            ax.set_xlim(0, env_size)
            ax.set_ylim(0, env_size)

            # Titles and axes labels (use paper-consistent font sizes)
            ax.set_title(f"{method_name} State Visit Count", fontsize=18)
            ax.set_xlabel("X Position", fontsize=18)
            ax.set_ylabel("Y Position", fontsize=18)
            ax.tick_params(axis="both", which="major", labelsize=14)

            cbar = plt.colorbar(im, ax=ax)
            cbar.set_label("Visit Count", fontsize=18)
            cbar.ax.tick_params(labelsize=16)

            # Save with safe filename - replace all problematic characters
            safe_name = method_name.replace("+", "").replace(" ", "_").replace("/", "_").lower()
            heatmap_path = os.path.join(out_dir, f"heatmap_{safe_name}.png")
            fig.savefig(heatmap_path, dpi=200, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)

        wall_info = " (with walls)" if has_walls else " (no walls detected)"
        print(f"Saved heatmap: {heatmap_path}{wall_info}")


def main():
    """Run the selected exploration methods and plot their coverage results.

    Workflow:
      1. Parse command-line arguments and build a ``ComparisonConfig``.
      2. Create the output directories and dump the parsed arguments to
         ``config.json`` inside ``--out_dir``.
      3. For every selected method and every seed in
         ``[seed_base, seed_base + n_seeds)``, call the method's runner, load
         the coverage curves and position counts it produced, and sanity-check
         them.
      4. Save the cumulative coverage plot, the windowed coverage plot (only
         when ``--coverage_reset_interval`` is positive), and the heatmaps.

    Raises:
        FileNotFoundError: If a runner does not return both output paths for
            a seed, or if loading the produced files raises it.
        ValueError: Re-raised unchanged when loading or validating a seed's
            data raises it (fail fast rather than plot partial results).
    """
    parser = argparse.ArgumentParser(description="Compare exploration methods in visual gridworld")

    # Environment
    parser.add_argument("--env_size", type=int, default=20)
    parser.add_argument("--n_seeds", type=int, default=10)
    parser.add_argument("--seed_base", type=int, default=1)

    # Training
    parser.add_argument("--total_steps", type=int, default=9000)
    parser.add_argument("--learning_starts", type=int, default=100)

    # Analysis
    parser.add_argument("--coverage_reset_interval", type=int, default=3000)
    parser.add_argument("--subprocess_timeout_s", type=int, default=1800)

    # Optional speed overrides for deep FPVR (comparison only; does not change defaults unless set)
    parser.add_argument("--fpvr_k_train", type=int, default=None, help="Override FPVR k_train (updates per env step).")
    parser.add_argument(
        "--fpvr_whitening_update_every",
        type=int,
        default=None,
        help="Override FPVR whitening_update_every (larger => faster, less frequent SVD).",
    )

    # Output - keep results under this script's directory by default
    default_out_dir = os.path.join(os.path.dirname(__file__), "results")
    parser.add_argument("--out_dir", type=str, default=default_out_dir)
    parser.add_argument("--out_png", type=str, default=None)

    # Method selection
    parser.add_argument("--methods", nargs="+",
                       choices=["random", "sr", "sp", "deep_fpvr", "tabular_fpvr"],
                       default=["sr", "sp"],
                       help="Which methods to run")

    args = parser.parse_args()

    if args.out_png is None:
        args.out_png = os.path.join(args.out_dir, "coverage_comparison.png")

    cfg = ComparisonConfig(
        env_size=args.env_size,
        seed_base=args.seed_base,
        n_seeds=args.n_seeds,
        total_steps=args.total_steps,
        learning_starts=args.learning_starts,
        coverage_reset_interval=args.coverage_reset_interval,
        out_dir=args.out_dir
    )

    # Ensure output directories exist
    os.makedirs(cfg.out_dir, exist_ok=True)

    # Ensure the PNG output directory exists (it may differ from out_dir)
    png_dir = os.path.dirname(os.path.abspath(args.out_png))
    os.makedirs(png_dir, exist_ok=True)

    # Save the exact arguments used for this run
    with open(os.path.join(cfg.out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2)

    # Method runners - keep method names consistent with the legend.
    # Every runner is called as runner(cfg, seed) and returns
    # (coverage_path, counts_path).
    timeout_s = int(args.subprocess_timeout_s)
    method_runners = {
        "random": ("Random Walk", run_random_walk),
        "sr": (
            "SR + DQN",
            lambda cfg0, seed0: run_sr_dqn(cfg0, seed0, timeout_s=timeout_s),
        ),
        "sp": (
            "SP + DQN",
            lambda cfg0, seed0: run_sp_dqn(cfg0, seed0, timeout_s=timeout_s),
        ),
        "deep_fpvr": (
            "FPVR",
            lambda cfg0, seed0: run_deep_fpvr(
                cfg0,
                seed0,
                fpvr_k_train=args.fpvr_k_train,
                fpvr_whitening_update_every=args.fpvr_whitening_update_every,
            ),
        ),
        "tabular_fpvr": ("Tabular FPVR", run_tabular_fpvr),
    }

    # Drop repeated entries (e.g. "--methods sr sr") while keeping order, so a
    # method is not run twice and the summary count below stays meaningful.
    selected_methods = list(dict.fromkeys(args.methods))

    # Collect results
    all_curves = {}  # method_name -> List[coverage_curve]
    all_curves_win = {}  # method_name -> List[windowed_coverage_curve]
    all_counts = {}  # method_name -> List[position_counts]

    for method_key in selected_methods:
        if method_key not in method_runners:
            print(f"Unknown method: {method_key}")
            continue

        method_name, runner_func = method_runners[method_key]

        curves = []
        curves_win = []
        counts = []

        for seed in range(cfg.seed_base, cfg.seed_base + cfg.n_seeds):
            coverage_path, counts_path = runner_func(cfg, seed)

            if not (coverage_path and counts_path):
                print(f"ERROR: {method_name} seed {seed} failed to generate data files")
                raise FileNotFoundError(f"Failed to generate data files for {method_name} seed {seed}")

            try:
                windowed_path = coverage_path.replace('coverage.npy', 'coverage_windowed.npy')
                cov, cov_win = load_coverage_data(coverage_path, windowed_path, cfg.total_steps + 1)

                pos_counts = extract_position_counts(counts_path, cfg.env_size)

                sanity_check_outputs(
                    method_name=method_name,
                    cov=cov,
                    cov_win=cov_win,
                    pos_counts=pos_counts,
                    total_points=cfg.total_steps + 1,
                    env_size=cfg.env_size,
                    reset_interval=cfg.coverage_reset_interval,
                )

                # Debug info: validate collected data
                max_cov = np.max(cov) if len(cov) > 0 else 0
                total_visits = np.sum(pos_counts) if pos_counts is not None else 0
                print(f"  {method_name} (seed {seed}): Max coverage={max_cov}, Total visits={total_visits}")

                curves.append(cov)
                curves_win.append(cov_win)
                counts.append(pos_counts)
            except (FileNotFoundError, ValueError) as e:
                print(f"ERROR: {method_name} seed {seed} data loading failed: {e}")
                raise  # Fail fast if required data is missing; keeps the original traceback

        if len(curves) > 0:
            all_curves[method_name] = curves
            all_curves_win[method_name] = curves_win
            all_counts[method_name] = counts
            print(f"  {method_name}: Successfully collected data from {len(curves)} seed(s)")
        else:
            print(f"  {method_name}: No valid data collected")

    # Derive sibling output paths from the extension only. A plain
    # str.replace(".png", ...) returns the input unchanged when the suffix is
    # missing or differently cased (so later plots would overwrite the first),
    # and rewrites ".png" occurring anywhere else in the path.
    png_root, png_ext = os.path.splitext(args.out_png)

    # Plot cumulative coverage (no reset markers)
    out_eps = png_root + ".eps"
    plot_comparison(all_curves, "Number of visited states", args.out_png, out_eps, reset_interval=None)
    print(f"Saved cumulative coverage plot: {args.out_png}")

    # Plot windowed coverage
    if cfg.coverage_reset_interval > 0:
        win_root = f"{png_root}_reset{cfg.coverage_reset_interval}"
        win_png = win_root + png_ext
        win_eps = win_root + ".eps"
        plot_comparison(all_curves_win,
                       f"Visited states (reset every {cfg.coverage_reset_interval} steps)",
                       win_png, win_eps, reset_interval=cfg.coverage_reset_interval)
        print(f"Saved windowed coverage plot: {win_png}")

    # Plot heatmaps
    print("\nGenerating heatmaps...")
    plot_heatmaps(all_counts, cfg.out_dir, cfg.env_size)

    print(f"\nComparison complete. Results saved in: {cfg.out_dir}")

    # Summary
    successful_methods = len(all_curves)
    total_methods = len(selected_methods)
    print(f"Successfully analyzed {successful_methods}/{total_methods} methods.")

    if successful_methods == 0:
        print("\nNote: No methods ran successfully. This is likely due to missing dependencies.")
        print("To run all methods, please install: pip install minigrid")
    

# Script entry point: run the comparison only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()