import os
import json
import pandas as pd
from pathlib import Path
import logging

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

def determine_bw_mode(row):
    """Classify how bandwidths are configured for one experiment row.

    Args:
        row: Mapping with the keys ``'model_type'`` and ``'bandwidths_raw'``
            (missing keys are treated as None).

    Returns:
        "N/A" for model types that do not use bandwidths, "Single" for
        ``'kernel'`` models, for a non-list ``bandwidths_raw`` or for a
        one-element list, "Empty" for an empty list, "Same" when every
        listed bandwidth compares equal, and "Diff" otherwise.
    """
    m_type = row.get('model_type')
    bws = row.get('bandwidths_raw')

    if m_type in ('baseline', 'odernn', 'grud', 'gp', 'log_ncde'):
        return "N/A"

    if m_type == 'kernel':
        return "Single"

    if not isinstance(bws, list):
        # Scalar or missing bandwidth config is treated as a single setting.
        return "Single"

    if not bws:
        return "Empty"
    if len(bws) == 1:
        return "Single"

    # Compare by equality rather than building a set: entries may be
    # unhashable (e.g. nested lists), which made set() raise TypeError and
    # misreport identical bandwidths as "Diff".
    first = bws[0]
    return "Same" if all(b == first for b in bws[1:]) else "Diff"

def parse_bandwidths(d):
    """Pull the bandwidth setting out of an experiment config.

    Keys are tried in priority order: ``bandwidths``, ``bws``, then the
    singular ``bandwidth``. The first key that is present with a non-None
    value wins.

    Returns:
        A ``(display_string, raw_value)`` pair. ``raw_value`` is the stored
        value unchanged for the plural keys, a one-element list for
        ``bandwidth``, and None (with display string "N/A") when no key is set.
    """
    candidates = (
        ('bandwidths', False),
        ('bws', False),
        ('bandwidth', True),
    )
    for key, wrap_in_list in candidates:
        if key not in d:
            continue
        value = d[key]
        if value is None:
            continue
        raw = [value] if wrap_in_list else value
        return str(value), raw

    return "N/A", None

def process_file(filepath):
    """Load one experiment-result JSON file and flatten it into a CSV row.

    Args:
        filepath: Path (or path string) to a JSON file whose top level is an
            object of experiment settings and metrics.

    Returns:
        A dict mapping output column names to values, or None when the file
        cannot be read, is not valid JSON, or its top level is not a JSON
        object. Each failure is logged.
    """
    filepath = Path(filepath)
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error(f"Error reading {filepath}: {e}")
        return None

    if not isinstance(data, dict):
        # A top-level list/scalar would make every .get() below raise and
        # abort the whole directory run; skip this file instead.
        logger.error(
            f"Error reading {filepath}: expected a JSON object, "
            f"got {type(data).__name__}"
        )
        return None

    bw_str, bw_raw = parse_bandwidths(data)
    get_val = data.get

    row = {
        'dataset': get_val('dataset_name'),
        'seed': get_val('seed'),
        'model_type': get_val('type'),

        'time_scaling': get_val('time_scaling_factor', 'no'),
        'tolerance': get_val('tol'),
        'hidden_dim': get_val('hidden_dim'),
        'batch_size': get_val('batch_size'),
        'epochs': get_val('num_epochs'),
        'lr': get_val('lr'),
        'weight_decay': get_val('weight_decay'),

        'interpolation': get_val('interpolation'),
        'kernel_func': get_val('kernel', 'N/A'),
        'aggregation': get_val('aggr', 'N/A'),
        'conv_kernel_size': get_val('conv_kernel_size', 'N/A'),

        'depth': get_val('depth', 'N/A'),
        'step_size': get_val('step_size', 'N/A'),

        'length_scale': get_val('length_scale', 'N/A'),
        'noise_std': get_val('noise_std', 'N/A'),

        'bandwidths_str': bw_str,
        'bandwidths_raw': bw_raw,

        'test_acc': get_val('test_accuracy'),
        'test_loss': get_val('test_loss'),
        'best_val_acc': get_val('best_val_accuracy'),
        'final_train_acc': get_val('final_train_accuracy'),
        'train_time_sec': get_val('training_time'),
        'total_time_sec': get_val('total_time'),
        'avg_nfe': get_val('avg_nfe', 0.0),
    }

    row['bw_mode'] = determine_bw_mode(row)
    # Only list-valued bandwidth configs are counted; scalars/missing give 0.
    row['num_bandwidths'] = len(bw_raw) if isinstance(bw_raw, list) else 0
    row['file_name'] = filepath.name

    return row

def aggregate_directory(target_dir_name):
    """Aggregate every ``*.json`` result under a directory into one CSV.

    JSON files are found recursively, each is converted to a row with
    ``process_file``, and the rows are written to
    ``<target_dir_name>/summary_results.csv``. Files that ``process_file``
    rejects (it returns None) are skipped.

    Args:
        target_dir_name: Directory to scan (str or Path).

    Returns:
        None. Progress and problems are reported through ``logger``.
    """
    base_path = Path(target_dir_name)

    if not base_path.is_dir():
        logger.warning(f"Directory not found: {base_path}")
        return

    logger.info(f"Processing directory: {base_path}")
    # rglob order depends on the filesystem; sort so the CSV row order is
    # reproducible across runs and machines.
    json_files = sorted(base_path.rglob("*.json"))

    if not json_files:
        logger.warning(f"No JSON files found in {base_path}")
        return

    extracted_rows = []
    for jf in json_files:
        row = process_file(jf)
        if row is None:
            continue
        row['file_rel_path'] = str(jf.relative_to(base_path))
        extracted_rows.append(row)

    if not extracted_rows:
        logger.warning(f"No readable experiment results in {base_path}; nothing written")
        return

    df = pd.DataFrame(extracted_rows)

    # Coerce metric/hyperparameter columns to numbers; placeholder strings
    # such as 'N/A' become NaN.
    numeric_cols = [
        'test_acc', 'test_loss', 'best_val_acc', 'final_train_acc',
        'train_time_sec', 'total_time_sec', 'avg_nfe',
        'length_scale', 'noise_std', 'conv_kernel_size', 'seed',
        'depth', 'step_size',
    ]
    for col in numeric_cols:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')

    output_path = base_path / "summary_results.csv"
    df.to_csv(output_path, index=False)
    logger.info(f"Saved aggregated CSV to: {output_path}")
    logger.info(f"Total experiments processed: {len(df)}")
    logger.info("-" * 40)

if __name__ == "__main__":
    import sys

    # Directories may be passed on the command line; with no arguments the
    # built-in default list is used.
    directories_to_process = sys.argv[1:] or [
        "dir_example",
    ]

    print("==========================================")
    print("      STARTING RESULTS AGGREGATION        ")
    print("==========================================")

    for d in directories_to_process:
        aggregate_directory(d)

    print("Aggregation complete.")