#!/usr/bin/env python3
"""
Minimal script to calculate and display performance metrics from ep_returns.npy files.
Follows DRY and KISS principles.
"""

import numpy as np
import os
import json
import argparse
from typing import Dict, List, Optional, Tuple

# Core calculation functions
def calculate_cvar(scores: np.ndarray, alpha: float = 0.1) -> float:
    """Return CVaR@alpha: the mean of the lowest ``alpha`` fraction of scores.

    Args:
        scores: Episode returns; sorted ascending before the tail is taken.
        alpha: Fraction of the lowest scores to average.

    Returns:
        Mean of the ``ceil(alpha * len(scores))`` smallest scores (at least
        one score is always used), or NaN when ``scores`` is empty.
    """
    n_scores = len(scores)
    if not n_scores:
        return np.nan

    # Always keep at least one element so a tiny alpha still yields a value.
    tail_size = max(1, int(np.ceil(alpha * n_scores)))
    worst_tail = np.sort(scores)[:tail_size]
    return np.mean(worst_tail)

def calculate_metrics(ep_returns: np.ndarray) -> Dict:
    """Summarise the episode returns of a single seed.

    Args:
        ep_returns: Episode returns for one seed.

    Returns:
        Dict with ``'mean'`` (average return), ``'cvar'`` (value of
        ``calculate_cvar`` at its default alpha) and ``'episodes'`` (number
        of returns). For empty input both statistics are NaN and
        ``'episodes'`` is 0.
    """
    n_episodes = len(ep_returns)
    if n_episodes:
        return {
            'mean': float(np.mean(ep_returns)),
            'cvar': float(calculate_cvar(ep_returns)),
            'episodes': n_episodes,
        }

    # Nothing recorded for this seed: report NaN so aggregation can skip it.
    return {'mean': np.nan, 'cvar': np.nan, 'episodes': 0}

def aggregate_seeds(seed_metrics: List[Dict]) -> Dict:
    """Combine per-seed metric dicts into mean ± SEM across seeds.

    Args:
        seed_metrics: Dicts carrying ``'mean'`` and ``'cvar'`` entries, one
            per seed. NaN entries are ignored.

    Returns:
        Dict with ``'mean'``, ``'mean_err'``, ``'cvar'``, ``'cvar_err'`` and
        ``'seeds'`` (number of non-NaN means), or an empty dict when there is
        no input or no non-NaN mean/cvar value.
    """
    def sem(x):
        """Standard error of the mean; NaN when fewer than two values."""
        values = np.asarray(x, float)
        if len(values) <= 1:
            return np.nan
        return np.std(values, ddof=1) / np.sqrt(len(values))

    def non_nan(key):
        """Collect the non-NaN values stored under ``key``."""
        return [entry[key] for entry in seed_metrics if not np.isnan(entry[key])]

    if not seed_metrics:
        return {}

    means = non_nan('mean')
    cvars = non_nan('cvar')
    if not (means and cvars):
        return {}

    return {
        'mean': np.mean(means),
        'mean_err': sem(means),
        'cvar': np.mean(cvars),
        'cvar_err': sem(cvars),
        'seeds': len(means),
    }

# Data loading and processing
def load_seed_data(seed_dir: str) -> Optional[np.ndarray]:
    """Read ``ep_returns.npy`` from a seed directory.

    Args:
        seed_dir: Directory expected to contain ``ep_returns.npy``.

    Returns:
        Whatever ``np.load`` yields for that file, or ``None`` when the file
        does not exist.
    """
    returns_file = os.path.join(seed_dir, "ep_returns.npy")
    return np.load(returns_file) if os.path.exists(returns_file) else None

def process_algorithm(base_dir: str, pattern: str, seeds: List[int]) -> Dict:
    """Compute aggregated metrics for one algorithm over its seeds.

    Each seed directory is ``base_dir`` joined with ``pattern`` followed by
    the seed number. Seeds whose data cannot be loaded are left out.

    Args:
        base_dir: Directory containing the per-seed run directories.
        pattern: Run-directory name prefix; the seed is appended to it.
        seeds: Seed numbers to look for.

    Returns:
        The result of ``aggregate_seeds`` over the seeds that had data.
    """
    # Generators keep the original per-seed order: load one seed, then
    # compute its metrics, before moving on to the next seed.
    seed_dirs = (os.path.join(base_dir, f"{pattern}{seed}") for seed in seeds)
    loaded_returns = (load_seed_data(seed_dir) for seed_dir in seed_dirs)
    per_seed = [
        calculate_metrics(returns)
        for returns in loaded_returns
        if returns is not None
    ]
    return aggregate_seeds(per_seed)

def process_all_experiments(config: Dict) -> Dict:
    """Compute aggregated metrics for every algorithm of every environment.

    Progress is printed to stdout while processing.

    Args:
        config: Maps environment name to a list of algorithm entries. Each
            entry needs ``'name'``, ``'pattern'`` and ``'seeds'``;
            ``'base_dir'`` is optional and defaults to ``'results'``.

    Returns:
        ``{env_name: {algo_name: metrics}}``. Algorithms for which
        ``process_algorithm`` returns nothing are omitted, but every
        environment in ``config`` gets an entry (possibly empty).
    """
    results = {}

    for env_name, algorithms in config.items():
        print(f"Processing {env_name}...")
        env_results = {}

        for algo_info in algorithms:
            algo_name = algo_info['name']
            seeds = algo_info['seeds']
            metrics = process_algorithm(
                algo_info.get('base_dir', 'results'),
                algo_info['pattern'],
                seeds,
            )

            if not metrics:
                # Make the omission visible instead of dropping it silently.
                print(f"  {algo_name}: no data found for {len(seeds)} configured seed(s), skipped")
                continue

            env_results[algo_name] = metrics
            # Report how many seeds actually contributed, not how many were
            # configured: missing seed directories are skipped upstream.
            print(f"  {algo_name}: {metrics['seeds']}/{len(seeds)} seeds processed")

        results[env_name] = env_results

    return results

# Algorithm extraction and grouping
def extract_algorithm_name(experiment_name: str) -> str:
    """Map an experiment name to the known algorithm it mentions.

    The match is a case-insensitive substring test against a fixed list of
    algorithm names, tried in order; the first hit wins.

    Args:
        experiment_name: Name of the experiment.

    Returns:
        The matched algorithm name (upper case), or ``experiment_name``
        unchanged when no known algorithm appears in it.
    """
    # Order matters: 'QL' is a substring of 'CQL' and 'FQL', so it must be
    # tried after them.
    known_algorithms = ('CODAC', 'CQL', 'FQL', 'ORAAC', 'QL', 'RADAC', 'RAFMAC')

    upper_name = experiment_name.upper()
    matches = (candidate for candidate in known_algorithms if candidate in upper_name)
    return next(matches, experiment_name)

def group_by_algorithm(results: Dict) -> Dict:
    """Merge experiment results that belong to the same algorithm.

    Experiments are assigned to an algorithm via ``extract_algorithm_name``.
    An algorithm with a single experiment keeps that experiment's numbers.
    An algorithm with several experiments gets statistics pooled over all of
    their seeds: the mean is weighted by each experiment's seed count and the
    SEM is rebuilt from each experiment's mean, SEM and seed count, which
    gives the same result as computing them from the individual seed values.

    Args:
        results: ``{env_name: {experiment_name: metrics}}`` where ``metrics``
            has ``'mean'``, ``'cvar'``, ``'seeds'`` and optionally
            ``'mean_err'`` / ``'cvar_err'``.

    Returns:
        ``{env_name: {algo_name: grouped}}`` where ``grouped`` has ``'mean'``,
        ``'mean_err'``, ``'cvar'``, ``'cvar_err'``, ``'seeds'`` and
        ``'experiments'`` (plus a ``'note'`` for pooled entries). Algorithms
        whose experiments are all unusable (NaN mean) are omitted.
    """
    def pooled(experiments, value_key, err_key, total_seeds):
        """Return (mean, SEM) over all seeds of ``experiments``."""
        grand_mean = sum(m['seeds'] * m[value_key] for m in experiments) / total_seeds
        if total_seeds < 2:
            return float(grand_mean), np.nan

        sum_sq = 0.0
        for m in experiments:
            n = m['seeds']
            if n > 1:
                # Within-experiment sum of squares: SEM**2 * n is the sample
                # variance (ddof=1), times (n - 1) undoes the normalisation.
                # A missing/NaN SEM propagates and yields a NaN pooled SEM.
                sum_sq += (n - 1) * n * m.get(err_key, np.nan) ** 2
            # Between-experiment contribution around the pooled mean.
            sum_sq += n * (m[value_key] - grand_mean) ** 2

        variance = sum_sq / (total_seeds - 1)
        return float(grand_mean), float(np.sqrt(variance / total_seeds))

    algo_grouped = {}

    for env_name, env_results in results.items():
        # Bucket this environment's experiments by algorithm name.
        by_algo = {}
        for exp_name, metrics in env_results.items():
            by_algo.setdefault(extract_algorithm_name(exp_name), []).append(metrics)

        grouped = algo_grouped[env_name] = {}
        for algo_name, exp_list in by_algo.items():
            if len(exp_list) == 1:
                # Single experiment: pass its results through unchanged.
                only = exp_list[0]
                grouped[algo_name] = {
                    'mean': only['mean'],
                    'mean_err': only.get('mean_err', np.nan),
                    'cvar': only['cvar'],
                    'cvar_err': only.get('cvar_err', np.nan),
                    'seeds': only['seeds'],
                    'experiments': 1
                }
                continue

            usable = [
                m for m in exp_list
                if 'mean' in m and not np.isnan(m['mean']) and m['seeds'] > 0
            ]
            if not usable:
                continue

            total_seeds = sum(m['seeds'] for m in usable)
            mean, mean_err = pooled(usable, 'mean', 'mean_err', total_seeds)
            cvar, cvar_err = pooled(usable, 'cvar', 'cvar_err', total_seeds)
            grouped[algo_name] = {
                'mean': mean,
                'mean_err': mean_err,
                'cvar': cvar,
                'cvar_err': cvar_err,
                'seeds': total_seeds,
                'experiments': len(exp_list),
                'note': 'Pooled over all seeds of multiple experiments'
            }

    return algo_grouped

# Display functions
def format_value(value: float, error: float) -> str:
    """Render ``value ± error`` with two decimals.

    Args:
        value: Central value.
        error: Error attached to the value.

    Returns:
        ``"N/A"`` when ``value`` is NaN; otherwise the value followed by
        ``±`` and either the error or ``N/A`` when the error is NaN.
    """
    if np.isnan(value):
        return "N/A"
    return f"{value:.2f}±N/A" if np.isnan(error) else f"{value:.2f}±{error:.2f}"

def print_table(data: Dict, group_by: str = 'env') -> None:
    """Print the results table, laid out per environment or per algorithm.

    Experiments are first merged per algorithm with ``group_by_algorithm``.

    Args:
        data: Results keyed by environment, then by experiment name.
        group_by: ``'algo'`` prints one table per algorithm; any other value
            prints one table per environment.
    """
    grouped = group_by_algorithm(data)
    printer = print_by_algorithm if group_by == 'algo' else print_by_environment
    printer(grouped)

def print_by_environment(data: Dict) -> None:
    """Print one table per environment, with one row per algorithm.

    Environments without any results are skipped. Rows are sorted by
    algorithm name.

    Args:
        data: ``{env_name: {algo_name: metrics}}`` where ``metrics`` has
            ``'mean'``, ``'mean_err'``, ``'cvar'``, ``'cvar_err'``,
            ``'seeds'`` and optionally ``'experiments'`` (shown as 1 when
            absent).
    """
    for env_name, env_results in data.items():
        if env_results:
            print(f"\n{env_name.upper()}:")
            print(f"{'Algorithm':<15} {'Mean±SEM':<20} {'CVaR±SEM':<20} {'Seeds':<6} {'Exps':<5}")
            print("-" * 67)

            for algo_name in sorted(env_results):
                metrics = env_results[algo_name]
                mean_str = format_value(metrics['mean'], metrics['mean_err'])
                cvar_str = format_value(metrics['cvar'], metrics['cvar_err'])
                exps = metrics.get('experiments', 1)
                print(f"{algo_name:<15} {mean_str:<20} {cvar_str:<20} {metrics['seeds']:<6} {exps:<5}")

def print_by_algorithm(data: Dict) -> None:
    """Print one table per algorithm, with one row per environment.

    Algorithms and environments are both listed in sorted order. An
    environment that has no entry for the algorithm gets a row of ``N/A``.

    Args:
        data: ``{env_name: {algo_name: metrics}}`` where ``metrics`` has
            ``'mean'``, ``'mean_err'``, ``'cvar'``, ``'cvar_err'``,
            ``'seeds'`` and optionally ``'experiments'`` (shown as 1 when
            absent).
    """
    # Union of the algorithm names seen in any environment.
    all_algorithms = {
        algo_name
        for env_results in data.values()
        for algo_name in env_results.keys()
    }

    for algo_name in sorted(all_algorithms):
        print(f"\n{algo_name}:")
        print(f"{'Environment':<25} {'Mean±SEM':<20} {'CVaR±SEM':<20} {'Seeds':<6} {'Exps':<5}")
        print("-" * 77)

        for env_name in sorted(data.keys()):
            env_results = data[env_name]
            if algo_name not in env_results:
                print(f"{env_name:<25} {'N/A':<20} {'N/A':<20} {'N/A':<6} {'N/A':<5}")
                continue

            metrics = env_results[algo_name]
            mean_str = format_value(metrics['mean'], metrics['mean_err'])
            cvar_str = format_value(metrics['cvar'], metrics['cvar_err'])
            exps = metrics.get('experiments', 1)
            print(f"{env_name:<25} {mean_str:<20} {cvar_str:<20} {metrics['seeds']:<6} {exps:<5}")

# Configuration
def get_default_config() -> Dict:
    """Build the built-in experiment configuration.

    Returns:
        ``{env_name: [entry, ...]}`` where each entry has ``'name'``,
        ``'pattern'`` (run-directory prefix, seed appended later) and
        ``'seeds'``. Ablation entries also carry ``'base_dir'``; the other
        entries omit it.
    """
    config = {}

    # Fresh seed lists are created per entry so no two entries share a list.
    config["walker2d-medium-replay-v2"] = [
        {
            'name': 'RADAC_CPW',
            'pattern': 'frozen_logs/ablation_results/walker2d-medium-replay-v2|radac|cpw|',
            'seeds': list(range(3)),
            'base_dir': '../',
        },
        {
            'name': 'RADAC_WANG',
            'pattern': 'frozen_logs/ablation_results/walker2d-medium-replay-v2|radac|wang|',
            'seeds': list(range(3)),
            'base_dir': '../',
        },
        {
            'name': 'RADAC_7_12',
            'pattern': 'walker2d-medium-replay-v2|7_12|radac|T-5|ms-offline|k-0|risky_data|',
            'seeds': list(range(5)),
        },
    ]

    config["halfcheetah-medium-replay-v2"] = [
        {
            'name': 'RADAC_CPW',
            'pattern': 'frozen_logs/ablation_results/halfcheetah-medium-replay-v2|radac|cpw|',
            'seeds': list(range(3)),
            'base_dir': '../',
        },
        {
            'name': 'ORAAC_7_3',
            'pattern': 'halfcheetah-medium-replay-v2|7_3|oraac|T-5|ms-offline|k-0|risky_data|',
            'seeds': list(range(5)),
        },
        {
            'name': 'CQL_6_10',
            'pattern': 'halfcheetah-medium-replay-v2|6_10|cql|T-5|ms-offline|k-0|risky_data|',
            'seeds': list(range(5)),
        },
    ]

    return config

def discover_config(runs_dir: str = "frozen_logs/runs") -> Dict:
    """Build a configuration by scanning the run directories in ``runs_dir``.

    A run directory name is split on ``'|'`` and must have at least 8 fields.
    Field 0 is the environment, field 1 the date, field 2 the algorithm, and
    the last field is the seed. Everything before the last field (with a
    trailing ``'|'``) becomes the pattern, so ``pattern + str(seed)``
    reproduces the directory name. Entries that are not directories, have too
    few fields, or whose last field is not an integer are skipped.

    Args:
        runs_dir: Directory whose immediate children are run directories.

    Returns:
        ``{env_name: [entry, ...]}`` where each entry has ``'name'``
        (``"<algorithm>_<date>_<field 5>"``), ``'pattern'``, ``'seeds'``
        (sorted ascending) and ``'base_dir'`` (set to ``runs_dir``).
    """
    import glob

    discovered = {}
    # glob.escape keeps any glob metacharacters in runs_dir literal; sorting
    # makes the discovery order independent of the filesystem.
    for dir_path in sorted(glob.glob(os.path.join(glob.escape(runs_dir), "*"))):
        if not os.path.isdir(dir_path):
            continue

        parts = os.path.basename(dir_path).split('|')
        if len(parts) < 8:
            continue

        # The pattern drops the last field, so the seed must come from that
        # same field for pattern + seed to point back at this directory.
        try:
            seed = int(parts[-1])
        except ValueError:
            continue  # stray directory whose last field is not a seed number

        env_name, date, algorithm = parts[0], parts[1], parts[2]
        # NOTE(review): directories that differ only in fields not used in
        # this key share one entry and keep the first pattern seen; confirm
        # such directories cannot occur.
        exp_key = f"{algorithm}_{date}_{parts[5]}"

        entry = discovered.setdefault(env_name, {}).setdefault(exp_key, {
            'name': exp_key,  # Will be grouped by algorithm later
            'pattern': '|'.join(parts[:-1]) + '|',
            'seeds': [],
            'base_dir': runs_dir
        })
        entry['seeds'].append(seed)

    # Convert to list format and sort seeds
    result = {}
    for env_name, experiments in discovered.items():
        for entry in experiments.values():
            entry['seeds'].sort()
        result[env_name] = list(experiments.values())

    return result

# Main function
def main():
    """Command-line entry point: compute (or reload) metrics and print them.

    With ``--summary_only`` and an existing output file, results are read
    from that file. Otherwise they are computed from the selected
    configuration, written to the output file as JSON, and then displayed.
    """
    cli = argparse.ArgumentParser(description='Calculate performance metrics')
    cli.add_argument('--config', choices=['default', 'discover'], default='default',
                     help='Configuration source')
    cli.add_argument('--runs_dir', default='frozen_logs/runs',
                     help='Directory for discover mode')
    cli.add_argument('--output', default='metrics.json',
                     help='Output JSON file')
    cli.add_argument('--table_type', choices=['env', 'algo'], default='env',
                     help='Group table by environment or algorithm')
    cli.add_argument('--summary_only', action='store_true',
                     help='Show only summary from existing results')
    opts = cli.parse_args()

    # Saved results are reused only when explicitly requested and present.
    reuse_saved = opts.summary_only and os.path.exists(opts.output)
    if not reuse_saved:
        if opts.config == 'discover':
            config = discover_config(opts.runs_dir)
        else:
            config = get_default_config()

        results = process_all_experiments(config)

        with open(opts.output, 'w') as out_file:
            json.dump(results, out_file, indent=2, default=str)
        print(f"\nResults saved to: {opts.output}")
    else:
        with open(opts.output, 'r') as saved_file:
            results = json.load(saved_file)

    print_table(results, opts.table_type)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()