import os
import string
import pandas as pd
import numpy as np
from datasets import load_dataset
from utils.util import letters, construct_data_path

# (hub path, config name) handed to ``load_dataset`` for each supported dataset.
paths_and_names = dict(
    logiqa=('lucasmccabe/logiqa', 'default'),
    aqua=('aqua_rat', 'raw'),
    truthfulqa=('truthful_qa', 'multiple_choice'),
)

# (train split, test split) names per dataset. TruthfulQA only provides a
# 'validation' split, which load_truthful_qa_splits() divides into train/test
# itself, hence no test split name.
splits = dict(
    logiqa=('train', 'test'),
    aqua=('train', 'test'),
    truthfulqa=('validation', None),
)

# Pretty-printed name for each dataset key.
titles = dict(
    logiqa='LogiQA',
    aqua='AQuA',
    truthfulqa='TruthfulQA',
)

def postprocess(data_name, df):
    """
    Normalise a raw dataset frame in place and return it.

    - logiqa: 'label' = letters[correct_option]; 'question' = context + newline + query.
    - aqua:   'label' copied from 'correct'; each option has its first two
              characters (after leading whitespace) removed and is then
              whitespace-stripped (presumably dropping an "A)"-style prefix).
    - truthfulqa: returned unchanged.

    data_name: str, must be a key of ``splits``
    df: pandas DataFrame, modified in place
    returns: the same DataFrame object
    raises: ValueError if data_name is not a known dataset
    """
    if data_name not in splits:
        raise ValueError(f"Unknown dataset name: {data_name}")

    if data_name == 'logiqa':
        df["label"] = df["correct_option"].apply(lambda idx: letters[idx])
        df["question"] = df["context"] + '\n' + df["query"]
    elif data_name == 'aqua':
        df["label"] = df["correct"]
        df["options"] = df["options"].apply(
            lambda raw_options: [opt.lstrip()[2:].strip() for opt in raw_options]
        )

    return df

def insert_in_options(row):
    """
    Return a new options list with ``row['best answer']`` inserted at the
    position ``letters.index(row['label'])``.

    ``row['options']`` may be a numpy array (what ``pd.read_parquet`` yields
    for list columns) or a plain Python sequence (what a freshly
    post-processed frame holds, e.g. AQuA options built as lists). The input
    is never mutated.

    row: mapping/Series with 'options', 'label' and 'best answer'
    returns: list of options including the inserted best answer
    """
    options = row['options']
    # Arrays (and Series) expose tolist(), which also converts numpy scalars to
    # Python objects; plain lists/tuples do not, so copy them with list().
    options = options.tolist() if hasattr(options, 'tolist') else list(options)
    options.insert(letters.index(row['label']), row['best answer'])
    return options

def shuffle_options_and_labels(row):
    """
    Jointly shuffle a row's answer options and their parallel labels.

    Each option is paired with its label before shuffling, so labels stay
    attached to their options. Pairs are formed with ``zip``, so surplus items
    of the longer sequence are dropped. The shuffle uses numpy's global RNG
    (``np.random.shuffle``); results are reproducible only if the caller
    seeds it.

    row: mapping/Series with 'options' and 'labels' sequences
    returns: pandas Series with keys 'options' and 'labels', each a list
    raises: ValueError when there are no pairs (unpacking an empty zip fails)
    """
    pairs = list(zip(row['options'], row['labels']))
    np.random.shuffle(pairs)
    new_options, new_labels = map(list, zip(*pairs))
    return pd.Series({'options': new_options, 'labels': new_labels})

def stratified_sample(data, col, size, random_state=42):
    """
    Draw ``size`` rows from ``data``, stratified on column ``col``.

    Each group of rows sharing a value of ``col`` contributes
    ceil(size * group_len / len(data)) rows, sampled with ``random_state``.
    Because of the ceiling, the pooled rows can exceed ``size``. A final
    ``sample(n=size, random_state=random_state)`` over the pool trims it to
    exactly ``size`` rows and shuffles their order. Rows whose ``col`` is NaN
    are dropped by ``groupby`` and never sampled.

    data: pandas DataFrame to sample from
    col: name of the column to stratify on (kept in the result)
    size: number of rows to return, 0 <= size <= len(data)
    random_state: seed passed to every ``DataFrame.sample`` call
    returns: DataFrame with ``size`` rows and all columns of ``data``;
             original index labels are preserved
    raises: ValueError if size is negative or larger than len(data)
    """
    total = len(data)
    if size < 0 or size > total:
        raise ValueError(f"size must be between 0 and {total}, got {size}")

    # Iterate groups explicitly rather than groupby(...).apply(...): pandas >= 2.2
    # deprecates apply() operating on the grouping columns, and later versions
    # exclude them, which would silently drop ``col`` from the result.
    # -(-a // b) is exact integer ceil(a / b) for non-negative ints.
    per_group = [
        group.sample(n=-(-size * len(group) // total), random_state=random_state)
        for _, group in data.groupby(col)
    ]
    pool = pd.concat(per_group) if per_group else data.iloc[:0]

    # The ceiling may over-allocate; trim back down to exactly ``size`` rows.
    return pool.sample(n=size, random_state=random_state)

def load_truthful_qa_splits():
    """
    Build (train, test) DataFrames from TruthfulQA's single 'validation' split.

    Steps:
      1. Load the 'multiple_choice' and 'generation' configs of ``truthful_qa``.
      2. Pull 'options' and 'labels' out of each row's ``mc1_targets``.
      3. Seed numpy's *global* RNG with 42, then jointly shuffle options and
         labels per row, because the correct answer is always first in the raw
         data. 'label' is ``letters[i]`` for the position i whose label is 1.
      4. Attach each question's 'category' from the 'generation' config,
         matching on whitespace-stripped question text. A question with no
         match raises KeyError.
      5. Stratify on 'category' to take ceil(80%) of rows as train; all
         remaining rows form the test split.

    returns: tuple (train DataFrame, test DataFrame)
    """
    mc_frame = pd.DataFrame(load_dataset("truthful_qa", "multiple_choice")['validation'])
    generation_frame = pd.DataFrame(load_dataset("truthful_qa", "generation")['validation'])

    targets = mc_frame['mc1_targets']
    mc_frame['options'] = targets.apply(lambda target: target['choices'])
    mc_frame['labels'] = targets.apply(lambda target: target['labels'])

    # The correct answer is always option A in the raw data; shuffle so it is not.
    np.random.seed(42)
    mc_frame[['options', 'labels']] = mc_frame.apply(shuffle_options_and_labels, axis=1)
    mc_frame['label'] = mc_frame['labels'].apply(lambda flags: letters[flags.index(1)])

    category_by_question = {
        question.strip(): category
        for question, category in zip(generation_frame['question'], generation_frame['category'])
    }
    mc_frame['category'] = mc_frame['question'].apply(
        lambda question: category_by_question[question.strip()]
    )

    train_size = int(np.ceil(0.8 * len(mc_frame)))
    train_frame = stratified_sample(mc_frame, 'category', train_size)

    in_train = mc_frame.index.isin(train_frame.index)
    test_frame = mc_frame[~in_train]

    return train_frame, test_frame

def load_data_subset(config):
    """
    Load (and cache) a reproducible random subset of a supported dataset.

    If ``construct_data_path(config)`` already exists, it is read as parquet and
    returned. Otherwise the requested split is fetched from the hub via
    ``load_dataset``, or built by ``load_truthful_qa_splits`` for TruthfulQA.
    It is normalised with ``postprocess``, sampled down to ``n`` rows with
    ``seed``, and written to that path as parquet. The returned frame is always
    the parquet round-trip, so first and later calls yield identical data and
    dtypes (e.g. list columns come back as numpy arrays both times).

    config: dictionary, configuration dictionary
            including dataset (the name) and dataset_params (which contains split, n and seed)
    returns: pandas DataFrame, subset of the data
    raises: ValueError if the dataset name is unknown and no cached file exists
    """
    data_path = construct_data_path(config)
    data_name, data_params = config['dataset'], config['dataset_params']
    split, n, seed = data_params['split'], data_params['n'], data_params['seed']

    # Reuse a subset cached by a previous call.
    if os.path.exists(data_path):
        return pd.read_parquet(data_path)

    if data_name not in splits:
        raise ValueError(f"Unknown dataset name: {data_name}")

    if data_name == 'truthfulqa':
        # TruthfulQA only ships one split; load_truthful_qa_splits carves it
        # into (train, test). Any split other than 'train' selects test.
        train_data, test_data = load_truthful_qa_splits()
        data = train_data if split == 'train' else test_data
    else:
        path_arg, name_arg = paths_and_names[data_name]
        # Any split other than 'train' maps to the dataset's second (test) split.
        split_arg = splits[data_name][0] if split == 'train' else splits[data_name][1]
        data = pd.DataFrame(load_dataset(path=path_arg, name=name_arg, split=split_arg))

    data = postprocess(data_name, data)
    data_subset = data.sample(n=n, random_state=seed)

    # Create the directory the cache file actually lives in (not a hard-coded
    # one); exist_ok avoids a check-then-create race.
    parent_dir = os.path.dirname(data_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    data_subset.to_parquet(data_path)

    # Return the parquet round-trip so this call matches every cached call.
    return pd.read_parquet(data_path)