#!/usr/bin/env python3
"""
Script to calculate normalized mean, SE, and CVaR from ep_returns.npy
Loads ep_returns from each seed in the specified directory and calculates statistics.
"""

import numpy as np
import os
import json
import argparse

def load_ep_returns(directory_path):
    """
    Read the episode-return array stored for a single run

    Args:
        directory_path (str): Run directory expected to contain ep_returns.npy

    Returns:
        np.ndarray: Contents of ep_returns.npy as returned by np.load

    Raises:
        FileNotFoundError: If ep_returns.npy is absent from the directory
    """
    npy_file = os.path.join(directory_path, "ep_returns.npy")

    # Success path first: the file is present, hand it straight to NumPy.
    if os.path.exists(npy_file):
        return np.load(npy_file)

    raise FileNotFoundError(f"ep_returns.npy not found in {directory_path}")

def transform_scores(ep_returns, env_name, normalize="none"):
    """
    Transform episode returns using the specified method

    Args:
        ep_returns (array-like): Episode returns
        env_name (str): Environment name, passed to gym.make when normalizing
        normalize (str): "base" maps every return through the environment's
            get_normalized_score; any other value (including "none") keeps
            the raw returns

    Returns:
        np.ndarray: Float array of transformed scores. If normalization fails
        for any reason (gym missing, unknown environment, environment without
        get_normalized_score, ...) a warning is printed and the raw returns
        are returned instead.
    """
    ep_returns = np.asarray(ep_returns, float)

    if normalize != "base":
        # normalize="none" (or anything unrecognized) → keep raw values
        return ep_returns

    try:
        import gym

        # NOTE(review): dataset environments usually have to be registered by
        # importing their package before gym.make can find them; this block
        # only imports gym, so confirm registration happens elsewhere.
        env = gym.make(env_name)
        try:
            vals = np.array([env.get_normalized_score(r) for r in ep_returns], dtype=float)
        finally:
            # Release the environment's resources; a failing close must not
            # discard scores that were already computed.
            try:
                env.close()
            except Exception as close_err:
                print(f"Warning: Could not close environment {env_name}: {close_err}")

        # If 0-1 range detected, scale to 0-100 (for safety). The size check
        # keeps an empty input from tripping np.nanmax and being reported as a
        # normalization failure.
        if vals.size and np.nanmax(vals) <= 1.0:
            vals *= 100.0
            print("  Note: Normalized scores were 0-1, scaled to 0-100")

        return vals
    except Exception as e:
        # Deliberate best-effort: fall back to raw returns rather than abort.
        print(f"Warning: Could not normalize scores for {env_name}: {e}")
        print("Using raw scores instead")
        return ep_returns

def calculate_cvar(scores, alpha=0.1):
    """
    Calculate empirical lower-tail CVaR@alpha

    The estimate is the mean of the k smallest scores, where
    k = ceil(alpha * n) clamped to the range [1, n]. Consequently alpha=0
    yields the minimum score and alpha>=1 yields the mean of all scores.

    Args:
        scores (array-like): Score array
        alpha (float): Risk level, i.e. fraction of worst episodes to average
            (default: 0.1)

    Returns:
        float: CVaR@alpha, or NaN when scores is empty

    Raises:
        ValueError: If alpha is negative
    """
    if len(scores) == 0:
        return np.nan

    if alpha < 0:
        # A negative level would turn into a negative slice index below and
        # silently average the wrong end of the distribution.
        raise ValueError(f"alpha must be non-negative, got {alpha}")

    # Sort in ascending order so the worst outcomes come first
    sorted_scores = np.sort(scores)
    n = len(sorted_scores)

    # Number of tail samples. alpha * n can land a hair above an integer
    # because of float round-off (0.07 * 100 == 7.000000000000001), which a
    # plain ceil would bump to the next integer; snap such values first.
    target = alpha * n
    nearest = np.rint(target)
    if np.isclose(target, nearest, rtol=1e-12, atol=1e-9):
        k = int(nearest)
    else:
        k = int(np.ceil(target))

    # Always average at least one and at most n samples
    k = min(max(k, 1), n)

    return np.mean(sorted_scores[:k])

def calculate_se(scores, ddof=1):
    """
    Calculate Standard Error (SE) of the mean

    Uses the sample standard deviation (ddof=1) by default so the result
    agrees with the SEM computed across seeds in process_directory. Pass
    ddof=0 to obtain the population-std variant.

    Args:
        scores (array-like): Score array
        ddof (int): Delta degrees of freedom forwarded to np.std (default: 1)

    Returns:
        float: Standard error, or NaN when there are too few samples for the
        requested ddof (including an empty input)
    """
    n = len(scores)

    # With n <= ddof the std is undefined; return NaN explicitly instead of
    # letting np.std emit a RuntimeWarning.
    if n == 0 or n - ddof <= 0:
        return np.nan

    return np.std(scores, ddof=ddof) / np.sqrt(n)

def process_directory(base_dir, env_name, algorithm_info, normalize="none", alpha=0.1):
    """
    Process each seed in the specified directory pattern

    For every seed the run directory is os.path.join(root, pattern + str(seed)),
    where root is base_dir, or ".." when base_dir is None. Each run's
    ep_returns.npy is loaded, transformed, and reduced to a mean and a
    CVaR@alpha; those per-seed values are then aggregated across seeds.

    Args:
        base_dir (str | None): Directory the patterns are resolved against;
            None keeps the historical default of the parent directory ("..")
        env_name (str): Environment name (used for score normalization)
        algorithm_info (dict): Algorithm information with keys 'pattern',
            'seeds' and 'name'
        normalize (str): Normalization method forwarded to transform_scores
        alpha (float): Risk level used for both per-seed and pooled CVaR

    Returns:
        dict: {'algorithm': name, 'seed_results': {seed: dict | None},
        'overall_results': dict | None}. A seed maps to None when processing
        it failed; 'overall_results' is None when no seed succeeded.

    Raises:
        FileNotFoundError: If none of the seed directories exist
    """
    pattern = algorithm_info['pattern']
    seeds = algorithm_info['seeds']
    algo_name = algorithm_info['name']

    # Honour base_dir; fall back to the previously hard-coded parent directory.
    root = ".." if base_dir is None else base_dir

    # Find directories for specified seeds
    seed_dirs = []
    for seed in seeds:
        dir_path = os.path.join(root, f"{pattern}{seed}")

        if os.path.exists(dir_path):
            seed_dirs.append((seed, dir_path))
        else:
            print(f"Warning: Directory not found: {dir_path}")

    if not seed_dirs:
        raise FileNotFoundError(f"No valid directories found for pattern: {pattern}")

    print(f"Found {len(seed_dirs)} seed directories for {algo_name}")

    # Calculate statistics for each seed
    seed_results = {}
    seed_means = []
    seed_cvars = []

    for seed, dir_path in seed_dirs:
        print(f"Processing seed {seed}: {dir_path}")

        # Best-effort per seed: one broken run must not abort the others.
        try:
            # Load ep_returns
            ep_returns = load_ep_returns(dir_path)
            print(f"  Loaded {len(ep_returns)} episodes")

            # Transform scores (raw returns for risky D4RL, normalized for standard D4RL)
            scores_for_metric = transform_scores(ep_returns, env_name, normalize=normalize)

            mean_score = float(np.mean(scores_for_metric))
            cvar_score = float(calculate_cvar(scores_for_metric, alpha=alpha))

            seed_results[seed] = {
                'mean': mean_score,
                'cvar': cvar_score,
                'episodes': len(scores_for_metric),
                'raw_scores': ep_returns.tolist(),
                'scores_for_metric': scores_for_metric.tolist()
            }

            # Save values for aggregation across seeds
            seed_means.append(mean_score)
            seed_cvars.append(cvar_score)

            print(f"  Seed {seed}: mean={mean_score:.4f}, cvar={cvar_score:.4f}")

        except Exception as e:
            print(f"  Error processing seed {seed}: {e}")
            seed_results[seed] = None

    # Literature-based aggregation (mean ± error across seeds)
    if seed_means and seed_cvars:
        def sample_std(x):
            # Sample std is undefined for a single seed; return NaN directly
            # rather than letting np.std warn about zero degrees of freedom.
            x = np.asarray(x, float)
            return float(np.std(x, ddof=1)) if len(x) > 1 else float('nan')

        def sem(x):
            # Standard error of the mean across seeds
            return sample_std(x) / np.sqrt(len(x))

        # Get first valid result (for episodes_per_seed)
        first_valid_result = next((r for r in seed_results.values() if r is not None), None)

        overall_results = {
            'mean': float(np.mean(seed_means)),
            'mean_err': float(sem(seed_means)),      # SEM (recommended)
            'mean_std': sample_std(seed_means),      # STD (more common in literature)
            'cvar': float(np.mean(seed_cvars)),
            'cvar_err': float(sem(seed_cvars)),      # CVaR SEM
            'cvar_std': sample_std(seed_cvars),      # CVaR STD
            'seeds_used': len(seed_means),
            'episodes_per_seed': len(first_valid_result['scores_for_metric']) if first_valid_result else 0,
            'normalization_method': normalize,
            'metric_basis': 'raw_returns' if normalize == 'none' else 'normalized_0_100'
        }

        # For appendix: pooled CVaR (combine episodes from all seeds)
        all_scores_for_metric = []
        for seed_result in seed_results.values():
            if seed_result:
                all_scores_for_metric.extend(seed_result['scores_for_metric'])

        if all_scores_for_metric:
            pooled_cvar = calculate_cvar(np.array(all_scores_for_metric), alpha=alpha)
            overall_results['pooled_cvar'] = float(pooled_cvar)
            overall_results['total_episodes'] = len(all_scores_for_metric)
    else:
        overall_results = None

    return {
        'algorithm': algo_name,
        'seed_results': seed_results,
        'overall_results': overall_results
    }

def get_algorithm_configs(env_name):
    """
    Build the list of algorithm run configurations for an environment

    Args:
        env_name (str): Environment name

    Returns:
        list: One dict per algorithm with keys 'name', 'pattern' (directory
        prefix to which the seed number is appended) and 'seeds'. Empty for
        environments without a configuration.
    """
    supported_envs = ("walker2d-medium-replay-v2", "halfcheetah-medium-replay-v2")
    if env_name not in supported_envs:
        return []

    # (display name, tail appended after the environment name). Both supported
    # environments share the same six variants.
    # NOTE(review): the RADAC-CVaR tail ends in "/" while every other one ends
    # in "|"; reproduced verbatim — confirm it matches the on-disk layout.
    variants = (
        ('RADAC-CPW', '|radac|cpw|'),
        ('RADAC-WANG', '|radac|wang|'),
        ('RADAC-CVaR', '|radac|cvar/'),
        ('RAFMAC-CPW', '|rafmac|cpw|'),
        ('RAFMAC-WANG', '|rafmac|wang|'),
        ('RAFMAC-CVaR', '|rafmac|cvar|'),
    )

    prefix = "frozen_logs/ablation_results/" + env_name

    # Fresh dicts and seed lists on every call so callers may mutate them.
    return [
        {'name': label, 'pattern': prefix + tail, 'seeds': [0, 1, 2]}
        for label, tail in variants
    ]

def _format_value(value, error, use_scientific=False):
    """Render "value±error", optionally sharing one power-of-ten exponent."""
    if use_scientific:
        # Only finite magnitudes can determine an exponent; NaN/inf values
        # (e.g. the SEM of a single seed) fall through to plain formatting.
        finite = [m for m in (abs(value), abs(error)) if np.isfinite(m)]
        if finite and max(finite) >= 100:
            exp = int(np.floor(np.log10(max(finite))))
            scale = 10 ** exp
            return f"{value/scale:.2f}×10^{exp}±{error/scale:.2f}×10^{exp}"
    # Normal decimal display
    return f"{value:.2f}±{error:.2f}"

def _overall_results(env_results, algo_name):
    """Return an algorithm's aggregated results dict, or None if unavailable."""
    entry = env_results.get(algo_name) if env_results else None
    return entry.get('overall_results') if entry else None

def _print_summary_table(env_results, err_key_suffix, err_label, use_scientific):
    """Print the per-algorithm "mean ± error" table for one error measure."""
    mean_header = f"Mean±{err_label}"
    cvar_header = f"CVaR±{err_label}"

    print(f"\nSUMMARY TABLE (Mean across seeds ± {err_label})")
    print(f"{'='*80}")
    print(f"{'Algorithm':<20} {mean_header:<25} {cvar_header:<25} {'Seeds':<8}")
    print(f"{'-'*80}")

    for algo_name in env_results:
        overall = _overall_results(env_results, algo_name)
        if overall:
            mean_str = _format_value(overall['mean'], overall[f'mean_{err_key_suffix}'], use_scientific)
            cvar_str = _format_value(overall['cvar'], overall[f'cvar_{err_key_suffix}'], use_scientific)
            print(f"{algo_name:<20} {mean_str:<25} {cvar_str:<25} {overall['seeds_used']:<8}")
        else:
            print(f"{algo_name:<20} {'N/A':<25} {'N/A':<25} {'N/A':<8}")

def main():
    """
    Command-line entry point

    Parses the arguments, runs process_directory for every configured
    algorithm of the selected environment(s), writes all results to the
    output JSON file and prints summary tables (SEM, STD, pooled CVaR and,
    with --all_envs, a cross-environment comparison).
    """
    parser = argparse.ArgumentParser(description='Calculate metrics from ep_returns.npy files')
    parser.add_argument('--env_name', default='halfcheetah-medium-replay-v2', help='Environment name')
    parser.add_argument('--output', default='metrics_summary.json', help='Output JSON file')
    parser.add_argument('--normalize', choices=['none', 'base'], default='none',
                       help='Normalization: none=raw returns (risky D4RL), base=0-100 normalized (standard D4RL)')
    parser.add_argument('--all_envs', action='store_true',
                       help='Process both walker2d and halfcheetah environments')
    parser.add_argument('--scientific', action='store_true',
                       help='Use scientific notation (×10^n) for large raw values')

    args = parser.parse_args()

    if args.all_envs:
        # Process both environments
        envs_to_process = ["walker2d-medium-replay-v2", "halfcheetah-medium-replay-v2"]
    else:
        # Process only specified environment
        envs_to_process = [args.env_name]

    all_results = {}

    for env_name in envs_to_process:
        print(f"\n{'='*80}")
        print(f"PROCESSING ENVIRONMENT: {env_name}")
        print(f"{'='*80}")

        # Get algorithm configurations for each environment
        algorithms = get_algorithm_configs(env_name)
        if not algorithms:
            # Make an unsupported/mistyped environment visible instead of
            # silently producing an empty report.
            print(f"Warning: No algorithm configurations defined for {env_name}")

        env_results = {}

        for algo_info in algorithms:
            algo_name = algo_info['name']
            print(f"\n{'-'*60}")
            print(f"Processing {algo_name} for {env_name}")
            print(f"{'-'*60}")

            # Best-effort per algorithm: a failure is recorded as None and the
            # remaining algorithms are still processed.
            try:
                results = process_directory(None, env_name, algo_info,
                                            normalize=args.normalize, alpha=0.1)
            except Exception as e:
                print(f"Error processing {algo_name}: {e}")
                env_results[algo_name] = None
                continue

            env_results[algo_name] = results

            # Display results
            overall = results['overall_results']
            if overall:
                print(f"\n{algo_name} Overall Results:")
                print(f"  Mean: {overall['mean']:.4f} ± {overall['mean_err']:.4f} (SEM)")
                print(f"  Mean: {overall['mean']:.4f} ± {overall['mean_std']:.4f} (STD)")
                print(f"  CVaR@0.1: {overall['cvar']:.4f} ± {overall['cvar_err']:.4f} (SEM)")
                print(f"  CVaR@0.1: {overall['cvar']:.4f} ± {overall['cvar_std']:.4f} (STD)")
                print(f"  Pooled CVaR@0.1: {overall.get('pooled_cvar', 'N/A')}")
                print(f"  Episodes per seed: {overall['episodes_per_seed']}")
                print(f"  Seeds Used: {overall['seeds_used']}")
                if 'total_episodes' in overall:
                    print(f"  Total Episodes: {overall['total_episodes']}")

        all_results[env_name] = env_results

    # Save results to JSON file (default=str covers any non-JSON-native value)
    output_path = args.output
    with open(output_path, 'w') as f:
        json.dump(all_results, f, indent=2, default=str)

    print(f"\nResults saved to: {output_path}")

    # Describe what the numbers are; with the default --normalize none the
    # wording is exactly "Raw returns".
    basis = "Raw returns" if args.normalize == 'none' else "Normalized scores (0-100)"

    # Display method explanation
    if args.scientific:
        print(f"\nDisplay method: {basis} with scientific notation (×10^n for large values)")
    else:
        print(f"\nDisplay method: {basis}")

    # Display results for each environment
    for env_name, env_results in all_results.items():
        print(f"\n{'='*80}")
        print(f"RESULTS FOR {env_name.upper()}")
        print(f"{'='*80}")

        # Display method explanation
        if args.scientific:
            print(f"\nDisplay: {basis} with scientific notation (×10^n for readability)")
        else:
            print(f"\nDisplay: {basis}")

        # Concise summary tables: SEM first, then STD
        _print_summary_table(env_results, 'err', 'SEM', args.scientific)
        _print_summary_table(env_results, 'std', 'STD', args.scientific)

        # Pooled CVaR
        print(f"\nPOOLED CVAR (CVaR@0.1 combining episodes from all seeds)")
        print(f"{'='*80}")
        print(f"{'Algorithm':<20} {'Pooled CVaR':<15} {'Total Episodes':<15}")
        print(f"{'-'*80}")

        for algo_name in env_results:
            overall = _overall_results(env_results, algo_name)
            if overall:
                pooled_cvar = overall.get('pooled_cvar', 'N/A')
                total_episodes = overall.get('total_episodes', 'N/A')
                print(f"{algo_name:<20} {pooled_cvar:<15} {total_episodes:<15}")
            else:
                print(f"{algo_name:<20} {'N/A':<15} {'N/A':<15}")

    # Cross-environment comparison table
    if args.all_envs:
        print(f"\n{'='*80}")
        print("CROSS-ENVIRONMENT COMPARISON (Common algorithms)")
        print(f"{'='*80}")

        # Collect every algorithm seen in any environment; environments that
        # lack results for one of them are shown as N/A below.
        algo_names = set()
        for env_results in all_results.values():
            if env_results:
                algo_names.update(env_results.keys())

        # Compare results across environments
        for algo_name in sorted(algo_names):
            print(f"\n{algo_name}:")
            print(f"{'Environment':<25} {'Mean±SEM':<20} {'CVaR±SEM':<20}")
            print(f"{'-'*70}")

            for env_name, env_results in all_results.items():
                overall = _overall_results(env_results, algo_name)
                if overall:
                    mean_str = _format_value(overall['mean'], overall['mean_err'], args.scientific)
                    cvar_str = _format_value(overall['cvar'], overall['cvar_err'], args.scientific)
                    print(f"{env_name:<25} {mean_str:<20} {cvar_str:<20}")
                else:
                    print(f"{env_name:<25} {'N/A':<20} {'N/A':<20}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
