import pandas as pd
import re
from datasets import Dataset
from sklearn.utils import shuffle

def read_data(data_path):
    """Load ``data_path`` as a CSV and keep only fully populated rows.

    Any row containing at least one missing value (NaN) is discarded; the
    surviving rows keep their original index labels.

    Args:
        data_path: Anything ``pandas.read_csv`` accepts (path or buffer).

    Returns:
        A DataFrame with no missing values.
    """
    raw = pd.read_csv(data_path)
    complete_rows = raw.dropna()
    return complete_rows
# Expected input columns: context, question, answer, explicit_answer, answer_index, options
def prepare_dataset(df, output_path='~/iclr/pragmatics/global_datasets/task_3.csv',
                    n_samples=2000, random_state=None):
    """Build the task-3 dataset from raw implicature rows and save it as CSV.

    Randomly samples up to ``n_samples`` rows. It renames ``explicit_answer``
    to ``correct answer`` and folds ``context``, ``question`` and ``answer``
    into a single two-speaker ``pretext`` string. It then drops the consumed
    columns (plus ``answer_index``) and writes the result to ``output_path``.

    Args:
        df: Input frame. Expects the columns ``context``, ``question``,
            ``answer``, ``explicit_answer`` and ``answer_index``. It is not
            modified.
        output_path: Destination passed to ``DataFrame.to_csv``. Defaults to
            the previously hard-coded location.
        n_samples: Maximum number of rows to keep. ``None`` keeps every row.
        random_state: Seed or RNG forwarded to ``DataFrame.sample`` for a
            reproducible shuffle. ``None`` uses NumPy's global RNG.

    Returns:
        The prepared DataFrame (the same data that is written to disk).
    """
    n_rows = len(df) if n_samples is None else min(n_samples, len(df))
    # Sampling without replacement is equivalent to "shuffle, then take the
    # first n". It returns a fresh frame, so the later steps never touch a
    # slice of the caller's data (no chained-assignment warnings).
    sampled = df.sample(n=n_rows, random_state=random_state)

    # Vectorized concatenation instead of a row-wise apply. It also works on
    # an empty frame, where apply(axis=1) returns a DataFrame that cannot be
    # assigned to a single column.
    # NOTE(review): no separator is inserted after the context; presumably
    # each context already ends with a newline — confirm against the data.
    pretext = (
        sampled['context'].astype(str)
        + 'Speaker_1: ' + sampled['question'].astype(str)
        + '\nSpeaker_2: ' + sampled['answer'].astype(str)
        + '\n'
    )

    result = (
        sampled.rename(columns={'explicit_answer': 'correct answer'})
        .assign(pretext=pretext)
        .drop(columns=['context', 'question', 'answer', 'answer_index'])
    )
    result.to_csv(output_path, index=False)
    return result




if __name__ == "__main__":
    # Script entry point: raw implicature-recovery CSV -> task-3 dataset CSV.
    raw_df = read_data('./data/grice_implicature_recovery.csv')
    prepare_dataset(raw_df)