import numpy as np
import scipy
import torch
import argparse

from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

from gntk_src.gntk import GNTK
from gcn_src import load_data


#from torch.utils.tensorboard import SummaryWriter

def parsers_parser():
    """Build the GNTK command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``num_mlp_layers``,
        ``num_layers``, ``scale``, ``jk``, ``init_label_per_class``,
        ``budget_num_per_query``, ``total_query_times``, ``norm``,
        ``save_log`` and ``model_name``.
    """
    # (flag, value type, default, help text); None means "no help text".
    option_specs = [
        ('--num_mlp_layers', int, 2, 'number of mlp layers'),
        ('--num_layers', int, 2, 'number of layers'),
        ('--scale', str, 'uniform', 'scaling methods, [uniform, degree]'),
        ('--jk', int, 0, 'whether to add jk'),
        ('--init_label_per_class', int, 5, None),
        ('--budget_num_per_query', int, 15, None),
        ('--total_query_times', int, 1, None),
        ('--norm', int, 1, None),
        ('--save_log', int, 0, None),
        ('--model_name', str, 'GNTK', None),
    ]
    cli = argparse.ArgumentParser(description='GNTK computation')
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()


def compute_gram(args, adj, features):
    """Compute the GNTK Gram matrix for a graph.

    Args:
        args: namespace supplying ``num_layers``, ``num_mlp_layers``, ``jk``
            and ``scale`` for the ``GNTK`` constructor.
        adj: adjacency structure passed straight through to ``GNTK``.
        features: node features passed straight through to ``GNTK``.

    Returns:
        Whatever ``GNTK.gntk`` returns for these inputs.
    """
    kernel = GNTK(
        num_layers=args.num_layers,
        num_mlp_layers=args.num_mlp_layers,
        jk=args.jk,
        scale=args.scale,
    )
    # The diagonal terms are computed first and then fed into the full kernel.
    diagonal_terms = kernel.diag(features, adj)
    return kernel.gntk(features, diagonal_terms, adj)


def kernel_svm(gram, labels, idx_train, idx_val, norm=True, save_log=False, args=None):
    """Fit a precomputed-kernel SVM on the training nodes and report accuracy.

    The Gram matrix is divided by its minimum entry. When ``norm`` is truthy,
    it is also cosine-normalised: ``K_ij <- K_ij / sqrt(K_ii * K_jj)``.
    ``C`` is chosen by ``GridSearchCV`` over 120 log-spaced values in
    ``[1e-2, 1e4]``.

    Args:
        gram: square Gram matrix over all nodes (array-like). It is NOT
            modified; a float copy is rescaled instead.
        labels: label indicator matrix with one row per node and one column
            per class (one-hot assumed; TODO confirm against ``load_data``).
        idx_train: 1-D array-like of training node indices.
        idx_val: 1-D array-like of validation node indices.
        norm: cosine-normalise the Gram matrix when truthy.
        save_log: when truthy, write both accuracies to TensorBoard.
        args: parsed CLI namespace, used only to name the TensorBoard run.
            When omitted, the module-level ``args`` bound by the script entry
            point is used if it exists.

    Returns:
        ``(test_acc, train_acc)``. Both values are also printed.
    """
    # argmax keeps one class index per row, so labels stay aligned with node
    # indices. np.where(labels)[1] silently drops all-zero rows.
    labels = np.asarray(labels).argmax(axis=1)
    idx_train = np.asarray(idx_train)
    idx_val = np.asarray(idx_val)

    # Out-of-place division so the caller's matrix is left untouched and can
    # be reused by another model. Assumes gram.min() > 0 (TODO confirm).
    gram = np.asarray(gram, dtype=float)
    gram = gram / gram.min()
    if norm:
        gram_diag = np.sqrt(np.diag(gram))
        gram /= gram_diag[:, None]
        gram /= gram_diag[None, :]

    X_train = gram[np.ix_(idx_train, idx_train)]
    X_test = gram[np.ix_(idx_val, idx_train)]
    y_train, y_test = labels[idx_train], labels[idx_val]

    C_list = np.logspace(-2, 4, 120)
    # max_iter must be an int: scikit-learn's parameter validation rejects
    # the float 5e5. NOTE(review): probability=True is not needed by
    # predict() and slows fitting; it is kept to preserve the model config.
    svc = SVC(kernel='precomputed', cache_size=16000, max_iter=500000, probability=True)
    clf = GridSearchCV(svc, {'C': C_list}, verbose=0, return_train_score=True)
    clf.fit(X_train, y_train)

    test_acc = accuracy_score(y_test, clf.predict(X_test))
    print(test_acc)
    train_acc = accuracy_score(y_train, clf.predict(X_train))
    print(train_acc)

    if save_log:
        # Imported lazily so TensorBoard is only required when logging.
        from torch.utils.tensorboard import SummaryWriter
        if args is None:
            # Backward compatibility: the script entry point binds ``args``
            # at module level and calls this function without passing it.
            args = globals().get('args')
        comment = '' if args is None else "{}_norm:{}_init_label_num:{}_budget_num_per_query:{}".format(
            args.model_name, args.norm, args.init_label_per_class, args.budget_num_per_query)
        with SummaryWriter(comment=comment) as writer:
            writer.add_scalar('train/acc', train_acc, 0)
            writer.add_scalar('test/acc', test_acc, 0)

    return test_acc, train_acc


def kernel_regression(gram, labels, idx_train, idx_val, norm=True, save_log=False, args=None):
    """Classify validation nodes by kernel regression on indicator labels.

    Solves ``K_train @ alpha = Y_train`` on the training block of the Gram
    matrix. It then predicts ``K_val,train @ alpha`` and takes the arg-max
    column as the class. Before solving, the Gram matrix is divided by its
    minimum entry. When ``norm`` is truthy, it is also cosine-normalised:
    ``K_ij <- K_ij / sqrt(K_ii * K_jj)``.

    Args:
        gram: square Gram matrix over all nodes (array-like). It is NOT
            modified; a float copy is rescaled instead.
        labels: label indicator matrix with one row per node and one column
            per class (one-hot assumed; TODO confirm against ``load_data``).
        idx_train: 1-D array-like of training node indices.
        idx_val: 1-D array-like of validation node indices.
        norm: cosine-normalise the Gram matrix when truthy.
        save_log: when truthy, write both accuracies to TensorBoard.
        args: parsed CLI namespace, used only to name the TensorBoard run.
            When omitted, the module-level ``args`` bound by the script entry
            point is used if it exists.

    Returns:
        ``(test_acc, train_acc)`` as floats. Both values are also printed.

    Raises:
        scipy.linalg.LinAlgError: if the training Gram block is singular.
    """
    # Explicit submodule import: a bare ``import scipy`` does not load
    # ``scipy.linalg`` on older SciPy releases.
    from scipy.linalg import solve

    labels = np.asarray(labels)
    idx_train = np.asarray(idx_train)
    idx_val = np.asarray(idx_val)

    # Out-of-place division so the caller's matrix is left untouched and can
    # be reused by another model. Assumes gram.min() > 0 (TODO confirm).
    gram = np.asarray(gram, dtype=float)
    gram = gram / gram.min()
    if norm:
        gram_diag = np.sqrt(np.diag(gram))
        gram /= gram_diag[:, None]
        gram /= gram_diag[None, :]

    X_train = gram[np.ix_(idx_train, idx_train)]
    X_test = gram[np.ix_(idx_val, idx_train)]
    y_train, y_test = labels[idx_train], labels[idx_val]

    # Solve once and reuse the coefficients for both test and train predictions.
    alpha = solve(X_train, y_train)

    test_pred = np.argmax(X_test @ alpha, axis=1)
    test_acc = float(np.mean(test_pred == np.argmax(y_test, axis=1)))
    print(test_acc)
    train_pred = np.argmax(X_train @ alpha, axis=1)
    train_acc = float(np.mean(train_pred == np.argmax(y_train, axis=1)))
    print(train_acc)

    if save_log:
        # Imported lazily so TensorBoard is only required when logging.
        from torch.utils.tensorboard import SummaryWriter
        if args is None:
            # Backward compatibility: the script entry point binds ``args``
            # at module level and calls this function without passing it.
            args = globals().get('args')
        comment = '' if args is None else "{}_norm:{}_init_label_num:{}_budget_num_per_query:{}".format(
            args.model_name, args.norm, args.init_label_per_class, args.budget_num_per_query)
        with SummaryWriter(comment=comment) as writer:
            writer.add_scalar('train/acc', train_acc, 0)
            writer.add_scalar('test/acc', test_acc, 0)

    return test_acc, train_acc


def query_selection(args, gram, idx_train, idx_val):
    """Placeholder for query selection; currently a no-op.

    The only input read is ``args.budget_num_per_query``, which is compared
    against 0. ``gram``, ``idx_train`` and ``idx_val`` are accepted but unused.

    Returns:
        None, in every case.
    """
    budget_available = args.budget_num_per_query > 0
    if not budget_available:
        return None
    # TODO: the positive-budget branch has no implementation yet.
    return None



if __name__ == '__main__':
    # `args` stays a module-level name: the kernel evaluation functions read
    # it when naming their TensorBoard runs.
    args = parsers_parser()

    adj, features, labels, idx_train, idx_val = load_data(
        init_num_per_class=args.init_label_per_class)
    gram_matrix = compute_gram(args, adj, features)

    # Evaluate with kernel regression. To use the kernel SVM instead, call
    # kernel_svm(gram_matrix, labels, idx_train, idx_val, args.norm, args.save_log).
    kernel_regression(gram_matrix, labels, idx_train, idx_val, args.norm, args.save_log)
