# This code is adapted from the scikit-learn tutorial "RBF SVM parameters":
# https://scikit-learn.org/stable/auto_examples/svm/plot_rbf_parameters.html#sphx-glr-auto-examples-svm-plot-rbf-parameters-py
"""Compare cross-validation and PV scores of an SVC over a (C, gamma) grid.

For every (C, gamma) pair on a logarithmic grid, an ``SVC`` is scored with
3-fold cross-validation and with ``aaaicode.lib.get_PV_classic``. The two
score grids are drawn side by side as heatmaps and saved as
``../plots/svm<datasetname>.pdf``.
"""
# Without the module docstring above this would print "None".
print(__doc__)

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from sklearn import model_selection

from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris,load_breast_cancer,load_wine

import sys
sys.path.append('../')
from aaaicode import lib
from mpl_toolkits.axes_grid1 import make_axes_locatable



# Utility normalizer that moves the midpoint of a colormap to the value of
# interest, so that more of the color range is spent around that value.

class MidpointNormalize(Normalize):
    """Piecewise-linear normalizer that maps ``midpoint`` to 0.5.

    Data in ``[vmin, midpoint]`` is mapped linearly onto ``[0, 0.5]`` and
    data in ``[midpoint, vmax]`` onto ``[0.5, 1]``. Values outside
    ``[vmin, vmax]`` are clamped to 0 or 1 (``np.interp`` end behaviour).

    Parameters
    ----------
    vmin, vmax : float, optional
        Data limits. Any limit left as ``None`` is filled in from the first
        data passed to ``__call__`` (``Normalize.autoscale_None``).
    midpoint : float, optional
        Data value mapped to the centre of the colormap. When ``None``, the
        average of ``vmin`` and ``vmax`` is used.
    clip : bool
        Forwarded to ``Normalize``. The output already lies in [0, 1].
    """

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        Normalize.__init__(self, vmin, vmax, clip)

    def _anchors(self):
        """Return the (data values, normalized values) interpolation anchors."""
        midpoint = self.midpoint
        if midpoint is None:
            midpoint = 0.5 * (self.vmin + self.vmax)
        return [self.vmin, midpoint, self.vmax], [0.0, 0.5, 1.0]

    def __call__(self, value, clip=None):
        """Map data *value* (scalar or array, possibly masked) into [0, 1]."""
        # Fill in any missing vmin/vmax from the data, as Normalize does.
        self.autoscale_None(value)
        x, y = self._anchors()
        # np.interp ignores masks, so interpolate the raw data and re-apply
        # the input's mask to the result.
        normalized = np.interp(np.ma.getdata(value), x, y)
        return np.ma.masked_array(normalized, mask=np.ma.getmask(value))

    def inverse(self, value):
        """Map normalized values in [0, 1] back to data values.

        Matplotlib colorbars use ``inverse`` together with ``__call__`` to
        lay out their axis. The inherited linear ``Normalize.inverse`` would
        not match the piecewise mapping above and would misplace the ticks.
        """
        x, y = self._anchors()
        return np.interp(value, y, x)

def _load_dataset(datasetname):
    """Return ``(X, y)`` for the scikit-learn data set named in *datasetname*.

    Matching is by substring: 'iris', 'cancer' (breast cancer) or 'wine'.
    If several keys match, the last one in that order wins, which is the
    precedence of the original sequence of ``if`` statements.

    Raises
    ------
    ValueError
        If *datasetname* contains none of the supported keys.
    """
    loaders = (
        ('iris', load_iris),
        ('cancer', load_breast_cancer),
        ('wine', load_wine),
    )
    matches = [loader for key, loader in loaders if key in datasetname]
    if not matches:
        raise ValueError(
            f"unknown dataset {datasetname!r}; expected a name containing "
            "'iris', 'cancer' or 'wine'"
        )
    bunch = matches[-1]()
    return bunch.data, bunch.target


def draw_Figure4(datasetname):
    """Plot CV and PV scores of an SVC over a log-spaced (C, gamma) grid.

    The features are standardized. Then, for each of the 100 x 100
    (C, gamma) pairs, an ``SVC(C=C, gamma=gamma)`` is scored with
    ``lib.get_PV_classic`` and with the mean 3-fold ``cross_val_score``.
    This is 10,000 grid points and is slow.

    Parameters
    ----------
    datasetname : str
        Selects the data set by substring ('iris', 'cancer' or 'wine'). It is
        also used in the output file name.

    Side effects
    ------------
    Prints progress for every grid point, saves the figure to
    ``./../plots/svm<datasetname>.pdf`` and calls ``plt.show()``.

    Raises
    ------
    ValueError
        If *datasetname* names no supported data set.
    """
    # #########################################################################
    # Load and prepare data set
    X_original, y_original = _load_dataset(datasetname)
    X_original = StandardScaler().fit_transform(X_original)

    # A vanishingly small test_size makes train_test_split hold out a single
    # sample (ceil(1e-8 * n_samples) == 1). The call effectively serves as a
    # deterministic shuffle (random_state=42); the held-out part is unused.
    X, _X_hold, y, _y_hold = model_selection.train_test_split(
        X_original, y_original, test_size=0.00000001, random_state=42)

    # #########################################################################
    # Score classifiers on a logarithmic (base-10) grid of C and gamma.
    C_range = np.logspace(-2, 8, 100)
    gamma_range = np.logspace(-9, 3, 100)

    cvscorelist = []
    pvscorelist = []
    totallen = len(C_range) * len(gamma_range)
    cnt = 0
    for C in C_range:
        for gamma in gamma_range:
            cnt += 1
            clf = SVC(C=C, gamma=gamma)
            pvscorelist.append(lib.get_PV_classic(clf, X, y))
            cvscorelist.append(
                model_selection.cross_val_score(clf, X, y, cv=3).mean())
            print(f'{cnt} in {totallen}')

    # Row index follows C (outer loop), column index follows gamma.
    grid_shape = (len(C_range), len(gamma_range))
    pvscores = np.array(pvscorelist).reshape(grid_shape)
    cvscores = np.array(cvscorelist).reshape(grid_shape)

    # #########################################################################
    # Visualization
    #
    # The scores are encoded with the 'hot' colormap. A custom normalizer puts
    # the colormap midpoint at 0.92, so small differences among high scores
    # are easier to see without collapsing all low scores to one color.
    fig, axs = plt.subplots(1, 2, figsize=(10, 3))
    plt.setp(axs, xticks=[0, len(gamma_range) - 1],
             xticklabels=[gamma_range[0], gamma_range[-1]])
    # Heatmap rows index C, so the y ticks must span len(C_range).
    plt.setp(axs, yticks=[0, len(C_range) - 1],
             yticklabels=[C_range[0], C_range[-1]])

    for ax, scores, title in ((axs[0], cvscores, 'CV'),
                              (axs[1], pvscores, 'PV')):
        image = ax.imshow(scores, interpolation='nearest', cmap=plt.cm.hot,
                          norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
        ax.set_title(title)
        ax.set_xlabel('gamma')
        cax = make_axes_locatable(ax).append_axes("right", size="5%",
                                                  pad=0.05)
        # Each colorbar must come from its own image. vmax is left unset, so
        # each norm takes its upper limit from its own panel's data.
        fig.colorbar(image, cax=cax)
    axs[0].set_ylabel('C')

    fig.subplots_adjust(left=.05, right=0.95, bottom=0.15, top=0.95)
    plt.savefig("./../plots/svm" + datasetname + ".pdf", bbox_inches='tight')
    plt.show()

if __name__ == '__main__':
    # Running this file as a script reproduces Figure 4 on the iris data set.
    draw_Figure4(datasetname='iris')


