import pandas as pd
import json
def parse_symptom_probs(answer_dict, threshold=0.8):
    """
    Converts a symptom probability mapping to binary labels using a threshold.

    Args:
        answer_dict (dict | str): mapping of symptom name to class
            probabilities, e.g. {'(A): Diarrhea': {'0': 1.0, '1': 4e-08}, ...},
            or its serialized form (JSON text or a Python dict literal).
        threshold (float): a symptom is labelled 1 only when the probability
            stored under key '1' is strictly greater than this value.
    Returns:
        dict: e.g. {'(A): Diarrhea': 0, ...}
    Raises:
        json.JSONDecodeError: if a string input cannot be parsed.
        TypeError: if answer_dict is neither a string nor a dict.
    """
    import ast

    if isinstance(answer_dict, dict):
        # Already deserialized; nothing to parse.
        probabilities = answer_dict
    else:
        try:
            # Strict JSON first: handles double-quoted payloads unchanged.
            probabilities = json.loads(answer_dict)
        except json.JSONDecodeError:
            try:
                # Python dict repr (single-quoted keys). Unlike a blanket
                # quote swap, this survives apostrophes inside symptom names.
                probabilities = ast.literal_eval(answer_dict)
            except (ValueError, SyntaxError, TypeError):
                # Last resort: the historical quote-swapping behaviour, so
                # every input that used to parse still does.
                probabilities = json.loads(answer_dict.replace("'", '"'))

    # A missing '1' entry is treated as probability 0, i.e. a negative label.
    return {
        symptom: int(probs.get('1', 0) > threshold)
        for symptom, probs in probabilities.items()
    }

# Load the base training split and the extra GPT-4o annotated rows.
df_train = pd.read_csv('data/ontreatment/on_treatment_train.csv')
gpt_4o_annotations = pd.read_csv("data/ontreatment/additional_on_treatment_gpt_4o_annotations.csv")

# Turn each row's serialized probability payload into a dict of 0/1 labels.
gpt_4o_annotations['parsed_labels'] = gpt_4o_annotations['probability_json'].apply(
    lambda raw: parse_symptom_probs(raw, threshold=0.8)
)
# Expand the label dicts so that every symptom becomes its own column, then
# attach those columns next to the original annotation columns.
parsed_wide = gpt_4o_annotations['parsed_labels'].apply(pd.Series)
gpt_4o_annotations = pd.concat([gpt_4o_annotations, parsed_wide], axis="columns")

sample_dict = {}  # not populated anywhere in this section
n_train_rows = len(df_train)
# For each proportion, draw that multiple of the training-set size from the
# annotations, stack it on top of the training rows and write it to its own CSV.
for proportion in (0.4, 0.6, 0.8, 1.0, 2.0, 3.0):
    n_annotated = int(proportion * n_train_rows)
    # Fixed seed keeps every draw reproducible across runs.
    sample = gpt_4o_annotations.sample(n=n_annotated, random_state=1)
    threshold_sample = pd.concat([sample, df_train], ignore_index=True)
    output_path = f'data/on_treatment_train_{int(proportion*100)}_threshold.csv'
    threshold_sample.to_csv(output_path, index=False)