'''
Code adapted from: https://github.com/aradha/recursive_feature_machines/blob/pip_install/tabular_benchmark_experiments/main.py
which is based on: https://github.com/LeoYu/neural-tangent-kernel-UCI/tree/master
'''
import torch
import csv
import argparse
import os
import math
import numpy as np
from utils import convert_one_hot
from train_rfm import get_kernel, update
import utils

# Use double precision throughout: factory calls such as ``torch.eye(d)`` in
# hyperparam_train then produce float64, matching the float64 NumPy data.
torch.set_default_dtype(torch.double)

def train(X_train, y_train, X_test, y_test, c, M, args,
          iters=5, reg=0, L=10, normalize=False):
    """Fit kernel ridge regression with a fixed feature matrix ``M`` and score it.

    No RFM iterations are run here: the kernel
    ``get_kernel(., ., M, L, args.model)`` is built once, the regularised
    system ``(K_train + reg * I) sol = Y_train`` is solved, and predictions on
    the test rows are ``(sol.T @ K_test).T``. Accuracy is the fraction of test
    rows whose arg-max prediction equals the arg-max of the one-hot target.

    Args:
        X_train, X_test: NumPy feature arrays, one sample per row.
        y_train, y_test: Class labels, one-hot encoded via
            ``convert_one_hot(., c)``.
        c: Number of classes.
        M: Feature matrix handed to ``get_kernel`` (e.g. the best matrix
            returned by ``hyperparam_train``).
        args: Namespace; ``args.model`` selects the kernel.
        iters: Unused; kept so the signature parallels ``hyperparam_train``.
        reg: Ridge term added to the diagonal of the training kernel.
        L: Kernel bandwidth passed to ``get_kernel``.
        normalize: If True, rescale every row of ``X_train``/``X_test`` to
            unit L2 norm. The caller's arrays are not modified, and all-zero
            rows are left as zeros rather than becoming NaN.

    Returns:
        Test accuracy in [0, 1] (a NumPy scalar).
    """
    y_train = convert_one_hot(y_train, c)
    y_test = convert_one_hot(y_test, c)

    if normalize:
        # Divide out of place so the inputs are left untouched (and integer
        # arrays are accepted); a zero norm is replaced by 1 to avoid 0/0.
        train_norms = np.linalg.norm(X_train, axis=-1, keepdims=True)
        test_norms = np.linalg.norm(X_test, axis=-1, keepdims=True)
        X_train = X_train / np.where(train_norms == 0, 1, train_norms)
        X_test = X_test / np.where(test_norms == 0, 1, test_norms)

    X_train = torch.from_numpy(X_train)
    y_train = torch.from_numpy(y_train)
    X_test = torch.from_numpy(X_test)
    y_test = torch.from_numpy(y_test)

    K_train = get_kernel(X_train, X_train, M, L, args.model)
    sol = np.linalg.solve(K_train + reg * np.eye(len(K_train)), y_train)

    K_test = get_kernel(X_train, X_test, M, L, args.model).numpy()
    y_pred = torch.from_numpy((sol.T @ K_test).T)

    preds = torch.argmax(y_pred, dim=-1)
    labels = torch.argmax(y_test, dim=-1)
    count = torch.sum(labels == preds).numpy()

    return count / len(labels)

def hyperparam_train(X_train, y_train, X_test, y_test, c, args,
                     iters=5, reg=0, L=10, normalize=False, rfm_update='agop',
                     agop_power=0.5, centering=False):
    """Run up to ``iters`` RFM iterations and keep the best-scoring feature matrix.

    Starting from ``M = I``, each iteration fits kernel ridge regression on
    the training rows, scores it on the held-out rows, and remembers ``M``
    when accuracy strictly improves. It then refines ``M`` with ``update``,
    optionally followed by ``utils.matrix_power(M, agop_power)``. No update
    is computed after the final iteration, because its result would never
    be evaluated.

    Side effect: sets ``args.bandwidth = L`` and ``args.rfm_update =
    rfm_update``. These are presumably read by ``update``; confirm in
    train_rfm.

    Args:
        X_train, X_test: NumPy feature arrays, one sample per row.
        y_train, y_test: Class labels, one-hot encoded via
            ``convert_one_hot(., c)``.
        c: Number of classes.
        args: Namespace; ``args.model`` selects the kernel.
        iters: Maximum number of RFM iterations (evaluations of ``M``).
        reg: Ridge term added to the diagonal of the training kernel.
        L: Kernel bandwidth.
        normalize: If True, rescale rows to unit L2 norm. The caller's arrays
            are not modified, and all-zero rows stay zero.
        rfm_update: Stored on ``args.rfm_update`` for ``update``.
        agop_power: Power applied to ``M`` after each update (skipped when 1).
        centering: Forwarded to ``update``.

    Returns:
        ``(best_acc, best_iter, best_M)``. These are the highest held-out
        accuracy, the iteration index that reached it, and the ``M`` used
        there. The result is ``(0., 0, 0.)`` if no iteration scores above 0.

    Raises:
        numpy.linalg.LinAlgError: from the linear solve (or, presumably,
            ``utils.matrix_power``); ``main`` uses it to skip a config.
    """
    y_train = convert_one_hot(y_train, c)
    y_test = convert_one_hot(y_test, c)

    if normalize:
        # Divide out of place so the inputs are left untouched (and integer
        # arrays are accepted); a zero norm is replaced by 1 to avoid 0/0.
        train_norms = np.linalg.norm(X_train, axis=-1, keepdims=True)
        test_norms = np.linalg.norm(X_test, axis=-1, keepdims=True)
        X_train = X_train / np.where(train_norms == 0, 1, train_norms)
        X_test = X_test / np.where(test_norms == 0, 1, test_norms)

    X_train = torch.from_numpy(X_train)
    y_train = torch.from_numpy(y_train)
    X_test = torch.from_numpy(X_test)
    y_test = torch.from_numpy(y_test)

    _, d = X_train.shape
    M = torch.eye(d)
    args.bandwidth = L
    args.rfm_update = rfm_update

    # The held-out targets never change, so decode them once.
    labels = torch.argmax(y_test, dim=-1)

    best_acc, best_iter, best_M = 0., 0, 0.
    for i in range(iters):
        K_train = get_kernel(X_train, X_train, M, L, args.model)
        sol_np = np.linalg.solve(K_train + reg * np.eye(len(K_train)), y_train)
        sol = torch.from_numpy(sol_np)

        # Predict in NumPy and convert explicitly, as ``train`` does, instead
        # of relying on tensor @ ndarray operator fallback to yield a tensor.
        K_test = get_kernel(X_train, X_test, M, L, args.model).numpy()
        y_pred = torch.from_numpy((sol_np.T @ K_test).T)

        preds = torch.argmax(y_pred, dim=-1)
        acc = torch.sum(labels == preds).numpy() / len(labels)

        if acc > best_acc:
            best_iter, best_acc, best_M = i, acc, M

        # An update after the last evaluation would be discarded; skipping it
        # saves work and keeps a failure there from voiding this whole run.
        if i == iters - 1:
            break

        M = update(X_train, X_train, M, sol, args, y_train,
                   centering=centering, K_train=K_train, return_per_class_agop=False)

        if agop_power != 1:
            M = utils.matrix_power(M, agop_power, is_torch=True)

    return best_acc, best_iter, best_M

def _read_dataset_info(path):
    """Parse a dataset metadata file into a ``{key: value}`` dict of strings.

    Every non-blank line must hold exactly two whitespace-separated tokens.
    The first token is the key, kept verbatim (the keys looked up below end
    in ``=``, e.g. ``n_clases=``). Blank lines are ignored.
    """
    info = {}
    with open(path, "r") as fh:
        for line in fh:
            tokens = line.split()
            if not tokens:
                continue
            key, value = tokens
            info[key] = value
    return info


def _read_folds(path):
    """Return the integer index lists stored in a fold file, one list per line."""
    with open(path, "r") as fh:
        return [[int(tok) for tok in line.split()] for line in fh]


def _load_data(path):
    """Load the feature matrix ``X`` and integer labels ``y`` from a data file.

    The first line is skipped as a header. In every other line, the first
    column is dropped (presumably a row id; TODO confirm). The last column
    is parsed as the integer label, and the columns in between as floats.
    """
    with open(path, "r") as fh:
        rows = [line.split() for line in fh.readlines()[1:]]
    X = np.asarray([[float(tok) for tok in row[1:-1]] for row in rows])
    y = np.asarray([int(row[-1]) for row in rows])
    return X, y


def _select_hyperparams(X, y, train_fold, val_fold, c, args, grid, max_iter):
    """Grid-search ``hyperparam_train`` and return the best configuration.

    ``grid`` holds ``(reg, normalize, rfm_update, agop_power, L, centering)``
    tuples, tried in order; on ties the earliest tuple wins. The result is a
    dict of the winning settings plus ``acc``, ``iter`` and ``M``. It is
    ``None`` if no configuration reached a validation accuracy above 0.
    """
    best = None
    for reg, norm, rule, power, L, cent in grid:
        try:
            # Index inside the loop so every call receives fresh copies, in
            # case hyperparam_train normalises its inputs in place.
            acc, iter_v, M = hyperparam_train(X[train_fold], y[train_fold],
                                              X[val_fold], y[val_fold], c, args,
                                              iters=max_iter, reg=reg, L=L, normalize=norm,
                                              rfm_update=rule, agop_power=power, centering=cent)
        except np.linalg.LinAlgError:
            # A singular matrix can occur for some AGOP / WAGOP power
            # settings (case-dependent); treat the configuration as unusable.
            continue

        if acc > (0 if best is None else best['acc']):
            best = dict(acc=acc, reg=reg, iter=iter_v, M=M, normalize=norm,
                        update=rule, agop_power=power, L=L, centering=cent)
    return best


def _kfold_accuracy(X, y, folds, c, best, args, n_splits=4):
    """Mean ``train`` accuracy of configuration ``best`` over ``n_splits`` fold pairs.

    ``folds`` alternates train and test index lists: pair ``k`` is
    ``folds[2k]`` (train) and ``folds[2k + 1]`` (test).
    """
    total = 0.0
    for k in range(n_splits):
        train_fold, test_fold = folds[2 * k], folds[2 * k + 1]
        total += train(X[train_fold], y[train_fold], X[test_fold], y[test_fold],
                       c, best['M'], args, iters=best['iter'], reg=best['reg'],
                       normalize=best['normalize'], L=best['L'])
    return total / n_splits


def main():
    """Benchmark RFM on every dataset under ``-dir`` and write a TSV report to ``-file``.

    For each sub-directory ``<name>`` that contains ``<name>.txt``:
      1. read its metadata, skipping datasets with more than 100000 rows;
      2. grid-search hyperparameters with ``hyperparam_train`` on the first
         train/validation split listed in ``conxuntos.dat``;
      3. score the selected configuration (with its learned ``M``) using
         ``train`` on the four train/test splits in ``conxuntos_kfold.dat``;
      4. append one result row to the report.
    The mean accuracy over all reported datasets is printed and appended to
    the report. Datasets for which no configuration succeeds are skipped
    with a message.
    """
    import itertools

    parser = argparse.ArgumentParser()
    parser.add_argument('-dir', default = "data", type = str, help = "data directory")
    parser.add_argument('-file', default = "result.log", type = str, help = "Output File")
    parser.add_argument('-model', default='gaussian')

    args = parser.parse_args()
    datadir = args.dir

    max_iter = 10
    regs = [10, 1, .1, 1e-2, 1e-3, 0]
    normalize = [True, False]
    rfm_update = ['wagop']
    agop_power = [1.0, 0.5]
    Ls = [1e-3, 1e-1, 1, 2, 5, 10]
    centering = [True, False]

    avg_acc_list = []

    # newline='' is what the csv module expects for files it writes to.
    with open(args.file, "w", newline='') as outf:
        out_writer = csv.writer(outf, delimiter='\t')
        out_writer.writerow(['Dataset', 'Size', 'NumFeatures', 'NumClasses', 'Validation Acc',
                             'Reg', 'Iter', 'Test Acc', "RFM Update", '(W)AGOP Power', 'Bandwidth',
                             'AGOP Grad Centering', 'Normalize Data'])

        for idx, dataset in enumerate(sorted(os.listdir(datadir))):
            dataset_dir = os.path.join(datadir, dataset)
            info_path = os.path.join(dataset_dir, f'{dataset}.txt')
            if not os.path.isdir(dataset_dir) or not os.path.isfile(info_path):
                continue

            dic = _read_dataset_info(info_path)
            c = int(dic["n_clases="])
            d = int(dic["n_entradas="])
            n_train_val = int(dic["n_patrons1="])
            n_test = int(dic["n_patrons2="]) if "n_patrons2=" in dic else 0
            n_tot = n_train_val + n_test

            if n_tot > 100000:
                continue
            print (idx, dataset, "\tN:", n_tot, "\td:", d, "\tc:", c)

            X, y = _load_data(os.path.join(dataset_dir, dic["fich1="]))

            # Hyperparameter selection on the first train/validation split.
            folds = _read_folds(os.path.join(dataset_dir, "conxuntos.dat"))
            train_fold, val_fold = folds[0], folds[1]

            # Normalisation is never tried for balance-scale (reason not
            # recorded here); listing only False avoids running each of its
            # configurations twice.
            norm_options = [False] if dataset == 'balance-scale' else normalize
            # itertools.product varies the last iterable fastest, i.e. the
            # same order as nested loops over reg > norm > rule > power > L > cent.
            grid = [
                (reg, norm, rule, power, L, cent)
                for reg, norm, rule, power, L, cent in itertools.product(
                    regs, norm_options, rfm_update, agop_power, Ls, centering)
                # There is no gradient centering in WAGOP, so skip those runs.
                if not (rule == 'wagop' and cent)
            ]

            print("Cross Validating")
            best = _select_hyperparams(X, y, train_fold, val_fold, c, args, grid, max_iter)
            if best is None:
                # Without a usable configuration there is no M / bandwidth to
                # evaluate, so skip rather than train with placeholder values.
                print(f"{dataset}: no hyperparameter configuration succeeded; skipping\n")
                continue

            # 4-fold evaluation of the selected configuration.
            kfold = _read_folds(os.path.join(dataset_dir, "conxuntos_kfold.dat"))
            print("Training")
            avg_acc = _kfold_accuracy(X, y, kfold, c, best, args)

            print ("acc:", avg_acc, best['reg'], best['iter'], best['normalize'],"\n")

            out_writer.writerow([dataset, n_tot, d, c, best['acc'], best['reg'], best['iter'],
                                 avg_acc*100, best['update'], best['agop_power'], best['L'],
                                 best['centering'], best['normalize']])
            outf.flush()

            avg_acc_list.append(avg_acc)

        print ("avg_acc:", np.mean(avg_acc_list) * 100)
        print("avg_acc:", np.mean(avg_acc_list)*100, file=outf, flush=True)

if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
