from dotenv import load_dotenv
# Load variables from a .env file into the process environment first, so the
# BIOS_DATA_PATH lookup below can be satisfied by it.
load_dotenv()
import pandas as pd 
import pickle
import os 
# Directory holding the dataset pickle files; serves as the default
# `data_path` of load_dataset. The subscript lookup raises KeyError at import
# time if the variable is not set.
BIOS_DATA_PATH = os.environ['BIOS_DATA_PATH']
import torch.nn.functional as F


def load_dataset(
    data_path: str = BIOS_DATA_PATH,
    train_file_name: str = 'bios_train.pickle',
    dev_file_name: str = 'bios_dev.pickle',
    test_file_name: str = 'bios_test.pickle'
) -> dict[str, pd.DataFrame]:
    """Load the train/dev/test pickles and return them as DataFrames.

    Each file is unpickled, converted to a ``pd.DataFrame``, reduced to the
    ``'hard_text_untokenized'`` and ``'p'`` columns, and those columns are
    renamed to ``'text'`` and ``'label'`` respectively.

    Args:
        data_path: Directory containing the three pickle files.
        train_file_name: File name of the training split.
        dev_file_name: File name of the development split.
        test_file_name: File name of the test split.

    Returns:
        A dict keyed by ``'train'``, ``'dev'`` and ``'test'`` whose values are
        two-column DataFrames (``'text'``, ``'label'``).

    Raises:
        FileNotFoundError: If one of the files does not exist under
            ``data_path``.
        KeyError: If a loaded table lacks one of the expected columns.
    """
    split_files = {
        'train': train_file_name,
        'dev': dev_file_name,
        'test': test_file_name,
    }
    # Source column -> output column; the keys double as the column selection.
    column_names = {'hard_text_untokenized': 'text', 'p': 'label'}

    frames = {}
    for split_name, file_name in split_files.items():
        pickle_path = os.path.join(data_path, file_name)
        # NOTE: pickle can execute arbitrary code on load; only point this at
        # trusted files.
        with open(pickle_path, 'rb') as handle:
            records = pickle.load(handle)
        selected = pd.DataFrame(records)[list(column_names)]
        frames[split_name] = selected.rename(columns=column_names)

    return frames
