import os
import numpy as np
import sys
sys.path.append("config")
import config_path

import pmlb
from ucimlrepo import fetch_ucirepo 

import dataset_info
import importlib
importlib.reload(dataset_info)

# Datasets read from local comma-separated files located at
# <config_path.data_repo_real>/raw_data/<name>/<name>.data, rather than fetched
# from the UCI repository or PMLB. (The misspelled name is kept because other
# code refers to it.)
datasests_by_rawdata = [
    "DMFT", "Led7", "Mofn", "XD6", "ThreeOfNine",
    "GermanGSS", "BCW", "PTumor", "PPD", "Vehicle",
    "ConfAd", "Coronary", "AsiaLung", "Cleveland", "Sensory",
]

# UCI ML Repository ids of the datasets fetched with ``fetch_ucirepo``.
_UCI_DATASET_IDS = {
    "Nursery": 76,
    "SPECT": 95,
    "CarEvaluation": 19,
    "Monk": 70,
    "SolarFlare": 89,
    "Credit": 27,
    "Income": 20,
    "Tumor": 83,
    "TicTacToe": 101,
    "Chess": 22,
    "Chess2": 23,
    "Connect4": 26,
    "Hayesroth": 44,
    "BalanceScale": 12,
    "Lenses": 58,
    "Lymphography": 63,
    "Mushroom": 73,
    "Votes": 105,
    "DTCR": 915,
}

# Column labels of the locally stored raw-data datasets (and Parity5p5).
# get_coords_from_data only uses how many labels there are: it builds one
# category->integer dictionary per listed column.
_RAW_COLUMN_NAMES = {
    "GermanGSS": ["Political_system", "Age", "Time", "Schooling", "Region", "Class"],
    "Parity5p5": ["a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "class"],
    "DMFT": ["begin", "end", "gender", "ethnic", "Class"],
    "Sensory": ["Occasion", "Judges", "Interval", "Sittings", "Position", "Squares",
                "Rows", "Columns", "Halfplot", "Trellis", "Method", "Class"],
    "Led7": ["1", "2", "3", "4", "5", "6", "7", "Class"],
    "Mofn": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "class"],
    "XD6": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "class"],
    "Coronary": ["Smoking", "M.work", "P.work", "Pressure", "Proteins", "Family"],
    "ThreeOfNine": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "class"],
    "BCW": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "class"],
    "Vehicle": ["Alchol", "Gender", "Type", "Age", "class"],
    "Cleveland": ["sex", "cp", "fbs", "restecg", "exang", "slope", "thal", "class"],
    "AsiaLung": ["dysponea", "tuberculosis", "lungcancer", "bronchitis",
                 "visit2asia", "smoking", "Xray", "tvl"],
    "ConfAd": ["uni", "tshirtsize", "favsubject", "regDataCat", "wegan",
               "participation", "isPresent"],
    "PPD": ["Lcore", "LSurf", "Lo2", "Lbp", "Sstbl", "Cstbl", "bpstbl", "comfort", "class"],
    "PTumor": ["Age", "sex", "bone", "bone_marrow", "lung", "pleura", "peritoneum",
               "liver", "brain", "skin", "neck", "suprac", "axillar", "mediastinum",
               "abdominal", "class"],
}


def _load_uci_table(dataset_name):
    """Fetch a UCI dataset, apply per-dataset cleanup, and join features with targets."""
    repo = fetch_ucirepo(id=_UCI_DATASET_IDS[dataset_name])
    X = repo.data.features
    y = repo.data.targets
    if dataset_name == "Credit":
        X = X.drop(["A2", "A3", "A8", "A11", "A14", "A15"], axis=1)
        X = X.dropna()
    if dataset_name == "DTCR":
        X = X.drop(["Age"], axis=1)
    if dataset_name == "Mushroom":
        X = X.drop(["stalk-root"], axis=1)
    if dataset_name == "Votes":
        X = X.dropna()
        y = y.loc[X.index]
    if dataset_name == "Tumor":
        X = X.drop(["axillar", "skin", "sex", "histologic-type", "degree-of-diffe"], axis=1)
    if dataset_name == "Hayesroth":
        X = X[0:132]
        y = y[0:132]
    if dataset_name == "Lymphography":
        # Drop duplicated column labels, then the last column (noted as N/A).
        X = X.loc[:, ~X.columns.duplicated()]
        X = X.iloc[:, :-1]
    if dataset_name == "Income":
        X = X.drop(columns=["age", "fnlwgt", "education-num", "capital-gain",
                            "capital-loss", "hours-per-week"])
    # join() is a left join on the index, so rows dropped from X above
    # (e.g. by dropna for "Credit") are not re-introduced from y.
    X = X.join(y)

    print(y)
    print(X)
    return X


def _load_raw_rows(dataset_name):
    """Read <data_repo_real>/raw_data/<name>/<name>.data into a 2-D string array."""
    data_path = os.path.join(config_path.data_repo_real, "raw_data", dataset_name,
                             f"{dataset_name}.data")
    with open(data_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line): they would produce a
        # ragged row and make np.array() fail.
        rows = [line.split(',') for line in f.read().splitlines() if line.strip()]
    X_np = np.array(rows)
    if dataset_name == "BCW":
        # The first column is a sample id, not an attribute.
        X_np = X_np[:, 1:]
    if dataset_name == "PTumor":
        X_np = np.delete(X_np, [2, 3], axis=1)
    if dataset_name == "ConfAd":
        # isPresent is the target; columns 3 and 6 are swapped so that it sits
        # at index 6, the position of "isPresent" in _RAW_COLUMN_NAMES.
        X_np[:, [3, 6]] = X_np[:, [6, 3]]
    return X_np


def _check_tensor_sizes(dataset_name, sizes):
    """Raise ValueError unless the observed category counts match dataset_info."""
    print("tensor size is", tuple(sizes))
    D_in_info = dataset_info.tensor_dims[dataset_name]
    tensor_size_in_info = dataset_info.tensor_sizes[dataset_name]
    if D_in_info != len(sizes) or tuple(tensor_size_in_info) != tuple(sizes):
        raise ValueError(
            f"dataset_info includes wrong information in {dataset_name}: "
            f"expected {D_in_info} dims with sizes {tuple(tensor_size_in_info)}, "
            f"observed sizes {tuple(sizes)}"
        )


def get_coords_from_data(dataset_name, target=True):
    """Load a categorical dataset and encode each record as integer coordinates.

    Supported sources:

    * names in ``_UCI_DATASET_IDS``: fetched with ``fetch_ucirepo``; features
      and targets are joined into one table. Categories are numbered in order
      of first appearance (pandas ``unique()``).
    * names in ``datasests_by_rawdata``: read from local comma-separated files.
      Categories are numbered in sorted order (``np.unique``).
    * ``"Parity5p5"``: fetched from PMLB. Its last column (labelled "class")
      is not encoded in the coordinates.

    The per-column category counts are printed and checked against
    ``dataset_info``. The records (columns of the result) are then shuffled.

    Args:
        dataset_name: Name of the dataset to load.
        target: Unused; kept for backward compatibility.

    Returns:
        ``(atts, X_np_coords)``. ``atts[d]`` maps each category of column ``d``
        to its integer index; for raw-data datasets there is one entry per label
        in ``_RAW_COLUMN_NAMES``. ``X_np_coords`` is an int64 array of shape
        ``(D, N)`` whose column ``n`` holds the coordinates of one record.

    Raises:
        ValueError: If ``dataset_name`` is unsupported, or if the observed
            category counts disagree with ``dataset_info``.
    """
    # Load the data as a 2-D array X_np with one record per row.
    if dataset_name in _UCI_DATASET_IDS:
        X = _load_uci_table(dataset_name)
        X_np = X.to_numpy()
        print("load UCI data done")
    elif dataset_name in datasests_by_rawdata:
        X_np = _load_raw_rows(dataset_name)
    elif dataset_name == "Parity5p5":
        X_np = np.array(pmlb.fetch_data("parity5+5"))
    else:
        raise ValueError(f"dataset name error: {dataset_name!r}")

    # Show all categories and build the category -> integer dictionaries.
    print("Categories...")
    sizes = []
    if dataset_name in datasests_by_rawdata or dataset_name == "Parity5p5":
        D = len(X_np[0])  # tensor dim
        if dataset_name == "Parity5p5":
            D -= 1  # exclude the last ("class") column from the tensor modes
        N = len(X_np)  # number of observed records
        for d in range(D):
            m = np.unique(X_np[:, d])
            print(m)
            sizes.append(len(m))
        _check_tensor_sizes(dataset_name, sizes)

        n_att_columns = len(_RAW_COLUMN_NAMES[dataset_name])
        atts = {
            d: {category: j for j, category in enumerate(np.unique(X_np[:, d]))}
            for d in range(n_att_columns)
        }
    else:
        for col in X:
            m = X[col].unique()
            print(m)
            sizes.append(len(m))
        _check_tensor_sizes(dataset_name, sizes)

        D = len(X.columns)  # tensor dim
        N = len(X)  # number of observed records
        # NOTE(review): a NaN category only round-trips through the dict lookup
        # below if to_numpy() yields the same NaN object (NaN != NaN) — confirm
        # for datasets with missing values.
        atts = {
            d: {category: j for j, category in enumerate(X[col].unique())}
            for d, col in enumerate(X.columns)
        }

    # Integer matrix in COO layout: X_np_coords[d, n] is the index of record
    # n's category in column d.
    X_np_coords = np.zeros((D, N), dtype='int64')
    for d in range(D):
        X_np_coords[d, :] = [atts[d][category] for category in X_np[:, d]]

    # X_np_coords.T is a view, so shuffling its rows permutes the columns
    # (records) of X_np_coords in place.
    np.random.shuffle(X_np_coords.T)
    return atts, X_np_coords

def sep_train_valid_test(X_np_coords):
    """Split records (columns of ``X_np_coords``) into train / valid / test sets.

    With ``N`` records, the first ``int(N*0.7)`` columns form the training set,
    the next ``int(N*0.15) + 1`` the validation set, and the following
    ``int(N*0.15) + 1`` (clipped at ``N``) the test set. Each split is returned
    transposed, i.e. with one record per row.
    """
    N = np.shape(X_np_coords)[1]

    n_train = int(N*0.7)
    n_valid = int(N*0.15) + 1
    n_test = int(N*0.15) + 1
    cuts = (0, n_train, n_train + n_valid, n_train + n_valid + n_test)

    X_train, X_valid, X_test = (
        X_np_coords[:, lo:hi].T for lo, hi in zip(cuts, cuts[1:])
    )

    # Every record must land in exactly one split.
    assert N == len(X_train) + len(X_valid) + len(X_test), "data is leaking!!"
    return X_train, X_valid, X_test

def tuple_skipping_m(N, m):
    """Return the indices ``0, 1, ..., N-1`` as a tuple, omitting ``m``.

    If ``m`` is not in ``range(N)``, nothing is omitted.

    >>> tuple_skipping_m(5, 2)
    (0, 1, 3, 4)
    >>> tuple_skipping_m(7, 3)
    (0, 1, 2, 4, 5, 6)
    >>> tuple_skipping_m(4, 1)
    (0, 2, 3)
    """
    return tuple(i for i in range(N) if i != m)

def reduce_same_indices(X):
    """Collapse duplicate rows of ``X``.

    Returns:
        A tuple ``(unique_rows, counts)``: the distinct rows of ``X`` in the
        sorted order produced by ``np.unique``, and how often each occurs.
    """
    return tuple(np.unique(X, axis=0, return_counts=True))

def no_empty(dataset_name, coords):
    """Check that every tensor mode is fully covered by ``coords``.

    ``coords`` is indexed as ``coords[:, d]`` (one column per mode). Returns
    True only if, for every mode ``d``, the number of distinct indices in that
    column equals ``dataset_info.tensor_sizes[dataset_name][d]``.
    """
    n_modes = dataset_info.tensor_dims[dataset_name]
    expected_sizes = dataset_info.tensor_sizes[dataset_name]
    return all(
        len(np.unique(coords[:, mode])) == expected_sizes[mode]
        for mode in range(n_modes)
    )

def prep_npys(dataset_name, max_attempts=25):
    """Split a dataset into train/valid/test tensors and save them as .npy files.

    The records from ``get_coords_from_data`` are split with
    ``sep_train_valid_test``. If the training split misses some category
    (checked with ``no_empty``), the records are reshuffled in place and the
    split is retried, up to ``max_attempts`` attempts in total. The dataset is
    loaded only once.

    Each split is reduced with ``reduce_same_indices`` and written to
    ``<config_path.data_repo_real>/<dataset_name>/`` (created if missing) as
    ``X_<split>_coords.npy`` (distinct coordinate rows) and
    ``X_<split>_values.npy`` (their counts). The category->integer
    dictionaries are written to ``att2int.npy``.

    Args:
        dataset_name: Name of the dataset to preprocess.
        max_attempts: Maximum number of splits tried before giving up.

    Raises:
        RuntimeError: If no split with every category present in the training
            set is found within ``max_attempts`` attempts.
    """
    atts, X_np_coords = get_coords_from_data(dataset_name)

    # A for/else (rather than an assert) so the retry limit still applies
    # under ``python -O``.
    for attempt in range(max_attempts):
        if attempt > 0:
            # get_coords_from_data already shuffled once; later attempts
            # reshuffle the records in place (the .T view permutes columns).
            np.random.shuffle(X_np_coords.T)
        X_train, X_valid, X_test = sep_train_valid_test(X_np_coords)
        uniX_train_coords, uniX_train_values = reduce_same_indices(X_train)
        if no_empty(dataset_name, uniX_train_coords):
            print("there is no empty label")
            break
        print("empty label exists. shuffle again..")
    else:
        raise RuntimeError("This dataset is too sparse...")

    save_dir = os.path.join(config_path.data_repo_real, dataset_name)
    os.makedirs(save_dir, exist_ok=True)

    reduced_splits = (
        ("train", (uniX_train_coords, uniX_train_values)),
        ("valid", reduce_same_indices(X_valid)),
        ("test", reduce_same_indices(X_test)),
    )
    saved_paths = []
    for split, (coords, values) in reduced_splits:
        for suffix, array in (("coords", coords), ("values", values)):
            path = os.path.join(save_dir, f"X_{split}_{suffix}.npy")
            np.save(path, array)
            saved_paths.append(path)

    # The dict is stored as a pickled object array; read it back with
    # np.load(path, allow_pickle=True).item()
    np.save(os.path.join(save_dir, "att2int.npy"), atts)

    for path in saved_paths:
        print("saved in", path)

# Choose the dataset to preprocess. To switch datasets, replace the active
# assignment below with one of these. Each one must assign ``dataset_name``,
# the variable passed to prep_npys.
#dataset_name = "Chess"
#dataset_name = "Chess2"
#dataset_name = "Income"
#dataset_name = "TicTacToe"
#dataset_name = "Connect4"
#dataset_name = "Hayesroth"
#dataset_name = "BalanceScale"
#dataset_name = "Lenses"
#dataset_name = "Lymphography"
#dataset_name = "Mushroom"
#dataset_name = "Votes"
#dataset_name = "Tumor"
#dataset_name = "DTCR"
#dataset_name = "PTumor"
#dataset_name = "PPD"
#dataset_name = "Vehicle"
#dataset_name = "ConfAd"
#dataset_name = "Coronary"
#dataset_name = "AsiaLung"
#dataset_name = "Cleveland"
#dataset_name = "Sensory"

dataset_name = "SolarFlare"
print(f"running {dataset_name}")
prep_npys(dataset_name)