import json
import sys
import matplotlib.pyplot as plt

def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Every non-blank line is decoded independently with ``json.loads``.
    Lines containing only whitespace (e.g. an extra empty line at the end
    of the file) are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_path: Path to the ``.jsonl`` file, read as UTF-8 text.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

def extract_metrics_from_choices(data):
    """Sum generated-token counts and wall times over every usable choice.

    Each element of *data* may carry a ``'choices'`` list. A choice is
    counted only when it has both ``'new_tokens'`` and ``'wall_time'``.
    A field holding a list contributes the sum of its items; any other
    value is used as-is.

    Returns:
        ``(total_tokens, total_wall_time)``. The wall-time total starts
        from ``0.0``, so it is ``0.0`` when no choice qualifies.
    """
    def _collapse(value):
        # A per-choice field may be a list (summed) or a single value.
        return sum(value) if isinstance(value, list) else value

    usable = [
        choice
        for record in data
        for choice in record.get('choices', [])
        if 'new_tokens' in choice and 'wall_time' in choice
    ]
    token_total = sum((_collapse(choice['new_tokens']) for choice in usable), 0)
    time_total = sum((_collapse(choice['wall_time']) for choice in usable), 0.0)
    return token_total, time_total

def extract_individual_times(data):
    """Collect per-choice token counts and wall times as two parallel lists.

    Only choices that contain both ``'new_tokens'`` and ``'wall_time'`` are
    included, in the order they appear in *data*. A field holding a list is
    reduced to the sum of its items; any other value is kept as-is.

    Returns:
        ``(tokens_list, wall_time_list)``, with index *i* of both lists
        describing the same choice.
    """
    def _total(value):
        # A per-choice field may be a list (summed) or a single value.
        return sum(value) if isinstance(value, list) else value

    selected = [
        choice
        for record in data
        for choice in record.get('choices', [])
        if 'new_tokens' in choice and 'wall_time' in choice
    ]
    tokens_list = [_total(choice['new_tokens']) for choice in selected]
    wall_time_list = [_total(choice['wall_time']) for choice in selected]
    return tokens_list, wall_time_list

def _finite_heights(values):
    """Return *values* as a list with non-finite entries (inf/NaN) replaced by 0.0.

    An infinite bar height breaks matplotlib's axis autoscaling, so such
    entries are drawn as empty bars instead of aborting the whole plot.
    """
    import math

    return [v if math.isfinite(v) else 0.0 for v in values]


def _draw_bar_panel(ax, models, values, ylabel, title, colors):
    """Draw one labelled bar chart (one bar per model) onto the axes *ax*."""
    ax.bar(models, values, color=list(colors))
    ax.set_xlabel('Models')
    ax.set_ylabel(ylabel)
    ax.set_title(title)


def plot_bar_graphs(models, total_tokens, total_wall_time, generation_speed,
                    colors=('blue', 'orange')):
    """Save per-model bar charts of token count, wall time and speed.

    Writes four PNG files to the current working directory:
    ``total_tokens_generated.png``, ``total_wall_time.png``,
    ``generation_speed.png`` (one chart each, 10x6 in) and
    ``combined_plot.png`` (all three stacked vertically, 12x18 in). Every
    figure is closed after saving so repeated calls do not accumulate open
    pyplot figures.

    Args:
        models: Bar labels, one per model.
        total_tokens: Token totals, aligned with *models*.
        total_wall_time: Wall-time totals in seconds, aligned with *models*.
        generation_speed: Tokens per second, aligned with *models*.
            Non-finite entries are plotted as 0.
        colors: Bar colors, forwarded to matplotlib's ``bar`` ``color``
            argument.
    """
    panels = [
        (total_tokens, 'Total Tokens Generated',
         'Total Tokens Generated by Each Model', 'total_tokens_generated.png'),
        (total_wall_time, 'Total Wall Time (seconds)',
         'Total Wall Time for Each Model', 'total_wall_time.png'),
        (generation_speed, 'Generation Speed (tokens per second)',
         'Generation Speed of Each Model', 'generation_speed.png'),
    ]

    # One standalone figure per metric.
    for values, ylabel, title, filename in panels:
        fig, ax = plt.subplots(figsize=(10, 6))
        _draw_bar_panel(ax, models, _finite_heights(values), ylabel, title, colors)
        fig.savefig(filename)
        plt.close(fig)

    # All metrics stacked in a single figure.
    fig, axs = plt.subplots(len(panels), 1, figsize=(12, 18))
    for ax, (values, ylabel, title, _filename) in zip(axs, panels):
        _draw_bar_panel(ax, models, _finite_heights(values), ylabel, title, colors)
    fig.tight_layout()
    fig.savefig('combined_plot.png')
    plt.close(fig)

def plot_individual_times(wall_time_list1, wall_time_list2,
                          labels=('Baseline Model', 'Rest Version'),
                          colors=('blue', 'orange')):
    """Plot per-generation wall times of two runs as grouped bars.

    Bar *i* of each series shows the i-th entry of that run's list. The two
    series are drawn side by side, each shifted half a bar width from the
    integer position. The original drew both at the same x, so one series
    covered the other. The lists may differ in length. The figure is saved
    to ``individual_generation_times_combined.png`` in the current working
    directory and then closed.

    Args:
        wall_time_list1: Wall times in seconds for the first run.
        wall_time_list2: Wall times in seconds for the second run.
        labels: Legend labels for the two runs, in the same order.
        colors: Bar colors for the two runs, in the same order.
    """
    width = 0.4
    series = (
        (-width / 2, wall_time_list1, labels[0], colors[0]),
        (width / 2, wall_time_list2, labels[1], colors[1]),
    )

    fig, ax = plt.subplots(figsize=(12, 6))
    for offset, times, label, color in series:
        positions = [index + offset for index in range(len(times))]
        ax.bar(positions, times, width=width, color=color, label=label, alpha=0.7)
    ax.set_xlabel('Generation Instance')
    ax.set_ylabel('Wall Time (seconds)')
    ax.set_title('Individual Generation Times for Tokens')
    ax.legend()
    fig.savefig('individual_generation_times_combined.png')
    plt.close(fig)

def main(file1_path, file2_path):
    """Compare a baseline run and a "Rest Version" run and report results.

    Loads both JSONL files, then:

    - saves bar charts of total tokens, total wall time and tokens/second;
    - saves a per-generation wall-time chart;
    - prints the average wall time per generation for each run.

    Args:
        file1_path: JSONL file of the baseline run.
        file2_path: JSONL file of the "Rest Version" run.
    """
    data1 = load_jsonl(file1_path)
    data2 = load_jsonl(file2_path)

    total_tokens1, total_wall_time1 = extract_metrics_from_choices(data1)
    total_tokens2, total_wall_time2 = extract_metrics_from_choices(data2)

    # With no recorded wall time the speed is undefined. Report 0.0 rather
    # than inf: an infinite bar height breaks matplotlib's axis autoscaling.
    generation_speed1 = total_tokens1 / total_wall_time1 if total_wall_time1 > 0 else 0.0
    generation_speed2 = total_tokens2 / total_wall_time2 if total_wall_time2 > 0 else 0.0

    models = ['Baseline', 'Rest Version']
    total_tokens = [total_tokens1, total_tokens2]
    total_wall_time = [total_wall_time1, total_wall_time2]
    generation_speed = [generation_speed1, generation_speed2]

    plot_bar_graphs(models, total_tokens, total_wall_time, generation_speed)

    # Only the wall times are needed here; the token lists are discarded.
    _, wall_time_list1 = extract_individual_times(data1)
    _, wall_time_list2 = extract_individual_times(data2)

    plot_individual_times(wall_time_list1, wall_time_list2)

    # Average wall time per generation; a run with no timed generations has
    # no meaningful average, so say so instead of printing "inf".
    for label, times in (('Baseline', wall_time_list1), ('Rest Version', wall_time_list2)):
        if times:
            print(f"Average Generation Time - {label}: {sum(times) / len(times):.2f} seconds")
        else:
            print(f"Average Generation Time - {label}: n/a (no timed generations)")

if __name__ == "__main__":
    if len(sys.argv) != 3:
        # sys.exit with a string writes it to stderr and exits with status 1,
        # so shells and CI can detect the usage error.
        sys.exit("Usage: python script.py <baseline_file.jsonl> <rest_version_file.jsonl>")
    main(sys.argv[1], sys.argv[2])
