#!/usr/bin/env python3
import os
import re
import pandas as pd
import argparse
from pathlib import Path
import glob


def create_summary_directory(output_path):
    """
    Create the output directory for the summary files if it does not exist.

    Missing parent directories are created as well; an already existing
    directory is left untouched.

    Args:
        output_path (str | os.PathLike): Directory to create.

    Returns:
        pathlib.Path: The directory path as a ``Path`` object.
    """
    summary_dir = Path(output_path)
    summary_dir.mkdir(parents=True, exist_ok=True)
    return summary_dir


def merge_csv_files(directory):
    """
    Merge all CSV files found directly in ``directory`` into one DataFrame.

    Files are read in sorted filename order so the row order of the result is
    reproducible across runs and filesystems (``Path.glob`` order is not
    guaranteed). Files that cannot be read or parsed are reported on stdout
    and skipped.

    Args:
        directory (str | os.PathLike): Directory containing the CSV files.
            Only ``*.csv`` files at the top level are considered.

    Returns:
        pandas.DataFrame: Row-wise concatenation of every readable CSV, with a
        fresh 0..n-1 index.

    Raises:
        ValueError: If no CSV files are found, or none of them could be read.
    """
    directory_path = Path(directory)

    # Sort for deterministic concatenation order.
    csv_files = sorted(directory_path.glob("*.csv"))
    if not csv_files:
        raise ValueError(f"No CSV files found in {directory_path}")

    dfs = []
    for file in csv_files:
        try:
            dfs.append(pd.read_csv(file))
        except (OSError, ValueError) as e:
            # OSError: unreadable file; ValueError covers pandas ParserError,
            # EmptyDataError and UnicodeDecodeError. Best-effort: skip the file.
            print(f"Error reading {file}: {e}")

    if not dfs:
        raise ValueError("No valid CSV files could be read")

    # Concatenate all dataframes
    return pd.concat(dfs, ignore_index=True)


def validate_csv_sizes(prefill_df, ar_df, sd_df, sd_log_df, k):
    """
    Check that the merged CSVs have mutually consistent row counts.

    The checks, in order, are:
      * ar has exactly as many rows as prefill;
      * sd_log has exactly as many rows as sd;
      * sd has ``k`` times as many rows as prefill.
    The first failing check is reported on stdout; on success a summary of
    all row counts is printed.

    Args:
        prefill_df (pandas.DataFrame): DataFrame from prefill.csv
        ar_df (pandas.DataFrame): DataFrame from ar.csv
        sd_df (pandas.DataFrame): DataFrame from sd.csv
        sd_log_df (pandas.DataFrame): DataFrame from sd_log.csv
        k (int): Expected multiplier of sd rows relative to prefill rows

    Returns:
        bool: True if every check passes, False at the first failure.
    """
    prefill_rows = len(prefill_df)
    ar_rows = len(ar_df)
    sd_rows = len(sd_df)
    sd_log_rows = len(sd_log_df)

    if prefill_rows != ar_rows:
        print(f"Validation failed: prefill has {prefill_rows} rows, ar has {ar_rows} rows")
        return False

    if sd_rows != sd_log_rows:
        print(f"Validation failed: sd has {sd_rows} rows, sd_log has {sd_log_rows} rows")
        return False

    expected_sd_rows = prefill_rows * k
    if sd_rows != expected_sd_rows:
        print(f"Validation failed: sd has {sd_rows} rows, expected {expected_sd_rows} rows (prefill's {prefill_rows} rows × {k})")
        return False

    print(f"Validation passed: prefill has {prefill_rows} rows, ar has {ar_rows} rows, sd has {sd_rows} rows, sd_log has {sd_log_rows} rows")
    return True


def _build_merge_key(df, columns):
    """Return a Series joining the ``str`` form of ``columns`` with ``"_"``, row by row."""
    key = df[columns[0]].astype(str)
    for col in columns[1:]:
        key = key + "_" + df[col].astype(str)
    return key


def merge_final_summary(prefill_df, ar_df, sd_df, sd_log_df):
    """
    Join prefill, ar and sd_log timings onto each sd row.

    Each sd row is matched to exactly one prefill row and one ar row on
    (dataset, temperature, num_prompts, exp_num). Their ``total_time`` values
    are added as ``prefill_time`` and ``ar_time``. The row is then matched to
    sd_log on (num_prompts, exp_num, num_speculative_tokens) to add
    ``system_efficiency``, ``average_time_per_proposal_tok_ms``,
    ``scoring_time_ms`` and ``verification_time_ms``. sd rows without a
    matching sd_log row get NaN in those columns (left join). A fixed set of
    bookkeeping columns (seqlen, tensor_parallel_size, token counts, speeds,
    seed) is dropped from the result.

    Args:
        prefill_df (pandas.DataFrame): DataFrame from prefill.csv
        ar_df (pandas.DataFrame): DataFrame from ar.csv
        sd_df (pandas.DataFrame): DataFrame from sd.csv
        sd_log_df (pandas.DataFrame): DataFrame from sd_log.csv

    Returns:
        pandas.DataFrame: One row per sd row, in sd's row order.

    Raises:
        ValueError: If a required key column is missing, an sd key has no
            prefill/ar match, or a join key is duplicated in prefill, ar or
            sd_log (which would otherwise silently multiply summary rows).
        KeyError: If a column selected or dropped by name is absent.
    """
    dataset_col = "dataset"
    temp_col = "temperature"
    np_col = "num_prompts"
    exp_num_col = "exp_num"
    nst_col = "num_speculative_tokens"

    # sd_log is only joined on (num_prompts, exp_num, num_speculative_tokens),
    # so it does not need dataset/temperature.
    frames = {"prefill": prefill_df, "ar": ar_df, "sd": sd_df, "sd_log": sd_log_df}
    required = {
        dataset_col: ("prefill", "ar", "sd"),
        temp_col: ("prefill", "ar", "sd"),
        np_col: ("prefill", "ar", "sd", "sd_log"),
        exp_num_col: ("prefill", "ar", "sd", "sd_log"),
        nst_col: ("sd", "sd_log"),
    }
    for col, frame_names in required.items():
        if any(col not in frames[name].columns for name in frame_names):
            raise ValueError(f"Column '{col}' not found in one of the DataFrames")

    # Work on copies so the callers' DataFrames are not modified.
    prefill_copy = prefill_df.copy()
    ar_copy = ar_df.copy()
    sd_copy = sd_df.copy()
    sd_log_copy = sd_log_df.copy()

    key_cols = [dataset_col, temp_col, np_col, exp_num_col]
    prefill_copy["merge_key"] = _build_merge_key(prefill_copy, key_cols)
    ar_copy["merge_key"] = _build_merge_key(ar_copy, key_cols)
    sd_copy["merge_key"] = _build_merge_key(sd_copy, key_cols)

    # Every sd key must exist in prefill and ar.
    sd_keys = set(sd_copy["merge_key"])
    missing_in_prefill = sd_keys - set(prefill_copy["merge_key"])
    missing_in_ar = sd_keys - set(ar_copy["merge_key"])
    if missing_in_prefill:
        raise ValueError(f"{len(missing_in_prefill)} keys in sd are missing in prefill")
    if missing_in_ar:
        raise ValueError(f"{len(missing_in_ar)} keys in sd are missing in ar")

    # ...and must match exactly one row there.
    dup_prefill = int((prefill_copy["merge_key"].value_counts() > 1).sum())
    dup_ar = int((ar_copy["merge_key"].value_counts() > 1).sum())
    if dup_prefill:
        raise ValueError(f"{dup_prefill} keys in prefill have multiple rows")
    if dup_ar:
        raise ValueError(f"{dup_ar} keys in ar have multiple rows")

    # Attach prefill and ar timings.
    prefill_subset = prefill_copy[["merge_key", "total_time"]].rename(columns={"total_time": "prefill_time"})
    merged = pd.merge(sd_copy, prefill_subset, on="merge_key", how="left")
    ar_subset = ar_copy[["merge_key", "total_time"]].rename(columns={"total_time": "ar_time"})
    merged = pd.merge(merged, ar_subset, on="merge_key", how="left")

    # Drop the merge key and per-run bookkeeping columns.
    drop_columns = ["merge_key", "seqlen", "tensor_parallel_size", "total_prefill_tokens",
                    "total_decode_tokens", "prefill_speed", "decode_speed", "seed"]
    merged = merged.drop(drop_columns, axis=1)

    # Attach speculative-decoding log metrics.
    spec_cols = [np_col, exp_num_col, nst_col]
    sd_log_copy["spec_merge_key"] = _build_merge_key(sd_log_copy, spec_cols)
    dup_sd_log = int((sd_log_copy["spec_merge_key"].value_counts() > 1).sum())
    if dup_sd_log:
        # A duplicated key would turn the left join below into a row-multiplying join.
        raise ValueError(f"{dup_sd_log} keys in sd_log have multiple rows")
    merged["spec_merge_key"] = _build_merge_key(merged, spec_cols)
    sd_log_subset = sd_log_copy[["spec_merge_key", "system_efficiency", "average_time_per_proposal_tok_ms",
                                 "scoring_time_ms", "verification_time_ms"]]
    merged = pd.merge(merged, sd_log_subset, on="spec_merge_key", how="left")
    merged = merged.drop("spec_merge_key", axis=1)

    return merged


def get_moe_summary(moe_summary_path, *, dataset="humaneval", temperature=0.0,
                    exclude_num_speculative_tokens=3, exp_num=8):
    """
    Load the MoE summary CSV and select the rows to append to this summary.

    Keeps the rows whose ``dataset`` equals ``dataset``, whose ``temperature``
    equals ``temperature``, and whose ``num_speculative_tokens`` differs from
    ``exclude_num_speculative_tokens``. Every kept row's ``exp_num`` column is
    then set to ``exp_num`` (the column is added if absent). The defaults
    reproduce the original hard-coded selection.

    Args:
        moe_summary_path (str | os.PathLike): Path to the MoE summary CSV file
        dataset (str): Dataset name to keep.
        temperature (float): Temperature value to keep.
        exclude_num_speculative_tokens: ``num_speculative_tokens`` value to drop.
        exp_num (int): Value written to the ``exp_num`` column.

    Returns:
        pandas.DataFrame: The filtered rows (a new DataFrame, original index
        labels preserved) with ``exp_num`` set.
    """
    moe_df = pd.read_csv(moe_summary_path)

    keep = (
        (moe_df["dataset"] == dataset)
        & (moe_df["temperature"] == temperature)
        & (moe_df["num_speculative_tokens"] != exclude_num_speculative_tokens)
    )
    # .assign builds a new frame, so no assignment happens on a filtered slice
    # (which is what raised SettingWithCopyWarning before).
    return moe_df.loc[keep].assign(exp_num=exp_num)


def _merge_sort_save(label, input_dir, sort_by, output_path):
    """
    Merge every CSV in ``input_dir``, sort it, and write it to ``output_path``.

    Args:
        label (str): Short name used in progress messages (e.g. "prefill").
        input_dir (str): Directory passed to ``merge_csv_files``.
        sort_by (list[str]): Columns to sort by, in priority order.
        output_path (pathlib.Path): Destination CSV file.

    Returns:
        pandas.DataFrame: The merged and sorted DataFrame.
    """
    print(f"\nMerging CSV files from directory {label}: {input_dir}")
    merged = merge_csv_files(input_dir)
    merged.sort_values(by=sort_by, inplace=True)
    merged.to_csv(output_path, index=False)
    print(f"Saved merged {output_path.name} with {len(merged)} rows to {output_path}")
    return merged


def main():
    """
    Build the combined summary CSV from the command line.

    Steps:
      1. Merge the prefill, ar, and sd result directories, saving each merged
         CSV into the output directory.
      2. Validate their row counts against sd_log.csv.
      3. Join them into one summary.
      4. Append the selected MoE rows.
      5. Write ``summary.csv``.

    Returns:
        int: 0 on success (failures raise).

    Raises:
        ValueError: If row-count validation fails, or the MoE summary lacks a
            column present in the merged summary. Errors from the helper
            functions propagate unchanged.
    """
    parser = argparse.ArgumentParser(description="CSV Merger and Validator")
    parser.add_argument("--output", "-o", default="./csv_results/summary", help="Output directory for summary")
    parser.add_argument("--prefill_input", default="./csv_results/prefill", help="Directory prefill containing CSV files")
    parser.add_argument("--ar_input", default="./csv_results/ar", help="Directory auto-regression containing CSV files")
    parser.add_argument("--sd_input", default="./csv_results/sd", help="Directory speculative decoding containing CSV files")
    parser.add_argument("--sd_log", default="./csv_results/summary/sd_log.csv", help="Path to sd_log.csv")
    parser.add_argument("--k", type=int, default=2, help="Expected multiplier for the number of rows in speculative decoding relative to auto-regression")
    parser.add_argument("--moe_summary", default="../moe/csv_results/summary/summary.csv", help="Path to the summary.csv file from the MoE model")

    args = parser.parse_args()

    # Step 1: Create summary directory
    summary_dir = create_summary_directory(args.output)
    print(f"Create summary directory at {summary_dir}")

    # Step 2: Merge CSV files from each directory
    base_sort = ["dataset", "temperature", "num_prompts"]
    prefill_df = _merge_sort_save("prefill", args.prefill_input, base_sort, summary_dir / "prefill.csv")
    ar_df = _merge_sort_save("ar", args.ar_input, base_sort, summary_dir / "ar.csv")
    sd_df = _merge_sort_save("sd", args.sd_input, base_sort + ["num_speculative_tokens"], summary_dir / "sd.csv")

    # Step 2.1: sd_log is a single CSV file, not a directory of CSVs.
    sd_log_df = pd.read_csv(args.sd_log)

    # Step 3: Validate sizes
    print("\nValidating CSV sizes...")
    if not validate_csv_sizes(prefill_df, ar_df, sd_df, sd_log_df, args.k):
        raise ValueError("Validation failed.")

    # Step 4: Merge to create final summary
    print("\nCreating final summary...")
    summary_df = merge_final_summary(prefill_df, ar_df, sd_df, sd_log_df)

    # Step 5: concatenate with MoE summary
    print("\nConcatenating with MoE summary...")
    moe_df = get_moe_summary(args.moe_summary)
    missing = [col for col in summary_df.columns if col not in moe_df.columns]
    if missing:
        raise ValueError(f"MoE summary is missing columns present in the summary: {missing}")
    # Align MoE rows to the summary's column order; extra MoE columns are dropped.
    moe_df = moe_df[summary_df.columns]
    summary_df = pd.concat([summary_df, moe_df], ignore_index=True)
    summary_df.sort_values(by=["num_speculative_tokens", "exp_num", "num_prompts"], inplace=True)

    # Step 6: Save the final summary
    summary_output_path = summary_dir / "summary.csv"
    summary_df.to_csv(summary_output_path, index=False)
    print(f"Saved summary.csv with {len(summary_df)} rows to {summary_output_path}")

    print("\nProcess completed successfully!")

    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())