import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, matthews_corrcoef

# Global matplotlib text styling, applied to every figure this module draws.
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 28,  # base font size; adjust as needed
    'font.weight': 'bold',
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})
import matplotlib.lines as mlines
import ast

# Module-level colour names. plot_in_one() binds its own local variables with
# the same spelling, so these globals are not read inside that function.
color_method1, color_method2 = 'tab:blue', 'tab:orange'

def label_corr(df, threshold=0.9):
    """Add a binary ``label`` column to *df* in place and print its class counts.

    A row is labelled 1 when ``abs(r_obs) >= threshold`` and 0 otherwise.
    A missing ``r_obs`` compares false and is therefore labelled 0.

    Args:
        df: DataFrame with a numeric ``r_obs`` column. It is modified in
            place: the ``label`` column is created or overwritten.
        threshold: Minimum absolute value of ``r_obs`` for the positive class.

    Returns:
        None.
    """
    # Vectorised comparison instead of a Python-level lambda per row; the
    # explicit cast also keeps the column integer-typed when df has no rows.
    df['label'] = (df['r_obs'].abs() >= threshold).astype(int)
    print(df['label'].value_counts())

def get_most_probable(distribution):
    """Return the value with the highest probability in a serialized distribution.

    ``distribution`` is the string form of a sequence of
    ``(value, probability)`` pairs and is parsed with ``ast.literal_eval``.
    When several pairs share the highest probability, the value of the first
    such pair is returned.
    """
    pairs = ast.literal_eval(distribution)
    values, probs = zip(*pairs)
    # Position of the first occurrence of the largest probability.
    best = max(range(len(probs)), key=probs.__getitem__)
    return values[best]

def get_dist(distribution):
    """Parse the string form of a distribution into the Python literal it denotes."""
    return ast.literal_eval(distribution)

def compute_precision_recall_accuracy(df, threshold=0.9, column='predicted_coef'):
    """Score thresholded predictions in ``df[column]`` against ``df['label']``.

    A row is predicted positive when ``abs(df[column]) >= threshold`` and
    predicted negative when ``abs(df[column]) < threshold``. Rows whose score
    is NaN satisfy neither comparison, so they enter none of the four
    confusion-matrix cells, yet they still count in the accuracy denominator
    (``len(df)``).

    Args:
        df: DataFrame holding a 0/1 ``label`` column and the score column.
        threshold: Cut-off applied to the absolute score.
        column: Name of the score column.

    Returns:
        Tuple ``(precision, recall, accuracy, f1, mcc)``. Any metric whose
        denominator is zero is reported as 0.
    """
    actual_pos = df['label'] == 1
    actual_neg = df['label'] == 0
    magnitude = abs(df[column])
    predicted_pos = magnitude >= threshold
    predicted_neg = magnitude < threshold

    # Confusion-matrix cells as plain Python ints.
    tp = int((actual_pos & predicted_pos).sum())
    fp = int((actual_neg & predicted_pos).sum())
    fn = int((actual_pos & predicted_neg).sum())
    tn = int((actual_neg & predicted_neg).sum())

    if (tp + fp) > 0:
        precision = tp / (tp + fp)
    else:
        precision = 0

    if (tp + fn) > 0:
        recall = tp / (tp + fn)
    else:
        recall = 0

    n_rows = len(df)
    if n_rows > 0:
        accuracy = (tp + tn) / n_rows
    else:
        accuracy = 0

    if (precision + recall) > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0

    # Matthews correlation coefficient; defined as 0 when any marginal is empty.
    mcc_denominator = ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5
    if mcc_denominator > 0:
        mcc = ((tp * tn) - (fp * fn)) / mcc_denominator
    else:
        mcc = 0

    return precision, recall, accuracy, f1, mcc

def compute_metrics_zero_shot(df, thresholds, method):
    """Evaluate a zero-shot method against threshold-derived labels.

    For each threshold the rows are labelled from ``r_obs`` via
    ``label_corr`` and the method's score column is thresholded at the same
    value by ``compute_precision_recall_accuracy``.

    Args:
        df: DataFrame with an ``r_obs`` column, plus ``most_probable`` for
            ``method="ours"`` or ``pair_id`` for ``method="trummer23"``.
            The caller's frame is not modified; all work is done on a
            separate frame.
        thresholds: Iterable of thresholds to evaluate.
        method: ``"ours"`` scores the ``most_probable`` column of *df*;
            ``"trummer23"`` scores the ``predictions`` column joined in from
            the zero-shot RoBERTa output file on ``pair_id``.

    Returns:
        Dict mapping each threshold to
        ``[accuracy, precision, recall, f1, mcc]``.

    Raises:
        ValueError: If *method* is not one of the supported names.
    """
    score_columns = {"ours": "most_probable", "trummer23": "predictions"}
    if method not in score_columns:
        # Previously an unknown method fell through both branches and either
        # crashed on an unbound name or silently reused stale metrics.
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(score_columns)}"
        )
    column = score_columns[method]

    if method == "trummer23":
        # Only this method needs the baseline predictions, so only it pays
        # for reading the file.
        roberta_zero_shot = pd.read_csv("outputs/roberta_classifier/test_with_predictions_train_ratio_0.0.csv")
        # Join once, before the loop. Re-merging for a later threshold would
        # find 'predictions' already present, pandas would suffix both copies
        # (predictions_x / predictions_y) and the column lookup would fail.
        # NOTE(review): a left join leaves NaN predictions for pair_ids that
        # are absent from the baseline file, and repeats rows if pair_id is
        # duplicated there - confirm the file has exactly one row per pair.
        scored = df.merge(roberta_zero_shot[['pair_id', 'predictions']], on='pair_id', how='left')
    else:
        # Private copy so the labelling below does not write into the
        # caller's frame (which may itself be a slice of a larger frame).
        scored = df.copy()

    data = {}
    for threshold in thresholds:
        print(f"Threshold: {threshold}")
        label_corr(scored, threshold)
        precision, recall, accuracy, f1, mcc = compute_precision_recall_accuracy(scored, threshold, column)
        data[threshold] = [
            accuracy,
            precision,
            recall,
            f1,
            mcc
        ]
    return data

def compute_metrics_with_training(thresholds, training_ratio=0.2, our_path=""):
    """Collect per-threshold metrics for the three compared methods.

    For each threshold the trained-RoBERTa prediction file for that
    threshold and *training_ratio* is read. Its own ``labels`` and
    ``predictions`` columns give the trained-classifier metrics. The two
    zero-shot methods are evaluated on the rows of *our_path* whose
    ``pair_id`` occurs in that file.

    Args:
        thresholds: Iterable of thresholds; each selects one prediction file.
        training_ratio: Training ratio embedded in the prediction file name.
        our_path: CSV path of our method's outputs (must contain ``pair_id``).

    Returns:
        Tuple ``(data0, data1, data2)`` of dicts keyed by threshold, each
        value being ``[accuracy, precision, recall, f1, mcc]``:
        ``data0`` is zero-shot RoBERTa, ``data1`` is the trained RoBERTa
        classifier, ``data2`` is our method.
    """
    data0 = {}
    data1 = {}
    data2 = {}
    df_ours = pd.read_csv(our_path)
    for threshold in thresholds:
        path = f"outputs/roberta_classifier/test_with_predictions_train_ratio_{training_ratio}_coef_t_{threshold}.csv"
        df = pd.read_csv(path)

        # Explicit copy: the zero-shot evaluation labels this frame, and
        # assigning a column to a bare boolean-mask slice of df_ours can
        # raise pandas' SettingWithCopyWarning.
        # NOTE(review): if some pair_ids are missing from df_ours, the
        # zero-shot methods are scored on fewer rows than the trained
        # classifier below - confirm the two files cover the same pairs.
        df_sub = df_ours[df_ours['pair_id'].isin(df['pair_id'])].copy()

        y_true = df['labels']
        y_pred = df['predictions']
        # zero_division=0 yields the same value as sklearn's default "warn"
        # mode but without the warning, matching the hand-rolled metrics
        # that report 0 for an empty denominator.
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, zero_division=0)
        recall = recall_score(y_true, y_pred, zero_division=0)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        mcc = matthews_corrcoef(y_true, y_pred)

        data0[threshold] = compute_metrics_zero_shot(df_sub, [threshold], "trummer23")[threshold]
        data1[threshold] = [
            accuracy,
            precision,
            recall,
            f1,
            mcc
        ]
        data2[threshold] = compute_metrics_zero_shot(df_sub, [threshold], "ours")[threshold]
    return data0, data1, data2

def plot_in_one():
    """Draw the three-method comparison figure and save it as PNG and PDF.

    One panel is drawn per threshold. Each panel shows grouped bars for the
    five metrics (accuracy, precision, recall, F1, MCC) of zero-shot RoBERTa,
    RoBERTa trained on 20% of the data, and our method. Positive bars carry
    their value as vertical text, and an 'x' marker flags every method whose
    MCC is exactly 0. The figure is written to
    ``figures/classify_comparison_20_training_all.{png,pdf}``; the
    ``figures`` directory is created when missing.
    """
    # Function-scope imports, done once instead of on every panel iteration.
    import os
    import matplotlib.colors as mcolors
    import seaborn as sns

    thresholds = [0.5, 0.6, 0.7, 0.8]
    metrics_names = ['Acc.', 'Prec.', 'Recall', 'F1', 'MCC']
    palette = sns.color_palette("Set2", n_colors=3)

    color_baseline, color_method1, color_method2 = palette
    color_marker = '#B22222'

    data_path = "outputs/real_world_correlations/real_world_correlations_gpt-4o_lc_prior_iter_0.csv"

    baseline_results, method1_results, method2_results = compute_metrics_with_training(thresholds, 0.2, data_path)

    # Set up subplots
    fig, axs = plt.subplots(1, len(thresholds), figsize=(25, 6), sharey=True)

    # narrower bar so three will fit
    width = 0.30
    x = np.arange(len(metrics_names))
    mcc_idx = metrics_names.index('MCC')

    # (horizontal shift, per-threshold results, colour, label): left, centre, right.
    series = [
        (-width, baseline_results, color_baseline, 'Zero-shot RoBERTa'),
        (0.0, method1_results, color_method1, 'RoBERTa 20% training'),
        (width, method2_results, color_method2, 'LCP (Ours)'),
    ]

    for ax, thresh in zip(axs, thresholds):
        for shift, results, color, label in series:
            bars = ax.bar(x + shift, results[thresh], width, label=label,
                          facecolor=mcolors.to_rgba(color, 0.6),
                          edgecolor='black',
                          linewidth=1)
            # Print each positive value vertically in the middle of its bar.
            for bar in bars:
                h = bar.get_height()
                if h > 0:
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        h / 2,
                        f"{h:.2f}",
                        ha='center', va='center',
                        rotation='vertical',
                        fontsize=23, fontweight='bold',
                        color='black'
                    )

        ax.set_title(f'Threshold = {thresh}')
        ax.set_xticks(x)
        ax.set_xticklabels(metrics_names, rotation=0, fontsize=25)

    # add red 'X' for MCC=0
    # The panels share one y-axis, so the limits are only final once every
    # panel has its bars. Reading them here, once, puts the marker at the
    # same height in all panels instead of at a height that depends on how
    # many panels had been drawn so far.
    ymin, ymax = axs[0].get_ylim()
    marker_y = ymin + 0.04 * (ymax - ymin)
    for ax, thresh in zip(axs, thresholds):
        for shift, results, _color, _label in series:
            if results[thresh][mcc_idx] == 0:
                ax.plot(x[mcc_idx] + shift, marker_y,
                        marker='x', color=color_marker,
                        markersize=10, markeredgewidth=2)

    # Hand-built legend entries: one square per method plus the MCC marker.
    baseline_patch = mlines.Line2D([], [], color=color_baseline, marker='s', linestyle='None',
                                   markersize=25, label='Zero-shot RoBERTa')
    m1_patch = mlines.Line2D([], [], color=color_method1, marker='s', linestyle='None',
                             markersize=25, label='RoBERTa 20% training')
    m2_patch = mlines.Line2D([], [], color=color_method2, marker='s', linestyle='None',
                             markersize=25, label='LCP (Ours)')
    x_marker = mlines.Line2D([], [], color=color_marker, marker='x', linestyle='None',
                             markersize=25, label='MCC = 0')

    fig.legend(handles=[baseline_patch, m1_patch, m2_patch, x_marker],
               loc='lower center', ncol=4, fontsize='medium')

    # Leave the bottom 10% of the figure free for the legend.
    plt.tight_layout(rect=[0, 0.1, 1, 1])

    # savefig does not create missing directories.
    os.makedirs('figures', exist_ok=True)
    plt.savefig('figures/classify_comparison_20_training_all.png', dpi=300, bbox_inches='tight')
    plt.savefig('figures/classify_comparison_20_training_all.pdf', bbox_inches='tight')
    # Report success only after the files have actually been written.
    print("saved at figures/classify_comparison_20_training_all.png")
    plt.close()

# Script entry point: build and save the comparison figure when run directly.
if __name__ == "__main__":
    plot_in_one()