from utils import DataLoader, find_best_parameters, compute_quadratic_features, matrix_sqrt, iso_scale, normalize
import numpy as np 
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures, scale
from sklearn.model_selection import train_test_split
from classifier import QuadraticNuclear
from sklearn.svm import LinearSVC
import pickle
import torch
import argparse

# Command-line interface: the single positional argument is the path the
# pickled results dictionary is written to at the end of the script.
parser = argparse.ArgumentParser(
    description="Run the SVM / QuadraticNuclear comparison and pickle the accuracies."
)
parser.add_argument("output_name", help="output file name")
args = parser.parse_args()


# Seed NumPy's global RNG once before anything else runs.
np.random.seed(0)

# One entry per repetition of the experiment; each value is used as the
# seed of the train/test split in the main loop.
random_seeds = [1, 7, 50, 808, 300]

# ----------------------
# SET DATASET NAMES HERE
# ----------------------
# Other names the author left commented out: 'heart', 'splice', 'liver-disorders'
datasets = ['diabetes']
# ----------------------

# Result container: experimental_data[<experiment key>][<dataset name>] is a
# list that the main loop appends (train_accuracy, test_accuracy) tuples to.
# Every (key, dataset) pair gets its own independent list.
experimental_data = {
    result_key: {fname: [] for fname in datasets}
    for result_key in ('svm_before', 'svm_after', 'nuc_before', 'nuc_after')
}

def _tune_fit_score(clf, param_name, grid, train_xy, test_xy, logfile,
                    banner=None, **fit_kwargs):
    """Select one hyper-parameter by cross-validation, refit, and score.

    Args:
        clf: classifier exposing ``fit`` and ``score`` and a settable
            attribute called ``param_name``.
        param_name: name of the hyper-parameter being searched.
        grid: candidate values handed to ``find_best_parameters``.
        train_xy: ``(X, y)`` pair used for the search and the final fit.
        test_xy: ``(X, y)`` pair used only for the held-out score.
        logfile: forwarded unchanged to ``find_best_parameters``.
        banner: optional text printed in front of the selected parameters;
            when omitted the parameters are printed on their own.
        **fit_kwargs: extra keyword arguments forwarded to ``clf.fit``.

    Returns:
        ``(train_accuracy, test_accuracy)`` as returned by ``clf.score``.
    """
    X_fit, y_fit = train_xy
    X_eval, y_eval = test_xy

    chosen = find_best_parameters(clf, {param_name: grid}, X_fit, y_fit,
                                  n_folds_cv=4, logfile=logfile)
    # Report the selection before fitting so the output order stays the same
    # even if fit() prints anything itself.
    if banner is None:
        print(chosen)
    else:
        print(banner, chosen)

    setattr(clf, param_name, chosen[param_name])
    clf.fit(X_fit, y_fit, **fit_kwargs)

    # Training score is evaluated first, then the held-out score.
    return clf.score(X_fit, y_fit), clf.score(X_eval, y_eval)


# Repeat the whole comparison once per split seed; for every dataset, run a
# linear SVM and the QuadraticNuclear classifier, first on the data as loaded
# ("before") and then after centring + iso_scale ("after").
for random_seed in random_seeds:
    # Put NumPy's global RNG back to a known state at the start of each
    # repetition (the author's note: resets "hidden" seeds).
    np.random.seed(0)

    # Show which cross-validation seed is in use.
    print("SEED ", random_seed)

    for fname in datasets:
        log_name = fname + '_log'

        print(f"\n### DATASET {fname} ### \n")
        data = DataLoader(fname)
        data.compute_quadratic_features()

        # ---------------- before scaling ----------------
        print("--- BEFORE SCALING ---")

        # NOTE(review): data.split presumably fills data.X_train / X_test /
        # y_train / y_test, which are read just below - confirm in DataLoader.
        data.split(test_size=0.2, random_seed=random_seed)

        # Linear SVM on the loader's own train/test attributes.
        C_lst = [0.0000001, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10]
        svm_clf = LinearSVC(C=0.1, dual=False, max_iter=10000)
        svm_scores = _tune_fit_score(
            svm_clf, 'C', C_lst,
            (data.X_train, data.y_train),
            (data.X_test, data.y_test),
            logfile=log_name,
        )
        print("SVM train/test : ", *svm_scores)
        experimental_data['svm_before'][fname].append(svm_scores)

        # Nuclear classifier on a fresh split of data.X / data.y.
        nuclear_clf = QuadraticNuclear(lr=0.00005, lam=1)
        lam = [0.0001, 0.001, 0.01, 0.1, 1, 10, 50, 100, 500, 1000, 10000]

        X_train, X_test, y_train, y_test = train_test_split(
            data.X, data.y, test_size=0.2, random_state=random_seed)

        nuc_scores = _tune_fit_score(
            nuclear_clf, 'lam', lam,
            (X_train, y_train),
            (X_test, y_test),
            logfile=log_name,
            banner="Best NuclearQuad param : ",
            n_epoch=20,
        )
        experimental_data['nuc_before'][fname].append(nuc_scores)
        print("Nuclear train/test :", *nuc_scores)

        # ---------------- after scaling ----------------
        print("------------ AFTER SCALING ------------")

        # Split again with the same arguments as above, so the scaled run
        # starts from freshly produced arrays.
        X_train, X_test, y_train, y_test = train_test_split(
            data.X, data.y, test_size=0.2, random_state=random_seed)

        # Statistics come from the training split only: the test split is
        # centred with the training mean and passed through normalize()
        # together with the second value iso_scale returned.
        training_mean = np.mean(X_train, axis=0)
        centred_train = scale(X_train, with_std=False)
        X_train_scaled, training_covariance = iso_scale(centred_train)
        X_test_scaled = normalize(X_test - training_mean, training_covariance)

        X_quad_train, X_quad_test = compute_quadratic_features(
            X_train_scaled, X_test_scaled, homogeneous=False)

        # Linear SVM on the quadratic features of the scaled data.
        C_lst = [0.0000001, 0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10]
        svm_clf = LinearSVC(C=0.1, dual=False, max_iter=10000)
        svm_scores = _tune_fit_score(
            svm_clf, 'C', C_lst,
            (X_quad_train, y_train),
            (X_quad_test, y_test),
            logfile=log_name,
        )
        print("SVM train/test : ", *svm_scores)
        experimental_data['svm_after'][fname].append(svm_scores)

        # Nuclear classifier directly on the scaled (non-quadratic) inputs,
        # reusing the same lam grid as the unscaled run.
        nuclear_clf = QuadraticNuclear(lr=0.00005, lam=1)
        nuc_scores = _tune_fit_score(
            nuclear_clf, 'lam', lam,
            (X_train_scaled, y_train),
            (X_test_scaled, y_test),
            logfile=log_name,
            banner="Best NuclearQuad param : ",
            n_epoch=20,
        )
        experimental_data['nuc_after'][fname].append(nuc_scores)
        print("Nuclear train/test :", *nuc_scores)


# Persist the collected accuracies to the path given on the command line.
# The context manager guarantees the handle is closed (and buffered bytes
# flushed) even if pickle.dump raises part-way through.
with open(args.output_name, "wb") as experiment_file:
    pickle.dump(experimental_data, experiment_file)


