#!/usr/bin/env python
"""
坐标维度投影可视化

无视钥匙、门的状态标志位和智能体朝向，对所有状态求平均，
只保留位置坐标 (x, y) 的投影 heatmap。

这种可视化的参考价值：
1. 优点：
   - 简洁直观，一张图看到整体空间偏好
   - 不受状态配置限制，可以看到"平均"行为
   - 可能揭示 option 在空间上的整体偏好区域
   - 适合快速了解 option 的空间分布模式

2. 缺点：
   - 丢失了关键信息：不同状态配置下的行为差异
   - 可能掩盖重要的状态依赖行为
   - 对于 KeyLockEnv 这种状态空间分离的环境，平均可能没有太大意义
   - 无法看到 option 如何适应不同的任务阶段

3. 使用建议：
   - 作为补充可视化，不能替代按状态配置的可视化
   - 适合快速浏览多个 option 的空间偏好
   - 结合按状态配置的可视化一起使用，获得更全面的理解
"""

import os
import numpy as np
import matplotlib.pyplot as plt
from key_lock_options import (
    state_to_index,
    get_wall_positions,
)


# Number of agent orientations enumerated per grid cell.
_NUM_DIRECTIONS = 4

# Door/key flag tuples, in the positional order ``state_to_index`` receives
# them after (x, y, direction):
#     (yellow_door_open, blue_door_open, yellow_key_on_map, blue_key_on_map)
# Every 0/1 combination (16 total); order matches the original nested loops.
_ALL_CONFIGS = tuple(
    (y_door, b_door, y_on_map, b_on_map)
    for y_door in (0, 1)
    for b_door in (0, 1)
    for y_on_map in (0, 1)
    for b_on_map in (0, 1)
)

# The 5 flag combinations this module treats as logically reachable.
_REACHABLE_CONFIGS = (
    (0, 0, 1, 1),  # both doors closed, both keys on the map
    (0, 0, 1, 0),  # both doors closed, blue key off the map
    (0, 1, 1, 0),  # blue door open, blue key off the map
    (0, 1, 0, 0),  # blue door open, both keys off the map
    (1, 1, 0, 0),  # both doors open, both keys off the map
)


def _to_feature_matrix(features):
    """Return *features* as a 2-D array of shape (k, N).

    A 1-D vector (e.g. a single eigenvector) is promoted to shape (1, N).

    Raises:
        ValueError: if *features* is neither 1-D nor 2-D.
    """
    features = np.asarray(features)
    if features.ndim == 1:
        features = features[np.newaxis, :]
    if features.ndim != 2:
        raise ValueError(
            f"features must have shape (N,) or (k, N), got {features.shape}"
        )
    return features


def _cell_state_indices(size, configs, wall_positions, num_states, index_fn=None):
    """Map each non-wall cell (x, y) to the feature indices of its states.

    For every direction in ``range(_NUM_DIRECTIONS)`` and every flag tuple in
    *configs*, a state index is computed with ``index_fn(x, y, direction,
    *config, size)`` (``state_to_index`` by default). Indices whose
    computation raises ``IndexError``/``ValueError``, or that fall outside
    ``[0, num_states)``, are skipped. Cells left with no valid index are
    omitted from the result.

    Returns:
        dict mapping ``(x, y)`` to a 1-D integer index array.
    """
    if index_fn is None:
        index_fn = state_to_index
    # Set membership is O(1); tuple() normalises list/array entries so they
    # compare equal to the (x, y) tuples probed below.
    walls = {tuple(pos) for pos in wall_positions}

    cells = {}
    for x in range(size):
        for y in range(size):
            if (x, y) in walls:
                continue
            indices = []
            for direction in range(_NUM_DIRECTIONS):
                for config in configs:
                    try:
                        s = index_fn(x, y, direction, *config, size)
                    except (IndexError, ValueError):
                        continue  # invalid state for this cell/config
                    # Explicit bounds check: a negative index would otherwise
                    # wrap around silently when used to index the features.
                    if 0 <= s < num_states:
                        indices.append(s)
            if indices:
                cells[(x, y)] = np.asarray(indices, dtype=np.intp)
    return cells


def _projection_heatmap(feature_row, size, cell_indices):
    """Average *feature_row* over each cell's state indices.

    Returns:
        A (size, size) float array indexed ``[y, x]``; cells absent from
        *cell_indices* (walls / no valid state) are NaN.
    """
    heatmap = np.full((size, size), np.nan)
    for (x, y), indices in cell_indices.items():
        heatmap[y, x] = np.mean(feature_row[indices])
    return heatmap


def _make_axes(num_to_plot):
    """Create the figure: one 8x8 panel, or a 2-column grid for several.

    Returns:
        ``(fig, axes)`` where *axes* is a flat array of Axes.
    """
    if num_to_plot == 1:
        rows, cols = 1, 1
    else:
        rows, cols = (num_to_plot + 1) // 2, 2
    fig, axes = plt.subplots(
        rows, cols, figsize=(8 * cols, 8 * rows), squeeze=False
    )
    return fig, axes.ravel()


def _draw_panel(ax, heatmap, signed, title, size):
    """Draw one projection heatmap with a colorbar on *ax*.

    Signed features use ``coolwarm`` with limits symmetric about 0, so the
    neutral colour means zero. Otherwise ``viridis`` is used. NaN cells are
    shown in translucent gray.
    """
    vmin = vmax = None
    if signed:
        cmap = plt.cm.coolwarm.copy()
        finite = heatmap[np.isfinite(heatmap)]
        if finite.size:
            bound = float(np.max(np.abs(finite)))
            if bound > 0:
                vmin, vmax = -bound, bound
    else:
        cmap = plt.cm.viridis.copy()
    cmap.set_bad(color='gray', alpha=0.5)

    # This extent is matplotlib's default for origin='upper'. Cell (x, y) is
    # centred on tick (x, y), and the y axis points down, matching row 0 at
    # the top.
    im = ax.imshow(
        np.ma.masked_invalid(heatmap),
        origin='upper',
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        aspect='auto',
        extent=(-0.5, size - 0.5, size - 0.5, -0.5),
    )
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('X', fontsize=10)
    ax.set_ylabel('Y', fontsize=10)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)


def _render_projection(
    features,
    size,
    save_dir,
    prefix,
    iteration,
    max_features_to_plot,
    env,
    feature_names,
    configs,
    title_suffix,
    log_label,
):
    """Shared implementation of the coordinate-projection plots.

    Each plotted feature is averaged over all directions and the given
    *configs*, drawn as a heatmap, and saved as
    ``{save_dir}/{prefix}_iter{iteration:04d}.png``.
    """
    features = _to_feature_matrix(features)
    k, num_states = features.shape
    num_to_plot = min(k, max_features_to_plot)
    if num_to_plot <= 0:
        print(f"  [Viz] No features to plot; skipped {log_label}.")
        return

    wall_positions = get_wall_positions(env, size) if env is not None else []
    # The state indices depend only on the grid, walls and configs, not on
    # the feature, so compute them once for all panels.
    cell_indices = _cell_state_indices(size, configs, wall_positions, num_states)

    fig, axes = _make_axes(num_to_plot)
    try:
        for feat_idx, ax in enumerate(axes[:num_to_plot]):
            feature_row = features[feat_idx]
            heatmap = _projection_heatmap(feature_row, size, cell_indices)

            if feature_names is not None and feat_idx < len(feature_names):
                label = feature_names[feat_idx]
            else:
                label = f"Feature {feat_idx}"
            title = f"{label}\n({title_suffix})"

            # Eigenvectors can be negative (diverging map); VPS features are
            # usually non-negative (sequential map).
            signed = bool(np.any(feature_row < 0))
            _draw_panel(ax, heatmap, signed, title, size)

        for ax in axes[num_to_plot:]:
            ax.axis('off')  # hide unused grid slots

        fig.tight_layout()
        os.makedirs(save_dir, exist_ok=True)
        filename = os.path.join(save_dir, f"{prefix}_iter{iteration:04d}.png")
        fig.savefig(filename, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"  [Viz] Saved {log_label}: {filename}")


def visualize_projection(
    features,  # shape: (k, N) for VPS or (N,) for eigenvector
    size: int,
    save_dir: str,
    prefix: str = "projection",
    iteration: int = 0,
    max_features_to_plot: int = 4,
    env=None,
    feature_names=None,
):
    """Plot features projected onto the (x, y) coordinates.

    Each cell's value is the unweighted mean over all 4 directions and all
    16 door/key flag combinations, including unreachable ones.

    Args:
        features: Feature values, shape (k, N) for VPS or (N,) for eigenvector
        size: Grid size
        save_dir: Directory to save figures
        prefix: Prefix for saved files
        iteration: Current iteration number
        max_features_to_plot: Maximum number of features to plot
        env: Optional KeyLockEnv instance for wall visualization
        feature_names: Optional list of feature names for titles
    """
    _render_projection(
        features,
        size,
        save_dir,
        prefix,
        iteration,
        max_features_to_plot,
        env,
        feature_names,
        configs=_ALL_CONFIGS,
        title_suffix="Coordinate Projection - All States Avg",
        log_label="projection visualization",
    )


def visualize_projection_simple(
    features,  # shape: (k, N) for VPS or (N,) for eigenvector
    size: int,
    save_dir: str,
    prefix: str = "projection_simple",
    iteration: int = 0,
    max_features_to_plot: int = 4,
    env=None,
    feature_names=None,
):
    """Plot the coordinate projection averaged over reachable states only.

    Like ``visualize_projection``, but each cell averages over the 4
    directions and only the 5 door/key configurations treated as logically
    reachable, which better matches how the environment actually evolves.
    Arguments are the same as for ``visualize_projection``.
    """
    _render_projection(
        features,
        size,
        save_dir,
        prefix,
        iteration,
        max_features_to_plot,
        env,
        feature_names,
        configs=_REACHABLE_CONFIGS,
        title_suffix="Coordinate Projection - Reachable States Avg",
        log_label="simple projection visualization",
    )


if __name__ == "__main__":
    # Running the module directly only prints a short usage hint.
    _usage_lines = (
        "This module provides projection visualization functions.",
        "Import and use:",
        "  from projection_visualization import visualize_projection_simple",
        "  visualize_projection_simple(features, size, save_dir, ...)",
    )
    for _line in _usage_lines:
        print(_line)

