#!/usr/bin/env python3
"""
Visualize the safety distributions for six environments:
- halfcheetah-medium-expert-v2
- halfcheetah-medium-replay-v2
- walker2d-medium-expert-v2
- walker2d-medium-replay-v2
- hopper-medium-expert-v2
- hopper-medium-replay-v2

Compare the distributions of RADAC, ORAAC, and DiffusionQL for each
environment and highlight the safe regions.
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path
import time

# Datasets to plot, in panel order (row-major in the combined figure).
ENVIRONMENTS = [
    "halfcheetah-medium-expert-v2",
    "halfcheetah-medium-replay-v2",
    "walker2d-medium-expert-v2",
    "walker2d-medium-replay-v2",
    "hopper-medium-expert-v2",
    "hopper-medium-replay-v2",
]

# Algorithm keys: used in the .npy file names and as keys of ALGORITHM_STYLES.
ALGORITHMS = ["radac", "oraac", "ql"]

# Per-robot plot settings, keyed by the env-name prefix before the first "-".
#   file_prefix   metric name at the start of each .npy file name.
#   safe_region*  (low, high) interval shaded green.  HalfCheetah defines one
#                 interval per dataset (medium-replay / medium-expert); the
#                 other robots share a single "safe_region".
#   x_label       axis label text (the plotting code in this file hides x labels).
#   x_lim         minimum x-axis extent; the plot widens it to cover the data.
ENV_CONFIGS = {
    "halfcheetah": dict(
        file_prefix="velocity",
        safe_region_medium_replay=(0, 5),
        safe_region_medium_expert=(0, 10),
        x_label="Velocity (m/s)",
        x_lim=(-1, 8),  # shared by both HalfCheetah datasets
    ),
    "walker2d": dict(
        file_prefix="angle",
        safe_region=(-0.5, 0.5),
        x_label="Angle (rad)",
        x_lim=(-1, 1),
    ),
    "hopper": dict(
        file_prefix="angle",
        safe_region=(-0.1, 0.1),
        x_label="Angle (rad)",
        x_lim=(-0.3, 0.3),  # narrow default extent for Hopper's small angles
    ),
}

# Line appearance per algorithm.  The plotting code reads "color",
# "linestyle" and "label"; "alpha" is kept for reference but not read there.
ALGORITHM_STYLES = {
    "radac": dict(color="blue", linestyle="-", alpha=0.15, label="RADAC"),
    "oraac": dict(color="red", linestyle="-.", alpha=0.15, label="ORAAC"),
    "ql": dict(color="purple", linestyle="dotted", alpha=0.15, label="DiffusionQL"),
}

def load_safety_data(base_dir: str, env_name: str, algorithm: str) -> "np.ndarray | None":
    """Load one algorithm's safety-metric samples for one environment.

    The file is looked up at
    ``<base_dir>/<env_name>/<file_prefix>_<algorithm>_<env_name>.npy``, where
    ``file_prefix`` comes from ``ENV_CONFIGS`` for the environment type (the
    part of ``env_name`` before the first ``-``).

    Args:
        base_dir: Root directory holding one sub-directory per environment.
        env_name: Full environment name, e.g. ``"hopper-medium-replay-v2"``.
        algorithm: Algorithm key, e.g. ``"radac"``.

    Returns:
        The array returned by ``np.load``, or ``None`` (after printing a
        warning) when no such file exists.

    Raises:
        KeyError: If the environment type has no entry in ``ENV_CONFIGS``.
    """
    env_type = env_name.split('-')[0]
    config = ENV_CONFIGS[env_type]

    file_name = f"{config['file_prefix']}_{algorithm}_{env_name}.npy"
    file_path = os.path.join(base_dir, env_name, file_name)

    # isfile rather than exists: a directory at this path would otherwise be
    # handed to np.load and crash instead of being reported as missing.
    if not os.path.isfile(file_path):
        print(f"  Warning: {file_path} not found")
        return None

    data = np.load(file_path)
    print(f"  Loaded {algorithm}: {len(data):,} samples")
    return data

def _select_safe_region(env_name: str, config: dict):
    """Return the ``(low, high)`` safe interval for ``env_name``, or ``None``.

    A dataset-specific key (``safe_region_medium_replay`` /
    ``safe_region_medium_expert``) wins when the config defines it; otherwise
    the generic ``safe_region`` key is used.
    """
    for dataset in ("medium-replay", "medium-expert"):
        key = f"safe_region_{dataset.replace('-', '_')}"
        if dataset in env_name and key in config:
            return config[key]
    return config.get("safe_region")


def _should_fill_kde(alg: str, env_name: str) -> bool:
    """Return True when the area under ``alg``'s KDE curve should be shaded."""
    if alg == "ql":
        return True
    if alg == "oraac":
        return "hopper-medium-expert" not in env_name
    if alg == "radac":
        return "hopper-medium-expert" in env_name or "hopper-medium-replay" in env_name
    return False


def _plot_algorithm(ax, alg: str, data, env_name: str) -> None:
    """Draw one algorithm's distribution on ``ax``.

    Small samples (<= 10 points) get a single labelled histogram.  Otherwise a
    KDE curve is drawn (optionally filled) with a faint unlabelled histogram
    behind it.  If the KDE raises, a labelled histogram replaces it.
    """
    style = ALGORITHM_STYLES[alg]
    n_samples = len(data)
    bins = min(30, max(5, n_samples // 20))

    if n_samples <= 10:
        # Too few points for a meaningful KDE: one labelled histogram only.
        ax.hist(data, bins=bins, density=True, alpha=0.5,
                color=style["color"], linewidth=1, edgecolor=style["color"],
                label=f"{style['label']} (n={n_samples})")
        return

    n_lines_before = len(ax.lines)
    try:
        sns.kdeplot(data, color=style["color"], linestyle=style["linestyle"],
                    linewidth=2.0, alpha=0.9, ax=ax, label=style["label"])
        # seaborn can skip drawing (e.g. zero-variance data); only fill a
        # curve this call actually added, never a previous algorithm's line.
        if _should_fill_kde(alg, env_name) and len(ax.lines) > n_lines_before:
            x, y = ax.lines[-1].get_data()
            ax.fill_between(x, 0, y, color=style["color"], alpha=0.15, zorder=0)
    except Exception as e:  # best-effort: any KDE failure falls back to a histogram
        print(f"    Warning: KDE failed for {alg}: {e}")
        ax.hist(data, bins=bins, density=True, alpha=0.5,
                color=style["color"], linewidth=1, edgecolor=style["color"],
                label=f"{style['label']} (hist only)")
        return

    # Faint background histogram under the KDE; omitted for DiffusionQL.
    if alg != "ql":
        ax.hist(data, bins=bins, density=True, alpha=0.25,
                color=style["color"], linewidth=0, edgecolor='none')


def plot_safety_distribution(ax, data_dict: dict, env_name: str, config: dict):
    """Plot the safety distribution of every algorithm for one environment.

    DiffusionQL (``"ql"``) is drawn on a twin y-axis created on demand, since
    its density scale is set independently (fixed limits for Walker2d/Hopper).
    Axis ticks, labels and spines are hidden; a combined legend is drawn
    with the safe region listed last and RADAC in bold.

    Args:
        ax: Matplotlib Axes to draw on.
        data_dict: Maps algorithm key to a sample array; ``None`` or empty
            arrays are skipped.
        env_name: Full environment name, used for the title and for
            per-environment styling choices.
        config: The ``ENV_CONFIGS`` entry for this environment's robot type.
    """
    ax_right = None  # twin y-axis, created lazily for DiffusionQL only

    safe_region = _select_safe_region(env_name, config)
    if safe_region is not None:
        safe_low, safe_high = safe_region
        if safe_low < safe_high:
            ax.axvspan(safe_low, safe_high, color="green", alpha=0.2, label="Safe Region")

    plotted = []
    for alg, data in data_dict.items():
        # Skip missing files and empty arrays (min()/max() raise on the latter).
        if data is None or np.size(data) == 0:
            continue
        plotted.append(data)

        print(f"    {alg}: {len(data)} samples, range=[{data.min():.3f}, {data.max():.3f}], mean={data.mean():.3f}")

        if alg == "ql":
            if ax_right is None:
                ax_right = ax.twinx()
            target = ax_right
        else:
            target = ax
        _plot_algorithm(target, alg, data, env_name)

    # x_lim is a minimum extent, widened to include all observed samples.
    x_min, x_max = config["x_lim"]
    if plotted:
        x_min = min(x_min, min(d.min() for d in plotted) - 0.1)
        x_max = max(x_max, max(d.max() for d in plotted) + 0.1)
    ax.set_xlim(x_min, x_max)

    if ax_right is not None:
        ax_right.set_xlim(x_min, x_max)
        # Fixed right-axis scales so the DiffusionQL curve stays readable.
        if "walker2d" in env_name:
            ax_right.set_ylim(0, 4)
        if "hopper" in env_name:
            ax_right.set_ylim(0, 15)

    all_axes = [ax] if ax_right is None else [ax, ax_right]

    # Strip ticks, labels and spines from every axis (minimal figure style).
    for a in all_axes:
        a.set_xlabel("")
        a.set_ylabel("")
        a.set_xticks([])
        a.set_yticks([])
        a.tick_params(axis="x", bottom=False, top=False)
        a.tick_params(axis="y", left=False, right=False)
        for side in ("left", "right", "top", "bottom"):
            a.spines[side].set_visible(False)

    # Gather legend entries from both axes into one legend.
    handles, labels = [], []
    for a in all_axes:
        h, lab = a.get_legend_handles_labels()
        handles.extend(h)
        labels.extend(lab)

    # List the safe region after the algorithms.
    if "Safe Region" in labels:
        idx = labels.index("Safe Region")
        handles.append(handles.pop(idx))
        labels.append(labels.pop(idx))

    if handles:
        leg = ax.legend(handles, labels, loc="upper right", frameon=False,
                        fontsize=12, handlelength=3, handletextpad=0.6)
        for txt in leg.get_texts():
            if txt.get_text() == "RADAC":
                txt.set_fontweight("bold")

    for a in all_axes:
        for line in a.get_lines():
            line.set_linewidth(2.0)

    env_display = env_name.replace('-v2', '').replace('-', ' ').title()
    ax.set_title(env_display, fontsize=12, fontweight='bold')

def create_combined_safety_plot(base_dir: str, output_dir: str):
    """Plot every environment in ENVIRONMENTS on one grid figure and save it.

    The grid has (up to) three columns and as many rows as needed; unused
    panels are hidden.  The figure is written to ``output_dir`` as
    timestamped PNG (300 dpi) and PDF files, shown, and then closed.

    Args:
        base_dir: Root directory with one sub-directory per environment.
        output_dir: Destination directory (created if missing).

    Returns:
        Dict mapping environment name to ``{algorithm: array or None}``.
    """
    print("Creating combined safety distribution plot...")
    print(f"Loading data from: {base_dir}")

    all_data = {}
    for env in ENVIRONMENTS:
        print(f"\nLoading {env}:")
        all_data[env] = {alg: load_safety_data(base_dir, env, alg) for alg in ALGORITHMS}

    # Grid derived from the environment count: 3 columns of 8x6-inch panels
    # (2x3 / 24x12 inches for the default six environments).
    n_envs = len(ENVIRONMENTS)
    ncols = min(3, max(1, n_envs))
    nrows = max(1, -(-n_envs // ncols))  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(8 * ncols, 6 * nrows), squeeze=False)
    axes = axes.flatten()

    for ax, env in zip(axes, ENVIRONMENTS):
        config = ENV_CONFIGS[env.split('-')[0]]
        plot_safety_distribution(ax, all_data[env], env, config)

    # Hide leftover panels when the environments don't fill the grid.
    for ax in axes[n_envs:]:
        ax.set_visible(False)

    fig.suptitle("Safety Distribution Comparison Across Environments",
                 fontsize=16, fontweight='bold', y=0.98)

    fig.tight_layout()
    fig.subplots_adjust(top=0.94)  # leave room for the suptitle

    os.makedirs(output_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")  # shared by the PNG and PDF

    png_path = os.path.join(output_dir, f"combined_safety_distributions_{timestamp}.png")
    fig.savefig(png_path, dpi=300, bbox_inches='tight')
    print(f"\nSaved PNG to: {png_path}")

    pdf_path = os.path.join(output_dir, f"combined_safety_distributions_{timestamp}.pdf")
    fig.savefig(pdf_path, format='pdf', bbox_inches='tight')
    print(f"Saved PDF to: {pdf_path}")

    plt.show()
    # Release the figure so it doesn't stay open while later plots are made.
    plt.close(fig)

    return all_data

def create_individual_plots(base_dir: str, output_dir: str):
    """Save one standalone PNG per environment in ENVIRONMENTS.

    Data is (re)loaded for each environment.  All files from one call share a
    single timestamp so they can be grouped by run.

    Args:
        base_dir: Root directory with one sub-directory per environment.
        output_dir: Destination directory (created if missing).
    """
    print("\nCreating individual environment plots...")

    # Create the directory here as well, so this works even when
    # create_combined_safety_plot has not run first.
    os.makedirs(output_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    for env in ENVIRONMENTS:
        print(f"\nProcessing {env}...")

        env_data = {alg: load_safety_data(base_dir, env, alg) for alg in ALGORITHMS}
        config = ENV_CONFIGS[env.split('-')[0]]

        fig, ax = plt.subplots(figsize=(14, 6))
        try:
            plot_safety_distribution(ax, env_data, env, config)

            png_path = os.path.join(output_dir, f"{env}_safety_distribution_{timestamp}.png")
            fig.savefig(png_path, dpi=300, bbox_inches='tight')
            print(f"  Saved: {png_path}")
        finally:
            # Always release the figure, even if plotting or saving raised.
            plt.close(fig)

def main(base_dir: str = "../frozen_logs/safety_plot_results",
         output_dir: str = "safety_plot_summary"):
    """Produce the combined and per-environment safety plots.

    Args:
        base_dir: Directory containing per-environment ``.npy`` files.
            Relative paths resolve against the current working directory.
        output_dir: Directory that receives the generated PNG/PDF files.
    """
    # isdir rather than exists: a regular file with this name is not usable.
    if not os.path.isdir(base_dir):
        print(f"Error: {base_dir} directory not found!")
        return

    print("=" * 60)
    print("SAFETY DISTRIBUTION VISUALIZATION")
    print("=" * 60)
    print(f"Base directory: {base_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Environments: {len(ENVIRONMENTS)}")
    print(f"Algorithms: {', '.join(ALGORITHMS)}")

    create_combined_safety_plot(base_dir, output_dir)
    create_individual_plots(base_dir, output_dir)

    print("\n" + "=" * 60)
    print("VISUALIZATION COMPLETED!")
    print("=" * 60)
    print(f"Check {output_dir} for all output files")
    print(f"Check {output_dir} for all output files")

if __name__ == "__main__":
    # Script entry point; importing this module only defines the helpers.
    main()
