import numpy as np
import pandas as pd
import json
import os
import random
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder, LabelEncoder
# Root directory that contains the `data/` folder used by the loaders below.
# '.' resolves against the current working directory; the commented-out
# alternative would resolve against the directory of this file instead.
path = '.'#os.path.dirname(os.path.abspath(__file__))
def get_dataset(dataname):
    '''
    Get dataset from the data folder, remove columns with more than 20% missing values and remove rows with missing values
    Args:
        dataname: str, name of the dataset
    Returns:
        data: pandas dataframe, dataset without the dropped columns and without rows containing missing values
        info: dict, information about the dataset; 'N_cols' and 'C_cols' no longer list the dropped columns
    Raises:
        FileNotFoundError: if the csv file or the info json of the dataset does not exist
    '''
    csv_path = f'{path}/data/{dataname}/{dataname}.csv'
    info_path = f'{path}/data/{dataname}/{dataname}_info.json'
    # both files are required, so fail as soon as either one is missing
    if not os.path.exists(csv_path) or not os.path.exists(info_path):
        raise FileNotFoundError(f'Dataset {dataname} not found')
    data = pd.read_csv(csv_path)
    with open(info_path) as f:
        info = json.load(f)

    #remove columns with more than 20% missing values
    # the columns to drop are collected first: removing entries from the
    # column lists while iterating over them would skip the following column
    max_missing = 0.2*data.shape[0]
    listed_cols = info['N_cols'] + info['C_cols']
    dropped = {col for col in listed_cols if data[col].isnull().sum() > max_missing}
    data = data.drop(columns=list(dropped))
    info['N_cols'] = [col for col in info['N_cols'] if col not in dropped]
    info['C_cols'] = [col for col in info['C_cols'] if col not in dropped]

    #remove rows with missing values
    data = data.dropna()
    return data, info
def set_seed(seed):
    '''
    Seed the random number generators used in this file
    Args:
        seed: value passed to both random.seed and np.random.seed
    '''
    # Python's `random` module first, then NumPy's global generator
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

def get_dataset_few_shot_generated(dataname,shot,seed,gpt_shot,example=False,description=False):
    '''
    Sample gpt_shot rows per class from a generated dataset
    Args:
        dataname: str, name of the dataset
        shot: only used to build the csv file name when example is truthy
        seed: random seed; also part of the csv file name when example is truthy
        gpt_shot: int, number of rows sampled (without replacement) for each class
        example: part of the csv file name; when truthy the name also contains shot and seed
        description: part of the csv file name
    Returns:
        pandas dataframe with gpt_shot rows per class, classes in sorted order
    Raises:
        AssertionError: if a class has fewer than gpt_shot rows
    '''
    set_seed(seed)
    if example:
        csv_name = f'{dataname}-{shot}-{seed}-{example}-{description}.csv'
    else:
        csv_name = f'{dataname}-{example}-{description}.csv'
    df = pd.read_csv(f'{path}/data/generated/{dataname}/{csv_name}')
    with open(f'{path}/data/{dataname}/{dataname}_info.json') as f:
        info = json.load(f)
    y = df[info['target']].to_numpy()
    classes, counts = np.unique(y, return_counts=True)
    # raised explicitly (instead of via `assert`) so the check is not removed
    # under `python -O`; the exception type stays the same for existing callers
    if not (counts >= gpt_shot).all():
        raise AssertionError(f'Class {classes[counts < gpt_shot]} has less than {gpt_shot} samples')
    picked = []
    for target_class in classes.tolist():
        candidates = np.where(y == target_class)[0]
        picked += np.random.choice(candidates, gpt_shot, replace=False).tolist()
    return df.iloc[picked]
def get_dataset_few_shot(dataname,shot,seed):
    '''
    Get few shot dataset
    Args:
        dataname: str, name of the dataset
        shot: int, number of samples per class
        seed: int, random seed
    Returns:
        df: pandas dataframe, dataset
        info: dict, information about the dataset
        X: dict, features (an entry is None when the dataset has no column of that kind)
            {
                'support_x_n': numpy array, numerical features of support set
                'support_x_c': numpy array, categorical features of support set
                'query_x_n': numpy array, numerical features of query set
                'query_x_c': numpy array, categorical features of query set
                'unlabeled_x_n': numpy array, numerical features of unlabeled set
                'unlabeled_x_c': numpy array, categorical features of unlabeled set
            }
        y: dict, labels
            {
                'support_y': numpy array, labels of support set
                'query_y': numpy array, labels of query set
                'unlabeled_y': numpy array, labels of unlabeled set
            }
    Raises:
        AssertionError: if the task is not a classification task or a class has
            fewer than shot samples in the training split
    '''
    set_seed(seed)
    df, info = get_dataset(dataname)
    target = info['target']
    # checks are raised explicitly (instead of via `assert`) so they are not
    # removed under `python -O`; the exception type stays the same for callers
    if info['task_type'] not in ('binclass', 'multiclass'):
        raise AssertionError('Only support classification task')
    classes = df[target].unique().tolist()
    # 80% train (split below into support + unlabeled), 20% test (query)
    X_train, X_test, y_train, y_test = train_test_split(df.drop(columns=[target]), df[[target]], test_size=0.2, random_state=seed)

    # reindex so that the row labels of the training split are 0..n-1
    X_train = X_train.reset_index(drop=True)
    y_train = y_train.reset_index(drop=True)

    # support set: the first `shot` training rows of each class
    labeled_index = []
    for target_class in classes:
        index = y_train[y_train[target] == target_class].index
        if len(index) < shot:
            raise AssertionError(f'Class {target_class} has less than {shot} samples')
        labeled_index += list(index[:shot])
    labeled_index = np.array(labeled_index, dtype=np.int64)

    # unlabeled set: every other training row, in ascending row order
    # (a boolean mask keeps the order independent of set iteration order)
    is_unlabeled = np.ones(X_train.shape[0], dtype=bool)
    is_unlabeled[labeled_index] = False
    unlabeled_index = np.flatnonzero(is_unlabeled)

    X_train_labeled = X_train.loc[labeled_index]
    y_train_labeled = y_train.loc[labeled_index]
    X_train_unlabeled = X_train.loc[unlabeled_index]
    y_train_unlabeled = y_train.loc[unlabeled_index]

    N_cols = info['N_cols']
    C_cols = info['C_cols']

    def _values(frame, cols):
        # feature matrix for the given columns, or None when there are none
        return frame[cols].values if len(cols) > 0 else None

    X = {
        'support_x_n': _values(X_train_labeled, N_cols),
        'support_x_c': _values(X_train_labeled, C_cols),
        'query_x_n': _values(X_test, N_cols),
        'query_x_c': _values(X_test, C_cols),
        'unlabeled_x_n': _values(X_train_unlabeled, N_cols),
        'unlabeled_x_c': _values(X_train_unlabeled, C_cols)
    }
    y = {
        'support_y': y_train_labeled[target].values.flatten(),
        'query_y': y_test[target].values.flatten(),
        'unlabeled_y': y_train_unlabeled[target].values.flatten()
    }
    return df, info, X, y
def XY2df(X,y,info):
    '''
    Wrap the feature and label arrays into pandas dataframes
    The dicts X and y are modified in place and returned.
    Args:
        X: dict, features with the keys '<split>_x_n' / '<split>_x_c' for the splits support, query and unlabeled
        y: dict, labels with the keys 'support_y', 'query_y' and 'unlabeled_y'
        info: dict, information about the dataset ('N_cols', 'C_cols', 'target')
    Returns:
        X: dict, same object, entries replaced by dataframes (left untouched for a kind with no columns)
        y: dict, same object, entries replaced by single-column dataframes named after the target
    '''
    feature_groups = (
        (info['N_cols'], ('support_x_n', 'query_x_n', 'unlabeled_x_n')),
        (info['C_cols'], ('support_x_c', 'query_x_c', 'unlabeled_x_c')),
    )
    for col_names, keys in feature_groups:
        if len(col_names) == 0:
            # no column of this kind: the entries are left as they are
            continue
        for key in keys:
            X[key] = pd.DataFrame(X[key],columns=col_names)
    label_columns = [info['target']]
    for key in ('support_y', 'query_y', 'unlabeled_y'):
        y[key] = pd.DataFrame(y[key],columns=label_columns)
    return X,y

def label2num(support_y,query_y):
    '''
    Encode labels as integers
    The code of a label is its position among the sorted unique labels of
    support_y. A label of query_y that does not occur in support_y keeps
    the initial value 0.
    Args:
        support_y: numpy array, labels of support set
        query_y: numpy array, labels of query set
    Returns:
        support_y_num: numpy array (int64), numerical labels of support set
        query_y_num: numpy array (int64), numerical labels of query set
    '''
    encoded_support = np.zeros(support_y.shape, dtype=np.int64)
    encoded_query = np.zeros(query_y.shape, dtype=np.int64)
    pairs = ((support_y, encoded_support), (query_y, encoded_query))
    for code, known_label in enumerate(np.unique(support_y)):
        for labels, encoded in pairs:
            encoded[labels == known_label] = code
    return encoded_support, encoded_query

def get_cat2num_info(dataname):
    '''
    Load the cat2num.json file of a dataset
    Args:
        dataname: str, name of the dataset
    Returns:
        info: parsed content of {path}/data/{dataname}/cat2num.json
    '''
    # the context manager closes the file handle as soon as parsing is done
    with open(f'{path}/data/{dataname}/cat2num.json') as f:
        info = json.load(f)
    return info





