import os
import json
import matplotlib.pyplot as plt

# Aggregated statistics to plot. The traversal below assumes the shape
# {game_name: {model_name: {system_prompt_index: {stat_name: value}}}};
# NOTE(review): inferred from usage in this file — confirm against the producer.
input_file_path = '/home/ubuntu/AI-ESP/output_cleaned_statistics.json'
# Directory to save plots
output_dir = '/home/ubuntu/AI-ESP/data_analysis/9_9/'

# Raw model identifiers (as stored in the JSON) -> names shown on the plots.
# NOTE(review): "gpt-4-turbo-2024-04-09" is relabelled as a different model
# ("gpt-4o-2024-08-09"); presumably the raw key is a mislabel — confirm.
model_name_mapping = {
    "llama-3": "llama-3.1-405b",
    "chatgpt": "gpt-3.5-turbo",
    "gpt-4-turbo-2024-04-09": "gpt-4o-2024-08-09",
    "mistral": "mistral-large-latest",
    "gemini-1.5-pro": "gemini-1.5-pro",
    "claude-3-5-sonnet-20240620": "claude-3-5-sonnet-20240620"
}

# JSON interchange text is UTF-8; don't depend on the platform's locale encoding.
with open(input_file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

os.makedirs(output_dir, exist_ok=True)

# Style palettes. Both have 7 entries and are indexed modulo their length, so
# an 8th model would reuse the first model's color/marker pair.
colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
markers = ['o', 's', 'D', '^', 'v', '<', '>']

# Every per-system-prompt win rate for each raw model name, across all games.
# A key is only created when the model has at least one entry, so the
# averages below never divide by zero.
model_performance = {}
for game_data in data.values():
    for model_name, model_data in game_data.items():
        for stats in model_data.values():
            model_performance.setdefault(model_name, []).append(stats['win_rate'])

average_win_rate = {model: sum(rates) / len(rates) for model, rates in model_performance.items()}

# Best overall model first, so the strongest models get the first styles.
sorted_models = sorted(average_win_rate.items(), key=lambda x: x[1], reverse=True)

# Assign colors and markers by overall rank, keyed by *display* name (the
# plotting functions look styles up after applying model_name_mapping).
# If two raw names map to one display name, the lower-ranked one's style wins.
model_style_mapping = {model_name_mapping.get(model_name, model_name): (colors[i % len(colors)], markers[i % len(markers)])
                       for i, (model_name, _) in enumerate(sorted_models)}

# a function to plot game statistics with averaged data across system_prompt_index and sorted legend
def plot_game_statistics(game_name, game_data, model_style_mapping, output_dir):
    """Scatter-plot each model's average turn count vs. average win rate for one game.

    Each model's statistics are averaged over all of its system-prompt
    variants, and the legend lists models from highest to lowest averaged
    win rate. The figure is saved as
    ``<output_dir>/<game_name>_performance_averaged.png``.

    Args:
        game_name: Game identifier; used in the title and output filename.
        game_data: Mapping of raw model name -> {system_prompt_index: stats},
            where each stats dict provides 'average_turn_count' and 'win_rate'.
        model_style_mapping: Display model name -> (color, marker).
        output_dir: Existing directory the PNG is written into.

    Raises:
        KeyError: if a stats dict lacks a required key, or a model's display
            name is missing from ``model_style_mapping``.
    """
    plt.figure(figsize=(10, 6))

    model_averages = []
    for model_name, model_data in game_data.items():
        runs = list(model_data.values())
        if not runs:
            continue  # no system-prompt variants recorded; nothing to average
        model_averages.append({
            'model_name': model_name,
            'avg_turn_count': sum(stats['average_turn_count'] for stats in runs) / len(runs),
            'avg_win_rate': sum(stats['win_rate'] for stats in runs) / len(runs),
        })

    # sort the models by win rate (descending) so the legend is ranked
    model_averages.sort(key=lambda entry: entry['avg_win_rate'], reverse=True)

    for entry in model_averages:
        display_name = model_name_mapping.get(entry['model_name'], entry['model_name'])
        avg_win_rate = entry['avg_win_rate']
        color, marker = model_style_mapping[display_name]

        # Plot the marker for the averaged model data
        plt.scatter(entry['avg_turn_count'], avg_win_rate, color=color, marker=marker,
                    label=f'{display_name} (Win Rate: {avg_win_rate:.2f})', s=100, alpha=0.75)

    plt.title(f'{game_name} - Model Performance (Averaged)')
    plt.xlabel('Average Turn Count')
    plt.ylabel('Win Rate')
    plt.xlim(left=0)  # Ensures the x-axis starts at 0
    plt.ylim(0, 1)    # Ensures the y-axis is between 0 and 1
    if model_averages:
        # With a point anchor, `loc` names the legend corner pinned to that
        # point: 'upper left' at (1.05, 1) puts the legend just outside the
        # right edge of the axes ('best' is not a corner and misplaces it).
        plt.legend(fontsize=10, loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.)
    plt.grid(True)

    plt.tight_layout(rect=[0, 0, 0.75, 1])  # Reserve the right 25% of the figure for the legend

    # save plot
    plot_file_path = os.path.join(output_dir, f'{game_name}_performance_averaged.png')
    plt.savefig(plot_file_path)
    plt.show()
    plt.close()

# Define a function to plot average user rating per game per model
def plot_average_user_rating(game_name, game_data, model_style_mapping, output_dir):
    """Plot each model's average user rating for one game, ranked high to low.

    A model's rating is the mean of 'average_user_rating' over those
    system-prompt variants that report one. Variants without a rating are
    ignored rather than counted as 0. Models with no rating at all are left
    off the plot. Models are placed at x = 0, 1, 2, ... in rank order.
    The figure is saved as ``<output_dir>/<game_name>_user_rating_averaged.png``.

    Args:
        game_name: Game identifier; used in the title and output filename.
        game_data: Mapping of raw model name -> {system_prompt_index: stats}.
        model_style_mapping: Display model name -> (color, marker).
        output_dir: Existing directory the PNG is written into.

    Raises:
        KeyError: if a rated model's display name is missing from
            ``model_style_mapping``.
    """
    plt.figure(figsize=(10, 6))

    model_averages = []
    for model_name, model_data in game_data.items():
        # Counting a missing rating as 0 would drag the mean toward the bottom
        # of the scale, so only variants that actually report one are averaged.
        ratings = [stats['average_user_rating'] for stats in model_data.values()
                   if stats.get('average_user_rating') is not None]
        if not ratings:
            continue
        model_averages.append({
            'model_name': model_name_mapping.get(model_name, model_name),
            'avg_user_rating': sum(ratings) / len(ratings),
        })

    # Sort the models by user rating (descending order)
    model_averages.sort(key=lambda entry: entry['avg_user_rating'], reverse=True)

    # Plot the models in ranked order; the rank doubles as the x position.
    for rank, entry in enumerate(model_averages):
        display_name = entry['model_name']  # already mapped above
        avg_user_rating = entry['avg_user_rating']
        color, marker = model_style_mapping[display_name]

        plt.scatter(rank, avg_user_rating, color=color, marker=marker,
                    label=f'{display_name} (Avg Rating: {avg_user_rating:.2f})', s=100, alpha=0.75)

    plt.title(f'{game_name} - Model User Rating (Averaged)')
    plt.xlabel('Models (Ranked by User Rating)')
    plt.ylabel('Average User Rating')
    # Right-align rotated labels so each one ends under its tick.
    plt.xticks(range(len(model_averages)), [entry['model_name'] for entry in model_averages],
               rotation=45, ha='right')
    plt.ylim(0, 10)  # Assuming user ratings are on a scale of 0-10
    if model_averages:
        # Without this call the per-point labels above are never rendered.
        # Pin the legend's upper-left corner just outside the axes' right edge.
        plt.legend(fontsize=10, loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.)
    plt.grid(True)

    plt.tight_layout(rect=[0, 0, 0.75, 1])  # Reserve the right 25% of the figure for the legend

    # Save the plot
    plot_file_path = os.path.join(output_dir, f'{game_name}_user_rating_averaged.png')
    plt.savefig(plot_file_path)
    plt.show()
    plt.close()

# Two passes over every game: first the averaged win-rate vs. turn-count
# scatter plots, then the ranked average-user-rating plots. A status line is
# printed after each pass completes.
for title, per_model_stats in data.items():
    plot_game_statistics(
        game_name=title,
        game_data=per_model_stats,
        model_style_mapping=model_style_mapping,
        output_dir=output_dir,
    )

print(f"Plots saved in {output_dir}")

for title, per_model_stats in data.items():
    plot_average_user_rating(
        game_name=title,
        game_data=per_model_stats,
        model_style_mapping=model_style_mapping,
        output_dir=output_dir,
    )

print(f"User rating plots saved in {output_dir}")