import os
import glob
import pandas as pd
import numpy as np
import argparse

_LOSS_PREFIX = "loss_log_"
_UCB_PREFIX = "ucb_log_"
_CSV_SUFFIX = ".csv"
# Columns that must match between a loss-log row and a UCB-log row for the
# UCB 'pred' value to be compared against the observed 'reward'.
_MERGE_KEYS = ["step", "model", "qp", "cp", "bs"]
# Column order of the DataFrame returned by calculate_stats_for_directory.
_RESULT_COLUMNS = [
    "Experiment",
    "Count",
    "Reward Mean",
    "Reward Min",
    "Reward Max",
    "Reward Std",
    "RMSE",
]


def _reward_stats(df):
    """Return (mean, min, max, std, count) of ``df['reward']``.

    If the column is absent, returns four NaNs and a count of 0.
    """
    if "reward" not in df.columns:
        return np.nan, np.nan, np.nan, np.nan, 0
    rewards = df["reward"]
    return rewards.mean(), rewards.min(), rewards.max(), rewards.std(), len(rewards)


def _compute_rmse(df_analyzed, ucb_file, basename):
    """Return the RMSE between UCB predictions and observed rewards.

    Rows of ``df_analyzed`` are inner-joined with the UCB log on
    ``_MERGE_KEYS``. The squared error of ``pred - reward`` is averaged over
    the matched rows. Returns NaN when there is no reward column, no UCB file,
    an empty UCB log, or no matching rows. Read/merge errors are reported and
    also yield NaN (best effort, so one bad file does not stop the scan).
    """
    if "reward" not in df_analyzed.columns or not os.path.exists(ucb_file):
        return np.nan
    try:
        ucb_df = pd.read_csv(ucb_file)
        if ucb_df.empty:
            return np.nan
        # Take only the join keys plus the one value column from each side.
        # If the UCB log also had a 'reward' column, a full merge would rename
        # both to reward_x / reward_y and the 'reward' lookup would fail.
        merged_df = pd.merge(
            df_analyzed[_MERGE_KEYS + ["reward"]],
            ucb_df[_MERGE_KEYS + ["pred"]],
            on=_MERGE_KEYS,
            how="inner",
        )
        if merged_df.empty:
            return np.nan
        sq_error = (merged_df["pred"] - merged_df["reward"]) ** 2
        return float(np.sqrt(sq_error.mean()))
    except Exception as e:
        print(f"  Error processing UCB log for {basename}: {e}")
        return np.nan


def _fmt_stat(value):
    """Format a statistic with 4 decimals, or 'N/A' when it is missing (NaN)."""
    return "N/A" if pd.isna(value) else f"{value:.4f}"


def _print_results(log_dir, warmup_steps, results):
    """Print the per-experiment result dicts as a fixed-width table."""
    if not results:
        print(f"No results calculated for {log_dir}\n")
        return

    # Widen the name column for long experiment names instead of misaligning rows.
    name_w = max([55] + [len(r["Experiment"]) for r in results])
    header = (
        f"{'Experiment':<{name_w}} | {'Count':<6} | {'Mean':<8} | {'Min':<8} | "
        f"{'Max':<8} | {'Std':<8} | {'RMSE':<8}"
    )
    rule = "-" * len(header)

    print(f"\nResults for {log_dir} (Warmup Steps Excluded: {warmup_steps}):")
    print(rule)
    print(header)
    print(rule)
    for r in results:
        cells = [
            _fmt_stat(r[col])
            for col in ("Reward Mean", "Reward Min", "Reward Max", "Reward Std", "RMSE")
        ]
        print(
            f"{r['Experiment']:<{name_w}} | {r['Count']:<6} | "
            + " | ".join(f"{c:<8}" for c in cells)
        )
    print(rule)
    print("\n")


def calculate_stats_for_directory(log_dir, warmup_steps=30):
    """Summarize reward statistics and prediction RMSE for every loss log in a directory.

    For each ``loss_log_<name>.csv`` in ``log_dir``:

    - Rows with ``step < warmup_steps`` are dropped (no filtering when
      ``warmup_steps <= 0``).
    - Mean/min/max/std/count of the ``reward`` column are computed.
    - If a matching ``ucb_log_<name>.csv`` exists, the RMSE of its ``pred``
      column against the loss log's ``reward`` is computed.

    Files that are unreadable, empty, lack a ``step`` column (when warmup
    filtering is requested), or have no rows after warmup are reported and
    skipped. A table of the results is printed.

    Args:
        log_dir: Directory to scan (not recursive).
        warmup_steps: Rows with ``step`` below this value are excluded.

    Returns:
        pandas.DataFrame with columns Experiment, Count, Reward Mean,
        Reward Min, Reward Max, Reward Std, RMSE (one row per processed
        experiment; empty if none were processed).
    """
    print(f"Scanning directory: {log_dir}")

    # Escape the directory so characters such as '[' in its name are taken
    # literally rather than as glob wildcards.
    pattern = os.path.join(glob.escape(log_dir), f"{_LOSS_PREFIX}*{_CSV_SUFFIX}")

    results = []
    for loss_file in sorted(glob.glob(pattern)):
        basename = os.path.basename(loss_file)

        try:
            loss_df = pd.read_csv(loss_file)
        except Exception as e:
            print(f"  Error reading {basename}: {e}")
            continue
        if loss_df.empty:
            print(f"  Warning: Empty loss log for {basename}")
            continue

        # glob guarantees the prefix and suffix. Slice them off rather than
        # using str.replace, which would also strip copies inside the name.
        stem = basename[len(_LOSS_PREFIX):]
        exp_name = stem[: -len(_CSV_SUFFIX)]

        if warmup_steps > 0:
            if "step" not in loss_df.columns:
                print(f"  Warning: 'step' column missing in {basename}; cannot exclude warmup")
                continue
            df_analyzed = loss_df[loss_df["step"] >= warmup_steps]
        else:
            df_analyzed = loss_df

        if df_analyzed.empty:
            print(f"  Warning: No data after warmup for {basename}")
            continue

        if "reward" not in df_analyzed.columns:
            print(f"  Warning: 'reward' column missing in {basename}")
        reward_mean, reward_min, reward_max, reward_std, count = _reward_stats(df_analyzed)

        ucb_file = os.path.join(log_dir, _UCB_PREFIX + stem)
        rmse = _compute_rmse(df_analyzed, ucb_file, basename)

        results.append({
            "Experiment": exp_name,
            "Count": count,
            "Reward Mean": reward_mean,
            "Reward Min": reward_min,
            "Reward Max": reward_max,
            "Reward Std": reward_std,
            "RMSE": rmse,
        })

    _print_results(log_dir, warmup_steps, results)
    return pd.DataFrame(results, columns=_RESULT_COLUMNS)

if __name__ == "__main__":
    # Command-line entry point: one or more log directories plus an optional
    # warmup cutoff that is forwarded unchanged to each directory's analysis.
    cli = argparse.ArgumentParser(description='Calculate reward stats and RMSE.')
    cli.add_argument(
        'directories',
        nargs='+',
        help='List of directories to process',
    )
    cli.add_argument(
        '--warmup',
        type=int,
        default=30,
        help='Number of warmup steps to exclude',
    )
    opts = cli.parse_args()

    # Directories are processed and reported independently, in command-line order.
    for log_dir in opts.directories:
        calculate_stats_for_directory(log_dir, opts.warmup)
