import numpy as np
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import time
from tqdm import tqdm
from . import classic_kernel
from . import utils
# import hickle

from . import generic_rfm

def train_hyperparams_and_test(X_train, y_train, X_test, y_test, c,
                               hyperparams, log_hyperparams,
                               device='cpu', return_train_traj=False, return_last=False):
    """Run the generic RFM train/test routine with the quadratic kernel.

    This is a thin wrapper: it selects ``classic_kernel.quadratic_kernel_L_M``
    as the kernel, builds the matching update rule for ``M``, and forwards
    everything to ``generic_rfm.train_hyperparams_and_test``, whose return
    value is returned unchanged.

    Two entries of ``hyperparams`` are read here (both are required keys):

    * ``'take_sqrt'``   -- truthy to apply a matrix square root in the update.
    * ``'geom_update'`` -- truthy to combine the previous ``M`` with the newly
      computed matrix instead of replacing ``M`` outright.

    All other arguments are passed through untouched; their meaning is
    defined by ``generic_rfm.train_hyperparams_and_test``.
    """
    # Look the flags up once, up front, so a missing key fails immediately
    # rather than in the middle of training.
    sqrt_enabled = hyperparams['take_sqrt']
    geometric = hyperparams['geom_update']

    kernel_M_function = classic_kernel.quadratic_kernel_L_M

    def kernel_M_update_function(X, L, sol, M):
        """Return the next ``M`` given the current state ``(X, L, sol, M)``."""
        # NOTE(review): "wagop" presumably stands for a weighted AGOP matrix;
        # confirm against classic_kernel.quadratic_kernel_L_M_wagop.
        wagop_mat = classic_kernel.quadratic_kernel_L_M_wagop(X, L, sol, M)

        # Plain update: the new matrix (optionally square-rooted) replaces M.
        if not geometric:
            return utils.matrix_sqrt(wagop_mat) if sqrt_enabled else wagop_mat

        # Geometric update: sandwich the new matrix between M and M.T, then
        # take the matrix square root twice.  Without take_sqrt the new
        # matrix enters the product twice.
        if sqrt_enabled:
            sandwiched = M @ wagop_mat @ M.T
        else:
            sandwiched = M @ wagop_mat @ wagop_mat @ M.T
        return utils.matrix_sqrt(utils.matrix_sqrt(sandwiched))

    return generic_rfm.train_hyperparams_and_test(
        X_train, y_train, X_test, y_test, c,
        hyperparams, log_hyperparams,
        device=device,
        return_train_traj=return_train_traj,
        return_last=return_last,
        kernel_M_function=kernel_M_function,
        kernel_M_update_function=kernel_M_update_function,
    )
