import os
import re
import pandas as pd
import argparse
from pathlib import Path
import glob


def create_summary_directory(output_path):
    """Ensure the directory ``output_path`` exists and return it as a ``Path``.

    Missing parent directories are created too, and an already existing
    directory is accepted without error.
    """
    target = Path(output_path)
    target.mkdir(exist_ok=True, parents=True)
    return target


def merge_csv_files(directory):
    """
    Read every ``*.csv`` file directly inside ``directory`` and stack them.

    Files are read in sorted path order so that the row order of the result
    does not depend on the order in which the filesystem lists the files.
    A file that cannot be read is reported on stdout and skipped.

    Args:
        directory (str): Path to the directory containing CSV files

    Returns:
        pandas.DataFrame: All readable files concatenated row-wise, with a
        fresh 0..n-1 index

    Raises:
        ValueError: If no CSV file in the directory could be read (this
            includes a directory that contains no ``*.csv`` files at all)
    """
    directory_path = Path(directory)

    # glob() yields entries in arbitrary filesystem order; sort for
    # reproducible output.
    csv_files = sorted(directory_path.glob("*.csv"))

    dfs = []
    for file in csv_files:
        try:
            dfs.append(pd.read_csv(file))
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort
            # the merge of the remaining ones.
            print(f"Error reading {file}: {e}")

    if not dfs:
        raise ValueError("No valid CSV files could be read")

    return pd.concat(dfs, ignore_index=True)


def validate_csv_sizes(prefill_df, ar_df, sd_df, sd_log_df, k):
    """
    Check that the row counts of the four tables are consistent.

    The checks, in order, are: ar has as many rows as prefill, sd_log has as
    many rows as sd, and sd has exactly ``k`` times as many rows as prefill.
    Only ``len()`` of each argument is used.

    Args:
        prefill_df (pandas.DataFrame): DataFrame from prefill.csv
        ar_df (pandas.DataFrame): DataFrame from ar.csv
        sd_df (pandas.DataFrame): DataFrame from sd.csv
        sd_log_df (pandas.DataFrame): DataFrame from sd_log.csv
        k (int): Expected multiplier for sd

    Returns:
        bool: Always True; a failed check raises instead of returning False

    Raises:
        ValueError: On the first check that fails
    """
    prefill_rows = len(prefill_df)
    ar_rows = len(ar_df)
    sd_rows = len(sd_df)
    sd_log_rows = len(sd_log_df)
    expected_sd_rows = prefill_rows * k

    if prefill_rows != ar_rows:
        raise ValueError(f"Validation failed: prefill has {prefill_rows} rows, ar has {ar_rows} rows")

    if sd_rows != sd_log_rows:
        raise ValueError(f"Validation failed: sd has {sd_rows} rows, sd_log has {sd_log_rows} rows")

    if sd_rows != expected_sd_rows:
        raise ValueError(f"Validation failed: sd has {sd_rows} rows, expected {expected_sd_rows} rows (prefill's {prefill_rows} rows × {k})")

    # Report the prefill count itself (the original message printed sd_rows here).
    print(f"Validation passed: prefill has {prefill_rows} rows, ar has {ar_rows} rows, sd has {sd_rows} rows, sd_log has {sd_log_rows} rows")
    return True


def _build_merge_key(df, columns):
    """Return one '_'-joined string key per row of ``df`` built from ``columns``."""
    key = df[columns[0]].astype(str)
    for column in columns[1:]:
        key = key + "_" + df[column].astype(str)
    return key


def _count_repeated_keys(keys):
    """Return how many distinct values in ``keys`` occur on more than one row."""
    return int((keys.value_counts() > 1).sum())


def merge_final_summary(prefill_df, ar_df, sd_df, sd_log_df):
    """
    Build the summary table: one row per sd row, enriched with the matching
    prefill time, ar time and sd_log metrics.

    Rows are matched on (dataset, temperature, num_prompts); sd_log rows are
    additionally matched on num_speculative_tokens. Matching compares the
    values after conversion to strings, so all tables must render a value the
    same way (e.g. ``0`` and ``0.0`` do not match). The input DataFrames are
    not modified.

    An sd row with no matching sd_log row keeps NaN in the sd_log metric
    columns (left join).

    Args:
        prefill_df (pandas.DataFrame): DataFrame from prefill.csv; its
            ``total_time`` becomes ``prefill_time``
        ar_df (pandas.DataFrame): DataFrame from ar.csv; its ``total_time``
            becomes ``ar_time``
        sd_df (pandas.DataFrame): DataFrame from sd.csv
        sd_log_df (pandas.DataFrame): DataFrame from sd_log.csv

    Returns:
        pandas.DataFrame: Merged summary DataFrame, in the row order of sd_df

    Raises:
        ValueError: If a key column is missing from one of the tables, if an
            sd key has no match in prefill or ar, or if a key occurs on more
            than one row of prefill, ar or sd_log
        KeyError: If a ``total_time`` column, one of the dropped sd columns
            or one of the sd_log metric columns is missing
    """
    dataset_col = "dataset"
    temp_col = "temperature"
    np_col = "num_prompts"
    spec_col = "num_speculative_tokens"
    key_cols = [dataset_col, temp_col, np_col]

    # All four tables are keyed on these columns below, so check all four.
    all_frames = (prefill_df, ar_df, sd_df, sd_log_df)
    for col in key_cols:
        if any(col not in df.columns for df in all_frames):
            raise ValueError(f"Column '{col}' not found in one of the DataFrames")
    if spec_col not in sd_df.columns or spec_col not in sd_log_df.columns:
        raise ValueError(f"Column '{spec_col}' not found in one of the DataFrames")

    # Keys are kept as standalone Series so the inputs need no full copy.
    prefill_keys = _build_merge_key(prefill_df, key_cols)
    ar_keys = _build_merge_key(ar_df, key_cols)
    sd_keys = _build_merge_key(sd_df, key_cols)
    sd_log_keys = _build_merge_key(sd_log_df, key_cols + [spec_col])

    # Every sd key needs a partner in prefill and in ar.
    unique_sd_keys = set(sd_keys)
    missing_in_prefill = unique_sd_keys - set(prefill_keys)
    missing_in_ar = unique_sd_keys - set(ar_keys)

    if missing_in_prefill:
        raise ValueError(f"{len(missing_in_prefill)} keys in sd are missing in prefill")
    if missing_in_ar:
        raise ValueError(f"{len(missing_in_ar)} keys in sd are missing in ar")

    # A repeated key on the right side of a left join would silently
    # multiply summary rows, so every lookup table must be unique per key.
    for label, keys in (("prefill", prefill_keys), ("ar", ar_keys), ("sd_log", sd_log_keys)):
        repeated = _count_repeated_keys(keys)
        if repeated:
            raise ValueError(f"{repeated} keys in {label} have multiple rows")

    # Attach prefill_time and ar_time to each sd row.
    merged = sd_df.assign(merge_key=sd_keys)
    for source_df, source_keys, time_col in (
        (prefill_df, prefill_keys, 'prefill_time'),
        (ar_df, ar_keys, 'ar_time'),
    ):
        times = source_df[['total_time']].rename(columns={'total_time': time_col})
        times = times.assign(merge_key=source_keys)
        merged = pd.merge(merged, times, on='merge_key', how='left')

    # Drop the sd columns that are not wanted in the summary.
    drop_columns = ['seqlen', 'tensor_parallel_size', 'total_prefill_tokens', 'total_decode_tokens', 'prefill_speed', 'decode_speed', 'seed']
    merged = merged.drop(columns=drop_columns)

    # Attach the sd_log metrics, matched on the key plus num_speculative_tokens.
    metric_cols = ['system_efficiency', 'average_time_per_proposal_tok_ms', 'scoring_time_ms', 'verification_time_ms']
    sd_log_subset = sd_log_df[metric_cols].assign(spec_merge_key=sd_log_keys)
    merged["spec_merge_key"] = _build_merge_key(merged, ['merge_key', spec_col])
    merged = pd.merge(
        merged,
        sd_log_subset,
        on='spec_merge_key',
        how='left'
    )

    # The helper key columns are not part of the result.
    return merged.drop(columns=['spec_merge_key', 'merge_key'])


def main():
    """Command-line entry point: merge the per-run CSV files and write the summary.

    Steps: create the output directory; merge, sort and save the prefill, ar
    and sd CSV directories as prefill.csv / ar.csv / sd.csv; load sd_log.csv;
    validate the row counts; build and save summary.csv.

    Returns:
        int: 0 after a successful run; failures propagate as exceptions
    """
    cli = argparse.ArgumentParser(description="CSV Merger and Validator")
    cli.add_argument("--output", "-o", default="./csv_results/summary", help="Output directory for summary")
    cli.add_argument("--prefill_input", default="./csv_results/prefill", help="Directory prefill containing CSV files")
    cli.add_argument("--ar_input", default="./csv_results/ar", help="Directory auto-regression containing CSV files")
    cli.add_argument("--sd_input", default="./csv_results/sd", help="Directory speculative decoding containing CSV files")
    cli.add_argument("--sd_log", default="./csv_results/summary/sd_log.csv", help="Path to sd_log.csv")
    cli.add_argument("--k", type=int, default=3, help="Expected multiplier for the number of rows in speculative decoding relative to auto-regression")
    args = cli.parse_args()

    # Output location for every file written below.
    summary_dir = create_summary_directory(args.output)
    print(f"Create summary directory at {summary_dir}")

    # prefill and ar share one sort order; sd adds num_speculative_tokens.
    base_sort_keys = ['dataset', 'temperature', 'num_prompts']
    sd_sort_keys = base_sort_keys + ['num_speculative_tokens']

    # prefill
    print(f"\nMerging CSV files from directory prefill: {args.prefill_input}")
    prefill_df = merge_csv_files(args.prefill_input).sort_values(by=base_sort_keys)
    prefill_output_path = summary_dir / "prefill.csv"
    prefill_df.to_csv(prefill_output_path, index=False)
    print(f"Saved merged prefill.csv with {len(prefill_df)} rows to {prefill_output_path}")

    # ar
    print(f"\nMerging CSV files from directory ar: {args.ar_input}")
    ar_df = merge_csv_files(args.ar_input).sort_values(by=base_sort_keys)
    ar_output_path = summary_dir / "ar.csv"
    ar_df.to_csv(ar_output_path, index=False)
    print(f"Saved merged ar.csv with {len(ar_df)} rows to {ar_output_path}")

    # sd
    print(f"\nMerging CSV files from directory sd: {args.sd_input}")
    sd_df = merge_csv_files(args.sd_input).sort_values(by=sd_sort_keys)
    sd_output_path = summary_dir / "sd.csv"
    sd_df.to_csv(sd_output_path, index=False)
    print(f"Saved merged sd.csv with {len(sd_df)} rows to {sd_output_path}")

    # sd_log is a single, already merged file.
    sd_log_df = pd.read_csv(args.sd_log)

    # Row-count sanity check before combining the tables.
    print("\nValidating CSV sizes...")
    sizes_ok = validate_csv_sizes(prefill_df, ar_df, sd_df, sd_log_df, args.k)
    if not sizes_ok:
        raise ValueError("Validation failed")

    # Final summary table.
    print("\nCreating final summary...")
    summary_df = merge_final_summary(prefill_df, ar_df, sd_df, sd_log_df)
    summary_output_path = summary_dir / "summary.csv"
    summary_df.to_csv(summary_output_path, index=False)
    print(f"Saved summary.csv with {len(summary_df)} rows to {summary_output_path}")

    print("\nProcess completed successfully!")

    return 0


if __name__ == "__main__":
    # main() returns a status code; pass it on as the process exit status
    # instead of discarding it.
    raise SystemExit(main())