import pandas as pd
import datetime
import os
import json

def analyze_trace_data(file_path, start_time, end_time):
    """
    Loads a trace CSV and keeps the rows whose TIMESTAMP lies in [start_time, end_time].

    The 'HH:MM' bounds are anchored to the calendar date of the first row in the
    file, and both bounds are inclusive.

    Args:
        file_path: Path to the trace data CSV file (must contain a 'TIMESTAMP' column)
        start_time: Start time for filtering (format: 'HH:MM')
        end_time: End time for filtering (format: 'HH:MM')

    Returns:
        DataFrame with filtered data ('TIMESTAMP' converted to datetime).
        A trace without any rows yields an empty DataFrame.

    Raises:
        ValueError: If start_time or end_time does not match 'HH:MM'
    """
    print(f"Reading trace data from {file_path}...")
    trace = pd.read_csv(file_path)

    # Convert timestamp to datetime so it can be compared with the bounds below
    trace['TIMESTAMP'] = pd.to_datetime(trace['TIMESTAMP'])

    # Parse the bounds before looking at the data so malformed input is
    # reported even when the trace turns out to be empty.
    start_clock = datetime.datetime.strptime(start_time, '%H:%M').time()
    end_clock = datetime.datetime.strptime(end_time, '%H:%M').time()

    # An empty trace has no first row to take the reference date from;
    # return it as-is instead of failing with an IndexError on iloc[0].
    if trace.empty:
        print(f"Found 0 records between {start_time} and {end_time}")
        return trace

    # The date of the first row turns the clock times into full datetimes
    reference_date = trace['TIMESTAMP'].iloc[0].date()
    start_datetime = datetime.datetime.combine(reference_date, start_clock)
    end_datetime = datetime.datetime.combine(reference_date, end_clock)

    # Inclusive on both ends
    in_window = (trace['TIMESTAMP'] >= start_datetime) & (trace['TIMESTAMP'] <= end_datetime)
    filtered_df = trace[in_window]

    print(f"Found {len(filtered_df)} records between {start_time} and {end_time}")

    return filtered_df

def extract_jsonl_with_timestamps(jsonl_path, filtered_trace_df, output_path):
    """
    Extracts samples from JSONL file and adds timestamps based on trace data.

    One output record is written per trace timestamp (in ascending timestamp
    order). Samples are assigned in file order; if there are more timestamps
    than samples, the samples are reused round-robin. Blank lines and lines
    that are not valid JSON are skipped.

    Each output record carries "timestamp", "prompt" and "generated"; "ttft"
    and "tpot" are copied as well when the sample provides both with non-null
    values (a value of 0 counts as provided).

    Args:
        jsonl_path: Path to the input JSONL file
        filtered_trace_df: DataFrame with filtered trace data containing timestamps
            in a 'TIMESTAMP' column
        output_path: Path to save the enriched JSONL file

    Raises:
        ValueError: If there are timestamps to fill but the JSONL file contains
            no valid samples
        KeyError: If a sample lacks "prompt" or "generated"
    """
    print(f"Processing JSONL file: {jsonl_path}")

    # Get the number of timestamps
    num_timestamps = len(filtered_trace_df)
    print(f"Number of timestamps to process: {num_timestamps}")

    # Get sorted timestamps from the trace data
    timestamps = filtered_trace_df['TIMESTAMP'].sort_values().reset_index(drop=True)

    # Read all valid items from the JSONL file
    valid_items = []
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                valid_items.append(json.loads(line))
            except json.JSONDecodeError:
                print(f"Error parsing JSON at line {line_number}, skipping")

    num_items = len(valid_items)
    print(f"Found {num_items} valid items in JSONL file")

    # Round-robin needs at least one sample; fail with a clear message (and
    # before the output file is touched) instead of a ZeroDivisionError.
    if num_timestamps > 0 and num_items == 0:
        raise ValueError(
            f"No valid items found in {jsonl_path}; cannot assign {num_timestamps} timestamps"
        )

    # Prepare the output items with timestamps using round-robin if needed
    output_items = []
    for i, timestamp in enumerate(timestamps):
        item = valid_items[i % num_items]
        enriched_item = {
            "timestamp": timestamp.strftime("%Y-%m-%d %H:%M:%S.%f"),
            "prompt": item["prompt"],
            "generated": item["generated"],
        }
        # Compare against None rather than relying on truthiness so that a
        # legitimate value of 0 is not silently dropped.
        if item.get('ttft') is not None and item.get('tpot') is not None:
            enriched_item["ttft"] = item["ttft"]
            enriched_item["tpot"] = item["tpot"]
        output_items.append(enriched_item)

    # Save the enriched items to the output file
    with open(output_path, 'w', encoding='utf-8') as f:
        for enriched_item in output_items:
            f.write(json.dumps(enriched_item) + '\n')

    print(f"Saved {len(output_items)} items with timestamps to {output_path}")
    if num_items < num_timestamps:
        # Ceiling division: number of (possibly partial) passes over the samples
        passes = -(-num_timestamps // num_items)
        print(f"Note: Data was repeated {passes} times to match all timestamps")

def process_trace_data(trace_file, jsonl_file, output_jsonl, start_time=None, end_time=None, duration_seconds=None):
    """
    Process trace data and generate timestamped dataset.

    Reads the trace CSV, restricts it to a time window (inclusive on both ends)
    and hands the remaining rows to extract_jsonl_with_timestamps together with
    the JSONL paths.

    The window end is chosen in this order: duration_seconds (measured from the
    window start) if given, otherwise end_time, otherwise the last timestamp in
    the trace. 'HH:MM' values are anchored to the date of the first trace row.

    Args:
        trace_file: Path to the trace data CSV file (must contain a 'TIMESTAMP' column)
        jsonl_file: Path to the input JSONL file
        output_jsonl: Path to save the enriched JSONL file
        start_time: Optional start time for filtering (format: 'HH:MM'). If not provided, uses beginning of dataset
        end_time: Optional end time for filtering (format: 'HH:MM'). Ignored when duration_seconds is given
        duration_seconds: Optional duration in seconds from the window start
    """
    # Read trace data
    print(f"Reading trace data from {trace_file}...")
    df = pd.read_csv(trace_file)
    df['TIMESTAMP'] = pd.to_datetime(df['TIMESTAMP'])

    # An empty trace has no first row to derive the reference date from
    # (iloc[0] would raise IndexError); there is nothing to filter, so pass the
    # empty frame straight through.
    if df.empty:
        print("Trace data contains no records, nothing to filter")
        print(f"Found {len(df)} records in the specified time range")
        extract_jsonl_with_timestamps(jsonl_file, df, output_jsonl)
        return

    # Get reference date from the first timestamp
    reference_date = df['TIMESTAMP'].iloc[0].date()

    def _at_reference_date(clock_text):
        # Turn an 'HH:MM' string into a full datetime on the reference date
        clock = datetime.datetime.strptime(clock_text, '%H:%M').time()
        return datetime.datetime.combine(reference_date, clock)

    # If no start_time provided, use the beginning of the dataset
    if start_time is None:
        start_datetime = df['TIMESTAMP'].min()
        print(f"Using beginning of dataset as start time: {start_datetime}")
    else:
        start_datetime = _at_reference_date(start_time)
        print(f"Using specified start time: {start_time}")

    # Calculate end time based on provided parameters
    if duration_seconds is not None:
        # Duration takes precedence over an explicit end_time
        end_datetime = start_datetime + datetime.timedelta(seconds=duration_seconds)
        print(f"Using duration of {duration_seconds} seconds")
    elif end_time:
        # Use provided end time
        end_datetime = _at_reference_date(end_time)
        print(f"Using end time {end_time}")
    else:
        # If no end time or duration specified, use all data until the end
        end_datetime = df['TIMESTAMP'].max()
        print("No end time or duration specified, using all data until the end")

    # Filter the data
    in_window = (df['TIMESTAMP'] >= start_datetime) & (df['TIMESTAMP'] <= end_datetime)
    df = df[in_window]
    print(f"Found {len(df)} records in the specified time range")

    # Process the JSONL file with timestamps
    extract_jsonl_with_timestamps(jsonl_file, df, output_jsonl)

def main():
    """
    Script entry point: timestamp the llama8b-sharegpt test split.

    All paths are hard-coded. The test JSONL is paired with trace timestamps
    from a 1200-second window; no start_time is passed, so the window starts at
    the beginning of the trace.
    """
    # Trace CSV that supplies the timestamps
    trace_csv = "datasets/AzureLLMInferenceTrace_code.csv"

    # (input JSONL, timestamped output JSONL) for each split of llama8b-sharegpt.
    # Alternative dataset directories: llama8b-lmsys, llama70b-sharegpt,
    # llama70b-lmsys, gemma27b-sharegpt, gemma27b-lmsys -- presumably with the
    # same file naming; verify before switching.
    splits = {
        "train": (
            "datasets/llama8b-sharegpt/train.jsonl",
            "datasets/llama8b-sharegpt/train_timestamped.jsonl",
        ),
        "test": (
            "datasets/llama8b-sharegpt/test_with_slo.jsonl",
            "datasets/llama8b-sharegpt/test_with_slo_timestamped.jsonl",
        ),
    }

    # Pick splits["train"] here to timestamp the training split instead.
    source_jsonl, target_jsonl = splits["test"]

    # Other ways to call process_trace_data:
    #   - no window arguments at all: use every timestamp in the trace
    #   - start_time="18:30", end_time="18:40": explicit clock-time window
    #   - start_time="18:30", duration_seconds=600: 10 minutes from a clock time
    process_trace_data(
        trace_csv,
        source_jsonl,
        target_jsonl,
        duration_seconds=1200,
    )

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
