#!/usr/bin/env python
"""
summarize_eps_act_results.py
============================
Load eps_act_log.npy files for each algorithm/environment under eps_act_results/
and compute OOD action statistics, summarizing them in tables.

Outputs:
- ε_act statistics per environment and algorithm (mean, std, max, plateau step, etc.)
- Comparison tables across environments and algorithms, printed to stdout
- CSV export of the summary table, plus JSON exports of the detailed and comparison results
"""

import argparse
import json
import time
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd

# ────────────────────────────────────────────────────────────
# Configuration
# ────────────────────────────────────────────────────────────

# Default eps_act_results directory (a relative path resolves against the
# current working directory, not against this script's location).
DEFAULT_EPS_ACT_DIR = "../frozen_logs/eps_act_results"

# Target environments and algorithms. Each (env, algo) pair is looked up in a
# subdirectory named "{env}|{algo}|" under the eps_act_results directory.
ENVIRONMENTS = [
    "halfcheetah-medium-expert-v2",
    "hopper-medium-expert-v2", 
    "walker2d-medium-expert-v2"
]

ALGORITHMS = [
    "radac",
    "oraac", 
    "ql"
]

# Display names per algorithm. The summary and save functions look results up
# by these display names, not by the raw algorithm identifiers above.
ALGORITHM_DISPLAY_NAMES = {
    "radac": "RADAC",
    "oraac": "ORAAC", 
    "ql": "Diffusion-QL"
}

# Default hyperparameters matching plot_eps_act_plus.py
# ε_act threshold used for plateau detection and the below-threshold ratios.
DEFAULT_THR = 0.05
# Smoothing factor of the EMA applied to ε_act before plateau detection.
DEFAULT_EMA_ALPHA = 0.1
# Plateau detection window length, in steps.
DEFAULT_WINDOW = 10000
# Minimum fraction of below-threshold steps inside the window for a plateau.
DEFAULT_RATIO = 0.9
# Number of trailing steps used for the "last 500k" statistics.
DEFAULT_LAST_WIN = 500000

# ────────────────────────────────────────────────────────────
# Utility functions
# ────────────────────────────────────────────────────────────

def ema(x: np.ndarray, alpha: float = 0.1) -> np.ndarray:
    """Compute the exponential moving average of ``x`` along its first axis.

    Uses the recursive form ``y[0] = x[0]`` and
    ``y[i] = alpha * x[i] + (1 - alpha) * y[i - 1]``.

    Args:
        x: Input values. They are converted to float64 up front, so the
            arithmetic precision does not depend on the input dtype.
        alpha: Smoothing factor; larger values follow ``x`` more closely.

    Returns:
        A float64 array with the same shape as ``x``. It is empty when ``x``
        is empty; the previous implementation raised IndexError in that case.
    """
    values = np.asarray(x, dtype=float)
    smoothed = np.empty_like(values)
    if len(values) == 0:
        return smoothed

    decay = 1 - alpha  # loop-invariant
    smoothed[0] = values[0]
    for i in range(1, len(values)):
        smoothed[i] = alpha * values[i] + decay * smoothed[i - 1]
    return smoothed

def first_plateau_step(signal: np.ndarray, thr: float, window: int = DEFAULT_WINDOW, ratio: float = DEFAULT_RATIO) -> int:
    """Detect the first step when a plateau below threshold starts (same as plot_eps_act_plus.py).

    Step ``i`` starts a plateau when ``signal[i] < thr`` and at least ``ratio``
    of the values in ``signal[i : i + window]`` are below ``thr``.

    Only starts ``i < len(signal) - window`` are examined. In particular, the
    last full window (starting at ``len(signal) - window``) is not considered.
    This matches the original loop-based implementation exactly.

    Args:
        signal: Signal to scan (typically the EMA-smoothed ε_act).
        thr: Threshold the signal must stay below.
        window: Window length, in steps, used to test each candidate start.
        ratio: Minimum fraction of below-threshold values inside the window.

    Returns:
        The first qualifying index. Returns ``len(signal) - 1`` when no start
        qualifies, including when the signal is not longer than ``window``.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    signal = np.asarray(signal)
    n_starts = len(signal) - window
    if n_starts <= 0:
        return len(signal) - 1

    below = signal < thr
    # Prefix sums give every window's below-threshold count in O(n) total,
    # instead of re-scanning `window` elements for each candidate start.
    prefix = np.concatenate(([0], np.cumsum(below, dtype=np.int64)))
    window_frac = (prefix[window:window + n_starts] - prefix[:n_starts]) / window
    hits = np.flatnonzero(below[:n_starts] & (window_frac >= ratio))
    return int(hits[0]) if hits.size else len(signal) - 1

def calculate_eps_stats(
    eps_data: np.ndarray,
    thr: float = DEFAULT_THR,
    ema_alpha: float = DEFAULT_EMA_ALPHA,
    window: int = DEFAULT_WINDOW,
    ratio: float = DEFAULT_RATIO,
) -> Dict:
    """Compute statistics from ε_act data (same definitions as plot_eps_act_plus.py).

    Args:
        eps_data: Per-step ε_act values as loaded from ``eps_act_log.npy``.
        thr: Threshold for plateau detection and the below-threshold ratios.
        ema_alpha: Smoothing factor of the EMA applied before plateau detection.
        window: Plateau detection window length, in steps.
        ratio: Minimum fraction of below-threshold steps inside the window.

    Returns:
        A dict of whole-run statistics, statistics over the last
        ``DEFAULT_LAST_WIN`` steps, the plateau step and the parameters used.
        Returns ``None`` when ``eps_data`` is empty.
    """
    if len(eps_data) == 0:
        return None

    # Trailing window, clamped to the run length (same as plot_eps_act_plus.py).
    last_win = min(len(eps_data), DEFAULT_LAST_WIN)
    win_slice = eps_data[-last_win:]

    # EMA and plateau detection, honouring the caller's window/ratio.
    ema_eps = ema(eps_data, ema_alpha)
    plateau_step = first_plateau_step(ema_eps, thr, window=window, ratio=ratio)

    # Convert NumPy scalars to built-in numbers so json.dump writes real numbers.
    # For example, float32 scalars would otherwise be stringified via default=str.
    return {
        'total_steps': int(len(eps_data)),
        'mean_eps': float(np.mean(eps_data)),
        'std_eps': float(np.std(eps_data)),
        'min_eps': float(np.min(eps_data)),
        'max_eps': float(np.max(eps_data)),
        'mean_last': float(np.mean(win_slice)),
        'std_last': float(np.std(win_slice)),
        'max_last': float(np.max(win_slice)),
        'plateau_step': int(plateau_step),
        'below_thr_ratio': float((eps_data < thr).mean()),
        'below_thr_ratio_last': float((win_slice < thr).mean()),
        'thr': thr,
        'ema_alpha': ema_alpha,
        'window': window,
        'ratio': ratio,
    }

def _load_run_stats(dir_path: Path, thr: float, ema_alpha: float, window: int, ratio: float):
    """Load one run's eps_act_log.npy and return its stats, or None if unavailable."""
    if not dir_path.exists():
        print(f"Warning: Directory not found: {dir_path}")
        return None

    npy_file = dir_path / "eps_act_log.npy"
    if not npy_file.exists():
        print(f"Warning: eps_act_log.npy not found in: {dir_path}")
        return None

    try:
        print(f"Loading: {npy_file}")
        eps_data = np.load(npy_file)
        stats = calculate_eps_stats(eps_data, thr=thr, ema_alpha=ema_alpha,
                                    window=window, ratio=ratio)
    except Exception as e:
        # Best-effort per file: one unreadable log must not abort the whole summary.
        print(f"Error loading {npy_file}: {e}")
        return None

    if stats is None:
        print(f"Warning: {npy_file} is empty")
        return None

    print(f"  Loaded {len(eps_data):,} steps, mean: {stats['mean_eps']:.6f}")
    return stats

def load_eps_act_data(
    eps_act_dir: Path,
    thr: float = DEFAULT_THR,
    ema_alpha: float = DEFAULT_EMA_ALPHA,
    window: int = DEFAULT_WINDOW,
    ratio: float = DEFAULT_RATIO,
) -> Dict[str, Dict[str, Dict]]:
    """Load all data from the eps_act_results directory.

    Args:
        eps_act_dir: Directory that holds one ``"{env}|{algo}|"`` subdirectory per run.
        thr, ema_alpha, window, ratio: Forwarded to ``calculate_eps_stats``.

    Returns:
        ``results[env][display_name]``: a stats dict, or ``None`` when the run
        is missing, empty or unreadable. Every (env, algorithm) pair is keyed
        by its display name, which is what the summary and save functions look up.
    """
    results = {}
    for env in ENVIRONMENTS:
        results[env] = {}
        for algo in ALGORITHMS:
            dir_path = eps_act_dir / f"{env}|{algo}|"
            results[env][ALGORITHM_DISPLAY_NAMES[algo]] = _load_run_stats(
                dir_path, thr, ema_alpha, window, ratio
            )
    return results

def create_summary_table(results: Dict[str, Dict[str, Dict]]) -> pd.DataFrame:
    """Summarize results into a DataFrame with one formatted row per (environment, algorithm).

    All statistic cells are pre-formatted strings. Runs without stats get
    'N/A' in every statistic column. The '×100 scale' columns show the
    last-window statistics multiplied by 100.
    """
    stat_columns = [
        'Total Steps',
        'Mean ε_act (all)',
        'Std ε_act (all)',
        'Max ε_act (all)',
        'Mean (last 500k)',
        'Std (last 500k)',
        'Max (last 500k)',
        'Mean (×100 scale)',
        'Std (×100 scale)',
        'Max (×100 scale)',
        'Plateau Step',
        'Below Thr Ratio',
        'Below Thr Ratio (last 500k)',
    ]

    rows = []
    for env in ENVIRONMENTS:
        for algo in ALGORITHMS:
            display_name = ALGORITHM_DISPLAY_NAMES[algo]
            # .get: tolerate result dicts without an entry for this run.
            stats = results.get(env, {}).get(display_name)

            row = {'Environment': env, 'Algorithm': display_name}
            if stats is None:
                row.update(dict.fromkeys(stat_columns, 'N/A'))
            else:
                row.update({
                    'Total Steps': f"{stats['total_steps']:,}",
                    'Mean ε_act (all)': f"{stats['mean_eps']:.6f}",
                    'Std ε_act (all)': f"{stats['std_eps']:.6f}",
                    'Max ε_act (all)': f"{stats['max_eps']:.6f}",
                    'Mean (last 500k)': f"{stats['mean_last']:.6f}",
                    'Std (last 500k)': f"{stats['std_last']:.6f}",
                    'Max (last 500k)': f"{stats['max_last']:.6f}",
                    'Mean (×100 scale)': f"{stats['mean_last'] * 100:.2f}",
                    'Std (×100 scale)': f"{stats['std_last'] * 100:.2f}",
                    'Max (×100 scale)': f"{stats['max_last'] * 100:.2f}",
                    'Plateau Step': f"{stats['plateau_step']:,}",
                    'Below Thr Ratio': f"{stats['below_thr_ratio']:.3f}",
                    'Below Thr Ratio (last 500k)': f"{stats['below_thr_ratio_last']:.3f}",
                })
            rows.append(row)

    return pd.DataFrame(rows)

def print_summary_table(df: pd.DataFrame):
    """Pretty-print last-window statistics (×100 scale) and plateau step per environment.

    Reads the pre-formatted '×100 scale' columns from create_summary_table, so
    the console output matches the CSV. Re-parsing the 6-decimal strings and
    scaling them again could round differently.
    """
    print("\n" + "="*120)
    print("OOD ACTION (ε_act) SUMMARY TABLE - Last 500k steps (×100 scale)")
    print("="*120)

    for env in ENVIRONMENTS:
        print(f"\n{env.upper()}")
        print("-" * 80)

        header = f"{'Algorithm':<12} {'Mean±Std (×100)':<25} {'Max (×100)':<15} {'Plateau':<10}"
        print(header)
        print("-" * 80)

        for _, row in df[df['Environment'] == env].iterrows():
            if row['Mean (last 500k)'] == 'N/A':
                print(f"{row['Algorithm']:<12} {'N/A':<25} {'N/A':<15} {'N/A':<10}")
                continue

            mean_std_scaled = f"{row['Mean (×100 scale)']}±{row['Std (×100 scale)']}"
            print(f"{row['Algorithm']:<12} {mean_std_scaled:<25} "
                  f"{row['Max (×100 scale)']:<15} {row['Plateau Step']:<10}")

def save_results(results: Dict, df: pd.DataFrame, output_dir: Path):
    """Write detailed JSON, the summary CSV and a compact comparison JSON to ``output_dir``.

    The directory, including any missing parents, is created if needed.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Detailed results as JSON
    json_path = output_dir / "eps_act_detailed_results.json"
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=str)
    print(f"Detailed results saved to: {json_path}")

    # Summary table as CSV
    csv_path = output_dir / "eps_act_summary_table.csv"
    df.to_csv(csv_path, index=False)
    print(f"Summary table saved to: {csv_path}")

    # Compact per-environment/algorithm comparison (runs without stats are omitted)
    comparison_keys = ('mean_eps', 'std_eps', 'max_eps', 'mean_last',
                       'plateau_step', 'below_thr_ratio')
    comparison_data = {}
    for env in ENVIRONMENTS:
        env_results = results.get(env, {})
        comparison_data[env] = {}
        for algo in ALGORITHMS:
            display_name = ALGORITHM_DISPLAY_NAMES[algo]
            stats = env_results.get(display_name)
            if stats:
                comparison_data[env][display_name] = {key: stats[key] for key in comparison_keys}

    comparison_path = output_dir / "eps_act_comparison.json"
    with open(comparison_path, 'w', encoding='utf-8') as f:
        json.dump(comparison_data, f, indent=2, default=str)
    print(f"Comparison data saved to: {comparison_path}")

# ────────────────────────────────────────────────────────────
# Main
# ────────────────────────────────────────────────────────────

def main():
    """Command-line entry point: load every run, print the summary and save it.

    The --thr, --ema_alpha, --window and --ratio options are forwarded to the
    statistics computation. Exits with status 1 if the input directory is missing.
    """
    parser = argparse.ArgumentParser(description="Summarize eps_act results from eps_act_results/")
    parser.add_argument("--eps_act_dir", type=Path, default=DEFAULT_EPS_ACT_DIR,
                       help=f"Directory containing eps_act results (default: {DEFAULT_EPS_ACT_DIR})")
    parser.add_argument("--output_dir", type=Path, default="eps_act_summary",
                       help="Output directory for results (default: eps_act_summary)")
    parser.add_argument("--thr", type=float, default=DEFAULT_THR,
                       help=f"Threshold for plateau detection (default: {DEFAULT_THR})")
    parser.add_argument("--ema_alpha", type=float, default=DEFAULT_EMA_ALPHA,
                       help=f"EMA alpha parameter (default: {DEFAULT_EMA_ALPHA})")
    parser.add_argument("--window", type=int, default=DEFAULT_WINDOW,
                       help=f"Window size for plateau detection (default: {DEFAULT_WINDOW})")
    parser.add_argument("--ratio", type=float, default=DEFAULT_RATIO,
                       help=f"Ratio threshold for plateau detection (default: {DEFAULT_RATIO})")

    args = parser.parse_args()

    if args.window < 1:
        parser.error("--window must be >= 1")
    if not 0.0 <= args.ratio <= 1.0:
        parser.error("--ratio must be between 0 and 1")

    # Check input directory
    if not args.eps_act_dir.exists():
        print(f"Error: eps_act directory not found: {args.eps_act_dir}")
        # Non-zero status so calling scripts can detect the failure.
        raise SystemExit(1)

    print(f"Loading eps_act results from: {args.eps_act_dir}")
    print(f"Output directory: {args.output_dir}")
    print(f"Threshold: {args.thr}, EMA alpha: {args.ema_alpha}")
    print(f"Window: {args.window}, Ratio: {args.ratio}")
    print(f"Last window size: {DEFAULT_LAST_WIN}")

    # Load data, applying the CLI hyperparameters
    start_time = time.perf_counter()
    results = load_eps_act_data(
        args.eps_act_dir,
        thr=args.thr,
        ema_alpha=args.ema_alpha,
        window=args.window,
        ratio=args.ratio,
    )
    load_time = time.perf_counter() - start_time

    print(f"\nData loading completed in {load_time:.2f} seconds")

    df = create_summary_table(results)
    print_summary_table(df)
    save_results(results, df, args.output_dir)

    print(f"\nSummary completed! Check {args.output_dir} for detailed results.")

if __name__ == "__main__":
    main()