#!/usr/bin/env python3
"""
6つの環境の安全分布をまとめて可視化するスクリプト
- halfcheetah-medium-expert-v2
- halfcheetah-medium-replay-v2  
- walker2d-medium-expert-v2
- walker2d-medium-replay-v2
- hopper-medium-expert-v2
- hopper-medium-replay-v2

各環境でRADAC、ORAAC、DiffusionQLの分布を比較し、安全領域を可視化します。
"""

import os
import time
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Environments to visualise, generated as "<robot>-<dataset>-v2" with the
# robot varying slowest (so both datasets of a robot are adjacent).
_ROBOTS = ("halfcheetah", "walker2d", "hopper")
_DATASETS = ("medium-expert", "medium-replay")
ENVIRONMENTS = [
    f"{robot}-{dataset}-v2"
    for robot in _ROBOTS
    for dataset in _DATASETS
]

# Algorithm keys; these are also the keys of ALGORITHM_STYLES below.
ALGORITHMS = ["radac", "oraac", "ql"]

# Per-environment-type settings. HalfCheetah carries one safe region per
# dataset type; the other environments have a single safe region.
ENV_CONFIGS = {
    "halfcheetah": {
        "file_prefix": "velocity",
        # Safe region for the medium-replay dataset.
        "safe_region_medium_replay": (0, 5),
        # Safe region for the medium-expert dataset.
        "safe_region_medium_expert": (0, 10),
        "x_label": "Velocity (m/s)",
        # Right-hand bound shortened for medium-replay (14 -> 8).
        "x_lim": (-1, 8),
    },
    "walker2d": {
        "file_prefix": "angle",
        "safe_region": (-0.5, 0.5),
        "x_label": "Angle (rad)",
        "x_lim": (-1, 1),
    },
    "hopper": {
        "file_prefix": "angle",
        "safe_region": (-0.1, 0.1),
        "x_label": "Angle (rad)",
        # Zoomed in for Hopper ((-1, 1) -> (-0.3, 0.3)).
        "x_lim": (-0.3, 0.3),
    },
}

# Drawing style per algorithm.
ALGORITHM_STYLES = {
    "radac": {
        "color": "blue",
        "linestyle": "-",
        "alpha": 0.15,
        "label": "RADAC",
    },
    "oraac": {
        "color": "red",
        "linestyle": "-.",
        "alpha": 0.15,
        "label": "ORAAC",
    },
    "ql": {
        "color": "purple",
        "linestyle": "dotted",
        "alpha": 0.15,
        "label": "DiffusionQL",
    },
}

def load_safety_data(base_dir: str, env_name: str, algorithm: str) -> Optional[np.ndarray]:
    """Load the saved safety samples of one algorithm on one environment.

    The file is looked up at
    ``<base_dir>/<env_name>/<file_prefix>_<algorithm>_<env_name>.npy``, where
    ``file_prefix`` is taken from the ``ENV_CONFIGS`` entry of the
    environment type.

    Args:
        base_dir: Root directory holding one sub-directory per environment.
        env_name: Full environment name, e.g. ``"hopper-medium-replay-v2"``.
        algorithm: Algorithm key used in the file name, e.g. ``"radac"``.

    Returns:
        The array stored in the file, or ``None`` (after printing a warning)
        when the file does not exist.

    Raises:
        KeyError: If the environment type (the text before the first ``-``
            in ``env_name``) has no entry in ``ENV_CONFIGS``.
    """
    # Environment type is the leading token: halfcheetah, walker2d or hopper.
    env_type = env_name.split('-')[0]
    config = ENV_CONFIGS[env_type]

    file_name = f"{config['file_prefix']}_{algorithm}_{env_name}.npy"
    file_path = os.path.join(base_dir, env_name, file_name)

    # Attempt the load directly instead of checking for existence first, so
    # a file that disappears between check and load cannot crash the run.
    try:
        data = np.load(file_path)
    except FileNotFoundError:
        print(f"  Warning: {file_path} not found")
        return None

    print(f"  Loaded {algorithm}: {len(data):,} samples")
    return data

def plot_safety_distribution(ax, data_dict: dict, env_name: str, config: dict):
    """Draw the safety distributions of a single environment on ``ax``.

    A green band marks the safe region. Every algorithm except ``"ql"`` is
    drawn on ``ax`` as a histogram plus a KDE curve; ``"ql"`` (DiffusionQL)
    is drawn as a KDE curve on a twin y-axis so it can be scaled
    independently. Ticks, axis labels and spines are hidden
    (distribution_plots.py style), and a merged legend and a title are added.

    Args:
        ax: Matplotlib axes to draw on.
        data_dict: Maps an algorithm key (a key of ``ALGORITHM_STYLES``) to a
            NumPy array of samples, or to ``None`` when no data is available.
        env_name: Full environment name, e.g. ``"walker2d-medium-expert-v2"``.
        config: ``ENV_CONFIGS`` entry of the environment type. Must provide
            ``"x_lim"`` and the safe-region key selected for ``env_name``.

    Entries that are ``None`` or hold no samples are skipped. When nothing
    can be plotted, the x-limits fall back to ``config["x_lim"]`` instead of
    raising.
    """
    # Twin y-axis for DiffusionQL; created only once it is actually needed.
    ax_right = None

    # HalfCheetah has one safe region per dataset type; the other
    # environments have a single one.
    if "medium-replay" in env_name and "halfcheetah" in env_name:
        safe_region = config["safe_region_medium_replay"]
    elif "medium-expert" in env_name and "halfcheetah" in env_name:
        safe_region = config["safe_region_medium_expert"]
    else:
        safe_region = config["safe_region"]

    # Shade the safe region.
    safe_low, safe_high = safe_region
    if safe_low < safe_high:
        ax.axvspan(safe_low, safe_high, color="green", alpha=0.2, label="Safe Region")

    # Keep only arrays that exist and contain samples: min()/max() and the
    # histogram below cannot handle missing or zero-size data.
    plottable = {}
    for alg, data in data_dict.items():
        if data is None:
            continue
        if data.size == 0:
            print(f"    Warning: no samples for {alg}, skipping")
            continue
        plottable[alg] = data

    # Draw the distribution of each algorithm.
    for alg, data in plottable.items():
        style = ALGORITHM_STYLES[alg]

        # Basic statistics, printed for debugging.
        print(f"    {alg}: {len(data)} samples, range=[{data.min():.3f}, {data.max():.3f}], mean={data.mean():.3f}")

        # Use fewer bins when there are few samples.
        bins = min(30, max(5, len(data) // 20))

        if alg == "ql":
            # DiffusionQL goes on the right-hand axis and gets no histogram.
            if ax_right is None:
                ax_right = ax.twinx()
            ax_cur = ax_right
        else:
            ax_cur = ax
            ax_cur.hist(data, bins=bins, density=True, alpha=0.25,
                        color=style["color"], linewidth=0, edgecolor='none')

        # KDE is best-effort: any failure falls back to a labelled histogram.
        try:
            if len(data) > 10:  # only estimate a density with enough samples
                lines_before = len(ax_cur.lines)
                sns.kdeplot(data, color=style["color"], linestyle=style["linestyle"],
                            linewidth=2.0, alpha=0.9, ax=ax_cur, label=style["label"])
                # kdeplot may draw nothing (e.g. it skips zero-variance
                # data); in that case lines[-1] would be another
                # algorithm's curve, so only fill when a line was added.
                kde_drawn = len(ax_cur.lines) > lines_before

                # Which curves get a filled area underneath:
                #   DiffusionQL - always
                #   ORAAC       - everywhere except hopper-medium-expert
                #   RADAC       - only on the two hopper datasets
                should_fill = (
                    alg == "ql"
                    or (alg == "oraac" and "hopper-medium-expert" not in env_name)
                    or (alg == "radac" and ("hopper-medium-expert" in env_name
                                            or "hopper-medium-replay" in env_name))
                )

                if should_fill and kde_drawn:
                    line = ax_cur.lines[-1]  # the KDE curve just added
                    x, y = line.get_xdata(), line.get_ydata()
                    ax_cur.fill_between(x, 0, y, color=style["color"],
                                        alpha=style["alpha"], zorder=0)
            else:
                # Too few samples for a KDE: show a histogram only.
                ax_cur.hist(data, bins=bins, density=True, alpha=0.5,
                            color=style["color"], linewidth=1, edgecolor=style["color"],
                            label=f"{style['label']} (n={len(data)})")
        except Exception as e:
            print(f"    Warning: KDE failed for {alg}: {e}")
            # KDE failed: show a histogram only.
            ax_cur.hist(data, bins=bins, density=True, alpha=0.5,
                        color=style["color"], linewidth=1, edgecolor=style["color"],
                        label=f"{style['label']} (hist only)")

    # X-limits: start from the configured range and widen it so that all
    # plotted data (plus a 0.1 margin) stays visible.
    x_min, x_max = config["x_lim"]
    if plottable:
        x_min = min(x_min, min(d.min() for d in plottable.values()) - 0.1)
        x_max = max(x_max, max(d.max() for d in plottable.values()) + 0.1)
    ax.set_xlim(x_min, x_max)

    if ax_right is not None:
        ax_right.set_xlim(x_min, x_max)

        # ax_right exists only when DiffusionQL data was drawn. Capping its
        # y-range makes the DiffusionQL curve appear larger.
        if "walker2d" in env_name:
            ax_right.set_ylim(0, 4)
        if "hopper" in env_name:
            ax_right.set_ylim(0, 15)

    # Hide the x-axis label of the main axes.
    ax.set_xlabel("")

    # Hide y/x ticks, the y label and all spines on every axes in use
    # (distribution_plots.py style).
    axes_in_use = [ax] if ax_right is None else [ax, ax_right]
    for cur in axes_in_use:
        cur.set_ylabel("")
        cur.set_yticks([])
        cur.tick_params(axis="y", left=False, right=False)
        cur.set_xticks([])
        cur.tick_params(axis="x", bottom=False, top=False)
        for side in ("left", "right", "top", "bottom"):
            cur.spines[side].set_visible(False)

    # Merge the legend entries of both axes.
    handles, labels = ax.get_legend_handles_labels()
    if ax_right is not None:
        handles_r, labels_r = ax_right.get_legend_handles_labels()
        handles.extend(handles_r)
        labels.extend(labels_r)

    # Move "Safe Region" to the end of the legend.
    if "Safe Region" in labels:
        idx = labels.index("Safe Region")
        safe_h, safe_l = handles.pop(idx), labels.pop(idx)
        handles.append(safe_h)
        labels.append(safe_l)

    # Draw the legend, with the RADAC entry in bold.
    if handles:
        leg = ax.legend(handles, labels, loc="upper right", frameon=False,
                        fontsize=12, handlelength=3, handletextpad=0.6)
        for txt in leg.get_texts():
            if txt.get_text() == "RADAC":
                txt.set_fontweight("bold")

    # Give every curve the same line width.
    all_lines = list(ax.get_lines())
    if ax_right is not None:
        all_lines.extend(ax_right.get_lines())
    for line in all_lines:
        line.set_linewidth(2.0)

    # Title, e.g. "hopper-medium-replay-v2" -> "Hopper Medium Replay".
    env_display = env_name.replace('-v2', '').replace('-', ' ').title()
    ax.set_title(env_display, fontsize=12, fontweight='bold')

def create_combined_safety_plot(base_dir: str, output_dir: str):
    """Render every environment in one 2x3 figure and save it as PNG and PDF.

    Loads the data of each (environment, algorithm) pair from ``base_dir``,
    draws one subplot per entry of ``ENVIRONMENTS``, writes timestamped PNG
    and PDF files into ``output_dir`` (created if missing) and finally shows
    the figure.

    Args:
        base_dir: Root directory of the saved safety data.
        output_dir: Directory that receives the output files.

    Returns:
        Dict mapping each environment name to a dict of
        ``{algorithm: loaded data or None}``.
    """
    print("Creating combined safety distribution plot...")
    print(f"Loading data from: {base_dir}")

    # Load everything up front so the loading log is not interleaved with
    # the plotting log.
    all_data = {}
    for env in ENVIRONMENTS:
        print(f"\nLoading {env}:")
        all_data[env] = {
            alg: load_safety_data(base_dir, env, alg) for alg in ALGORITHMS
        }

    # Wide 2x3 grid, one panel per environment.
    fig, grid = plt.subplots(2, 3, figsize=(24, 12))
    panels = grid.flatten()

    for idx, env in enumerate(ENVIRONMENTS):
        env_config = ENV_CONFIGS[env.split('-')[0]]
        plot_safety_distribution(panels[idx], all_data[env], env, env_config)

    # Figure-level title.
    fig.suptitle("Safety Distribution Comparison Across Environments",
                 fontsize=16, fontweight='bold', y=0.98)

    # Tighten the layout, then leave head-room for the figure title.
    plt.tight_layout()
    plt.subplots_adjust(top=0.94)

    # Both output files share one timestamped stem.
    os.makedirs(output_dir, exist_ok=True)
    stamp = time.strftime("%Y%m%d_%H%M%S")
    stem = os.path.join(output_dir, f"combined_safety_distributions_{stamp}")

    png_path = f"{stem}.png"
    plt.savefig(png_path, dpi=300, bbox_inches='tight')
    print(f"\nSaved PNG to: {png_path}")

    pdf_path = f"{stem}.pdf"
    plt.savefig(pdf_path, format='pdf', bbox_inches='tight')
    print(f"Saved PDF to: {pdf_path}")

    plt.show()

    return all_data

def create_individual_plots(base_dir: str, output_dir: str):
    """Create and save one stand-alone PNG per environment.

    For each entry of ``ENVIRONMENTS`` the data of every algorithm is loaded
    from ``base_dir``, plotted on its own wide figure and written to
    ``output_dir`` as ``<env>_safety_distribution_<timestamp>.png``.

    Args:
        base_dir: Root directory of the saved safety data.
        output_dir: Directory that receives the PNG files; created if it
            does not exist.
    """
    print("\nCreating individual environment plots...")

    # Create the output directory here as well, so this function does not
    # depend on create_combined_safety_plot having run first.
    os.makedirs(output_dir, exist_ok=True)

    for env in ENVIRONMENTS:
        print(f"\nProcessing {env}...")

        # Load data for every algorithm (None where the file is missing).
        env_data = {
            alg: load_safety_data(base_dir, env, alg) for alg in ALGORITHMS
        }

        env_type = env.split('-')[0]
        config = ENV_CONFIGS[env_type]

        # Individual plot (wide layout).
        fig, ax = plt.subplots(figsize=(14, 6))
        try:
            plot_safety_distribution(ax, env_data, env, config)

            timestamp = time.strftime("%Y%m%d_%H%M%S")
            png_path = os.path.join(output_dir, f"{env}_safety_distribution_{timestamp}.png")
            fig.savefig(png_path, dpi=300, bbox_inches='tight')
            print(f"  Saved: {png_path}")
        finally:
            # Close this specific figure even if plotting or saving failed,
            # so figures do not accumulate across environments.
            plt.close(fig)

def main():
    """Build the combined figure, then the per-environment figures.

    Input and output directories are fixed relative paths. If the input
    directory does not exist an error is printed and nothing is generated.
    """
    base_dir = "../frozen_logs/safety_plot_results"
    output_dir = "safety_plot_summary"

    # Nothing to do without the input data.
    if not os.path.exists(base_dir):
        print(f"Error: {base_dir} directory not found!")
        return

    banner = "=" * 60

    # Header describing the run.
    print(banner)
    print("SAFETY DISTRIBUTION VISUALIZATION")
    print(banner)
    print(f"Base directory: {base_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Environments: {len(ENVIRONMENTS)}")
    print(f"Algorithms: {', '.join(ALGORITHMS)}")

    # Combined figure first, then one figure per environment.
    create_combined_safety_plot(base_dir, output_dir)
    create_individual_plots(base_dir, output_dir)

    # Footer.
    print("\n" + banner)
    print("VISUALIZATION COMPLETED!")
    print(banner)
    print(f"Check {output_dir} for all output files")

# Run the visualisation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()