import os
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import seaborn as sns
from matplotlib.lines import Line2D

def fix_fonts(title=20, label=20, xtick=15, ytick=15, default=15):
    """Switch matplotlib's global text to Times New Roman at the given sizes.

    The settings are written into ``plt.rcParams``, so they apply to every
    figure created afterwards in this process.

    Args:
        title: Size of axes titles (``axes.titlesize``).
        label: Size of axis labels (``axes.labelsize``).
        xtick: Size of x-axis tick labels (``xtick.labelsize``).
        ytick: Size of y-axis tick labels (``ytick.labelsize``).
        default: Base size for all remaining text (``font.size``).
    """
    face = "Times New Roman"
    overrides = {
        # Request the face by name, and also make it the sole serif
        # candidate so a 'serif' family lookup resolves to it as well.
        "font.family": face,
        "font.serif": [face],
        "font.size": default,
        "xtick.labelsize": xtick,
        "ytick.labelsize": ytick,
        "axes.labelsize": label,
        "axes.titlesize": title,
    }
    plt.rcParams.update(overrides)


def _plot_accuracy_panel(ax, groups, column, values, bar_color, palette,
                         full_model, desc_circuit, put_circuit, random_circuit):
    """Draw per-group accuracy bars on ``ax`` plus four horizontal reference lines.

    Args:
        ax: Axes to draw on.
        groups: Category labels for the x axis, one per bar.
        column: Name of the value column (seaborn uses it as the y-axis label).
        values: Pre-aggregated bar heights, aligned with ``groups``.
        bar_color: Fill colour of the bars.
        palette: Colour sequence; entry 0 colours the description-circuit line
            and entry 1 the put-circuit line.
        full_model, desc_circuit, put_circuit, random_circuit: y positions of
            the full-model, description-circuit, put-circuit and
            random-circuit reference lines.
    """
    data = pd.DataFrame({'Category': groups, column: values})
    sns.barplot(x='Category', y=column, data=data, ax=ax, color=bar_color)
    ax.axhline(y=full_model, color='black', linewidth=2, label='Full Model')
    ax.axhline(y=desc_circuit, color=palette[0], linestyle='--', linewidth=2,
               label='Description Circuit')
    ax.axhline(y=put_circuit, color=palette[1], linestyle='--', linewidth=2,
               label='Put Circuit')
    ax.axhline(y=random_circuit, color='gray', linestyle=':', linewidth=2,
               label='Random Circuit')


def plot_overlap_and_circuit_performance(
        output_path='../outputs/overlap_and_circuit_performance.png'):
    """Plot head overlap and circuit top-k accuracy per group and save a PNG.

    Builds a three-row figure with a shared x axis (one bar per group):

    1. Head overlap coefficient per group (unfilled black bars).
    2. Top-k accuracy of predicting the description object, with reference
       lines for the full model, description circuit, put circuit and a
       random circuit.
    3. The same for the put object.

    All numbers are hard-coded here. A single figure-level legend sits above
    the panels.

    Args:
        output_path: Destination of the PNG (saved at 600 dpi). Relative paths
            resolve against the current working directory. The parent
            directory is created if missing. Defaults to the original
            hard-coded location.
    """
    fix_fonts(title=15, label=15, xtick=15, ytick=15, default=11.5)
    palette = sns.color_palette()
    groups = ['Group A', 'Group B', 'Group C', 'Group D']
    fig, axes = plt.subplots(3, 1, sharex=True, figsize=(10, 6))

    # Top panel: head overlap coefficient per group (values are pre-aggregated).
    overlap_pct = pd.DataFrame({
        'Category': groups,
        'Overlap Coefficient': [48/120, 6/60, 6/80, 1/80],
    })
    sns.barplot(x='Category', y='Overlap Coefficient', data=overlap_pct,
                ax=axes[0], fill=False, color="black")
    axes[0].set_title("Head Overlap Coefficient Per Group")

    # Middle panel: top-k accuracy of predicting the description object.
    # NOTE(review): every group's bar is 0.2 in both accuracy panels --
    # possibly placeholder values; confirm before publishing.
    _plot_accuracy_panel(axes[1], groups, 'Desc Obj', [0.2, 0.2, 0.2, 0.2],
                         bar_color=palette[0], palette=palette,
                         full_model=0.91, desc_circuit=0.52,
                         put_circuit=0.08, random_circuit=0.034)
    axes[1].set_title("Top-K Accuracy of Predicting Object")

    # Bottom panel: top-k accuracy of predicting the put object.
    _plot_accuracy_panel(axes[2], groups, 'Put Obj', [0.2, 0.2, 0.2, 0.2],
                         bar_color=palette[1], palette=palette,
                         full_model=0.92, desc_circuit=0.01,
                         put_circuit=0.99, random_circuit=0.011)

    # seaborn labels each x axis with the column name ("Category").
    # plt.xlabel('') would only clear the *current* Axes (the last subplot),
    # so clear every panel explicitly.
    for ax in axes:
        ax.set_xlabel('')

    # One shared legend for the whole figure instead of one per subplot.
    handles = [
        mpatches.Patch(color=palette[0], label="Description Circuit"),
        mpatches.Patch(color=palette[1], label="Put Circuit"),
        Line2D([0], [0], color='black', linestyle='--', label='Full Circuit'),
        mpatches.Patch(facecolor='white', edgecolor='black',
                       label='LOO Circuit', linewidth=1.0),
        Line2D([0], [0], color='black', linestyle='-', label='Full Model'),
        Line2D([0], [0], color='gray', linestyle=':', label='Random Circuit'),
    ]
    fig.legend(handles=handles, loc='upper center',
               bbox_to_anchor=(0.5, 1.), ncol=len(handles))
    # Lay out the panels in the bottom 95% of the figure, leaving room for
    # the legend above them.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    # savefig does not create missing directories, so make the parent first.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output_path, dpi=600)
    # The figure is only written to disk, never shown; release it.
    plt.close(fig)

if __name__ == "__main__":
    # Script entry point: render the combined figure and write it to disk.
    plot_overlap_and_circuit_performance()