from dotenv import load_dotenv
load_dotenv()
import pandas as pd 
import os 
import json

# Root directory of the GoEmotions files read by load_dataset (emotions.txt,
# ekman_mapping.json and the split TSVs). It is read from the process
# environment, which load_dotenv() may have populated from a .env file.
# Failing here, at import time, is deliberate: load_dataset uses this value as
# its default ``data_path``. The re-raise only makes the failure actionable;
# the exception type stays KeyError.
try:
    GOEMOTIONS_DATA_PATH = os.environ['GOEMOTIONS_DATA_PATH']
except KeyError:
    raise KeyError(
        'GOEMOTIONS_DATA_PATH is not set; export it or add it to a .env file. '
        'It must point to the directory containing emotions.txt and the '
        'train/dev/test TSV files.'
    ) from None


def filter_multiple_targets(df: pd.DataFrame, target_col: str='label') -> pd.DataFrame:
    """Keep only rows with exactly one target and unwrap that target.

    Each value in ``df[target_col]`` is expected to be an indexable, sized
    sequence of labels (e.g. a list). Rows whose sequence does not contain
    exactly one element are dropped. In the remaining rows the sequence is
    replaced by its single element.

    Args:
        df: Input frame. It is not modified.
        target_col: Name of the column holding the label sequences.

    Returns:
        A new DataFrame with the surviving rows (original index preserved),
        where ``target_col`` holds scalar labels.
    """
    # Build the mask with a comparison so it is always bool dtype. With
    # ``apply(lambda x: len(x) == 1)``, an empty input yields an *object*
    # Series, and ``df[...]`` would treat it as a column selector rather than
    # a row mask, dropping every column.
    single_mask = df[target_col].map(len) == 1
    # Take an explicit copy so the assignment below targets an independent
    # frame: no SettingWithCopyWarning and no write-through to the caller's df.
    single = df.loc[single_mask].copy()
    single[target_col] = single[target_col].map(lambda labels: labels[0])
    return single



def load_dataset(
    data_path: str=GOEMOTIONS_DATA_PATH,
    train_file_name: str='train.tsv',
    dev_file_name: str='dev.tsv',
    test_file_name: str='test.tsv',
    ekman: bool=False
) -> dict[str, pd.DataFrame]:
    """Load the train/dev/test splits as single-label DataFrames.

    Files read from ``data_path``:

    * ``emotions.txt``: one emotion name per line. Line ``i`` (0-based) is
      the name for label id ``i``.
    * ``ekman_mapping.json`` (only when ``ekman`` is true): a JSON object
      mapping each Ekman category to a list of emotion names. ``'neutral'``
      is always mapped to itself.
    * One headerless, tab-separated file per split, with columns text,
      comma-separated label ids, and comment id. The comment id is dropped.

    Rows carrying more than one label are discarded by
    ``filter_multiple_targets``.

    Args:
        data_path: Directory containing the files above.
        train_file_name: File name of the train split inside ``data_path``.
        dev_file_name: File name of the dev split inside ``data_path``.
        test_file_name: File name of the test split inside ``data_path``.
        ekman: If true, replace each emotion name with its Ekman category.

    Returns:
        ``{'train': df, 'dev': df, 'test': df}``. Each frame has columns
        ``text`` and ``label``, where ``label`` is a single string.

    Raises:
        FileNotFoundError: If any of the files is missing.
        KeyError: If a label id is not in ``emotions.txt``, or (with
            ``ekman``) an emotion name is not covered by the mapping.
    """
    with open(os.path.join(data_path, 'emotions.txt'), encoding='utf-8') as f:
        id_to_emotion = {i: line.strip() for i, line in enumerate(f)}

    label_to_ekman = None
    if ekman:
        with open(os.path.join(data_path, 'ekman_mapping.json'), encoding='utf-8') as f:
            ekman_mapping = json.load(f)
        label_to_ekman = {'neutral': 'neutral'}
        # Invert category -> [labels] into label -> category. The loop
        # variable must not be named ``ekman``: that would shadow the flag.
        for category, labels in ekman_mapping.items():
            for label in labels:
                label_to_ekman[label] = category

    filename_dict = {'train': train_file_name, 'dev': dev_file_name, 'test': test_file_name}
    datasets = {}
    for split, fname in filename_dict.items():
        df = pd.read_csv(
            os.path.join(data_path, fname),
            sep='\t',
            header=None,
            names=['text', 'label', 'comment_id'],
            # Keep label ids as text. If a file (or a low_memory parser chunk)
            # contains only single-id rows, pandas would otherwise infer
            # integers and ``str.split`` below would fail.
            dtype={'label': str},
        )
        df['label'] = df['label'].map(
            lambda ids: [id_to_emotion[int(i)] for i in ids.split(',')]
        )
        df = df.drop(columns=['comment_id'])
        df = filter_multiple_targets(df)

        if label_to_ekman is not None:
            unknown = set(df['label']) - label_to_ekman.keys()
            if unknown:
                raise KeyError(
                    f'labels missing from ekman_mapping.json: {sorted(unknown)}'
                )
            # assign() returns a new frame, so this never writes into a slice.
            df = df.assign(label=df['label'].map(label_to_ekman))

        datasets[split] = df

    return datasets
