import os
import numpy as np
import dataset_info
import importlib
import pmlb
importlib.reload(dataset_info)
from ucimlrepo import fetch_ucirepo 

def get_coords_from_data(dataset_name):
    """Load a categorical dataset and encode it as integer COO coordinates.

    Every distinct value of a column is mapped to an integer id, and every
    sample becomes one column of the returned coordinate matrix.  The sample
    order is shuffled randomly before returning.

    Args:
        dataset_name: One of "SPECT", "Chess", "SolarFlare", "Lymphography"
            (loaded with ``fetch_ucirepo``), "Tumor", "Votes" (pickle files
            under ``rawdata/``), "DMFT" (comma-separated text file
            ``rawdata/DMFT.data``) or "Led7" (loaded with ``pmlb``).

    Returns:
        A pair ``(atts, X_np_coords)``.  ``atts`` maps each column index to a
        dict ``{category value: integer id}``.  ``X_np_coords`` is an int64
        array of shape ``(D, N)`` (columns x samples) holding those ids.
        For "DMFT" and "Led7" the ids follow the sorted order of
        ``np.unique``; for the other datasets they follow the order returned
        by the column's ``unique()``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # ids passed to fetch_ucirepo for the datasets loaded through ucimlrepo
    uci_ids = {"SPECT": 95, "Chess": 22, "SolarFlare": 89, "Lymphography": 63}
    # pickle files holding the locally stored datasets
    pickle_paths = {"Tumor": "rawdata/tumor", "Votes": "rawdata/votes"}

    if dataset_name in uci_ids:
        # data (as pandas dataframes)
        X = fetch_ucirepo(id=uci_ids[dataset_name]).data.features
        if dataset_name == "Lymphography":
            X = X.loc[:, ~X.columns.duplicated()]
            X = X.iloc[:, :-1]  # drop the last column (remove NaN value)
        X_np = X.to_numpy()
        print("load UCI data done")

    elif dataset_name in pickle_paths:
        # Imported here because the module itself imports neither name.
        import pickle
        import pandas as pd

        # NOTE: pickle.load can execute arbitrary code; only use it on
        # trusted files.
        with open(pickle_paths[dataset_name], 'rb') as f:
            a = pickle.load(f)
        X = a[0]
        X_np = np.array(X.astype(int))
        X = pd.DataFrame(X)

    elif dataset_name == "DMFT":
        data_path = "rawdata/DMFT.data"
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.read().splitlines()
        X_np = np.array([line.split(',') for line in lines])

    elif dataset_name == "Led7":
        X_np = np.array(pmlb.fetch_data("led7"))

    else:
        raise ValueError(f"dataset name error: {dataset_name!r}")

    # Show all categories.
    print("Categories...")
    sizes = []
    # Dict to see the correspondence between category and number
    atts = {}

    if dataset_name in ("DMFT", "Led7"):
        D = len(X_np[0])  # tensor dim
        N = len(X_np)     # the number of observed elements
        for d in range(D):
            m = np.unique(X_np[:, d])
            print(m)
            sizes.append(len(m))
        print("tensor size is", tuple(sizes))

        column_names = {
            "DMFT": ["begin", "end", "gender", "ethnic", "prevention"],
            "Led7": ["a1", "a2", "a3", "a4", "a5", "a6", "a7", "class"],
        }
        names_col = column_names[dataset_name]

        for d, _ in enumerate(names_col):
            nuq = np.unique(X_np[:, d])
            atts[d] = {value: j for j, value in enumerate(nuq)}

    else:
        # For UCI datasets (and the pickled ones wrapped in a DataFrame)
        names_col = X.columns
        D = len(names_col)  # tensor dim
        N = len(X)          # observed elements
        for d, col in enumerate(names_col):
            nuq = X[col].unique()
            print(nuq)
            sizes.append(len(nuq))
            atts[d] = {value: j for j, value in enumerate(nuq)}
        print("tensor size is", tuple(sizes))

    # To make npy file in COO format, prepare integer matrix
    X_np_coords = np.zeros((D, N), dtype='int64')
    for d in range(D):
        att = atts[d]
        X_np_coords[d, :] = [att[value] for value in X_np[:, d]]

    # Shuffling the transposed view permutes the columns (samples) in place.
    np.random.shuffle(np.transpose(X_np_coords))
    return atts, X_np_coords

def sep_train_valid_test(X_np_coords):
    """Split the samples (columns) into train, validation and test parts.

    The first ``int(N * 0.7)`` columns form the training part.  Validation
    and test then each take up to ``int(N * 0.15) + 1`` of the columns that
    follow.  Every part is returned transposed, i.e. with one sample per row.

    Args:
        X_np_coords: 2-D array whose second axis indexes the N samples.

    Returns:
        The tuple ``(X_train, X_valid, X_test)``.
    """
    N = np.shape(X_np_coords)[1]
    holdout = int(N * 0.15) + 1

    train_stop = int(N * 0.7)
    valid_stop = train_stop + holdout
    test_stop = valid_stop + holdout

    bounds = ((0, train_stop), (train_stop, valid_stop), (valid_stop, test_stop))
    X_train, X_valid, X_test = (
        np.transpose(X_np_coords[:, lo:hi]) for lo, hi in bounds
    )

    # Sanity check: the three parts together must contain every sample.
    total = sum(np.shape(part)[0] for part in (X_train, X_valid, X_test))
    assert N == total, "data is leaking!!"

    return X_train, X_valid, X_test

def tuple_skipping_m(N, m):
    """Return the integers 0 .. N-1 as a tuple, leaving out ``m``.

    Examples:
        tuple_skipping_m(5, 2) == (0, 1, 3, 4)
        tuple_skipping_m(7, 3) == (0, 1, 2, 4, 5, 6)
        tuple_skipping_m(4, 1) == (0, 2, 3)
    """
    kept = []
    for idx in range(N):
        if idx == m:
            continue
        kept.append(idx)
    return tuple(kept)

def reduce_same_indices(X):
    """Collapse duplicate rows of ``X``.

    Returns:
        A pair ``(unique_rows, counts)`` as produced by ``np.unique`` along
        axis 0: the distinct rows and how many times each one occurs.
    """
    return np.unique(X, axis=0, return_counts=True)

def no_empty(dataset_name, coords):
    """Check that every index of every mode occurs in ``coords``.

    For each of the ``dataset_info.tensor_dims[dataset_name]`` columns of
    ``coords``, the number of distinct values is compared with the matching
    entry of ``dataset_info.tensor_sizes[dataset_name]``.

    Returns:
        True if all columns match their recorded size, False otherwise.
    """
    n_modes = dataset_info.tensor_dims[dataset_name]
    expected_sizes = dataset_info.tensor_sizes[dataset_name]
    return all(
        len(np.unique(coords[:, mode])) == expected_sizes[mode]
        for mode in range(n_modes)
    )

def prep_npys(dataset_name):
    """Create and save the train/valid/test ``.npy`` files of one dataset.

    The dataset is loaded once with ``get_coords_from_data``.  Its samples
    are then split with ``sep_train_valid_test``; the split is accepted only
    if ``no_empty`` holds for the de-duplicated training coordinates.
    Otherwise the samples are reshuffled and the split is retried, up to 25
    times in total.

    Files written to ``../data/<dataset_name>/`` (the directory is created
    if it does not exist):
        X_{train,valid,test}_coords.npy: unique coordinate rows of each part.
        X_{train,valid,test}_values.npy: occurrence count of each unique row.
        att2int.npy: the category -> integer dict; read it back with
            ``np.load("att2int.npy", allow_pickle=True).item()``.

    Args:
        dataset_name: Name understood by ``get_coords_from_data`` and present
            in ``dataset_info``.

    Raises:
        AssertionError: If no acceptable split is found within 25 attempts.
    """
    max_attempts = 25

    # Load once; retries only need a new sample order, not a new download.
    atts, X_np_coords = get_coords_from_data(dataset_name)

    for _ in range(max_attempts):
        X_train, X_valid, X_test = sep_train_valid_test(X_np_coords)
        uniX_train_coords, uniX_train_values = reduce_same_indices(X_train)

        if no_empty(dataset_name, uniX_train_coords):
            print("there is no empty label")
            break

        print("empty label exists. shuffle again..")
        # Shuffling the transposed view permutes the samples (columns) in place.
        np.random.shuffle(np.transpose(X_np_coords))
    else:
        # Raised explicitly (not via ``assert``) so the retry limit is still
        # enforced when Python runs with -O.
        raise AssertionError("This dataset is too sparse...")

    uniX_valid_coords, uniX_valid_values = reduce_same_indices(X_valid)
    uniX_test_coords, uniX_test_values = reduce_same_indices(X_test)

    dataset_dir = "../data/"
    out_dir = os.path.join(dataset_dir, dataset_name)
    # np.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    outputs = {
        "X_train_coords.npy": uniX_train_coords,
        "X_train_values.npy": uniX_train_values,
        "X_valid_coords.npy": uniX_valid_coords,
        "X_valid_values.npy": uniX_valid_values,
        "X_test_coords.npy": uniX_test_coords,
        "X_test_values.npy": uniX_test_values,
        # when you read att dict, please write as follows:
        # np.load("att2int.npy", allow_pickle=True).item()
        "att2int.npy": atts,
    }
    for file_name, payload in outputs.items():
        np.save(os.path.join(out_dir, file_name), payload)
    

# Prepare the .npy files of every dataset loaded through pmlb / ucimlrepo.
for dataset_name in (
    "Led7",
    "Chess",
    "Lymphography",
    "SPECT",
    "SolarFlare",
):
    prep_npys(dataset_name)

#dataset_name = "DMFT"
#prep_npys(dataset_name)
#dataset_name = "Tumor"
#prep_npys(dataset_name)
#dataset_name = "Votes"
#prep_npys(dataset_name)
