import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_recall_fscore_support


def _read_answered_rows(path, output_column, answer_prefix):
    """Read one inference CSV and keep only rows whose *output_column* contains *answer_prefix*."""
    df = pd.read_csv(path)
    # na=False: rows with a missing model output are dropped instead of putting NaN
    # into the boolean mask (pandas refuses to index with a NaN-containing mask).
    # regex=False: the prefix is matched as literal text, not as a regular expression.
    answered = df[output_column].str.contains(answer_prefix, regex=False, na=False)
    return df[answered]


def _first_class_f1(df, pred_cols):
    """Return the F1 score of the first (lowest-sorted) label, pooling all prediction columns of *df*."""
    y_true = []
    y_pred = []
    for col in pred_cols:
        y_pred.extend(df[col].to_list())
        # The ground-truth column is the prediction column name minus its last 5
        # characters (presumably a suffix such as "_pred" -- confirm against the CSV schema).
        y_true.extend(df[col[:-5]].to_list())
    # average=None returns one score per label, ordered by sorted label value;
    # index 0 is therefore the lowest-sorted label, not a macro average.
    _, _, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None)
    return f1[0]


def load_data(output_dir, output_column='ontreatment_binary_response', answer_prefix='My answer is:',
              *, splits=(0.4, 0.6, 0.8, 1, 2, 3), n_folds=5, baseline_f1=0.783784):
    """Summarise test-set F1 per machine-to-human annotation ratio across folds.

    For every ``split`` in *splits* and every fold in ``range(n_folds)`` this reads
    ``{output_dir}/folds/{split}/{fold}/inference_test_set.csv``. It keeps only rows
    whose *output_column* contains *answer_prefix*, then restricts every frame to the
    accession numbers present in the reference frame (first split, fold 0). Keys whose
    accession set differs from the reference are printed. Finally it scores the
    prediction columns against their ground-truth columns.

    Args:
        output_dir: Root directory containing the ``folds/`` tree.
        output_column: Column holding the raw model output text.
        answer_prefix: Literal text a valid model answer must contain.
        splits: Augmentation ratios to read; must be non-empty. The first one is the
            alignment reference.
        n_folds: Number of folds per split.
        baseline_f1: F1 reported for ratio 0 (no augmentation); its SD is 0.

    Returns:
        ``(proportions, mean_f1, std_f1)``. ``proportions`` is ``[0, *splits]``. The
        other two lists hold, per proportion, the mean and sample SD (ddof=1) over
        folds of the first-label F1 score (see ``_first_class_f1``).
    """
    # Read every (split, fold) inference file, keeping only answered rows.
    inference_dict = {
        split: {
            fold: _read_answered_rows(
                f"{output_dir}/folds/{split}/{fold}/inference_test_set.csv",
                output_column,
                answer_prefix,
            )
            for fold in range(n_folds)
        }
        for split in splits
    }

    # Align all frames on the reference frame's accession numbers so every
    # (split, fold) is scored on the same cases.
    reference = inference_dict[splits[0]][0]
    ref_accessions = set(reference['Accession Number'].unique())
    for split in splits:
        for fold in range(n_folds):
            df = inference_dict[split][fold]
            if set(df['Accession Number'].unique()) != ref_accessions:
                print(f'{split}_{fold}')
            inference_dict[split][fold] = df[df['Accession Number'].isin(ref_accessions)]

    # Prediction columns are taken from the reference frame; all frames are
    # assumed to share the same schema.
    pred_cols = [col for col in reference.columns if 'pred' in col]

    proportions = [0, *splits]
    mean_f1 = [baseline_f1]  # proportion 0 is the fixed, pre-computed baseline
    std_f1 = [0]
    for split in splits:
        f1_per_fold = [
            _first_class_f1(inference_dict[split][fold], pred_cols)
            for fold in range(n_folds)
        ]
        mean_f1.append(np.mean(f1_per_fold))
        std_f1.append(np.std(f1_per_fold, ddof=1))
    return proportions, mean_f1, std_f1

def make_figure(proportions, mean_f1, std_f1, save_dir, n_folds=5):
    """Plot mean F1 (± 1 SD) against annotation ratio and save it as a PNG.

    The full series is drawn as a green line with error bars. The first point
    (ratio 0, the baseline) is re-drawn on top in crimson with its own legend entry.

    Args:
        proportions: x values; element 0 is the baseline ratio.
        mean_f1: Mean F1 per proportion.
        std_f1: Standard deviation of F1 per proportion, drawn as error bars.
        save_dir: Output directory; it is created if missing. The figure is written
            to ``{save_dir}/F1_score_comparison.png`` at 300 dpi.
        n_folds: Number of folds the SD was computed over (shown in the title).
    """
    from pathlib import Path

    fig, ax = plt.subplots()
    try:
        ax.errorbar(
            proportions,
            mean_f1,
            yerr=std_f1,
            marker='o',
            linestyle='-',
            color='green',
            capsize=4,
            label='with augmentation',
        )
        # Highlight the baseline point separately so it gets its own legend entry.
        ax.errorbar(
            [proportions[0]],
            [mean_f1[0]],
            yerr=[std_f1[0]],
            marker='o',
            linestyle='',  # single point: no connecting line
            color='crimson',
            capsize=4,
            label='baseline (p = 0)',
        )

        ax.set_xlabel('Machine-to-Human Annotation Ratio')
        ax.set_ylabel('F1 Score')
        ax.set_title(
            'Impact of Augmenting Training Data with Machine Annotations\n'
            f'(Error bars = ±1 SD across {n_folds} folds)'
        )
        ax.grid(True)
        ax.legend()
        fig.tight_layout()

        out_dir = Path(save_dir)
        # savefig does not create missing directories.
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_dir / 'F1_score_comparison.png', dpi=300, bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

if __name__ == "__main__":
    # Read fold results from ./outputs and write the plot into ./figures
    # (both relative to the current working directory).
    summary = load_data("outputs")
    make_figure(*summary, "figures")