import pandas as pd
import numpy as np
import logging
import ast
from pathlib import Path

# Module-level logging setup: INFO and above, printed as "[LEVEL] message".
# NOTE(review): basicConfig runs at import time and configures the root logger
# for any importer of this module; consider moving it under the __main__ guard.
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)

# Expected file name of the per-run results CSV (the __main__ entry point reads
# a file with this name from exp_dir/).
INPUT_FILENAME = "summary_results.csv"
# File name of the generated summary table, written into the input CSV's directory.
OUTPUT_FILENAME = "final_summary_table.csv"

def clean_and_normalize_hps(df):
    """Normalize hyperparameter columns of *df* in place and return it.

    Each known hyperparameter column that is present gets its missing values
    replaced by the sentinel string 'N/A_HP' and is then cast to ``str``, so
    grouping and display treat every value uniformly. Spaces are removed from
    'bandwidths_str' so differently spaced spellings of the same list collapse
    into one group. Columns not listed here (e.g. 'lr') are left untouched.
    """
    normalized_columns = (
        'kernel_func', 'aggregation', 'conv_kernel_size',
        'bandwidths_str', 'interpolation', 'bw_mode',
        'length_scale', 'noise_std', 'tolerance', 'time_scaling',
        'depth', 'step_size',
    )

    for name in normalized_columns:
        if name not in df.columns:
            continue
        as_text = df[name].fillna('N/A_HP').astype(str)
        if name == 'bandwidths_str':
            # e.g. "[1, 2]" and "[1,2]" must land in the same group.
            as_text = as_text.str.replace(' ', '')
        df[name] = as_text

    return df

def format_mean_std(mean, std):
    """Render ``"<mean> ± <std>"`` with two decimals each.

    Returns "N/A" when *mean* is missing. A missing *std* (for instance the
    NaN sample std of a single-seed group) is rendered as 0.00.
    """
    if not pd.isna(mean):
        spread = 0.0 if pd.isna(std) else std
        return f"{mean:.2f} ± {spread:.2f}"
    return "N/A"

# Hyperparameter columns rendered into the 'Params' string, in display order,
# paired with their short labels.
_PARAM_LABELS = (
    ('bandwidths_str', 'BW'),
    ('lr', 'LR'),
    ('tolerance', 'Tol'),
    ('depth', 'Depth'),
    ('step_size', 'Step'),
    ('length_scale', 'LS'),
    ('noise_std', 'Noise'),
    ('conv_kernel_size', 'KS'),
    ('aggregation', 'Aggr'),
)


def _is_set(value):
    """Return True when a grouped hyperparameter value should be displayed.

    Treated as unset: ``None`` (``row.get`` on a column absent from the frame),
    the 'N/A_HP' sentinel written by ``clean_and_normalize_hps``, and NaN
    (e.g. 'lr', which is a group key with ``dropna=False`` but is not
    normalized to the sentinel).
    """
    if value is None or str(value) == 'N/A_HP':
        return False
    return not pd.isna(value)


def _display_name(model_type, row):
    """Return the model label, qualified by interpolation (baseline) or kernel (q-former/conv)."""
    if model_type == 'baseline':
        qualifier = row.get('interpolation')
    elif model_type in ('q-former', 'conv', 'qformer'):
        qualifier = row.get('kernel_func')
    else:
        qualifier = None
    if _is_set(qualifier):
        return f"{model_type} ({qualifier})"
    return model_type


def _format_params(row):
    """Join the set hyperparameters of a grouped row as 'Label:value | ...'."""
    parts = []
    for column, label in _PARAM_LABELS:
        value = row.get(column)
        if _is_set(value):
            parts.append(f"{label}:{value}")
    return " | ".join(parts)


def process_dataset_and_model(df_dataset, dataset_name, model_type, summary_rows):
    """Append one summary row per hyperparameter configuration of *model_type*.

    Rows of *df_dataset* whose 'model_type' equals *model_type* are grouped by
    every known hyperparameter column that is present (missing values kept as
    their own group). Seed-level metrics are then aggregated per group.

    Args:
        df_dataset: Results for a single dataset. Must contain 'model_type',
            'test_acc', 'best_val_acc', 'total_time_sec', 'avg_nfe' and 'seed'.
        dataset_name: Value written to the 'Dataset' field of each row.
        model_type: Model type to select, e.g. 'baseline' or 'q-former'.
        summary_rows: List that receives one dict per configuration (mutated).
            Each dict has the keys 'Dataset', 'Model', 'Seeds', 'Test Acc (%)',
            'Total Time (s)', 'Avg NFE', 'Params' and '_raw_acc' (the unformatted
            mean test accuracy, used for sorting).
    """
    df_model = df_dataset[df_dataset['model_type'] == model_type]
    if df_model.empty:
        return

    possible_hps = [
        'interpolation', 'kernel_func', 'aggregation', 'bandwidths_str',
        'conv_kernel_size', 'time_scaling', 'tolerance',
        'length_scale', 'noise_std', 'bw_mode',
        'depth', 'step_size',
        'hidden_dim', 'batch_size', 'lr',
    ]
    group_keys = [c for c in possible_hps if c in df_model.columns]

    if group_keys:
        grouper = df_model.groupby(group_keys, dropna=False)
    else:
        # groupby([]) raises ValueError; treat the whole slice as one configuration.
        grouper = df_model.groupby(lambda _: 0)

    df_grouped = grouper.agg(
        mean_test_acc=('test_acc', 'mean'),
        std_test_acc=('test_acc', 'std'),
        # Computed but not currently emitted in the summary row.
        mean_best_val_acc=('best_val_acc', 'mean'),
        mean_total_time=('total_time_sec', 'mean'),
        std_total_time=('total_time_sec', 'std'),
        mean_nfe=('avg_nfe', 'mean'),
        std_nfe=('avg_nfe', 'std'),
        count=('seed', 'count'),
    ).reset_index()

    if df_grouped.empty:
        return

    # NOTE(review): hidden_dim, batch_size, bw_mode and time_scaling are grouping
    # keys but are not rendered in 'Model'/'Params'. Distinct configurations can
    # therefore yield rows with identical labels; consider displaying them.
    for _, row in df_grouped.iterrows():
        summary_rows.append({
            'Dataset': dataset_name,
            'Model': _display_name(model_type, row),
            'Seeds': row['count'],
            'Test Acc (%)': format_mean_std(row['mean_test_acc'], row['std_test_acc']),
            'Total Time (s)': format_mean_std(row['mean_total_time'], row['std_total_time']),
            'Avg NFE': format_mean_std(row['mean_nfe'], row['std_nfe']),
            'Params': _format_params(row),
            '_raw_acc': row['mean_test_acc'],
        })

# Model types summarised per dataset, in the order their rows are generated.
# 'qformer' needs no entry: generate_summary_table rewrites it to 'q-former'.
_SUMMARY_MODEL_TYPES = (
    'baseline', 'kernel', 'gp', 'odernn', 'grud', 'q-former', 'conv', 'log_ncde',
)


def generate_summary_table(input_path):
    """Aggregate per-seed results into a summary CSV next to *input_path*.

    Reads the CSV and normalizes its hyperparameter columns. Then, for every
    dataset and every model type in ``_SUMMARY_MODEL_TYPES``, it collects one
    row per hyperparameter configuration via ``process_dataset_and_model``.
    Rows are sorted by dataset, model label and descending mean test accuracy,
    and written to ``OUTPUT_FILENAME`` in the input file's directory.

    Args:
        input_path: Path (str or ``Path``) to the per-run results CSV.

    Returns:
        The ``Path`` of the written summary CSV, or ``None`` when the input file
        does not exist or no summary rows were produced.
    """
    input_path = Path(input_path)
    if not input_path.exists():
        logger.error("File not found: %s", input_path)
        return None

    logger.info("Loading %s...", input_path)
    df = clean_and_normalize_hps(pd.read_csv(input_path))
    df['model_type'] = df['model_type'].replace({'qformer': 'q-former'})

    # Rows of any other model type would otherwise vanish from the table silently.
    unknown = sorted(set(df['model_type'].dropna().astype(str)) - set(_SUMMARY_MODEL_TYPES))
    if unknown:
        logger.warning("Ignoring unrecognized model types: %s", ", ".join(unknown))

    summary_rows = []
    for dataset_name, df_ds in df.groupby('dataset'):
        logger.info("Processing %s...", dataset_name)
        for model_type in _SUMMARY_MODEL_TYPES:
            process_dataset_and_model(df_ds, dataset_name, model_type, summary_rows)

    if not summary_rows:
        logger.warning("No summary rows generated.")
        return None

    df_summary = (
        pd.DataFrame(summary_rows)
        .sort_values(by=['Dataset', 'Model', '_raw_acc'], ascending=[True, True, False])
        .drop(columns=['_raw_acc'])
    )

    output_path = input_path.parent / OUTPUT_FILENAME
    df_summary.to_csv(output_path, index=False)
    logger.info("Full summary table saved to: %s", output_path)
    logger.info("Total rows generated: %d", len(df_summary))
    return output_path

if __name__ == "__main__":
    # Default experiment layout: the per-run results CSV lives in exp_dir/.
    generate_summary_table(Path("exp_dir") / INPUT_FILENAME)