import pandas as pd
import torch

def hallucination_dataset(data_path, template, test_frac):
    """Load a labelled instruction CSV and split it into train/test prompts.

    The CSV must provide an ``instruction`` column and a ``label`` column.
    A fraction ``test_frac`` of the rows is sampled with a fixed seed
    (``random_state=0``) to form the test split; every row whose index was
    not sampled goes to the train split, in its original file order.

    Args:
        data_path: Anything ``pandas.read_csv`` accepts as its first argument.
        template: Format string with an ``{instruction}`` placeholder.
        test_frac: Fraction of rows to hold out, passed to ``DataFrame.sample``.

    Returns:
        ``(train_prompts, train_labels, test_prompts, test_labels)`` where the
        prompts are lists of formatted strings and the labels are flattened
        ``torch.long`` tensors.
    """
    frame = pd.read_csv(data_path)
    held_out = frame.sample(frac=test_frac, random_state=0)
    sampled_mask = frame.index.isin(held_out.index)
    remaining = frame[~sampled_mask]

    def _texts_and_targets(split):
        # Pull the raw instruction strings and the label tensor for one split.
        texts = split["instruction"].tolist()
        targets = torch.tensor(split["label"].tolist()).flatten().long()
        return texts, targets

    train_texts, train_targets = _texts_and_targets(remaining)
    test_texts, test_targets = _texts_and_targets(held_out)

    # Formatting happens only after both splits have been extracted.
    train_prompts = [template.format(instruction=text) for text in train_texts]
    test_prompts = [template.format(instruction=text) for text in test_texts]

    return train_prompts, train_targets, test_prompts, test_targets

def hallucination_answer_dataset(data_path, template, test_frac):
    """Load a labelled instruction/response CSV and build train/test prompts.

    The CSV must provide ``instruction``, ``response`` and ``label`` columns.
    Rows are split exactly as a seeded ``DataFrame.sample`` dictates
    (``frac=test_frac``, ``random_state=0``): sampled rows become the test
    split, all other rows become the train split.

    Training prompts are the formatted instruction immediately followed by
    the row's response. Test prompts contain the formatted instruction only.

    Args:
        data_path: Anything ``pandas.read_csv`` accepts as its first argument.
        template: Format string with an ``{instruction}`` placeholder.
        test_frac: Fraction of rows to hold out, passed to ``DataFrame.sample``.

    Returns:
        ``(train_prompts, train_labels, test_prompts, test_labels)`` where the
        prompts are lists of strings and the labels are flattened
        ``torch.long`` tensors.
    """
    table = pd.read_csv(data_path)
    eval_rows = table.sample(frac=test_frac, random_state=0)
    keep_mask = ~table.index.isin(eval_rows.index)
    fit_rows = table[keep_mask]

    questions = fit_rows["instruction"].tolist()
    answers = fit_rows["response"].tolist()
    fit_targets = torch.tensor(fit_rows["label"].tolist()).flatten().long()

    # The test split's response column is not read.
    eval_questions = eval_rows["instruction"].tolist()
    eval_targets = torch.tensor(eval_rows["label"].tolist()).flatten().long()

    fit_prompts = [
        f"{template.format(instruction=question)}{answer}"
        for question, answer in zip(questions, answers)
    ]
    eval_prompts = [
        template.format(instruction=question) for question in eval_questions
    ]

    return fit_prompts, fit_targets, eval_prompts, eval_targets

if __name__ == "__main__":
    # Manual smoke run: load and split the sample CSV; the result is discarded.
    hallucination_dataset(
        data_path="./data/hallu_dec.csv",
        template="[INST] {instruction} [/INST]",
        test_frac=0.2,
    )