import numpy as np
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import time
from tqdm import tqdm
from . import classic_kernel
from . import utils

from . import generic_rfm
# import hickle


def train_hyperparams_and_test(X_train, y_train, X_test, y_test, c,
                               hyperparams, log_hyperparams,
                               device='cpu',return_train_traj=False,return_last=False):
    """Run the generic RFM hyperparameter train/test routine with the
    clamped Laplacian kernel.

    This is a thin adapter: it selects ``classic_kernel.clamped_laplacian_M``
    as the kernel, builds the matching update rule for ``M`` and forwards
    everything else unchanged to ``generic_rfm.train_hyperparams_and_test``.

    Args:
        X_train, y_train, X_test, y_test, c: forwarded as-is to the generic
            routine.
        hyperparams: mapping that must contain the keys ``'take_sqrt'`` and
            ``'geom_update'`` (read here); it is also forwarded unchanged.
        log_hyperparams: forwarded as-is.
        device: forwarded as-is.
        return_train_traj: forwarded as-is.
        return_last: forwarded as-is.

    Returns:
        Whatever ``generic_rfm.train_hyperparams_and_test`` returns.

    Raises:
        KeyError: if ``hyperparams`` lacks ``'take_sqrt'`` or
            ``'geom_update'``.
    """
    use_sqrt = hyperparams['take_sqrt']
    use_geom = hyperparams['geom_update']

    def _next_M(X, L, sol, M):
        # Quantity computed by the kernel module from the current solution
        # and the current M; every update rule below is built from it.
        grad_stat = classic_kernel.clamped_laplacian_M_wagop(X, L, sol, M)

        if not use_geom:
            # Plain update: the statistic itself, optionally square-rooted.
            return utils.matrix_sqrt(grad_stat) if use_sqrt else grad_stat

        # Geometric update: sandwich between M and M.T, then take the matrix
        # square root twice. Products are accumulated left to right so the
        # multiplication order matches M @ g @ M.T / M @ g @ g @ M.T.
        product = M @ grad_stat
        if not use_sqrt:
            product = product @ grad_stat
        product = product @ M.T
        return utils.matrix_sqrt(utils.matrix_sqrt(product))

    return generic_rfm.train_hyperparams_and_test(
        X_train, y_train, X_test, y_test, c,
        hyperparams, log_hyperparams,
        device=device,
        return_train_traj=return_train_traj,
        return_last=return_last,
        kernel_M_function=classic_kernel.clamped_laplacian_M,
        kernel_M_update_function=_next_M,
    )

# def hyperparam_train(X_train, y_train, X_test, y_test, c,
#                      iters=5, reg=0, L=10, normalize=False,device='cpu',return_train_traj=False, take_sqrt=False,
#                     return_last=False):
    
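#     # Returns the best test-set accuracy, the iteration at which it was reached, and the corresponding M, after training for at most `iters` iterations (with return_last set, the values from the last iteration are returned instead).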

#     y_t_orig = y_train
#     y_v_orig = y_test
#     y_train = utils.convert_one_hot(y_train, c)
#     y_test = utils.convert_one_hot(y_test, c)

#     if normalize:
#         X_train /= np.linalg.norm(X_train, axis=-1).reshape(-1, 1)
#         X_test /= np.linalg.norm(X_test, axis=-1).reshape(-1, 1)

#     X_train = torch.from_numpy(X_train).float().to(device)
#     y_train = torch.from_numpy(y_train).float().to(device)
#     X_test = torch.from_numpy(X_test).float().to(device)
#     y_test = torch.from_numpy(y_test).float().to(device)

#     best_acc = 0.
#     best_iter = 0.
#     best_M = 0.

#     n, d = X_train.shape
#     M = torch.eye(d).to(device)
#     train_traj = []

#     for i in range(iters):

#         K_train = classic_kernel.clamped_laplacian_M(X_train, X_train, L, M)
#         sol = torch.linalg.solve(K_train + reg * torch.eye(K_train.shape[0]).to(device), y_train).T
#         if return_train_traj:
#             train_traj.append((sol,M))
            

#         K_test = classic_kernel.clamped_laplacian_M(X_train, X_test, L, M)
#         y_pred = (sol @ K_test).T

#         preds = torch.argmax(y_pred, dim=-1)
#         labels = torch.argmax(y_test, dim=-1)
#         count = torch.sum(labels == preds).detach().cpu().numpy()

#         old_test_acc = count / len(labels)

#         if (old_test_acc > best_acc) or return_last:
#             best_iter = i
#             best_acc = old_test_acc
#             best_M = M
#         M  = classic_kernel.clamped_laplacian_M_wagop(X_train, L, sol, M)
#         if take_sqrt:
#             M = utils.matrix_sqrt(M)
#         # M  = utils.matrix_sqrt(classic_kernel.clamped_laplacian_M_wagop(X_train, L, sol, M))

#     if return_train_traj:
#         return best_acc, best_iter, best_M, train_traj
#     else:
#         return best_acc, best_iter, best_M



# def hyperparam_train_geom(X_train, y_train, X_test, y_test, c,
#                      iters=5, reg=0, L=10, normalize=False,device='cpu',return_train_traj=False, take_sqrt=False,
#                          return_last=False):
    
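#     # Returns the best test-set accuracy, the iteration at which it was reached, and the corresponding M, after training for at most `iters` iterations (with return_last set, the values from the last iteration are returned instead).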

#     y_t_orig = y_train
#     y_v_orig = y_test
#     y_train = utils.convert_one_hot(y_train, c)
#     y_test = utils.convert_one_hot(y_test, c)

#     if normalize:
#         X_train /= np.linalg.norm(X_train, axis=-1).reshape(-1, 1)
#         X_test /= np.linalg.norm(X_test, axis=-1).reshape(-1, 1)

#     X_train = torch.from_numpy(X_train).float().to(device)
#     y_train = torch.from_numpy(y_train).float().to(device)
#     X_test = torch.from_numpy(X_test).float().to(device)
#     y_test = torch.from_numpy(y_test).float().to(device)

#     best_acc = 0.
#     best_iter = 0.
#     best_M = 0.

#     n, d = X_train.shape
#     M = torch.eye(d).to(device)
#     train_traj = []

#     for i in range(iters):

#         K_train = classic_kernel.clamped_laplacian_M(X_train, X_train, L, M)
#         sol = torch.linalg.solve(K_train + reg * torch.eye(K_train.shape[0]).to(device), y_train).T
#         if return_train_traj:
#             train_traj.append((sol,M))
            

#         K_test = classic_kernel.clamped_laplacian_M(X_train, X_test, L, M)
#         y_pred = (sol @ K_test).T

#         preds = torch.argmax(y_pred, dim=-1)
#         labels = torch.argmax(y_test, dim=-1)
#         count = torch.sum(labels == preds).detach().cpu().numpy()

#         old_test_acc = count / len(labels)

#         if (old_test_acc > best_acc) or return_last:
#             best_iter = i
#             best_acc = old_test_acc
#             best_M = M
#         wagop  = classic_kernel.clamped_laplacian_M_wagop(X_train, L, sol, M)
#         if take_sqrt:
#             M = utils.matrix_sqrt(utils.matrix_sqrt(M @ wagop @ M.T))
#         else:
#             M = utils.matrix_sqrt(utils.matrix_sqrt(M @ wagop @ wagop.T @ M.T))
#         # M  = utils.matrix_sqrt(classic_kernel.clamped_laplacian_M_wagop(X_train, L, sol, M))

#     if return_train_traj:
#         return best_acc, best_iter, best_M, train_traj
#     else:
#         return best_acc, best_iter, best_M


def train(X_train, y_train, X_test, y_test, c, M,
          iters=5, reg=0, L=10, normalize=False, device='cpu'):
    """Fit a clamped-Laplacian kernel regressor with a fixed ``M`` and return
    its classification accuracy on the test set.

    Args:
        X_train: NumPy array of training inputs (converted with
            ``torch.from_numpy``); assumed 2-D, one sample per row -- TODO
            confirm against callers.
        y_train: training labels, passed to ``utils.convert_one_hot``.
        X_test: NumPy array of test inputs, same layout as ``X_train``.
        y_test: test labels, passed to ``utils.convert_one_hot``.
        c: second argument of ``utils.convert_one_hot`` (presumably the
            number of classes -- verify against that helper).
        M: matrix handed unchanged to ``classic_kernel.clamped_laplacian_M``.
            It is NOT moved to ``device`` here, so the caller must supply it
            on the right device.
        iters: accepted for interface compatibility; not used by this
            function.
        reg: coefficient of the identity added to the training kernel matrix
            before solving.
        L: kernel parameter forwarded to ``clamped_laplacian_M``.
        normalize: if true, every row of ``X_train`` and ``X_test`` is
            divided by its Euclidean norm before training.
        device: torch device on which the data tensors are placed.

    Returns:
        Fraction of test samples whose arg-max prediction equals the arg-max
        of the one-hot label (a NumPy floating scalar).
    """
    y_train = utils.convert_one_hot(y_train, c)
    y_test = utils.convert_one_hot(y_test, c)

    if normalize:
        # Divide out of place: the previous in-place `/=` rescaled the
        # caller's arrays as a hidden side effect (and raised on integer
        # dtypes, which cannot hold the true-division result).
        X_train = X_train / np.linalg.norm(X_train, axis=-1).reshape(-1, 1)
        X_test = X_test / np.linalg.norm(X_test, axis=-1).reshape(-1, 1)

    X_train = torch.from_numpy(X_train).float().to(device)
    y_train = torch.from_numpy(y_train).float().to(device)
    X_test = torch.from_numpy(X_test).float().to(device)
    y_test = torch.from_numpy(y_test).float().to(device)

    # Solve (K + reg * I) A = Y for the coefficients; they are stored
    # transposed so that predictions are `sol @ K_test`.
    K_train = classic_kernel.clamped_laplacian_M(X_train, X_train, L, M)
    ridge = reg * torch.eye(K_train.shape[0], device=device)
    sol = torch.linalg.solve(K_train + ridge, y_train).T

    K_test = classic_kernel.clamped_laplacian_M(X_train, X_test, L, M)
    y_pred = (sol @ K_test).T

    preds = torch.argmax(y_pred, dim=-1)
    labels = torch.argmax(y_test, dim=-1)
    count = torch.sum(labels == preds).detach().cpu().numpy()

    return count / len(labels)

