import os
import json
import matplotlib.pyplot as plt

# Directory that receives the generated plot images
output_dir = "/home/ubuntu/AI-ESP/data_analysis/9_9/"
# JSON file with the per-game, per-model statistics to plot
input_file_path = "/home/ubuntu/AI-ESP/output_cleaned_statistics.json"

# Legend display names, keyed by the model name used in the statistics file.
# Lookups fall back to the raw name, so models missing here are shown as-is.
# NOTE(review): "gpt-4-turbo-2024-04-09" is displayed as "gpt-4o-2024-05-13",
# which is a different model name -- confirm this relabeling is intentional.
model_name_mapping = {
    "llama-3": "llama-3-70b",
    "chatgpt": "gpt-3.5-turbo",
    "gpt-4-turbo-2024-04-09": "gpt-4o-2024-05-13",
    "mistral": "mistral-large-latest",
}
# These models are already stored under their display name.
model_name_mapping.update(
    (name, name) for name in ("gemini-1.5-pro", "claude-3-5-sonnet-20240620")
)

# Load the statistics. The rest of the script treats `data` as a mapping of
# game name -> {model name -> stats}. Decode explicitly as UTF-8 (the standard
# JSON encoding) instead of relying on the platform's default locale encoding.
with open(input_file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Style palette: matplotlib single-letter color codes and marker symbols
# (circle, square, diamond, and the four triangle orientations).
colors = "r g b c m y k".split()
markers = "o s D ^ v < >".split()

# Collect every model's win rate from each game so that models can be ranked
# globally and keep the same style in every plot.
model_performance = {}

for game_data in data.values():
    for model_name, stats in game_data.items():
        model_performance.setdefault(model_name, []).append(stats['win_rate'])

# Calculate average win rate across the games each model appears in
average_win_rate = {model: sum(rates) / len(rates) for model, rates in model_performance.items()}

# Sort models by their average win rate, best first
sorted_models = sorted(average_win_rate.items(), key=lambda x: x[1], reverse=True)

# Assign a consistent color and marker to each model based on global ranking.
# The marker index is shifted by one every time the color list wraps around,
# so all len(colors) * len(markers) combinations are used before any
# (color, marker) pair repeats. Indexing both lists with the same `i % n`
# would reuse the first model's style for the 8th model. For the first
# len(colors) models the result is the same as plain `i % n` indexing.
model_style_mapping = {
    model_name: (
        colors[i % len(colors)],
        markers[(i + i // len(colors)) % len(markers)],
    )
    for i, (model_name, _) in enumerate(sorted_models)
}

# Define a function to plot game statistics
def plot_game_statistics(game_name, game_data, model_style_mapping, output_dir):
    """Save a scatter plot of win rate versus average turn count for one game.

    Each model in ``game_data`` is drawn as a single point using its
    ``(color, marker)`` pair from ``model_style_mapping``. Legend entries are
    ordered by this game's win rate, highest first, and are labelled with the
    display name from the module-level ``model_name_mapping`` (falling back to
    the raw model name). The image is written to
    ``<output_dir>/<game_name>_performance.png``.

    Args:
        game_name: Name of the game; used in the plot title and the file name.
        game_data: Mapping of model name to a stats mapping that provides the
            ``'average_turn_count'`` and ``'win_rate'`` keys.
        model_style_mapping: Mapping of model name to a ``(color, marker)``
            pair.
        output_dir: Directory the PNG is written to; this function does not
            create it.

    Raises:
        KeyError: If a model in ``game_data`` has no entry in
            ``model_style_mapping``, or its stats lack a required key.
    """
    # Sort models by win rate for this specific game so the legend is ranked
    sorted_models_for_game = sorted(
        game_data.items(), key=lambda x: x[1]['win_rate'], reverse=True
    )

    # Work on explicit figure/axes objects rather than pyplot's implicit
    # "current figure" so the function cannot draw onto someone else's figure.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for model_name, stats in sorted_models_for_game:
            color, marker = model_style_mapping[model_name]
            display_name = model_name_mapping.get(model_name, model_name)
            avg_turn_count = stats['average_turn_count']
            win_rate = stats['win_rate']
            ax.scatter(
                avg_turn_count,
                win_rate,
                color=color,
                marker=marker,
                label=f'{display_name} (Win Rate: {win_rate:.2f})',
                s=100,
                alpha=0.75,
            )

        ax.set_title(f'{game_name} - Model Performance')
        ax.set_xlabel('Average Turn Count')
        ax.set_ylabel('Win Rate')
        ax.set_xlim(left=0)  # Ensures the x-axis starts at 0
        ax.set_ylim(0, 1)    # Ensures the y-axis is between 0 and 1
        ax.legend(fontsize=15)
        ax.grid(True)

        # Adjust layout to reduce empty space
        fig.tight_layout()

        # Save the plot
        plot_file_path = os.path.join(output_dir, f'{game_name}_performance.png')
        fig.savefig(plot_file_path)
    finally:
        # Always release the figure, even if plotting or saving failed;
        # pyplot keeps figures alive until they are explicitly closed.
        plt.close(fig)

# Generate plots for each game
# Render and save one plot for every game in the loaded statistics
for game_name in data:
    plot_game_statistics(
        game_name=game_name,
        game_data=data[game_name],
        model_style_mapping=model_style_mapping,
        output_dir=output_dir,
    )

print(f"Plots saved in {output_dir}")