import os
import numpy  as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split
# from matplotlib import pyplot as plt
from sklearn.datasets import make_classification
# import seaborn as sns
# Seed NumPy's legacy global RNG once, at import time, so anything that draws
# from np.random's global state is reproducible across runs. The
# train_test_split calls in this module pass their own integer random_state,
# so their splits do not depend on this seed.
np.random.seed(seed=0)

def process_mnist(data_dir=os.path.join("data/", "mnist")):
    """Load pre-saved MNIST arrays and return ``(train, val, test)`` DataFrames.

    Each of ``train.npy`` / ``test.npy`` must contain two arrays written back
    to back with ``np.save``: first the images, then the labels.  Images are
    flattened to ``28 * 28`` features and divided by 255; the label is
    appended as the last column.

    The validation set is the first ``len(test) // 2`` rows of the training
    file (file order, no shuffling); the remaining rows stay in train.

    Args:
        data_dir: Directory holding ``train.npy`` and ``test.npy``.  Defaults
            to ``data/mnist`` relative to the current working directory.

    Returns:
        Tuple ``(train, val, test)`` of ``pd.DataFrame`` with 785 columns.
    """
    def _load(name):
        # Images and labels are stored consecutively in one file handle.
        with open(os.path.join(data_dir, name), 'rb') as f:
            x = np.load(f)
            y = np.load(f).reshape(-1, 1)
        return np.concatenate([x.reshape(-1, 28 * 28) / 255., y], axis=1)

    train = _load('train.npy')
    test = _load('test.npy')

    # Validation size is tied to the test size, carved from the head of train.
    n_val = len(test) // 2
    train, val = train[n_val:], train[:n_val]

    print(train.shape, val.shape, test.shape)
    return pd.DataFrame(train), pd.DataFrame(val), pd.DataFrame(test)

def process_toy(path="data/ring_16.csv", plot_path="toydata_orig.png"):
    """Load the ring toy dataset, min-max scale four columns and split by row.

    The first two CSV columns are moved to the end, then the first four
    resulting columns are each rescaled to [0, 1] and named
    ``D1, D2, c1, c2`` (for a 4-column CSV the original first two columns
    become ``c1``/``c2``).  Rows are split in file order: ``[:6000]`` train,
    ``[6000:7000]`` val, ``[7000:]`` test.

    Args:
        path: CSV file to read (first line used as the header, as with the
            ``pd.read_csv`` default).
        plot_path: File for a scatter plot of ``c1`` vs ``c2`` coloured by
            ``D1``.  ``None`` skips plotting.  Plotting is also skipped, with
            a message, when seaborn is not installed.

    Returns:
        Tuple ``(train, val, test)`` of float ``pd.DataFrame`` with 4 columns.
    """
    data = np.array(pd.read_csv(path))
    data = np.concatenate([data[:, 2:], data[:, :2]], axis=1)

    # Column-wise min-max scaling of the first four columns.
    feats = data[:, :4].astype(float)
    lo = feats.min(axis=0)
    span = feats.max(axis=0) - lo
    # A constant column would divide by zero and become NaN; map it to 0.
    span[span == 0] = 1.0
    data = pd.DataFrame((feats - lo) / span, columns=["D1", "D2", "c1", "c2"])

    if plot_path is not None:
        try:
            # Imported lazily: seaborn is only needed for this diagnostic
            # plot, so the loader still works without it.
            import seaborn as sns
        except ImportError:
            print("seaborn is not installed; skipping plot", plot_path)
        else:
            palette = sns.color_palette("tab20", 4)
            plot = sns.scatterplot(x="c1", y="c2", hue="D1", palette=palette, data=data,
                                   s=15, linewidth=0, legend=False)
            plot.get_figure().savefig(plot_path)

    data = np.array(data)
    train = pd.DataFrame(data[:6000]).astype(float)
    val = pd.DataFrame(data[6000:7000]).astype(float)
    test = pd.DataFrame(data[7000:]).astype(float)
    print(train.shape)
    print(val.shape)
    print(test.shape)
    return train, val, test

def process_spambase(path="data/spambase.data", header="infer"):
    """Load Spambase and return shuffled ``(train, val, test)`` DataFrames.

    The whole file is split 80/20 into train/test with ``random_state=1``.
    The validation set is then the first ``len(test) // 2`` rows of the
    shuffled training portion, and the remainder stays in train.  Values are
    only cast to float; no encoding or scaling is applied.

    Args:
        path: CSV file to read.  Defaults to the historical location,
            relative to the current working directory.
        header: Forwarded to ``pd.read_csv``.  The default ``"infer"`` keeps
            the historical behaviour of treating the first line as column
            names.  NOTE(review): the UCI ``spambase.data`` file appears to
            ship without a header line, in which case ``header=None`` keeps
            the first sample.  Confirm against the local copy before
            switching, because it also changes the shuffled split.

    Returns:
        Tuple ``(train, val, test)`` of float ``pd.DataFrame`` whose columns
        are renumbered ``0..n-1``.
    """
    randomseed = 1
    data = pd.read_csv(path, header=header)

    train, test = train_test_split(data, test_size=0.2, shuffle=True, random_state=randomseed)
    n_val = len(test) // 2
    train, val = train[n_val:], train[:n_val]
    print(train.shape)
    print(val.shape)
    print(test.shape)

    # np.array drops the CSV column names/index so every split is renumbered.
    train, val, test = (pd.DataFrame(np.array(part)).astype(float) for part in (train, val, test))
    return train, val, test

def process_adult():
    """Load the Adult train/test CSVs and label-encode their categorical columns.

    Rows with any NaN are dropped.  The validation set is the first
    ``len(test) // 2`` rows of the training file (file order, no shuffling);
    the rest stays in train.  Columns listed in ``categorical`` are
    label-encoded with an encoder fitted on the post-split train rows only.
    All other columns pass through unchanged, and everything is cast to
    float at the end.

    Returns:
        Tuple ``(train, val, test)`` of float ``pd.DataFrame``.

    Raises:
        ValueError: from ``LabelEncoder.transform`` if val/test contain a
            category value that never occurs in train.
    """
    train = pd.read_csv("data/adult_train.csv")
    test = pd.read_csv("data/adult_test.csv")

    train = np.array(train.dropna(axis=0))
    test = np.array(test.dropna(axis=0))

    # Validation size is tied to the test size, carved from the head of train.
    n_val = len(test) // 2
    train, val = train[n_val:], train[:n_val]
    print(train.shape)
    print(val.shape)
    print(test.shape)

    categorical = {1, 3, 5, 6, 7, 8, 9, 13, 14}

    splits = (train, val, test)
    encoded = ([], [], [])
    for i in range(train.shape[-1]):
        columns = [split[:, i] for split in splits]
        if i in categorical:
            # LabelEncoder expects 1-D input; fitting on train only gives
            # val/test the same code book.
            encoder = LabelEncoder().fit(columns[0])
            columns = [encoder.transform(col) for col in columns]
        for out, col in zip(encoded, columns):
            out.append(col.reshape(-1, 1))

    train, val, test = (pd.DataFrame(np.concatenate(cols, axis=1)).astype(float) for cols in encoded)
    return train, val, test

def process_balance(path="data/balance-scale.data", header="infer"):
    """Load Balance Scale, split it and label-encode every column.

    The data is split 80/20 into shuffled train/test with
    ``random_state=3``.  Validation is the first ``len(test) // 2`` rows of
    the shuffled train portion.  Every column is label-encoded with an
    encoder fitted on the post-split train rows.

    Args:
        path: CSV file to read.  Defaults to the historical location,
            relative to the current working directory.
        header: Forwarded to ``pd.read_csv``.  The default ``"infer"`` keeps
            the historical behaviour.  NOTE(review): the UCI
            ``balance-scale.data`` file appears to have no header line, so
            ``"infer"`` consumes the first sample as column names.  Confirm
            before switching to ``None``, as that also changes the split.

    Returns:
        Tuple ``(train, val, test)`` of float ``pd.DataFrame``.

    Raises:
        ValueError: from ``LabelEncoder.transform`` if val/test contain a
            value that never occurs in train.
    """
    randomseed = 3
    data = pd.read_csv(path, header=header)
    train, test = train_test_split(data, test_size=0.2, shuffle=True, random_state=randomseed)
    train = np.array(train)
    test = np.array(test)

    n_val = len(test) // 2
    train, val = train[n_val:], train[:n_val]

    print(train.shape)
    print(val.shape)
    print(test.shape)

    splits = (train, val, test)
    encoded = ([], [], [])
    for i in range(train.shape[-1]):
        # 1-D input as LabelEncoder expects; fit on train only.
        encoder = LabelEncoder().fit(train[:, i])
        for out, split in zip(encoded, splits):
            out.append(encoder.transform(split[:, i]).reshape(-1, 1))

    train, val, test = (pd.DataFrame(np.concatenate(cols, axis=1)).astype(float) for cols in encoded)
    return train, val, test


def process_phishing():
    """Load the Telco churn CSV and label-encode its categorical columns.

    NOTE(review): despite the function name, the file read here is
    ``WA_Fn-UseC_-Telco-Customer-Churn.csv``; the name is kept for callers.

    Preprocessing:

    * Rows containing NaN are dropped, then the first column is dropped
      (presumably a customer ID; confirm).
    * Rows whose second-to-last column is the single space ``' '`` are
      dropped (presumably blank charge entries; confirm).

    Encoding and split:

    * Columns ``4, 17, 18`` pass through as-is.
    * Every other column is label-encoded with an encoder fitted on train
      only.
    * The train/test split is 85/15 and shuffled, with ``random_state=1``.
      Validation is the first ``len(test) // 2`` rows of the shuffled train
      portion.

    Returns:
        Tuple ``(train, val, test)`` of float ``pd.DataFrame``.

    Raises:
        ValueError: from ``LabelEncoder.transform`` if val/test contain a
            category value that never occurs in train.
    """
    data = pd.read_csv("data/WA_Fn-UseC_-Telco-Customer-Churn.csv", header=0)
    data = np.array(data.dropna(axis=0))[:, 1:]
    data = data[data[:, -2] != ' ']
    continuous = {4, 17, 18}

    randomseed = 1
    train, test = train_test_split(data, test_size=0.15, shuffle=True, random_state=randomseed)
    n_val = len(test) // 2
    train, val = train[n_val:], train[:n_val]

    splits = (train, val, test)
    encoded = ([], [], [])
    for i in range(train.shape[-1]):
        columns = [split[:, i] for split in splits]
        if i not in continuous:
            # LabelEncoder expects 1-D input; fit on train only.
            encoder = LabelEncoder().fit(columns[0])
            columns = [encoder.transform(col) for col in columns]
        for out, col in zip(encoded, columns):
            out.append(col.reshape(-1, 1))

    train, val, test = (pd.DataFrame(np.concatenate(cols, axis=1)).astype(float) for cols in encoded)
    return train, val, test

def process_default():
    """Load ``data/default.csv`` and label-encode its categorical columns.

    Reading and cleaning:

    * ``header=1`` takes column names from the file's second line and skips
      the first.
    * ``index_col=0`` turns the first column into the index, so it is not
      part of the output.
    * Rows with any NaN are dropped.

    Encoding and split:

    * Columns ``0, 2, 4, 11-22`` pass through as-is.
    * Every other column is label-encoded with an encoder fitted on the
      post-split train rows.
    * The train/test split is 80/20 and shuffled, with ``random_state=3``.
      Validation is the first ``len(test) // 2`` rows of the shuffled train
      portion.

    Returns:
        Tuple ``(train, val, test)`` of float ``pd.DataFrame``.

    Raises:
        ValueError: from ``LabelEncoder.transform`` if val/test contain a
            value that never occurs in train.
    """
    data = pd.read_csv("data/default.csv", header=1, index_col=0)
    data = np.array(data.dropna(axis=0))
    continuous = {0, 2, 4, *range(11, 23)}

    randomseed = 3
    train, test = train_test_split(data, test_size=0.2, shuffle=True, random_state=randomseed)
    n_val = len(test) // 2
    train, val = train[n_val:], train[:n_val]

    splits = (train, val, test)
    encoded = ([], [], [])
    for i in range(train.shape[-1]):
        columns = [split[:, i] for split in splits]
        if i not in continuous:
            # LabelEncoder expects 1-D input; fit on train only.
            encoder = LabelEncoder().fit(columns[0])
            columns = [encoder.transform(col) for col in columns]
        for out, col in zip(encoded, columns):
            out.append(col.reshape(-1, 1))

    train, val, test = (pd.DataFrame(np.concatenate(cols, axis=1)).astype(float) for cols in encoded)
    return train, val, test

def _compact_blog_features(arr):
    """Collapse two one-hot blocks to indices and binarize the last column.

    Overwrites columns of ``arr`` in place and returns its first 269 columns:

    * column 262 is set to the argmax over columns 262..268 (7-wide block);
    * column 263 is set to the argmax over columns 269..275 (7-wide block);
    * columns 264..267 are set to columns 276..279;
    * column 268 is set to 1 where the last column is 0, else 0.

    The first block previously read ``263:269`` (6 columns), which skipped
    column 262 and merged its category into index 0.  The parallel block is
    7 columns wide.  NOTE(review): per the BlogFeedback attribute list, these
    are presumably the weekday one-hots of basetime/publication; confirm.
    """
    arr[:, 262] = arr[:, 262:269].argmax(axis=1)
    arr[:, 263] = arr[:, 269:276].argmax(axis=1)
    arr[:, 264:268] = arr[:, 276:280]
    arr[:, 268] = (arr[:, -1] == 0).astype(int)
    return arr[:, :269]


def process_blog(data_dir="data/blog"):
    """Load BlogFeedback train/test CSVs and return ``(train, val, test)``.

    The test set concatenates the daily files for 2012-02-01..29 and
    2012-03-01..31.  March files from day 26 on use a ``01_00`` suffix
    instead of ``00_00``.  Both sets are compacted by
    ``_compact_blog_features``.  Validation is the first ``len(test) // 2``
    rows of the training file (file order); the rest stays in train.

    Args:
        data_dir: Directory holding the ``blogData_*.csv`` files.  Defaults
            to ``data/blog`` relative to the current working directory.

    Returns:
        Tuple ``(train, val, test)`` of ``pd.DataFrame`` with 269 columns.
    """
    def _read(name):
        return np.array(pd.read_csv(os.path.join(data_dir, name), header=None))

    train = _read("blogData_train.csv")

    test_names = [f"blogData_test-2012.02.{day:02d}.00_00.csv" for day in range(1, 30)]
    test_names += [
        f"blogData_test-2012.03.{day:02d}.{'01' if day >= 26 else '00'}_00.csv"
        for day in range(1, 32)
    ]
    test = np.concatenate([_read(name) for name in test_names])

    train = _compact_blog_features(train)
    test = _compact_blog_features(test)

    # Validation size is tied to the test size, carved from the head of train.
    n_val = len(test) // 2
    train, val = train[n_val:], train[:n_val]

    return pd.DataFrame(train), pd.DataFrame(val), pd.DataFrame(test)


def process_covtype(include_label=False):
    """Load Covertype and return a fixed positional ``(train, valid, test)`` split.

    Rows are taken in file order with no shuffling:

    * ``[:11340]`` is test;
    * ``[11340:15120]`` is valid;
    * ``[15120:]`` is train.

    Feature columns 0-9 pass through unchanged.  The remaining feature
    columns are label-encoded with encoders fitted on train only.

    TODO: by default the file's last column is not included in the output
    (historical behaviour).  NOTE(review): the other loaders in this module
    return their last column; confirm whether this omission is intended.

    Args:
        include_label: When True, append the last column, label-encoded with
            an encoder fitted on train.  Defaults to False, which preserves
            the feature-only output.

    Returns:
        Tuple ``(train, valid, test)`` of float ``pd.DataFrame``.

    Raises:
        ValueError: from ``LabelEncoder.transform`` if valid/test contain a
            value that never occurs in train for an encoded column.
    """
    data = np.array(pd.read_csv("data/covtype.data", header=None))
    test = data[:11340, :]
    valid = data[11340:15120, :]
    train = data[15120:, :]

    continuous = set(range(10))
    n_features = data.shape[-1] - 1  # the last column is handled separately

    splits = (train, valid, test)
    encoded = ([], [], [])
    for i in range(n_features):
        columns = [split[:, i] for split in splits]
        if i not in continuous:
            # LabelEncoder expects 1-D input; fit on train only.
            encoder = LabelEncoder().fit(columns[0])
            columns = [encoder.transform(col) for col in columns]
        for out, col in zip(encoded, columns):
            out.append(col.reshape(-1, 1))

    if include_label:
        encoder = LabelEncoder().fit(train[:, -1])
        for out, split in zip(encoded, splits):
            out.append(encoder.transform(split[:, -1]).reshape(-1, 1))

    train, valid, test = (pd.DataFrame(np.concatenate(cols, axis=1)).astype(float) for cols in encoded)
    return train, valid, test


    
  
def process_shoppers(return_cat_dims=False):
    """Load the online-shoppers CSV and label-encode its categorical columns.

    Columns 0-9 pass through unchanged.  Every other column is label-encoded
    with an encoder fitted on the post-split train rows.  The train/test
    split is 80/20 and shuffled, with ``random_state=1``.  Validation is the
    first ``len(test) // 2`` rows of the shuffled train portion.

    Args:
        return_cat_dims: When True, also return a dict mapping each encoded
            column index to its number of distinct train values.  Defaults to
            False, which keeps the original 3-tuple return.

    Returns:
        ``(train, val, test)`` float ``pd.DataFrame``; when
        ``return_cat_dims`` is True, ``(train, val, test, cat_dims)``.

    Raises:
        ValueError: from ``LabelEncoder.transform`` if val/test contain a
            category value that never occurs in train.
    """
    data = pd.read_csv("data/online_shoppers_intention.csv")
    continuous = set(range(10))

    randomseed = 1
    train, test = train_test_split(data, test_size=0.2, shuffle=True, random_state=randomseed)
    train = np.array(train)
    test = np.array(test)
    n_val = len(test) // 2
    train, val = train[n_val:], train[:n_val]

    splits = (train, val, test)
    encoded = ([], [], [])
    cat_dims = {}
    for i in range(train.shape[-1]):
        columns = [split[:, i] for split in splits]
        if i not in continuous:
            # LabelEncoder expects 1-D input; fit on train only.
            encoder = LabelEncoder().fit(columns[0])
            cat_dims[i] = len(encoder.classes_)
            columns = [encoder.transform(col) for col in columns]
        for out, col in zip(encoded, columns):
            out.append(col.reshape(-1, 1))

    train, val, test = (pd.DataFrame(np.concatenate(cols, axis=1)).astype(float) for cols in encoded)

    if return_cat_dims:
        return train, val, test, cat_dims
    return train, val, test



# def process_shoppers():
#     data = pd.read_csv("data/online_shoppers_intention.csv")

#     continuous = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
#     categorical = list(set(list(range(data.shape[-1]))) - set(continuous))

#     randomseed = 1
#     train, test = train_test_split(data, test_size=0.2, shuffle=True, random_state=randomseed) 
#     train = np.array(train)
#     test = np.array(test)
#     train, val = train[int(len(test)/2):], train[:int(len(test)/2)]

#     train_encoded_data = []
#     val_encoded_data = []
#     test_encoded_data = []
    
#     cat_dims = []
#     for i in range(train.shape[-1]):
        
#         train_curr = train[:, i].copy()
#         val_curr = val[:, i].copy()
#         test_curr = test[:, i].copy()
#         if i in categorical:
#             train_curr = train_curr.reshape(-1, 1)
#             val_curr = val_curr.reshape(-1, 1)
#             test_curr = test_curr.reshape(-1, 1)

#             encoder = OneHotEncoder(sparse=False)
#             encoder.fit(train_curr)
#             # cat_dims.append(len(encoder.classes_))

#             train_curr_ = encoder.transform(train_curr) # / (len(encoder.classes_) - 1)
#             val_curr_ = encoder.transform(val_curr) #/(len(encoder.classes_) - 1)
#             test_curr_ = encoder.transform(test_curr) # /(len(encoder.classes_) - 1)

#             train_encoded_data.append(train_curr_) #.reshape(-1, 1))
#             val_encoded_data.append(val_curr_) #.reshape(-1, 1))
#             test_encoded_data.append(test_curr_) #.reshape(-1, 1))

#         else:
#             train_curr_ = train_curr
#             val_curr_ = val_curr
#             test_curr_ = test_curr
#             train_curr_ = (train_curr - np.min(train_curr)) / (np.max(train_curr) - np.min(train_curr))
#             val_curr_ = (val_curr - np.min(val_curr)) / (np.max(val_curr) - np.min(val_curr))
#             test_curr_ = (test_curr - np.min(test_curr)) / (np.max(test_curr) - np.min(test_curr))
#             train_encoded_data.append(train_curr_.reshape(-1, 1))
#             val_encoded_data.append(val_curr_.reshape(-1, 1))
#             test_encoded_data.append(test_curr_.reshape(-1, 1))
    
#     train = np.concatenate(train_encoded_data, axis=1)
#     # train[:, :-1] =( train[:, :-1] - train[:, :-1].mean() ) / train[:, :-1].std()
#     val = np.concatenate(val_encoded_data, axis=1)
#     # val[:, :-1] = (val[:, :-1] - val[:, :-1].mean()) / val[:, :-1].std()
#     test = np.concatenate(test_encoded_data, axis=1)
#     # test[:, :-1] = ( test[:, :-1] - test[:, :-1].mean() )/   test[:, :-1].std()


    
#     train = pd.DataFrame(train).astype(float)
#     val = pd.DataFrame(val).astype(float)
#     test = pd.DataFrame(test).astype(float)

#     return train, val, test

