#!/usr/bin/env python
"""
改进的 VPS Option 可视化方案

设计思路：
1. 值函数 V(s) 和 VPS 特征 φ(s) 分开可视化
2. 每个 option 一张大图，包含所有状态配置的子图
3. 使用清晰的标题和颜色映射
4. 可以选择性地只展示关键状态配置
"""

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from key_lock_options import (
    state_to_index,
    get_wall_positions,
)


def visualize_vps_option_comprehensive(
    V_base,  # Value functions, shape (k_base, N)
    phi_base,  # VPS features, shape (k_base, N)
    size: int,
    save_dir: str,
    prefix: str = "vps_comprehensive",
    iteration: int = 0,
    max_options_to_plot: int = 4,
    env=None,
    state_configs=None,
):
    """Save one figure per option showing its value function and VPS feature.

    Each figure is a 2 x len(state_configs) grid:
    - top row: value function V(s) for every state configuration,
    - bottom row: VPS feature phi(s) for every state configuration.

    Figures are written to
    ``{save_dir}/{prefix}_opt{opt_idx}_iter{iteration:04d}.png``. Nothing is
    drawn (and ``save_dir`` is not created) when there are no options or no
    state configurations to plot.

    Args:
        V_base: Value functions, shape (k_base, N).
        phi_base: VPS features, indexed by the same option index as V_base.
        size: Grid size.
        save_dir: Directory to save figures; created if missing.
        prefix: Prefix for saved files.
        iteration: Current iteration number (zero-padded in the file name).
        max_options_to_plot: Maximum number of options to plot.
        env: Optional KeyLockEnv instance for wall visualization.
        state_configs: Optional list of state-config dicts. Each needs
            "short_name", "yellow_door_open" and "blue_door_open"; the
            "*_key_on_map" entries are optional. If None, five default
            configurations are used.
    """
    # Tuple unpacking doubles as a check that V_base is 2-D.
    k_base, _ = V_base.shape

    if state_configs is None:
        field_names = (
            "name", "short_name",
            "yellow_door_open", "blue_door_open",
            "yellow_key_on_map", "blue_key_on_map",
        )
        default_rows = [
            # initial state (no key, doors closed)
            ("初始状态\n(无钥匙，门关闭)", "初始", 0, 0, 1, 1),
            # blue key picked up (doors closed)
            ("蓝钥匙已拾取\n(门关闭)", "蓝钥匙", 0, 0, 1, 0),
            # blue door opened (blue key consumed)
            ("蓝门已打开\n(蓝钥匙已消耗)", "蓝门开", 0, 1, 1, 0),
            # yellow key picked up (blue door open)
            ("黄钥匙已拾取\n(蓝门已开)", "黄钥匙", 0, 1, 0, 0),
            # both doors open (keys consumed)
            ("两门皆开\n(钥匙已消耗)", "完成", 1, 1, 0, 0),
        ]
        state_configs = [dict(zip(field_names, row)) for row in default_rows]

    num_configs = len(state_configs)
    num_to_plot = min(k_base, max_options_to_plot)
    if num_configs == 0 or num_to_plot <= 0:
        # GridSpec rejects a zero-column layout, so bail out before drawing.
        return

    # Wall cells are only known when an environment is supplied.
    wall_positions = get_wall_positions(env, size) if env is not None else []

    os.makedirs(save_dir, exist_ok=True)

    # (per-option arrays, title label, colormap) for the two grid rows.
    grid_rows = (
        (V_base, "值函数 V(s)", 'viridis'),
        (phi_base, "VPS 特征 φ(s)", 'hot'),
    )

    for opt_idx in range(num_to_plot):
        fig = plt.figure(figsize=(4 * num_configs, 8))
        try:
            gs = GridSpec(2, num_configs, figure=fig, hspace=0.3, wspace=0.3)

            for row_idx, (features, label, cmap) in enumerate(grid_rows):
                for config_idx, config in enumerate(state_configs):
                    ax = fig.add_subplot(gs[row_idx, config_idx])
                    heatmap = _create_heatmap(
                        features[opt_idx], size, config, wall_positions
                    )
                    _plot_heatmap(
                        ax, heatmap, size,
                        title=f"{label}\n{config['short_name']}",
                        cmap=cmap,
                    )

            fig.suptitle(f'VPS Option {opt_idx} - 值函数和特征分布',
                         fontsize=14, fontweight='bold', y=0.98)

            filename = os.path.join(
                save_dir, f"{prefix}_opt{opt_idx}_iter{iteration:04d}.png"
            )
            fig.savefig(filename, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving failed;
            # otherwise pyplot keeps it alive across repeated calls.
            plt.close(fig)
        print(f"  [Viz] Saved {filename}")


def visualize_vps_option_side_by_side(
    V_base,  # Value functions, shape (k_base, N)
    phi_base,  # VPS features, shape (k_base, N)
    size: int,
    save_dir: str,
    prefix: str = "vps_side_by_side",
    iteration: int = 0,
    max_options_to_plot: int = 4,
    env=None,
    state_configs=None,
):
    """Save one figure per option with V(s) and phi(s) shown side by side.

    Layout: one row per state configuration and two columns
    (value function | VPS feature). Figures are written to
    ``{save_dir}/{prefix}_opt{opt_idx}_iter{iteration:04d}.png``. Nothing is
    drawn (and ``save_dir`` is not created) when there are no options or no
    state configurations to plot.

    Args:
        V_base: Value functions, shape (k_base, N).
        phi_base: VPS features, indexed by the same option index as V_base.
        size: Grid size.
        save_dir: Directory to save figures; created if missing.
        prefix: Prefix for saved files.
        iteration: Current iteration number (zero-padded in the file name).
        max_options_to_plot: Maximum number of options to plot.
        env: Optional environment instance used to look up wall positions.
        state_configs: Optional list of state-config dicts. If None, five
            default configurations are used.
    """
    # Tuple unpacking doubles as a check that V_base is 2-D.
    k_base, _ = V_base.shape

    if state_configs is None:
        field_names = (
            "name", "short_name",
            "yellow_door_open", "blue_door_open",
            "yellow_key_on_map", "blue_key_on_map",
        )
        default_rows = [
            ("初始状态", "初始", 0, 0, 1, 1),        # initial state
            ("蓝钥匙已拾取", "蓝钥匙", 0, 0, 1, 0),  # blue key picked up
            ("蓝门已打开", "蓝门开", 0, 1, 1, 0),    # blue door opened
            ("黄钥匙已拾取", "黄钥匙", 0, 1, 0, 0),  # yellow key picked up
            ("两门皆开", "完成", 1, 1, 0, 0),        # both doors open
        ]
        state_configs = [dict(zip(field_names, row)) for row in default_rows]

    num_configs = len(state_configs)
    num_to_plot = min(k_base, max_options_to_plot)
    if num_configs == 0 or num_to_plot <= 0:
        # plt.subplots rejects a zero-row layout, so bail out before drawing.
        return

    # Wall cells are only known when an environment is supplied.
    wall_positions = get_wall_positions(env, size) if env is not None else []

    os.makedirs(save_dir, exist_ok=True)

    # (per-option arrays, title label, colormap) for the left/right columns.
    grid_columns = (
        (V_base, "值函数 V(s)", 'viridis'),
        (phi_base, "VPS 特征 φ(s)", 'hot'),
    )

    for opt_idx in range(num_to_plot):
        # squeeze=False keeps `axes` 2-D even with a single configuration row.
        fig, axes = plt.subplots(
            num_configs, 2,
            figsize=(10, 4 * num_configs),
            squeeze=False,
        )
        try:
            for config_idx, config in enumerate(state_configs):
                for col_idx, (features, label, cmap) in enumerate(grid_columns):
                    heatmap = _create_heatmap(
                        features[opt_idx], size, config, wall_positions
                    )
                    _plot_heatmap(
                        axes[config_idx, col_idx], heatmap, size,
                        title=f"{label} - {config['short_name']}",
                        cmap=cmap,
                    )

            fig.suptitle(f'VPS Option {opt_idx} - 值函数和特征分布',
                         fontsize=14, fontweight='bold')
            fig.tight_layout()

            filename = os.path.join(
                save_dir, f"{prefix}_opt{opt_idx}_iter{iteration:04d}.png"
            )
            fig.savefig(filename, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving failed;
            # otherwise pyplot keeps it alive across repeated calls.
            plt.close(fig)
        print(f"  [Viz] Saved {filename}")


def visualize_vps_option_key_stages(
    V_base,  # Value functions, shape (k_base, N)
    phi_base,  # VPS features, shape (k_base, N)
    size: int,
    save_dir: str,
    prefix: str = "vps_key_stages",
    iteration: int = 0,
    max_options_to_plot: int = 4,
    env=None,
):
    """Visualize only the key stages: initial, blue door open, and finished.

    This is the most compact visualization, intended for a quick look at an
    option's behavior. It delegates to visualize_vps_option_comprehensive
    with a fixed three-entry list of state configurations.
    """
    field_names = (
        "name", "short_name",
        "yellow_door_open", "blue_door_open",
        "yellow_key_on_map", "blue_key_on_map",
    )
    stage_rows = (
        # initial state (no key, doors closed)
        ("初始状态\n(无钥匙，门关闭)", "初始", 0, 0, 1, 1),
        # blue door opened (key turning point)
        ("蓝门已打开\n(关键转折点)", "蓝门开", 0, 1, 1, 0),
        # both doors open (task complete)
        ("两门皆开\n(任务完成)", "完成", 1, 1, 0, 0),
    )
    key_configs = [dict(zip(field_names, row)) for row in stage_rows]

    visualize_vps_option_comprehensive(
        V_base,
        phi_base,
        size,
        save_dir,
        prefix=prefix,
        iteration=iteration,
        max_options_to_plot=max_options_to_plot,
        env=env,
        state_configs=key_configs,
    )


def _create_heatmap(feature_array, size, config, wall_positions):
    """Build the (size, size) heatmap for a single state configuration.

    Cells whose (x, y) appears in ``wall_positions`` stay NaN. Every other
    cell ``[y, x]`` holds the mean of ``feature_array`` over the four values
    0..3 of the third ``state_to_index`` argument (presumably the agent's
    facing direction -- confirm against ``state_to_index``). The key flags
    default to 1 when missing from ``config``.
    """
    grid = np.full((size, size), np.nan)

    for col in range(size):
        for row in range(size):
            if (col, row) in wall_positions:
                continue

            samples = [
                feature_array[
                    state_to_index(
                        col, row, heading,
                        config["yellow_door_open"],
                        config["blue_door_open"],
                        config.get("yellow_key_on_map", 1),
                        config.get("blue_key_on_map", 1),
                        size,
                    )
                ]
                for heading in range(4)
            ]
            # Rows index y and columns index x.
            grid[row, col] = np.mean(samples)

    return grid


def _plot_heatmap(ax, heatmap, size, title, cmap='viridis'):
    """Draw a heatmap with its own colorbar on the given axes.

    NaN cells in ``heatmap`` are masked and rendered in translucent gray.

    Args:
        ax: Matplotlib axes to draw on.
        heatmap: 2-D array of values; NaN marks cells to gray out.
        size: Grid size, used as the axis extent in both directions.
        title: Axes title.
        cmap: Name of a registered Matplotlib colormap.
    """
    masked_heatmap = np.ma.masked_invalid(heatmap)

    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap is still available. Copy before set_bad so the
    # registered colormap is not mutated.
    palette = plt.get_cmap(cmap).copy()
    palette.set_bad(color='gray', alpha=0.5)

    # origin='upper' draws row 0 at the top. The extent is given as
    # (left, right, bottom, top) = (0, size, size, 0) so the Y tick labels
    # increase downward together with the row index instead of running
    # opposite to it.
    im = ax.imshow(masked_heatmap, origin='upper', cmap=palette,
                   aspect='auto', extent=[0, size, size, 0])

    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.set_xlabel('X', fontsize=10)
    ax.set_ylabel('Y', fontsize=10)
    # Attach the colorbar through the axes' own figure rather than whichever
    # figure happens to be current in pyplot.
    ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)


def visualize_eigenvector_comprehensive(
    eig_vecs,  # Eigenvectors, shape (N, k_base)
    size: int,
    save_dir: str,
    prefix: str = "eigen_comprehensive",
    iteration: int = 0,
    max_eigenvectors_to_plot: int = 4,
    env=None,
    state_configs=None,
):
    """Save one figure per Laplacian eigenvector across state configurations.

    Each figure is a single row with one heatmap per state configuration.
    Figures are written to
    ``{save_dir}/{prefix}_vec{eig_idx}_iter{iteration:04d}.png``. Nothing is
    drawn (and ``save_dir`` is not created) when there are no eigenvectors or
    no state configurations to plot.

    Args:
        eig_vecs: Eigenvectors, shape (N, k_base) -- note that eigenvectors
            are the columns, i.e. the transpose of the (k_base, N) layout.
        size: Grid size.
        save_dir: Directory to save figures; created if missing.
        prefix: Prefix for saved files.
        iteration: Current iteration number (zero-padded in the file name).
        max_eigenvectors_to_plot: Maximum number of eigenvectors to plot.
        env: Optional KeyLockEnv instance for wall visualization.
        state_configs: Optional list of state-config dicts. If None, five
            default configurations are used.
    """
    # Tuple unpacking doubles as a check that eig_vecs is 2-D.
    _, k_base = eig_vecs.shape

    if state_configs is None:
        field_names = (
            "name", "short_name",
            "yellow_door_open", "blue_door_open",
            "yellow_key_on_map", "blue_key_on_map",
        )
        default_rows = [
            # initial state (no key, doors closed)
            ("初始状态\n(无钥匙，门关闭)", "初始", 0, 0, 1, 1),
            # blue key picked up (doors closed)
            ("蓝钥匙已拾取\n(门关闭)", "蓝钥匙", 0, 0, 1, 0),
            # blue door opened (blue key consumed)
            ("蓝门已打开\n(蓝钥匙已消耗)", "蓝门开", 0, 1, 1, 0),
            # yellow key picked up (blue door open)
            ("黄钥匙已拾取\n(蓝门已开)", "黄钥匙", 0, 1, 0, 0),
            # both doors open (keys consumed)
            ("两门皆开\n(钥匙已消耗)", "完成", 1, 1, 0, 0),
        ]
        state_configs = [dict(zip(field_names, row)) for row in default_rows]

    num_configs = len(state_configs)
    num_to_plot = min(k_base, max_eigenvectors_to_plot)
    if num_configs == 0 or num_to_plot <= 0:
        # plt.subplots rejects a zero-column layout, so bail out before drawing.
        return

    # Wall cells are only known when an environment is supplied.
    wall_positions = get_wall_positions(env, size) if env is not None else []

    os.makedirs(save_dir, exist_ok=True)

    for eig_idx in range(num_to_plot):
        # squeeze=False always yields a 2-D axes array, so a single
        # configuration needs no special casing.
        fig, axes = plt.subplots(
            1, num_configs,
            figsize=(4 * num_configs, 4),
            squeeze=False,
        )
        try:
            for config_idx, config in enumerate(state_configs):
                heatmap = _create_heatmap(
                    eig_vecs[:, eig_idx], size, config, wall_positions
                )
                _plot_heatmap(
                    axes[0, config_idx], heatmap, size,
                    title=f"特征向量 {eig_idx}\n{config['short_name']}",
                    # Diverging colormap: eigenvector entries can be
                    # positive or negative.
                    cmap='coolwarm',
                )

            fig.suptitle(f'拉普拉斯特征向量 {eig_idx} - 分布',
                         fontsize=14, fontweight='bold', y=0.98)
            fig.tight_layout()

            filename = os.path.join(
                save_dir, f"{prefix}_vec{eig_idx}_iter{iteration:04d}.png"
            )
            fig.savefig(filename, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving failed;
            # otherwise pyplot keeps it alive across repeated calls.
            plt.close(fig)
        print(f"  [Viz] Saved {filename}")


def visualize_eigenvector_key_stages(
    eig_vecs,  # Eigenvectors, shape (N, k_base)
    size: int,
    save_dir: str,
    prefix: str = "eigen_key_stages",
    iteration: int = 0,
    max_eigenvectors_to_plot: int = 4,
    env=None,
):
    """Visualize only the key stages: initial, blue door open, and finished.

    Delegates to visualize_eigenvector_comprehensive with a fixed
    three-entry list of state configurations.
    """
    field_names = (
        "name", "short_name",
        "yellow_door_open", "blue_door_open",
        "yellow_key_on_map", "blue_key_on_map",
    )
    stage_rows = (
        # initial state (no key, doors closed)
        ("初始状态\n(无钥匙，门关闭)", "初始", 0, 0, 1, 1),
        # blue door opened (key turning point)
        ("蓝门已打开\n(关键转折点)", "蓝门开", 0, 1, 1, 0),
        # both doors open (task complete)
        ("两门皆开\n(任务完成)", "完成", 1, 1, 0, 0),
    )
    key_configs = [dict(zip(field_names, row)) for row in stage_rows]

    visualize_eigenvector_comprehensive(
        eig_vecs,
        size,
        save_dir,
        prefix=prefix,
        iteration=iteration,
        max_eigenvectors_to_plot=max_eigenvectors_to_plot,
        env=env,
        state_configs=key_configs,
    )


if __name__ == "__main__":
    # Running the module directly only prints a short usage reminder.
    print(
        "This module provides improved visualization functions for VPS options and Eigenvectors.",
        "Import and use:",
        "  from improved_vps_visualization import visualize_vps_option_comprehensive",
        "  visualize_vps_option_comprehensive(V_base, phi_base, size, save_dir, ...)",
        "  from improved_vps_visualization import visualize_eigenvector_comprehensive",
        "  visualize_eigenvector_comprehensive(eig_vecs, size, save_dir, ...)",
        sep="\n",
    )

