from scipy.io import loadmat, savemat
from scipy.io import arff
import h5py
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import warnings
warnings.filterwarnings('ignore')


def load_data(data_path):
    """Load an outlier-detection dataset from disk.

    The loader is chosen from substrings of ``data_path``:

    * contains ``'http'`` or ``'smtp'``: the file is opened with h5py and the
      ``"X"`` / ``"y"`` datasets are read and transposed;
    * contains ``'arff'``: the file is read with ``scipy.io.arff``; the
      ``"outlier"`` column provides the labels (values containing ``'no'``
      map to 0, everything else to 1) and the ``"id"`` and ``"outlier"``
      columns are dropped from the features;
    * anything else: the file is read with ``scipy.io.loadmat`` and the
      ``"X"`` / ``"y"`` variables are used.

    Args:
        data_path: Path of the dataset file.

    Returns:
        A tuple ``(data, label, contamination)`` where ``data`` is a
        ``pandas.DataFrame`` with columns ``col0``, ``col1``, ...,
        ``label`` is a flat list of Python ints (one per sample) and
        ``contamination`` is ``sum(label) / len(label)``.
    """
    if 'http' in data_path or 'smtp' in data_path:
        # Read everything into memory inside the context manager so the
        # HDF5 handle is released as soon as loading is done.
        with h5py.File(data_path, 'r') as h5py_data:
            init_data = np.transpose(h5py_data["X"][()])
            init_label = np.transpose(h5py_data["y"][()])
    elif 'arff' in data_path:
        arff_data = arff.loadarff(data_path)
        init_data = pd.DataFrame(arff_data[0])
        # loadarff returns nominal attributes as bytes; decode before matching.
        init_data["outlier"] = init_data["outlier"].apply(lambda x: x.decode('utf-8'))
        init_label = init_data["outlier"].apply(lambda x: 0 if 'no' in x else 1).values.tolist()
        init_data.drop(columns=["id", "outlier"], inplace=True)
        init_data = init_data.values
    else:
        mat_data = loadmat(data_path)
        init_data = mat_data["X"]
        init_label = mat_data["y"]

    init_key = ["col" + str(size) for size in range(init_data.shape[1])]
    init_data = pd.DataFrame(init_data, columns=init_key)

    # Label arrays arrive as (n, 1) or (1, n) matrices; flatten first so each
    # element is converted individually instead of calling int() on a whole row.
    init_label = [int(value) for value in np.asarray(init_label).ravel()]
    contamination = sum(init_label) / len(init_label)

    return init_data, init_label, contamination


def analysis_data(count, neg, data_path, ts2=True, ts3=False):
    """Visualise a dataset with t-SNE, optionally mixed with uniform noise.

    The dataset is loaded with ``load_data``, scaled to ``[0, 1]`` per column
    with ``MinMaxScaler`` and embedded with t-SNE. Points are coloured with
    the ``Set1`` colormap indexed by their integer label.

    Args:
        count: Value used as the prefix of the saved image file name.
        neg: When truthy, append as many uniformly random points in
            ``[0, 1)`` as there are real samples, labelled ``2``, and save
            the 2-D plot under ``tsne_plot/origin_plus_neg_dataset/``
            instead of ``tsne_plot/origin_dataset/``.
        data_path: Path of the dataset file; its base name (up to the first
            dot) is used as the plot title and in the image file name.
        ts2: When truthy, compute a 2-D embedding and save it as a ``.jpg``.
        ts3: When truthy, compute a 3-D embedding and draw it on a new
            figure. That figure is neither saved nor closed here.

    Returns:
        None.
    """
    import os  # local import: only needed to create the output directory

    data, label, contamination = load_data(data_path)
    summary_key = data_path.split('/')[-1].split('.')[0]

    # Count from the labels directly; rebuilding the counts from the float
    # contamination ratio and truncating can be off by one.
    outlier_count = int(sum(label))
    inlier_count = len(label) - outlier_count
    print('data shape {}, inlier\'s number {}, outlier\'s number {}'.format(data.shape, inlier_count, outlier_count))

    data = MinMaxScaler().fit_transform(data)

    if neg:
        neg_data = np.random.uniform(0, 1, data.shape)
        data = np.concatenate((data, neg_data))
        # Build a new list so the list returned by load_data is not mutated.
        label = label + [2] * len(label)

    print(data.min(), data.max())

    # One RGBA row per sample, so every embedding is drawn with a single
    # scatter call instead of one call (and one artist) per point.
    point_colors = plt.cm.Set1(np.asarray(label, dtype=int))

    if ts2:
        print('tsne-2d')
        tsne_2d = TSNE(n_components=2, init='pca', random_state=0)
        data_ts = tsne_2d.fit_transform(data)
        fig_2d = plt.figure(figsize=(9, 9))
        plt.scatter(data_ts[:, 0], data_ts[:, 1], c=point_colors)
        plt.title(summary_key)
        if not neg:
            out_dir = 'tsne_plot/origin_dataset/'
        else:
            out_dir = 'tsne_plot/origin_plus_neg_dataset/'
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(out_dir + str(count) + '_' + summary_key + '.jpg')
        # Release the figure once it is on disk; otherwise figures pile up
        # when this function is called for many datasets.
        plt.close(fig_2d)

    if ts3:
        print('tsne-3d')
        # Separate name so the ``ts3`` flag is not shadowed by the estimator.
        tsne_3d = TSNE(n_components=3, init='pca', random_state=0)
        data_ts3 = tsne_3d.fit_transform(data)
        fig_3d = plt.figure(figsize=(9, 9))
        # Constructing Axes3D(fig) directly no longer attaches the axes to
        # the figure on current matplotlib; request a 3-D subplot instead.
        ax = fig_3d.add_subplot(projection='3d')
        ax.scatter3D(data_ts3[:, 0], data_ts3[:, 1], data_ts3[:, 2], c=point_colors)

# Dataset files to visualise, given relative to ./data/.
data_path_list = ["init/musk.mat","init/breastw.mat","arff/KDDCup99_idf.arff", "arff/WDBC_withoutdupl_v01.arff"]

# `count` is the running index that prefixes each saved image name. It was
# previously incremented without ever being initialised, which raised
# NameError on the first iteration.
# NOTE(review): numbering starts at 0 here — confirm the intended start value.
for count, data_path in enumerate(data_path_list):
    # Always overlay the uniform-noise samples in this run.
    neg = True
    analysis_data(count, neg, './data/'+data_path)