import os
import pickle
import itertools
import numpy as np
import pandas as pd


def get_all_groups(groups):
    """Return every single-element and every two-element group, each sorted.

    Args:
        groups: Iterable of group identifiers (e.g. column names). It may be
            a one-shot iterator; it is materialised internally.

    Returns:
        A list of lists: first one ``[g]`` per input element (in input
        order), then one sorted ``[a, b]`` per unordered pair, in the order
        produced by ``itertools.combinations``.
    """
    # The elements are traversed twice below; materialise them once so a
    # generator argument is not exhausted after the first pass.
    members = list(groups)
    singles = [[member] for member in members]
    pairs = [list(pair) for pair in itertools.combinations(members, 2)]
    # np.sort keeps the original ordering semantics; .tolist() hands back
    # native Python objects rather than NumPy scalar wrappers.
    return [np.sort(group).tolist() for group in singles + pairs]

# Candidate grouping keys (presumably DataFrame column names -- confirm
# against callers).
ALL_POSSIBLE_GROUP_COLS = [
    'num_properties_binned', 'nationality', 'sex or gender',
    'sport', 'IMDb ID',
]
# Every single key plus every unordered pair of keys, as built by
# get_all_groups.
ALL_POSSIBLE_GROUPS = get_all_groups(ALL_POSSIBLE_GROUP_COLS)

# pickle functions
def save_pickle(data, filename):
    """Pickle ``data`` into the file at ``filename``.

    Best-effort: any exception raised while opening or writing is caught and
    reported with ``print``; nothing is re-raised. Always returns ``None``.
    """
    try:
        with open(filename, 'wb') as sink:
            pickle.dump(data, sink)
    except Exception as err:
        print(f"An error occurred while saving to {filename}: {err}")

def load_pickle(filename):
    """Unpickle and return the object stored at ``filename``.

    Best-effort: if opening or unpickling raises, the error is reported with
    ``print`` and ``None`` is returned instead of propagating the exception.

    NOTE: unpickling can execute arbitrary code, so only load files from a
    trusted source.
    """
    try:
        with open(filename, 'rb') as source:
            loaded = pickle.load(source)
    except Exception as err:
        print(f"An error occurred while loading from {filename}: {err}")
        return None
    else:
        return loaded


# data loading functions
def get_save_dir(base_save_dir='./data/all/', temperature=1.0, run=0):
    """Build the directory path for one (temperature, run) combination.

    The result is ``base_save_dir`` joined with
    ``temperature=<temperature>/run=<run>``. The directory is not created.
    """
    run_subdir = f'temperature={temperature}/run={run}'
    return os.path.join(base_save_dir, run_subdir)

def load_data(save_dir, model_name='Llama2_7B_Chat', split='nq,'):
    """Read the consistency + group-properties CSV and add a ``correct`` column.

    The file read is
    ``<save_dir>/<model_name>_<split>_consistency+group_properties.csv``.
    The added boolean ``correct`` column is True where ``label`` equals 'S'.

    NOTE(review): the default ``split`` value ends with a comma inside the
    quotes, which looks like a typo for 'nq' -- confirm against the file
    names on disk before changing it.
    """
    csv_name = f'{model_name}_{split}_consistency+group_properties.csv'
    frame = pd.read_csv(os.path.join(save_dir, csv_name))
    label_is_s = frame['label'] == 'S'
    frame.loc[:, 'correct'] = label_is_s.astype(bool)
    return frame

def split_data(df, seed, proportion_cal=0.8):
    """Split ``df`` into calibration and test frames at the topic level.

    Each distinct value of ``df['topic']`` is assigned as a whole to one
    side, so no topic appears in both outputs.

    Args:
        df: DataFrame with a ``topic`` column.
        seed: Seed for ``numpy.random.RandomState``.
        proportion_cal: Per-topic probability of landing in calibration.

    Returns:
        ``(df_calibration, df_test)``, both with a fresh 0..n-1 index.
    """
    rng = np.random.RandomState(seed)

    # Sort the distinct topics so the draw-to-topic pairing does not depend
    # on row order in ``df``.
    all_topics = np.sort(df['topic'].unique())
    # One independent uniform draw per topic; the realised calibration share
    # therefore only approximates ``proportion_cal``.
    draws = rng.rand(len(all_topics))
    cal_topics = all_topics[draws < proportion_cal]

    in_calibration = df['topic'].isin(cal_topics)
    calibration_part = df[in_calibration].reset_index(drop=True)
    test_part = df[~in_calibration].reset_index(drop=True)
    return calibration_part, test_part

def split_data_biased(
    df,
    seed,
    proportion_cal,
    bias_col,
    bias_value,
    bias_proportion,
):
    """Split ``df`` into calibration/test with a controlled share of one group in test.

    Rows whose ``bias_col`` (compared as a string) differs from ``bias_value``
    are split with ``split_data``. For topics whose ``bias_col`` equals
    ``bias_value``, ``int(n_topics * bias_proportion)`` topics are sampled
    into the test side and the remaining ones go to calibration; rows are
    then selected by topic.

    Args:
        df: DataFrame with ``topic``, ``score``, ``label`` and ``bias_col``
            columns.
        seed: Seed used both for ``split_data`` and for sampling the biased
            topics.
        proportion_cal: Passed through to ``split_data`` for the rows outside
            the biased group.
        bias_col: Column that defines the biased group.
        bias_value: Value of ``bias_col`` (compared against the
            string-converted column) that identifies the biased group.
        bias_proportion: Fraction of the biased group's topics sent to test.

    Returns:
        ``(df_calibration, df_test)``, each sorted by topic, score and label
        and carrying a fresh 0..n-1 index.
    """
    # Rows outside the biased group: ordinary topic-level split.
    outside_group = df[bias_col].astype(str) != bias_value
    df_without_calibration, df_without_test = split_data(
        df.loc[outside_group], seed, proportion_cal
    )

    # Topics inside the biased group: send a fixed number of them to test.
    df_topics = df[['topic', bias_col]].drop_duplicates()
    inside_group = df_topics[bias_col].astype(str) == bias_value
    df_topics_with = df_topics.loc[inside_group]
    test_size_with = int(len(df_topics_with) * bias_proportion)
    topics_test = df_topics_with.sample(test_size_with, random_state=seed)['topic'].values
    topics_calibration = df_topics_with.loc[
        ~df_topics_with['topic'].isin(topics_test), 'topic'
    ].values
    df_with_test = df[df['topic'].isin(topics_test)]
    df_with_calibration = df[df['topic'].isin(topics_calibration)]

    # Combine both parts and order the rows by topic, score and label.
    # The "with" parts still carry df's original index labels while the
    # "without" parts were re-indexed from 0 by split_data, so the combined
    # labels would collide; reset them to match what split_data returns.
    sort_cols = ['topic', 'score', 'label']
    df_calibration = (
        pd.concat((df_with_calibration, df_without_calibration))
        .sort_values(sort_cols)
        .reset_index(drop=True)
    )
    df_test = (
        pd.concat((df_with_test, df_without_test))
        .sort_values(sort_cols)
        .reset_index(drop=True)
    )

    return df_calibration, df_test