'''
Load per-model results from the summary folder, calculate the average and standard deviation of each metric across seeds for every sample size, and save a model comparison to a CSV file in the result folder.
'''
import os
import sys
from typing import Optional
sys.path.append('.')
path0 = os.path.dirname(sys.argv[0])

import pandas as pd

# Folder holding the per-model summary CSV files (summary-<model>-<suffix>.csv),
# resolved relative to the directory of the invoked script (path0).
path_summary = os.path.join(path0, 'summary')


def read_summary_results(fname: str) -> pd.DataFrame:
    """
    Read the summary file and return a pandas DataFrame with the results
    
    Args:
        fname: Name of the summary file
        
    Returns:
        pandas DataFrame with the results
    """
    if not os.path.exists(fname):
        print(f"Warning: Summary file {fname} does not exist.")
        return None
    
    # Read the results file into a DataFrame
    try:
        # Read with header from file instead of specifying column names
        # This assumes the CSV file already has a header
        df = pd.read_csv(fname, sep=',', skipinitialspace=True)
        
        # Check if expected columns are present - updated to match comprehensive header from run_models_165_parallel.py
        expected_columns = ["n_sample", "seed"]
        
        # Basic train metrics
        expected_columns.extend([
            "train_mse", "train_nll", "train_mae", "train_crps", 
            "train_coverage_95", "train_width_95", "train_ace"
        ])
        
        # Train coverage at different confidence levels
        for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
            expected_columns.append(f"train_coverage_{conf}")
        
        # Train width at different confidence levels
        for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
            expected_columns.append(f"train_width_{conf}")
        
        # Basic test metrics
        expected_columns.extend([
            "test_mse", "test_nll", "test_mae", "test_crps",
            "test_coverage_95", "test_width_95", "test_ace"
        ])
        
        # Test coverage at different confidence levels
        for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
            expected_columns.append(f"test_coverage_{conf}")
        
        # Test width at different confidence levels
        for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
            expected_columns.append(f"test_width_{conf}")
        
        # Training metadata
        expected_columns.extend(["training_time", "epoch"])
        
        # If file has no header or different column names, assign the expected ones
        if not all(col in df.columns for col in expected_columns[:2]):
            print(f"Warning: CSV file {fname} does not have expected headers. Assigning default column names.")
            df = pd.read_csv(fname, sep=',', skipinitialspace=True, header=None, 
                             names=expected_columns[:len(df.columns)])
        
        # Make sure numeric columns are properly converted - updated to include all new metrics
        numeric_cols = ["n_sample", "seed"]
        
        # Basic train metrics
        numeric_cols.extend([
            "train_mse", "train_nll", "train_mae", "train_crps", 
            "train_coverage_95", "train_width_95", "train_ace"
        ])
        
        # Train coverage at different confidence levels
        for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
            numeric_cols.append(f"train_coverage_{conf}")
        
        # Train width at different confidence levels
        for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
            numeric_cols.append(f"train_width_{conf}")
        
        # Basic test metrics
        numeric_cols.extend([
            "test_mse", "test_nll", "test_mae", "test_crps",
            "test_coverage_95", "test_width_95", "test_ace"
        ])
        
        # Test coverage at different confidence levels
        for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
            numeric_cols.append(f"test_coverage_{conf}")
        
        # Test width at different confidence levels
        for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
            numeric_cols.append(f"test_width_{conf}")
        
        # Training metadata
        numeric_cols.extend(["training_time", "epoch"])
        for col in numeric_cols:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors='coerce')
        
        return df
    
    except Exception as e:
        print(f"Error reading summary file {fname}: {e}")
        return None

def calculate_summary_statistics(df: pd.DataFrame, threshold=100.0):
    """
    Calculate average and standard deviation of results across different seeds,
    grouped by n_sample
    
    Args:
        df: DataFrame with results for a specific model and dataset
        
    Returns:
        Dictionary with sample sizes as keys, each containing avg and std for each metric
    """

    if df is None or len(df) == 0:
        print(f"No results found")
        return None
    
    # Group by n_sample
    stats_by_sample = {}
    
    # Define comprehensive metrics list to match the new output format
    metrics = ['train_mse', 'train_nll', 'train_mae', 'train_crps', 'train_coverage_95', 'train_width_95', 'train_ace']
    
    # Add train coverage at different confidence levels
    for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
        metrics.append(f'train_coverage_{conf}')
    
    # Add train width at different confidence levels
    for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
        metrics.append(f'train_width_{conf}')
    
    # Add basic test metrics
    metrics.extend(['test_mse', 'test_nll', 'test_mae', 'test_crps', 'test_coverage_95', 'test_width_95', 'test_ace'])
    
    # Add test coverage at different confidence levels
    for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
        metrics.append(f'test_coverage_{conf}')
    
    # Add test width at different confidence levels
    for conf in [10, 20, 30, 40, 50, 60, 70, 80, 90]:
        metrics.append(f'test_width_{conf}')
    
    # Add training metadata
    metrics.extend(['training_time', 'epoch'])
    
    # Get unique sample sizes
    sample_sizes = df['n_sample'].unique()
    
    for sample_size in sample_sizes:
        # Filter data for this sample size
        df_sample = df[df['n_sample'] == sample_size]
        
        # Calculate statistics for this group
        stats = {}
        for metric in metrics:
            if metric in df.columns:
                
                data = df_sample[metric]
                data = data[data <= threshold]
                
                stats[f'{metric}_mean'] = data.mean()
                stats[f'{metric}_std'] = data.std()
        
        # Add count of seeds
        stats['n_seeds'] = len(df_sample)
        
        # Store stats for this sample size
        stats_by_sample[sample_size] = stats
    
    return stats_by_sample

def save_dataset_test_results(fname_suffix, dataset_prefix, model_list=('HVBLL', 'VBLL'), threshold=100.0):
    """
    Save a comparison of different models for the same dataset and settings.

    For each model, reads ``summary-<model>-<fname_suffix>.csv`` from
    ``path_summary`` and summarises it per sample size with
    ``calculate_summary_statistics``. It then writes one row per
    (model, sample size) to ``result/comparison-<dataset_prefix>.csv``
    under ``path0``. Models whose summary is missing or empty are skipped.

    Args:
        fname_suffix: Suffix of the summary file names
        dataset_prefix: Dataset name used in messages and in the output file name
        model_list: Model names to compare (any iterable of str)
        threshold: Outlier threshold forwarded to ``calculate_summary_statistics``

    Returns:
        pandas DataFrame with the comparison, indexed by
        "<model>_<sample size>" left-justified to at least 15 characters,
        or None if no model produced results
    """
    results = {}
    for model_name in model_list:
        fname = os.path.join(path_summary, f'summary-{model_name}-{fname_suffix}.csv')
        df = read_summary_results(fname)
        if df is None:
            continue
        stats_by_sample = calculate_summary_statistics(df, threshold=threshold)
        if stats_by_sample is None:
            continue

        # Include all sample sizes, with sample size in the model name
        for sample_size, stats in stats_by_sample.items():
            results[f"{model_name}_{sample_size}"] = stats

    if not results:
        print(f"No results found for dataset {dataset_prefix}")
        return None

    # One row per model/sample-size combination, one column per statistic
    df_results = pd.DataFrame(results).T

    os.makedirs(os.path.join(path0, 'result'), exist_ok=True)
    output_file = f"result/comparison-{dataset_prefix}.csv"

    # Left-justify the index labels to a minimum width of 15 characters so the
    # CSV lines up when viewed as text (longer labels are not truncated).
    df_results.index = [f"{idx:15s}" for idx in df_results.index]

    # float_format applies to the float-valued cells
    df_results.to_csv(os.path.join(path0, output_file), index=True, float_format='%10.3E')
    print(f"Comparison results saved to {output_file}")

    return df_results


if __name__ == "__main__":

    # Each case: (summary file suffix, dataset prefix for the output, outlier threshold)
    case_list = [('era5', 'era5', 100)]

    # 'Deep-GP' is deliberately left out: its results were very poor.
    model_list = ['HVBLL', 'VBLL', 'BLL', 'MC-Dropout', 'PNN', 'SWAG', 'DVI', 'MDN']

    # Set to False to skip regenerating the comparison tables.
    update_summary = True

    if update_summary:
        for case in case_list:
            fname_suffix, dataset_prefix, threshold = case
            print(f"Analyzing dataset {dataset_prefix}...")

            # Writes result/comparison-<dataset_prefix>.csv for this case
            results_summary = save_dataset_test_results(
                fname_suffix,
                dataset_prefix,
                model_list=model_list,
                threshold=threshold,
            )
