import re
import pandas as pd
import sys
import os
import argparse

# TODO: This script has not been tested yet.

def extract_metrics_from_log(log_file_path):
    """
    Extract system efficiency, average_time_per_proposal_tok_ms, scoring_time_ms,
    and verification_time_ms from a vLLM log file.

    Every occurrence is collected, in file order: each ``System efficiency: <x>``
    feeds 'system_efficiency', and each ``SpecDecodeWorker stage times: ...`` line
    feeds the three stage-time metrics.

    Args:
        log_file_path: Path to the log file. It is read as UTF-8; undecodable
            bytes are replaced instead of aborting the parse.

    Returns:
        Dictionary mapping each metric name to a list of floats (empty when the
        metric never appears). The three stage-time lists always have equal
        length because they come from the same matched lines.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(log_file_path, 'r', encoding='utf-8', errors='replace') as file:
        log_content = file.read()

    # Strict decimal literal. Unlike a loose [\d.]+, it cannot match a lone "."
    # or swallow a sentence-ending period ("0.875." would make float() raise).
    number = r'(\d+(?:\.\d*)?|\.\d+)'

    efficiency_pattern = r'System efficiency:\s*' + number
    stage_times_pattern = (
        r'SpecDecodeWorker stage times:\s*'
        r'average_time_per_proposal_tok_ms=' + number + r'\s+'
        r'scoring_time_ms=' + number + r'\s+'
        r'verification_time_ms=' + number
    )

    # One capture group -> list of strings; three groups -> list of 3-tuples.
    system_efficiency = [float(v) for v in re.findall(efficiency_pattern, log_content)]
    stage_times = re.findall(stage_times_pattern, log_content)

    return {
        'system_efficiency': system_efficiency,
        'average_time_per_proposal_tok_ms': [float(t[0]) for t in stage_times],
        'scoring_time_ms': [float(t[1]) for t in stage_times],
        'verification_time_ms': [float(t[2]) for t in stage_times],
    }

def extract_params_from_filename(filename):
    """
    Parse run parameters encoded in a log filename.

    Looks for ``np<N>`` (num_prompts), ``expnum<N>`` (exp_num) and ``st<N>``
    (num_speculative_tokens); the first match of each wins. The patterns are
    unanchored, so e.g. ``st`` also matches inside a word like ``fast2``.

    Args:
        filename: The log filename

    Returns:
        Dict with keys 'num_prompts', 'exp_num', 'num_speculative_tokens' (in
        that order); each value is an int, or None when its tag is absent.
    """
    tag_patterns = (
        ('num_prompts', r'np(\d+)'),
        ('exp_num', r'expnum(\d+)'),
        ('num_speculative_tokens', r'st(\d+)'),
    )

    parsed = {}
    for key, pattern in tag_patterns:
        hit = re.search(pattern, filename)
        parsed[key] = int(hit.group(1)) if hit else None
    return parsed


def process_log_files(log_dir, output_file, exclude_patterns=None, num_warmup=2):
    """
    Process all log files in a directory and save one summary row per file to a CSV file.

    Every regular file directly inside ``log_dir`` whose name ends in ``.log`` or
    ``.txt`` is parsed with ``extract_metrics_from_log``. Its run parameters come
    from ``extract_params_from_filename``. Each metric is reduced to the mean of its
    samples after discarding the first ``num_warmup`` ones (None when nothing
    remains). Files that fail to parse are reported on stderr and skipped; the
    remaining rows are still written. Rows are ordered by filename.

    Args:
        log_dir: Directory containing log files
        output_file: Path to output CSV file; missing parent directories are created
        exclude_patterns: Optional regex pattern, or iterable of patterns. A file whose
            name matches any of them (``re.search``) is skipped. If None, all .log
            and .txt files are processed.
        num_warmup: Number of leading samples per metric to drop before averaging,
            because the first results may not be stable yet (default 2).
    """
    results = []

    # isdir rather than exists: a regular file passed here would otherwise crash
    # in os.listdir below.
    if not os.path.isdir(log_dir):
        print(f"Error: Directory '{log_dir}' does not exist or is not a directory.",
              file=sys.stderr)
        return

    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if exclude_patterns is None:
        exclude_patterns = []
    elif isinstance(exclude_patterns, str):
        # Treat a single pattern string as one pattern, not as a sequence of characters.
        exclude_patterns = [exclude_patterns]
    excludes = [re.compile(p) for p in exclude_patterns]

    # Sorted so the CSV row order is reproducible (os.listdir order is arbitrary).
    log_files = sorted(
        name for name in os.listdir(log_dir)
        if name.endswith(('.log', '.txt'))
        and os.path.isfile(os.path.join(log_dir, name))
        and not any(rx.search(name) for rx in excludes)
    )

    print(f"Processing {len(log_files)} log files...")

    for log_file in log_files:
        log_path = os.path.join(log_dir, log_file)

        try:
            metrics = extract_metrics_from_log(log_path)
            params = extract_params_from_filename(log_file)

            result_row = {
                'num_prompts': params['num_prompts'],
                'exp_num': params['exp_num'],
                'num_speculative_tokens': params['num_speculative_tokens'],
            }
            for metric_name, values in metrics.items():
                result_row[metric_name] = _mean_after_warmup(values, num_warmup)
        except Exception as e:
            # Best-effort batch: one unreadable/malformed file must not abort the rest.
            print(f"Error processing {log_file}: {e}", file=sys.stderr)
        else:
            results.append(result_row)

    print(f"Processed {len(results)} of {len(log_files)} log files successfully.")

    if results:
        df = pd.DataFrame(results)
        df.to_csv(output_file, index=False)
        print(f"Results saved to {output_file}")
    else:
        print("No results to save")


def _mean_after_warmup(values, num_warmup):
    """Return the mean of ``values`` after dropping the first ``num_warmup`` items, or None if none remain."""
    kept = values[max(num_warmup, 0):]
    return sum(kept) / len(kept) if kept else None


def parse_arguments(argv=None):
    """
    Parse command line arguments.

    Args:
        argv: Optional list of argument strings. None (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``logdir`` and ``output`` string attributes.
    """
    parser = argparse.ArgumentParser(description='Process log files and extract metrics')

    # "%(default)s" is expanded by argparse, so the help text cannot drift from
    # the real default value.
    parser.add_argument('--logdir', type=str, default='./log_results/sd/',
                        help='Directory containing log files (default: %(default)s)')

    parser.add_argument('--output', type=str, default='./csv_results/summary/sd_log.csv',
                        help='Path to output CSV file (default: %(default)s)')

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Command-line entry point: read the flags, then summarize the log directory.
    cli_args = parse_arguments()
    process_log_files(cli_args.logdir, cli_args.output)