import os

import numpy as np
import pandas as pd
import torch
from matplotlib.artist import ArtistInspector
from sklearn import preprocessing




# Names of the raw KEEL datasets this script knows how to convert.
_KEEL_DATASETS = (
    "appendicitis", "wisconsin", "australian", "german", "titanic",
    "phoneme", "spambase", "segment0", "page-blocks0",
)


def _load_keel_dat(path):
    """Parse a KEEL-style ``.dat`` file into a float32 feature matrix and encoded labels.

    Header lines before ``@data`` are scanned for ``@attribute`` declarations;
    an attribute whose type specification contains ``{`` is treated as
    categorical and one-hot encoded. The last declared attribute is the label.

    Returns:
        ``(X, y)`` as NumPy arrays: ``X`` is float32 with one row per data line,
        ``y`` holds ``LabelEncoder`` codes (0..k-1 in sorted label order).

    Raises:
        ValueError: If the file has no ``@data`` line or no data rows.
    """
    with open(path, encoding="utf-8") as f:
        lines = [line.strip() for line in f]
    try:
        data_start = lines.index("@data")
    except ValueError:
        raise ValueError(f"{path}: no '@data' section found") from None

    # attribute name -> True if declared with a nominal "{...}" value set.
    is_categorical = {}
    for tokens in (line.split() for line in lines[:data_start]):
        if len(tokens) >= 2 and tokens[0] == "@attribute":
            is_categorical[tokens[1]] = "{" in " ".join(tokens[2:])

    # Fields may carry surrounding spaces ("1, 2, positive"); blank lines are skipped.
    rows = [
        [field.strip() for field in line.split(",")]
        for line in lines[data_start + 1:]
        if line
    ]
    if not rows:
        raise ValueError(f"{path}: '@data' section contains no rows")

    frame = pd.DataFrame(rows, columns=list(is_categorical))
    features, labels = frame.iloc[:, :-1], frame.iloc[:, -1]
    # Encode only categorical *feature* columns; the label column is never one-hot encoded.
    onehot_columns = [col for col in features.columns if is_categorical[col]]
    # A numeric dtype for the dummies avoids pandas>=2 bool columns, which cannot
    # be parsed back from strings; the remaining string columns are parsed below.
    features = pd.get_dummies(features, columns=onehot_columns, dtype=np.float32)
    X = features.astype("float32").to_numpy()

    y = preprocessing.LabelEncoder().fit_transform(labels.to_numpy(dtype=str)).astype("int")
    return X, y


def data_preprocessing(dataset_name, raw_dir="./data_raw/"):
    """Convert a raw KEEL ``.dat`` dataset into tensors and save them with ``save_data``.

    Reads ``<raw_dir>/<dataset_name>.dat``, one-hot encodes categorical
    (``{...}``-typed) feature attributes, label-encodes the last column and
    maps encoded class 0 to -1 (a binary label set therefore becomes {-1, 1}).

    Args:
        dataset_name: One of the names in ``_KEEL_DATASETS``.
        raw_dir: Directory holding the raw ``.dat`` files.

    Returns:
        ``(X, y)``: float32 feature tensor of shape (n_samples, n_features) and
        integer label tensor of shape (n_samples, 1).

    Raises:
        ValueError: If ``dataset_name`` is not supported, or the file has no
            ``@data`` section or no data rows.
        FileNotFoundError: If the ``.dat`` file does not exist.
    """
    if dataset_name not in _KEEL_DATASETS:
        raise ValueError(
            f"unsupported dataset {dataset_name!r}; expected one of {_KEEL_DATASETS}"
        )

    X_np, y_np = _load_keel_dat(os.path.join(raw_dir, f"{dataset_name}.dat"))

    X = torch.tensor(X_np)
    y = torch.tensor(y_np)
    # Encoded class 0 -> -1, then make y a column vector aligned with X's rows.
    y = y.masked_fill(y == 0, -1).reshape((X.shape[0], 1))

    save_data(ins=X, label=y, dataset_name=dataset_name)
    return X, y


def save_data(ins, label, dataset_name, path="./data/"):
    """Persist a feature tensor and its label tensor with ``torch.save``.

    Files are written to ``<path>/ins_<dataset_name>.pt`` and
    ``<path>/lable_<dataset_name>.pt``. A trailing separator on ``path`` is
    optional, and a missing output directory is created.

    Args:
        ins: Feature tensor (any object ``torch.save`` accepts).
        label: Label tensor.
        dataset_name: Name embedded in both file names.
        path: Output directory; ``""`` means the current directory.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    torch.save(ins, os.path.join(path, f"ins_{dataset_name}.pt"))
    # NOTE: "lable" (sic) is kept verbatim; downstream loaders presumably rely on this name.
    torch.save(label, os.path.join(path, f"lable_{dataset_name}.pt"))




# Convert each supported raw dataset in turn; data_preprocessing saves the tensors.
for dataset_name in (
    "appendicitis",
    "wisconsin",
    "australian",
    "german",
    "titanic",
    "phoneme",
    "spambase",
    "segment0",
    "page-blocks0",
):
    data_preprocessing(dataset_name)