from sklearn import model_selection
from sklearn.datasets import make_moons,make_classification,make_circles
import numpy as np
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
import sys
sys.path.append('../')
from aaaicode import lib



def get_data(datasetname, datasize, noisedegree, random_state=None):
    '''
    Description
        Generate a 2-D toy classification dataset and split it 50/50 into
        a training part and a held-out part.

    Parameters
        datasetname: a string containing 'moon', 'circle' or 'linear'
            (substring match). If several of these appear, 'linear' takes
            precedence over 'circle', which takes precedence over 'moon'.
        datasize: the size of each returned split; 2*datasize samples are
            generated in total
        noisedegree: the degree of noise in the original training data
            (``noise`` for moons/circles, ``flip_y`` for linear)
        random_state: optional seed forwarded to train_test_split; the
            default None keeps the split unseeded

    Returns
        X_train,y_train,X_hold,y_hold

    Raises
        ValueError: if datasetname contains none of the supported names
    '''
    # Generate twice the requested size so each half of the 50/50 split
    # below contains datasize samples.
    n_samples = datasize * 2

    if 'linear' in datasetname:
        X, y = make_classification(n_samples=n_samples, n_features=2, n_redundant=0, n_informative=1,
                                   random_state=42, n_clusters_per_class=1, flip_y=noisedegree)
        # Shift every coordinate by a uniform draw from [0, 2) using a fixed seed.
        rng = np.random.RandomState(2)
        X += 2 * rng.uniform(size=X.shape)
    elif 'circle' in datasetname:
        X, y = make_circles(noise=noisedegree, random_state=1, n_samples=n_samples)
    elif 'moon' in datasetname:
        X, y = make_moons(noise=noisedegree, random_state=1, n_samples=n_samples)
    else:
        # Fail loudly instead of hitting an unbound variable further down.
        raise ValueError(
            "unknown datasetname %r; expected a name containing 'moon', 'circle' or 'linear'"
            % (datasetname,))

    X_train, X_test, y_train, y_test = model_selection.train_test_split(
        X, y, test_size=0.5, random_state=random_state)

    return X_train, y_train, X_test, y_test

def _metrics_by_depth(X_train, y_train, X_test, y_test, depthlist):
    '''
    Fit one DecisionTreeClassifier per max_depth in depthlist and collect the
    'pv' and 'test' entries returned by lib.get_allmetrics_classic.

    Returns
        (pvlist, testlist): two lists aligned element-wise with depthlist
    '''
    pvlist = []
    testlist = []
    for depth in depthlist:
        print('Depth: ' + str(depth))
        dt = DecisionTreeClassifier(max_depth=depth)
        dic = lib.get_allmetrics_classic(dt, X_train, y_train, X_test, y_test)
        pvlist.append(dic['pv'])
        testlist.append(dic['test'])
    return pvlist, testlist


def draw_Figure8(datasetname, noisedegree, metricname):
    '''
    Description
        For several training-set sizes, train decision trees of increasing
        max_depth and plot one metric against depth (one line per size).

    Parameters
        datasetname: 'moon','linear','circle' (forwarded to get_data)
        noisedegree: degree of noise in the original training data
        metricname: 'pv'/'PV' or 'test' (case-insensitive); when it contains
            'test' the figure shows the test accuracy, otherwise the PV score

    Returns
        None. The figure is saved as
        ./../plots/<datasetname>-<noisedegree>-<metricname>-large-difftrain.pdf
        and then shown.
    '''
    datasizelist = [1000, 10000, 100000]
    depthlist = np.arange(1, 20, 2)
    # Decide once which metric is plotted; the same flag drives both the
    # y-label and the curves, so they can never disagree.
    show_test = 'test' in metricname.lower()

    plt.figure(figsize=(3, 3))
    # Create the axes before labelling so the label lands on the plotted axes.
    ax = plt.subplot(111)

    # Only the 'moon' panel carries a y-axis label (presumably the leftmost
    # panel of the composed figure - TODO confirm).
    if 'moon' in datasetname:
        label = "Test accuracy" if show_test else "PV score"
        suffix = "" if noisedegree < 0.2 else "\n (noisy training data)"
        plt.ylabel(label + suffix, fontsize=19)

    plt.figtext(0.5, 0.9, datasetname, fontsize=18, ha='center')

    plt.tick_params(axis='x', labelsize=15)
    plt.tick_params(axis='y', labelsize=15)
    plt.yticks(np.arange(0, 1.1, 0.5))
    ax.set_ylim([-0.05, 1.05])

    start, end = int(depthlist[0]), int(depthlist[-1])
    plt.xticks(np.arange(start - 2, end + 1, 4.0))

    for datasize in datasizelist:
        print('Datasize: ' + str(datasize))
        X_train, y_train, X_test, y_test = get_data(datasetname, datasize, noisedegree)
        pvlist, testlist = _metrics_by_depth(X_train, y_train, X_test, y_test, depthlist)

        if show_test:
            plt.plot(depthlist, testlist, 'x-', label=str(datasize))
        else:
            plt.plot(depthlist, pvlist, 'o-', label=str(datasize))
    plt.gray()

    plt.legend(fontsize=15)
    plt.savefig('./../plots/' + datasetname + '-' + str(noisedegree) + '-' + metricname + '-large-difftrain.pdf',
                bbox_inches='tight')
    plt.show()




if __name__ == '__main__':
    # Produce the noise-free 'moon' panels for both metrics.
    draw_Figure8('moon', 0.0, 'PV')
    draw_Figure8('moon', 0.0, 'test')

