import json
from collections import defaultdict
from tqdm import tqdm
from datetime import datetime

# Locations of the pipeline's input and its two generated artifacts.
# Raw session dump that the cleaning step reads.
input_file_path = "/home/ubuntu/AI-ESP-serve/output.json"
# Deduplicated sessions written by the cleaning step.
output_cleaned_path = "/home/ubuntu/AI-ESP-serve/output_cleaned.json"
# Aggregated per-game statistics written by the statistics step.
output_statistics_path = "/home/ubuntu/AI-ESP-serve/output_cleaned_statistics.json"

# Step 1: Keep only unique game sessions with the latest timestamp
def clean_game_sessions(input_file_path, output_cleaned_path):
    """Deduplicate game sessions, keeping the latest record per session id.

    Loads a JSON array of session objects, keeps exactly one record for each
    ``"game_session_id"`` (the one whose ``"date"`` is the most recent), writes
    the surviving records to ``output_cleaned_path`` as indented JSON, and
    returns them.

    Args:
        input_file_path: Path of the JSON file holding the list of sessions.
        output_cleaned_path: Path the deduplicated list is written to.

    Returns:
        list: The kept session dicts, ordered by the first appearance of each
        session id in the input.

    Raises:
        KeyError: If a session lacks ``"game_session_id"`` or ``"date"``.
        ValueError: If a ``"date"`` is not in ``%Y-%m-%d %H:%M:%S`` format.
    """
    timestamp_format = '%Y-%m-%d %H:%M:%S'

    # JSON text is UTF-8; be explicit so the result does not depend on the
    # locale's default encoding.
    with open(input_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # session id -> (parsed timestamp, session). Storing the parsed timestamp
    # next to the kept record avoids re-parsing it for every duplicate.
    latest_by_id = {}
    for session in tqdm(data, desc="Cleaning Sessions"):
        session_id = session["game_session_id"]
        session_timestamp = datetime.strptime(session["date"], timestamp_format)

        kept = latest_by_id.get(session_id)
        # Strictly-greater comparison: on equal timestamps the record seen
        # first is retained.
        if kept is None or session_timestamp > kept[0]:
            latest_by_id[session_id] = (session_timestamp, session)

    # Drop the cached timestamps; dict order is first-seen order of each id.
    cleaned_data = [session for _, session in latest_by_id.values()]

    # Save the cleaned data to output_cleaned.json
    with open(output_cleaned_path, 'w', encoding='utf-8') as outfile:
        json.dump(cleaned_data, outfile, indent=4)

    return cleaned_data

# Step 2: Compute game statistics with sorted system_prompt_index and user_rating
def compute_game_statistics(cleaned_data, output_statistics_path):
    """Aggregate per-group outcome statistics and write them as JSON.

    Sessions are grouped by ``game_name`` -> ``template_name`` ->
    ``system_prompt_index``. For every group the output holds the count of
    each ``game_status`` plus three derived values:

    * ``average_turn_count`` - mean of the sessions' ``"round"`` values
      (a missing ``"round"`` counts as 0).
    * ``win_rate`` - share of sessions whose status is a winning status for
      that game; games without configured winning statuses get 0.0.
    * ``average_user_rating`` - mean of the non-missing ``"user_rating"``
      values, or 0 when the group has none.

    Args:
        cleaned_data: Iterable of session dicts.
        output_statistics_path: Path the statistics JSON is written to.

    Returns:
        None. The result is only written to ``output_statistics_path``.
    """
    # Statuses that count as a win, per game.
    winning_statuses = {
        'Akinator': ['MODEL_WIN'],
        'Bluffing': ['MODEL_WIN'],
        'Taboo': ['MODEL_WIN', 'MAX_ROUND_REACHED'],
    }

    # game -> template -> prompt index -> accumulators for that group.
    groups = {}
    for record in tqdm(cleaned_data, desc="Computing Statistics"):
        game = record["game_name"]
        template = record["template_name"]
        prompt_index = record["system_prompt_index"]
        status = record["game_status"]
        turns = record.get("round", 0)
        rating = record.get("user_rating", None)

        by_prompt = groups.setdefault(game, {}).setdefault(template, {})
        if prompt_index not in by_prompt:
            by_prompt[prompt_index] = {"statuses": {}, "turns": [], "ratings": []}
        bucket = by_prompt[prompt_index]

        status_counts = bucket["statuses"]
        status_counts[status] = status_counts.get(status, 0) + 1
        bucket["turns"].append(turns)
        if rating is not None:
            bucket["ratings"].append(rating)

    report = {}
    for game, templates in groups.items():
        criteria = winning_statuses.get(game, [])
        game_report = {}
        report[game] = game_report

        for template, by_prompt in templates.items():
            template_report = {}
            game_report[template] = template_report

            # Emit prompt indices in ascending order.
            for prompt_index in sorted(by_prompt):
                bucket = by_prompt[prompt_index]
                status_counts = bucket["statuses"]
                turns = bucket["turns"]
                ratings = bucket["ratings"]

                if turns:
                    mean_turns = sum(turns) / len(turns)
                else:
                    mean_turns = 0

                # Winning statuses that never occurred are reported with an
                # explicit zero count, listed after the observed statuses.
                for win_status in criteria:
                    status_counts.setdefault(win_status, 0)
                wins = sum(status_counts[win_status] for win_status in criteria)

                games_played = sum(status_counts.values())
                if games_played > 0:
                    share_won = wins / games_played
                else:
                    share_won = 0

                if ratings:
                    mean_rating = sum(ratings) / len(ratings)
                else:
                    mean_rating = 0

                # Status counts first, then the derived metrics.
                entry = dict(status_counts)
                entry['average_turn_count'] = mean_turns
                entry['win_rate'] = share_won
                entry['average_user_rating'] = mean_rating
                template_report[prompt_index] = entry

    # Persist the statistics to output_cleaned_statistics.json
    with open(output_statistics_path, 'w') as stats_file:
        json.dump(report, stats_file, indent=4)

# Run the pipeline: deduplicate the raw sessions, then aggregate statistics
# over the deduplicated records.
cleaned_data = clean_game_sessions(
    input_file_path=input_file_path,
    output_cleaned_path=output_cleaned_path,
)
compute_game_statistics(
    cleaned_data=cleaned_data,
    output_statistics_path=output_statistics_path,
)
