import statistics

import pandas
from sklearn import model_selection
from sklearn.datasets import make_moons,make_classification,make_circles
import numpy as np
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
import sys
sys.path.append('../')
from aaaicode import lib


def get_data(datasetname,datasize):
    '''
    Description
        This function returns the training data for a named dataset.

        The name is matched by substring and the checks run in the order
        'connect', 'moon', 'circle', 'linear'; if several keywords occur in
        datasetname, the last matching one determines the result.

    Parameters
        datasetname: a name containing 'connect','moon','linear' or 'circle'
        datasize: number of training samples to return

    Returns
        X_train,y_train

    Raises
        ValueError: if datasetname contains none of the supported keywords
    '''

    supported = ('connect', 'moon', 'circle', 'linear')
    # Fail early with a clear message instead of reaching the return statement
    # with X_train / y_train never assigned.
    if not any(keyword in datasetname for keyword in supported):
        raise ValueError(
            "Unknown dataset name {!r}; expected a name containing one of {}".format(
                datasetname, ', '.join(supported)))

    if 'connect' in datasetname:
        # 42 feature columns named '1'..'42', followed by the label column.
        feature_columns = [str(i) for i in range(1, 43)]
        names = feature_columns + ['result']
        traindatafilepath = "./../dataset/connect.data"

        whole_dataset = pandas.read_csv(traindatafilepath, names=names)

        # The return value is not used, so this presumably transforms the
        # feature columns of the frame in place -- confirm in lib.
        lib.fit_string_data(whole_dataset, feature_columns)

        array = whole_dataset.values
        labelnum = len(feature_columns)
        X_all = array[:, 0:labelnum]
        y_all = array[:, labelnum]

        # Keep `datasize` rows for training; everything else is the hold-out
        # part, which is discarded. The row count is taken from the loaded
        # file rather than hard-coded.
        holdout_size = array.shape[0] - datasize
        X_train, _, y_train, _ = model_selection.train_test_split(
            X_all, y_all, test_size=holdout_size, random_state=42)

    if 'moon' in datasetname:
        X_train, y_train = make_moons(noise=0.0, random_state=0, n_samples=datasize)
    if 'circle' in datasetname:
        X_train, y_train = make_circles(noise=0.0, random_state=0, n_samples=datasize)
    if 'linear' in datasetname:
        # NOTE(review): flip_y=-1 is kept exactly as in the original. Recent
        # scikit-learn releases may reject values outside [0, 1]; changing it
        # could alter the generated data, so confirm against the installed
        # version before touching it.
        X, y = make_classification(n_samples=datasize, n_features=2, n_redundant=0, n_informative=1, random_state=42,
                                   n_clusters_per_class=1, flip_y=-1)
        # Add uniform noise in [0, 2) to every feature, with a fixed seed.
        rng = np.random.RandomState(2)
        X += 2 * rng.uniform(size=X.shape)
        X_train = X
        y_train = y

    return X_train,y_train

def draw_Figure6(datasetname):
    '''
    Description
        For several training-set sizes, plot the mean of lib.get_PV_classic
        over 11 runs against the decision-tree max_depth, one curve per size.
        The figure is saved as a PDF and then shown.

    Parameters
        datasetname: 'moon','linear','circle','connect'
    Returns
        None. The subfigure is written to
        ./../plots/<datasetname>-difftrain.pdf and displayed.
    '''

    # Training sizes, tree depths and y-range depend only on the dataset,
    # so they are chosen once up front.
    if 'connect' in datasetname:
        datasizelist = np.arange(1000,10000,2000)
        depthlist = np.arange(6, 26, 2)
        ylimits = [0.55, 0.8]
    else:
        datasizelist = np.arange(100,1000,200)
        depthlist = np.arange(1, 20, 2)
        ylimits = [-0.05, 1.05]

    plt.figure(figsize=(4, 4))
    ax = plt.subplot(111)

    plt.figtext(0.5, 0.9, datasetname, fontsize=25, ha='center')
    plt.ylabel("MV", fontsize=25)
    plt.tick_params(axis='x', labelsize=22)
    plt.tick_params(axis='y', labelsize=22)

    ax.set_ylim(ylimits)
    start, end = int(depthlist[0]), int(depthlist[-1])
    plt.xticks(np.arange(start - 3, end + 1, 4.0))

    for datasize in datasizelist:
        datasize = int(datasize)
        pvaveragelist = []
        for depth in depthlist:
            # Start a fresh list for every depth so each plotted point is the
            # mean of this depth's runs only. Previously the list was shared
            # across depths, which made each point a running mean over all
            # depths evaluated so far.
            pvlist = []
            for j in range(0, 11):
                print(j)
                # NOTE(review): get_data is called with identical arguments on
                # every run; it is left inside the loop because it is not
                # visible here whether lib.get_PV_classic modifies its inputs.
                X_train,y_train = get_data(datasetname,datasize)
                dt = DecisionTreeClassifier(max_depth=depth)
                pvlist.append(lib.get_PV_classic(dt,X_train,y_train))
            pvaveragelist.append(statistics.mean(pvlist))

        plt.plot(depthlist, pvaveragelist, 'o-', label=str(datasize))

    plt.gray()
    plt.legend(fontsize=18)
    plt.savefig("./../plots/" + datasetname + "-difftrain.pdf", bbox_inches='tight')
    plt.show()




if __name__ == "__main__":
    # Render a single Figure 6 panel; pass 'circle', 'linear' or 'connect'
    # instead of 'moon' to produce the other panels.
    draw_Figure6('moon')

