#!/usr/bin/env python3
"""
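Compute mean, standard error/deviation, and CVaR metrics from ep_returns.npy files.
Loads ep_returns.npy for each seed directory of each configured algorithm, optionally
converts the returns to D4RL-normalized scores (0-100), and aggregates statistics across seeds.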
"""

import numpy as np
import os
import json
from pathlib import Path
import argparse

def load_ep_returns(directory_path):
    """
    Read the per-episode returns stored as ``ep_returns.npy`` in a run directory.

    Args:
        directory_path (str): Run directory expected to contain ep_returns.npy

    Returns:
        np.ndarray: The array loaded from ep_returns.npy

    Raises:
        FileNotFoundError: If ep_returns.npy does not exist in the directory.
    """
    candidate = os.path.join(directory_path, "ep_returns.npy")
    if os.path.exists(candidate):
        return np.load(candidate)
    raise FileNotFoundError(f"ep_returns.npy not found in {directory_path}")

def transform_scores(ep_returns, env_name, normalize="none"):
    """
    Transform episode returns using the specified method.

    Args:
        ep_returns (array-like): Episode returns; converted to a float array.
        env_name (str): Environment name; passed to ``gym.make`` when
            ``normalize == "base"``.
        normalize (str): Normalization method.
            "none" -> raw returns (as floats).
            "base" -> each return is mapped through the environment's
            ``get_normalized_score`` and multiplied by 100 (0-100 scale).

    Returns:
        np.ndarray: Transformed scores (float). If normalization fails for any
        reason, a warning is printed and the raw returns are returned instead.
    """
    ep_returns = np.asarray(ep_returns, float)

    if normalize != "base" or ep_returns.size == 0:
        # normalize="none" (or nothing to normalize) -> raw returns
        return ep_returns

    try:
        import gym
        import d4rl  # noqa: F401  imported for its side effect (presumably registers D4RL envs with gym)

        env = gym.make(env_name)
        fractions = np.array([env.get_normalized_score(r) for r in ep_returns], dtype=float)
    except Exception as e:
        # Deliberate best-effort fallback: keep the run going on raw returns, but say so.
        print(f"Warning: Could not normalize scores for {env_name}: {e}")
        print("Using raw scores instead")
        return ep_returns

    # get_normalized_score yields a fraction of the reference range, so always
    # scale to 0-100. A data-dependent check (e.g. "scale only if max <= 1")
    # would leave any seed containing a super-expert episode (> 1.0) on a
    # different scale from the other seeds.
    return fractions * 100.0

def calculate_cvar(scores, alpha=0.1):
    """
    Compute CVaR@alpha as the mean of the worst (lowest) ceil(alpha * n) scores.

    Args:
        scores (array-like): Score array (lower = worse).
        alpha (float): Tail fraction / risk level (default: 0.1).

    Returns:
        float: CVaR@alpha. np.nan for empty input; the minimum score when the
        tail would contain no samples (alpha <= 0).
    """
    sorted_scores = np.sort(np.asarray(scores, dtype=float))
    n = sorted_scores.size
    if n == 0:
        return np.nan

    # Round before ceil: alpha * n carries binary float error
    # (0.1 * 30 == 3.0000000000000004), which would otherwise pull one extra
    # sample into the tail.
    k = int(np.ceil(round(alpha * n, 9)))
    if k <= 0:
        return sorted_scores[0]

    # CVaR@alpha = (1/k) * sum_{i=1}^{k} x_(i), with x_(i) sorted ascending
    return np.mean(sorted_scores[:k])

def calculate_se(scores):
    """
    Compute the standard error of the mean, s / sqrt(n).

    Uses the sample standard deviation (ddof=1), the same estimator as the
    across-seed SEM computed in process_directory.

    Args:
        scores (array-like): Score array

    Returns:
        float: Standard error; np.nan when fewer than two scores are given
        (the sample standard deviation is undefined there).
    """
    values = np.asarray(scores, dtype=float)
    n = values.size
    if n < 2:
        return np.nan
    return np.std(values, ddof=1) / np.sqrt(n)

def process_directory(base_dir, env_name, algorithm_info, normalize="none", alpha=0.1):
    """
    Compute per-seed and across-seed metrics for one algorithm configuration.

    Each seed's run directory is ``<base_dir>/<pattern><seed>``. When
    ``base_dir`` is None it defaults to ".." (the location this script has
    always read from).

    Args:
        base_dir (str | None): Root directory containing run directories; None -> "..".
        env_name (str): Environment name (used for score normalization).
        algorithm_info (dict): Algorithm info with keys 'pattern', 'seeds', 'name'.
        normalize (str): Normalization method passed to transform_scores ("none"/"base").
        alpha (float): CVaR tail fraction used for per-seed and pooled CVaR.

    Returns:
        dict: {'algorithm', 'seed_results', 'overall_results'}. ``seed_results``
        maps seed -> per-seed dict (None if that seed failed); ``overall_results``
        is None when no seed could be processed.

    Raises:
        FileNotFoundError: If none of the seed directories exists.
    """
    pattern = algorithm_info['pattern']
    seeds = algorithm_info['seeds']
    algo_name = algorithm_info['name']
    root = ".." if base_dir is None else base_dir

    # Find directories for the specified seeds
    seed_dirs = []
    for seed in seeds:
        dir_path = os.path.join(root, f"{pattern}{seed}")
        if os.path.exists(dir_path):
            seed_dirs.append((seed, dir_path))
        else:
            print(f"Warning: Directory not found: {dir_path}")

    if not seed_dirs:
        raise FileNotFoundError(f"No valid directories found for pattern: {pattern}")

    print(f"Found {len(seed_dirs)} seed directories for {algo_name}")

    # Compute statistics per seed
    seed_results = {}
    seed_means = []
    seed_cvars = []

    for seed, dir_path in seed_dirs:
        print(f"Processing seed {seed}: {dir_path}")
        try:
            ep_returns = load_ep_returns(dir_path)
            print(f"  Loaded {len(ep_returns)} episodes")

            # Transform scores (raw for risky D4RL, normalized for standard D4RL)
            scores_for_metric = transform_scores(ep_returns, env_name, normalize=normalize)

            # np.mean of an empty array warns and returns nan; make that explicit.
            mean_score = float(np.mean(scores_for_metric)) if len(scores_for_metric) else float('nan')
            cvar_score = float(calculate_cvar(scores_for_metric, alpha=alpha))

            seed_results[seed] = {
                'mean': mean_score,
                'cvar': cvar_score,
                'episodes': len(scores_for_metric),
                'raw_scores': np.asarray(ep_returns).tolist(),
                'scores_for_metric': np.asarray(scores_for_metric).tolist(),
            }
            seed_means.append(mean_score)
            seed_cvars.append(cvar_score)

            print(f"  Seed {seed}: mean={mean_score:.4f}, cvar={cvar_score:.4f}")
        except Exception as e:
            # Best-effort per seed: a broken run is recorded as None and
            # excluded from aggregation instead of aborting the algorithm.
            print(f"  Error processing seed {seed}: {e}")
            seed_results[seed] = None

    overall_results = None
    if seed_means:
        def sample_std(x):
            """Sample std (ddof=1); nan for fewer than two values (avoids a numpy warning)."""
            x = np.asarray(x, float)
            return float(np.std(x, ddof=1)) if len(x) > 1 else float('nan')

        def sem(x):
            """Standard error of the mean based on sample_std."""
            return sample_std(x) / np.sqrt(len(x)) if len(x) > 1 else float('nan')

        valid_results = [r for r in seed_results.values() if r is not None]

        overall_results = {
            'mean': float(np.mean(seed_means)),
            'mean_err': float(sem(seed_means)),      # SEM (recommended)
            'mean_std': sample_std(seed_means),      # STD (commonly reported)
            'cvar': float(np.mean(seed_cvars)),
            'cvar_err': float(sem(seed_cvars)),      # CVaR SEM
            'cvar_std': sample_std(seed_cvars),      # CVaR STD
            'seeds_used': len(seed_means),
            'episodes_per_seed': len(valid_results[0]['scores_for_metric']) if valid_results else 0,
            'normalization_method': normalize,
            'metric_basis': 'raw_returns' if normalize == 'none' else 'normalized_0_100',
            'cvar_alpha': alpha,
        }

        # Appendix: pooled CVaR over all episodes of all successful seeds
        pooled_scores = [s for r in valid_results for s in r['scores_for_metric']]
        if pooled_scores:
            overall_results['pooled_cvar'] = float(calculate_cvar(np.array(pooled_scores), alpha=alpha))
            overall_results['total_episodes'] = len(pooled_scores)

    return {
        'algorithm': algo_name,
        'seed_results': seed_results,
        'overall_results': overall_results,
    }

def get_algorithm_configs(env_name):
    """
    Build the list of algorithm configurations evaluated for an environment.

    Both supported environments use the same six ablation variants; only the
    environment name embedded in each directory pattern differs.

    Args:
        env_name (str): Environment name

    Returns:
        list: One dict per variant with keys 'name', 'pattern' and 'seeds'
        (seeds 0-2). An empty list for environments without configs.
    """
    supported = ("walker2d-medium-replay-v2", "halfcheetah-medium-replay-v2")
    if env_name not in supported:
        return []

    prefix = f"frozen_logs/ablation_results/{env_name}"
    # (display name, suffix appended to the prefix; the seed follows the suffix)
    variants = (
        ('RADAC-CPW', "|radac|cpw|"),
        ('RADAC-WANG', "|radac|wang|"),
        # NOTE(review): '/' separator here, unlike every other variant's '|';
        # presumably mirrors the on-disk layout — confirm before changing.
        ('RADAC-CVaR', "|radac|cvar/"),
        ('RAFMAC-CPW', "|rafmac|cpw|"),
        ('RAFMAC-WANG', "|rafmac|wang|"),
        ('RAFMAC-CVaR', "|rafmac|cvar|"),
    )
    return [
        {'name': label, 'pattern': prefix + suffix, 'seeds': [0, 1, 2]}
        for label, suffix in variants
    ]

def _format_value(value, error, use_scientific=False):
    """
    Format ``value±error`` for the summary tables.

    With ``use_scientific`` and a finite magnitude >= 100 in either number, both
    are expressed against a shared power of ten (``a×10^e±b×10^e``); otherwise
    plain two-decimal formatting is used. NaN/inf never drive the exponent,
    since ``int(log10(nan/inf))`` would raise.
    """
    finite = [m for m in (abs(value), abs(error)) if np.isfinite(m)]
    if use_scientific and finite and max(finite) >= 100:
        exp = int(np.floor(np.log10(max(finite))))
        scale = 10.0 ** exp
        return f"{value/scale:.2f}×10^{exp}±{error/scale:.2f}×10^{exp}"
    return f"{value:.2f}±{error:.2f}"


def _overall_results(env_results, algo_name):
    """Return the 'overall_results' dict for ``algo_name`` in ``env_results``, or None."""
    entry = (env_results or {}).get(algo_name)
    return entry.get('overall_results') if entry else None


def _print_algorithm_summary(algo_name, overall):
    """Print the per-algorithm results block shown right after processing it."""
    pooled = overall.get('pooled_cvar')
    print(f"\n{algo_name} Overall Results:")
    print(f"  Mean: {overall['mean']:.4f} ± {overall['mean_err']:.4f} (SEM)")
    print(f"  Mean: {overall['mean']:.4f} ± {overall['mean_std']:.4f} (STD)")
    print(f"  CVaR@0.1: {overall['cvar']:.4f} ± {overall['cvar_err']:.4f} (SEM)")
    print(f"  CVaR@0.1: {overall['cvar']:.4f} ± {overall['cvar_std']:.4f} (STD)")
    print(f"  Pooled CVaR@0.1: {'N/A' if pooled is None else f'{pooled:.4f}'}")
    print(f"  Episodes per seed: {overall['episodes_per_seed']}")
    print(f"  Seeds Used: {overall['seeds_used']}")
    if 'total_episodes' in overall:
        print(f"  Total Episodes: {overall['total_episodes']}")


def _print_env_tables(env_name, env_results, use_scientific, basis):
    """Print the SEM table, the STD table and the pooled-CVaR table for one environment."""
    print(f"\n{'='*80}")
    print(f"RESULTS FOR {env_name.upper()}")
    print(f"{'='*80}")
    suffix = " with scientific notation (×10^n for readability)" if use_scientific else ""
    print(f"\nDisplay: {basis}{suffix}")

    # The SEM and STD tables differ only in which error column they show.
    for label, mean_key, cvar_key in (("SEM", "mean_err", "cvar_err"),
                                      ("STD", "mean_std", "cvar_std")):
        print(f"\nSUMMARY TABLE (Seed mean ± {label})")
        print(f"{'='*80}")
        print(f"{'Algorithm':<20} {'Mean±' + label:<25} {'CVaR±' + label:<25} {'Seeds':<8}")
        print(f"{'-'*80}")
        for algo_name in env_results:
            overall = _overall_results(env_results, algo_name)
            if overall:
                mean_str = _format_value(overall['mean'], overall[mean_key], use_scientific)
                cvar_str = _format_value(overall['cvar'], overall[cvar_key], use_scientific)
                print(f"{algo_name:<20} {mean_str:<25} {cvar_str:<25} {overall['seeds_used']:<8}")
            else:
                print(f"{algo_name:<20} {'N/A':<25} {'N/A':<25} {'N/A':<8}")

    print("\nPOOLED CVAR (CVaR@0.1 pooled across all episodes)")
    print(f"{'='*80}")
    print(f"{'Algorithm':<20} {'Pooled CVaR':<15} {'Total Episodes':<15}")
    print(f"{'-'*80}")
    for algo_name in env_results:
        overall = _overall_results(env_results, algo_name)
        if overall:
            pooled = overall.get('pooled_cvar')
            pooled_str = 'N/A' if pooled is None else f"{pooled:.4f}"
            total_episodes = overall.get('total_episodes', 'N/A')
            print(f"{algo_name:<20} {pooled_str:<15} {total_episodes!s:<15}")
        else:
            print(f"{algo_name:<20} {'N/A':<15} {'N/A':<15}")


def _print_cross_env(all_results, use_scientific):
    """Print each algorithm's Mean±SEM / CVaR±SEM side by side across environments."""
    print(f"\n{'='*80}")
    print("CROSS-ENVIRONMENT COMPARISON (Common algorithms)")
    print(f"{'='*80}")

    # Union of algorithm names over all environments; an environment that
    # lacks an algorithm (or failed on it) is shown as N/A.
    algo_names = set()
    for env_results in all_results.values():
        if env_results:
            algo_names.update(env_results)

    for algo_name in sorted(algo_names):
        print(f"\n{algo_name}:")
        print(f"{'Environment':<25} {'Mean±SEM':<20} {'CVaR±SEM':<20}")
        print(f"{'-'*70}")
        for env_name, env_results in all_results.items():
            overall = _overall_results(env_results, algo_name)
            if overall:
                mean_str = _format_value(overall['mean'], overall['mean_err'], use_scientific)
                cvar_str = _format_value(overall['cvar'], overall['cvar_err'], use_scientific)
                print(f"{env_name:<25} {mean_str:<20} {cvar_str:<20}")
            else:
                print(f"{env_name:<25} {'N/A':<20} {'N/A':<20}")


def main():
    """
    Command-line entry point.

    Computes metrics for every configured algorithm of the selected
    environment(s), saves them to a JSON file, and prints summary tables.
    """
    parser = argparse.ArgumentParser(description='Calculate metrics from ep_returns.npy files')
    parser.add_argument('--env_name', default='walker2d-medium-replay-v2', help='Environment name')
    parser.add_argument('--output', default='metrics_summary.json', help='Output JSON file')
    parser.add_argument('--normalize', choices=['none', 'base'], default='none',
                        help='Normalization: none=raw returns (risky D4RL), base=0-100 normalized (standard D4RL)')
    parser.add_argument('--all_envs', action='store_true',
                        help='Process both walker2d and halfcheetah environments')
    parser.add_argument('--scientific', action='store_true',
                        help='Use scientific notation (×10^n) for large raw values')

    args = parser.parse_args()

    if args.all_envs:
        envs_to_process = ["walker2d-medium-replay-v2", "halfcheetah-medium-replay-v2"]
    else:
        envs_to_process = [args.env_name]

    all_results = {}

    for env_name in envs_to_process:
        print(f"\n{'='*80}")
        print(f"PROCESSING ENVIRONMENT: {env_name}")
        print(f"{'='*80}")

        algorithms = get_algorithm_configs(env_name)
        if not algorithms:
            print(f"Warning: No algorithm configs defined for {env_name}")

        env_results = {}
        for algo_info in algorithms:
            algo_name = algo_info['name']
            print(f"\n{'-'*60}")
            print(f"Processing {algo_name} for {env_name}")
            print(f"{'-'*60}")

            try:
                results = process_directory(None, env_name, algo_info,
                                            normalize=args.normalize, alpha=0.1)
            except Exception as e:
                # One failing algorithm must not abort the whole report.
                print(f"Error processing {algo_name}: {e}")
                env_results[algo_name] = None
                continue

            env_results[algo_name] = results
            if results['overall_results']:
                _print_algorithm_summary(algo_name, results['overall_results'])

        all_results[env_name] = env_results

    # Save results to JSON (default=str stringifies anything json can't encode)
    output_path = args.output
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"\nResults saved to: {output_path}")

    # Describe the metric scale actually requested. With --normalize base, a
    # failed normalization falls back to raw returns (a warning is printed then).
    basis = "Raw returns" if args.normalize == 'none' else "D4RL normalized scores (0-100)"
    if args.scientific:
        print(f"\nDisplay method: {basis} with scientific notation (×10^n for large values)")
    else:
        print(f"\nDisplay method: {basis}")

    for env_name, env_results in all_results.items():
        _print_env_tables(env_name, env_results, args.scientific, basis)

    if args.all_envs:
        _print_cross_env(all_results, args.scientific)

if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()
