import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os


def plot_llm_tokens(file_path: str, save_path: str):
    """Plot per-agent token counts as a grouped bar chart and save it as PNG.

    Every column of the CSV other than ``agent`` is melted into long format,
    so each agent gets one bar per such column (the column names become the
    legend entries).

    Args:
        file_path: Path to a CSV with an ``agent`` column plus one column per
            token type (presumably numeric token counts).
        save_path: Directory in which ``llm_tokens_plot.png`` is written.
    """
    df = pd.read_csv(file_path)
    # Long format: one row per (agent, token_type) pair, as needed for ``hue``.
    df_melted = df.melt(id_vars='agent', var_name='token_type', value_name='tokens')

    # Note: set_style changes process-wide matplotlib rcParams.
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(14, 8))
    try:
        # NOTE(review): if an agent appears in several rows, barplot shows the
        # mean (with a CI error bar), not the total.
        sns.barplot(data=df_melted, x='agent', y='tokens', hue='token_type', palette='viridis',
                    edgecolor='black', linewidth=0.7)  # subtle borders separate adjacent bars

        plt.title('LLM Agent Input and Output Tokens', fontsize=16, fontweight='bold')
        plt.xlabel('Agent', fontsize=12)
        plt.ylabel('Number of Tokens', fontsize=12)
        plt.xticks(rotation=45, ha='right', fontsize=10)
        plt.yticks(fontsize=10)
        plt.legend(title='Token Type', fontsize=10, title_fontsize=12)
        plt.grid(axis='y', linestyle='--', alpha=0.7)  # horizontal guide lines
        fig.tight_layout()

        fig.savefig(os.path.join(save_path, "llm_tokens_plot.png"), dpi=300)
    finally:
        # Always release this specific figure, even if plotting or saving
        # fails, so pyplot does not accumulate open figures.
        plt.close(fig)


def plot_agent_running_time(file_path: str, save_path: str):
    """Plot each agent's running time in minutes as a bar chart and save it.

    Args:
        file_path: Path to a CSV with ``agent`` and ``running_time_in_seconds``
            columns.
        save_path: Directory in which ``agent_running_time_plot.png`` is written.
    """
    df = pd.read_csv(file_path)
    df['running_time_in_minutes'] = df['running_time_in_seconds'] / 60

    # Note: set_style changes process-wide matplotlib rcParams.
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(14, 8))
    try:
        # Colour bars per agent through ``hue`` rather than passing ``palette``
        # without ``hue``, which seaborn >= 0.13 deprecates (FutureWarning).
        # dodge=False keeps one full-width bar per agent on older seaborn too.
        ax = sns.barplot(data=df, x='agent', y='running_time_in_minutes', hue='agent',
                         palette='plasma', dodge=False, edgecolor='black', linewidth=0.7)
        # hue duplicates the x axis here, so a legend would carry no information.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

        plt.title('Agent Running Time (in Minutes)', fontsize=16, fontweight='bold')
        plt.xlabel('Agent/Stage', fontsize=12)
        plt.ylabel('Running Time (Minutes)', fontsize=12)
        plt.xticks(rotation=45, ha='right', fontsize=10)
        plt.yticks(fontsize=10)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        fig.tight_layout()

        fig.savefig(os.path.join(save_path, "agent_running_time_plot.png"), dpi=300)
    finally:
        # Always release this specific figure, even if plotting or saving fails.
        plt.close(fig)


def plot_debug_logs(file_path: str, save_path: str):
    """Plot submit attempts per debug mode, split by success, and save as PNG.

    Args:
        file_path: Path to a CSV with ``debug_mode``, ``total_submit_attempts``
            and ``is_successful`` columns.
        save_path: Directory in which ``debug_logs_plot.png`` is written.
    """
    df = pd.read_csv(file_path)
    # NOTE(review): astype(bool) maps NaN and any non-empty string (e.g. a
    # textual "False" in an object column) to True; parse explicitly if the
    # column may contain missing or non-boolean values.
    df['is_successful'] = df['is_successful'].astype(bool)

    # Note: set_style changes process-wide matplotlib rcParams.
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(18, 7))
    try:
        sns.barplot(data=df, x='debug_mode', y='total_submit_attempts', hue='is_successful', palette='rocket', ax=ax,
                    edgecolor='black', linewidth=0.7)
        ax.set_title('Total Submit Attempts by Debug Mode and Success', fontsize=14, fontweight='bold')
        ax.set_xlabel('Debug Mode', fontsize=12)
        ax.set_ylabel('Total Submit Attempts', fontsize=12)

        # Restyle the existing tick labels in place (the same thing
        # plt.xticks(rotation=..., ha=...) does in the other plots). This
        # avoids set_xticklabels, and with it the FixedFormatter UserWarning
        # and the set_xticks(get_xticks()) workaround.
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=10)

        ax.tick_params(axis='y', labelsize=10)
        ax.legend(title='Is Successful?', fontsize=10, title_fontsize=12)
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        fig.tight_layout()

        fig.savefig(os.path.join(save_path, "debug_logs_plot.png"), dpi=300)
    finally:
        # Always release this specific figure, even if plotting or saving fails.
        plt.close(fig)


def plot_checker_agent_correctness(file_path: str, save_path: str):
    """Count correct vs. incorrect checker verdicts per agent and save as PNG.

    Args:
        file_path: Path to a CSV with ``agent_name`` and ``is_it_correct``
            columns; each row is counted once.
        save_path: Directory in which ``checker_agent_correctness_plot.png``
            is written.
    """
    df = pd.read_csv(file_path)
    # NOTE(review): astype(bool) maps NaN and any non-empty string (e.g. a
    # textual "False" in an object column) to True; parse explicitly if the
    # column may contain missing or non-boolean values.
    df['is_it_correct'] = df['is_it_correct'].astype(bool)

    # Note: set_style changes process-wide matplotlib rcParams.
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(12, 7))
    try:
        sns.countplot(data=df, x='agent_name', hue='is_it_correct', palette='viridis',
                      edgecolor='black', linewidth=0.7)

        plt.title('Checker Agent Correctness Assessment', fontsize=16, fontweight='bold')
        plt.xlabel('Agent Name', fontsize=12)
        plt.ylabel('Count', fontsize=12)
        plt.xticks(rotation=45, ha='right', fontsize=10)
        plt.yticks(fontsize=10)
        plt.legend(title='Is Correct?', fontsize=10, title_fontsize=12)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        fig.tight_layout()

        fig.savefig(os.path.join(save_path, "checker_agent_correctness_plot.png"), dpi=300)
    finally:
        # Always release this specific figure, even if plotting or saving fails.
        plt.close(fig)


def generate_all_plots(llm_tokens_file: str, agent_time_file: str, debug_logs_file: str,
                       checker_agent_file: str, path_to_save: str
                       ):
    """Render every report plot into the directory ``path_to_save``.

    The plots are produced in a fixed order: token usage, running time,
    debug logs, then checker correctness. FutureWarnings emitted while
    plotting are ignored only inside this call; the previous warning filters
    are restored on exit.
    """
    # (plotting function, CSV it reads) in rendering order.
    plot_jobs = (
        (plot_llm_tokens, llm_tokens_file),
        (plot_agent_running_time, agent_time_file),
        (plot_debug_logs, debug_logs_file),
        (plot_checker_agent_correctness, checker_agent_file),
    )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", FutureWarning)
        for render, csv_path in plot_jobs:
            render(csv_path, path_to_save)