import os
import json
import ast
import numpy as np

def load_results(results_dir):
    """Load every per-run JSON result file found directly in ``results_dir``.

    ``experiment_config.json`` and ``summary.json`` are skipped because they
    are not per-run results (``summary.json`` is written by this script).
    Files that cannot be decoded are reported and skipped rather than
    aborting the whole load.

    Args:
        results_dir: Path to the directory containing the result files.

    Returns:
        A list of the decoded JSON payloads, ordered by file name.
    """
    excluded = {"experiment_config.json", "summary.json"}
    all_results = []
    # os.listdir order is arbitrary; sort so the result order (and anything
    # derived from it, e.g. the summary key order) is reproducible.
    for fname in sorted(os.listdir(results_dir)):
        if not fname.endswith(".json") or fname in excluded:
            continue
        fpath = os.path.join(results_dir, fname)
        if not os.path.isfile(fpath):
            # e.g. a directory whose name happens to end in ".json"
            continue
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(fpath, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                print(f"⚠ Failed to load {fname}: {e}")
                continue
        all_results.append(data)
    return all_results

def summarize_results(all_results):
    """Aggregate per-run results by configuration (alpha, beta, gamma, lr).

    Args:
        all_results: Iterable of per-run result dicts. Each must contain the
            keys ``'alpha'``, ``'beta'``, ``'gamma'`` and ``'lr'`` (a missing
            one raises ``KeyError``). The metrics ``'final_accuracy'`` and
            ``'best_accuracy'`` default to 0 when absent.

    Returns:
        A JSON-serializable dict mapping ``str((alpha, beta, gamma, lr))`` to
        ``{'mean_final', 'std_final', 'mean_best', 'std_best', 'n_runs'}``.
        The ``std_*`` values are population standard deviations (``np.std``
        with its default ``ddof=0``). Configurations appear in sorted key
        order when the key tuples are mutually comparable, otherwise in
        first-seen order.
    """
    summary = {}
    for res in all_results:
        # Tuples are hashable, so the hyperparameters can key the grouping.
        key = (res['alpha'], res['beta'], res['gamma'], res['lr'])
        bucket = summary.setdefault(key, {'final_accs': [], 'best_accs': []})
        # NOTE(review): a missing metric counts as 0, which pulls the mean
        # down; confirm this is the intended treatment of incomplete runs.
        bucket['final_accs'].append(res.get('final_accuracy', 0))
        bucket['best_accs'].append(res.get('best_accuracy', 0))

    # Input order comes from the filesystem and is arbitrary; sort for a
    # reproducible summary. sorted() leaves `summary` untouched if it fails.
    try:
        ordered_keys = sorted(summary)
    except TypeError:
        # Mixed/None hyperparameter values can't be ordered; keep first-seen.
        ordered_keys = list(summary)

    # Tuple keys aren't valid JSON object keys, so stringify them;
    # str(tuple) can be parsed back with ast.literal_eval.
    return {
        str(k): {
            'mean_final': float(np.mean(summary[k]['final_accs'])),
            'std_final': float(np.std(summary[k]['final_accs'])),
            'mean_best': float(np.mean(summary[k]['best_accs'])),
            'std_best': float(np.std(summary[k]['best_accs'])),
            'n_runs': len(summary[k]['final_accs']),
        }
        for k in ordered_keys
    }

def save_summary(summary, out_file='summary.json'):
    """Write ``summary`` to ``out_file`` as JSON pretty-printed with a 2-space indent.

    The target file is created or truncated. ``summary`` must already be
    JSON-serializable, i.e. have string keys such as those returned by
    ``summarize_results``.
    """
    with open(out_file, mode='w') as handle:
        json.dump(obj=summary, fp=handle, indent=2)

def load_summary(file):
    """Read a summary JSON file and turn its stringified keys back into tuples.

    Each key is parsed with ``ast.literal_eval``, which accepts only Python
    literals. That makes ``str(tuple)`` keys round-trip without evaluating
    arbitrary code. Values are returned exactly as decoded from the JSON.
    """
    with open(file, 'r') as handle:
        raw = json.load(handle)
    restored = {}
    for text_key, stats in raw.items():
        restored[ast.literal_eval(text_key)] = stats
    return restored

if __name__ == "__main__":
    import sys

    # Run directory to summarize. It defaults to the original hard-coded run
    # and can be overridden by passing a directory as the first CLI argument.
    default_dir = "./homm_ablation_manual_lr_20250907_123108"
    results_dir = sys.argv[1] if len(sys.argv) > 1 else default_dir
    if not os.path.isdir(results_dir):
        sys.exit(f"Results directory not found: {results_dir}")

    all_results = load_results(results_dir)
    if not all_results:
        print(f"⚠ No result files found in {results_dir}")
    summary = summarize_results(all_results)

    summary_path = os.path.join(results_dir, 'summary.json')
    save_summary(summary, out_file=summary_path)

    # Round-trip check: reload the file and show a few restored tuple keys.
    loaded_summary = load_summary(summary_path)
    print("Loaded summary keys:", list(loaded_summary.keys())[:5])
