#!/usr/bin/env python3
"""
Visualization script for toy 2D risk bandit algorithm results.
Creates a single comparison figure with one scatter panel per algorithm, in the same style as run_toy_ql.py
"""

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import math

# Try to import torch, but make it optional
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("Warning: PyTorch not available. Using numpy for ground truth generation.")

def generate_ground_truth_data(N=10000, device='cpu', *, seed=None):
    """
    Generate the ground-truth "donut" action dataset for comparison.

    Based on the generate_donut_dataset function from risky_bandit.py.
    Only 2-D positions are sampled (no rewards):

    * ``int(N * 0.8)`` ring points: uniform angle, radius ~ Normal(0.9, 0.04);
    * the remaining points: isotropic Gaussian (std 0.08) at the origin;
    * every coordinate is then clipped to [-1, 1].

    Args:
        N: Total number of points to generate.
        device: Torch device used for sampling when PyTorch is available;
            ignored by the NumPy fallback.
        seed: Optional integer seed for reproducible sampling. When None,
            the global torch / numpy random state is used.

    Returns:
        numpy array of shape (N, 2), ring points first, then center points.

    Raises:
        ValueError: If ``N`` is negative.
    """
    if N < 0:
        raise ValueError(f"N must be non-negative, got {N}")

    ring_radius, ring_std = 0.9, 0.04
    pct_center, center_std = 0.20, 0.08

    # Ring gets the floor of its share; the center takes whatever is left so
    # the total is exactly N.
    n_ring = int(N * (1 - pct_center))
    n_center = N - n_ring

    if TORCH_AVAILABLE:
        gen = None
        if seed is not None:
            gen = torch.Generator(device=device).manual_seed(seed)
        theta = 2 * math.pi * torch.rand(n_ring, device=device, generator=gen)
        radius = torch.normal(ring_radius, ring_std, (n_ring,), generator=gen, device=device)
        ring_xy = torch.stack([radius * torch.cos(theta), radius * torch.sin(theta)], 1)
        center_xy = torch.normal(0., center_std, (n_center, 2), generator=gen, device=device)
        action = torch.cat([ring_xy, center_xy], 0).clamp_(-1., 1.)
        return action.cpu().numpy()

    # NumPy fallback. RandomState exposes the same rand/normal API as the
    # np.random module, so the unseeded path keeps using the global state.
    rng = np.random if seed is None else np.random.RandomState(seed)
    theta = 2 * math.pi * rng.rand(n_ring)
    radius = rng.normal(ring_radius, ring_std, n_ring)
    ring_xy = np.column_stack([radius * np.cos(theta), radius * np.sin(theta)])
    center_xy = rng.normal(0., center_std, (n_center, 2))
    return np.clip(np.vstack([ring_xy, center_xy]), -1., 1.)

def pick_color(name: str) -> str:
    """
    Return the scatter color for an algorithm name (scheme from run_toy_ql.py).

    Any name containing 'RAFMAC' or 'RADAC' is drawn green; every other
    name (ORAAC variants, Diffusion-QL, Flow-QL, CVAE-QL, ...) is red.
    """
    green, red = '#2ca02c', '#e41a1c'
    for marker in ('RAFMAC', 'RADAC'):
        if marker in name:
            return green
    return red

def load_algorithm_data(data_dir="../frozen_logs/toy_results"):
    """
    Load the saved action arrays for the fixed set of compared algorithms.

    Each display name is paired with a ``.npy`` file inside ``data_dir``
    (a relative default is resolved against the current working directory).
    Missing files and files that fail to load are reported on stdout and
    skipped instead of raising.

    Returns:
        dict mapping display name -> loaded array, in table order below.
    """
    root = Path(data_dir)

    # Display name -> file name for every algorithm shown in the figure.
    files_by_algorithm = {
        'CVAE-QL': 'ql_cvae_actions.npy',
        'Diffusion-QL': 'diffusion_ql_actions.npy',
        'Flow-QL': 'fql_actions.npy',
        'ORAAC': 'oraac_actions.npy',
        'ORAAC-Diffusion': 'oraac_diffusion_actions.npy',
        'ORAAC-Flow': 'oraac_flow_actions.npy',
        'RADAC': 'radac_actions.npy',
        'RAFMAC': 'rafmac_actions.npy',
    }

    loaded = {}
    for name, fname in files_by_algorithm.items():
        target = root / fname
        if not target.exists():
            print(f"File not found: {fname}")
            continue
        try:
            arr = np.load(target)
            loaded[name] = arr
            print(f"Loaded {name}: {arr.shape}")
        except Exception as err:
            print(f"Error loading {fname}: {err}")
    return loaded

def create_single_comparison_plot(algorithm_data, ground_truth_data, output_path="toy_comparison.png",
                                  *, n_cols=4, show_ground_truth=False):
    """
    Create a single comparison figure with one scatter panel per algorithm.

    Panels are laid out row-major in a grid with ``n_cols`` columns. At least
    two rows are always created, so up to 8 algorithms with the default
    ``n_cols=4`` give a 4x2 layout; more algorithms add rows instead of
    running off the grid. Unused cells are hidden.

    Args:
        algorithm_data: Mapping of algorithm name -> array of actions. Arrays
            with more than two dimensions are flattened to (-1, last_dim);
            the first two columns are plotted as x/y. Arrays with fewer than
            two dimensions or fewer than two columns get a placeholder panel.
        ground_truth_data: Array of reference actions, assumed (N, 2). It is
            only drawn when ``show_ground_truth`` is True.
        output_path: File the figure is saved to (300 dpi).
        n_cols: Number of grid columns (must be >= 1).
        show_ground_truth: If True, draw ``ground_truth_data`` as a faint gray
            backdrop beneath each algorithm's samples.

    Raises:
        ValueError: If ``n_cols`` is less than 1.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    n_panels = len(algorithm_data)
    # Keep the historical minimum of two rows; grow when there are more panels.
    n_rows = max(2, math.ceil(n_panels / n_cols))

    # squeeze=False guarantees a 2-D axes array for every grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    flat_axes = axes.ravel()

    try:
        # Set up all axes with grid and proper styling
        axis_lim = 1.1
        for ax in flat_axes:
            ax.set_xlim(-axis_lim, axis_lim)
            ax.set_ylim(-axis_lim, axis_lim)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.grid(True, alpha=0.3)
            ax.set_aspect('equal')

        gt_xy = None
        if show_ground_truth and ground_truth_data is not None:
            gt_xy = np.asarray(ground_truth_data)

        for ax, (alg_name, data) in zip(flat_axes, algorithm_data.items()):
            data = np.asarray(data)
            # Collapse any leading batch dimensions into a flat list of points.
            if data.ndim > 2:
                data = data.reshape(-1, data.shape[-1])

            if data.ndim == 2 and data.shape[1] >= 2:
                if gt_xy is not None:
                    ax.scatter(gt_xy[:, 0], gt_xy[:, 1], c='lightgray', alpha=0.3, s=5)
                ax.scatter(data[:, 0], data[:, 1],
                           c=pick_color(alg_name), alpha=0.5, s=30)
            else:
                print(f"Warning: {alg_name} data has insufficient dimensions: {data.shape}")
                ax.text(0.5, 0.5, f'Data shape: {data.shape}',
                        transform=ax.transAxes, ha='center', va='center')

            ax.set_title(alg_name, fontsize=20, fontweight='bold', pad=6)

        # Hide unused subplots (if any)
        for ax in flat_axes[n_panels:]:
            ax.set_visible(False)

        fig.tight_layout(pad=2.0, h_pad=1.5, w_pad=1.0)
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Comparison plot saved to: {output_path}")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

def main():
    """
    Run the visualization end to end.

    Loads the per-algorithm action arrays, generates the ground-truth
    dataset and writes the comparison figure. If no algorithm file could be
    loaded, prints a hint and stops before plotting.
    """
    print("Loading algorithm data...")
    loaded = load_algorithm_data()

    if loaded:
        print("Generating ground truth data...")
        reference = generate_ground_truth_data()

        print("Creating single comparison plot...")
        create_single_comparison_plot(loaded, reference)

        print("Visualization complete!")
    else:
        print("No algorithm data found. Please check the file paths.")

# Run the visualization only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
