import numpy as np
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import time
from tqdm import tqdm
import classic_kernel
import utils
# import hickle

def hyperparam_train(X_train, y_train, X_test, y_test, c,
                     iters=5, reg=0, L=10, normalize=False, device='cpu',
                     return_train_traj=False):
    """Iteratively refit a kernel regressor while updating the matrix ``M``.

    Starting from ``M = I``, each iteration:

    1. builds the train kernel with ``classic_kernel.clamped_laplacian_M``
       and solves ``(K_train + reg * I) @ alpha = y_train`` for the
       coefficients,
    2. scores the resulting predictor on ``(X_test, y_test)`` by argmax
       accuracy and remembers the best ``M`` seen so far,
    3. replaces ``M`` using ``classic_kernel.clamped_laplacian_M_fact`` and
       ``utils.matrix_pow(..., 1/4)``.

    Note that the "best" iteration is selected on ``(X_test, y_test)``, so
    the returned accuracy is a selection score on that split.

    Args:
        X_train: 2-D NumPy array of training inputs (its shape is unpacked
            as ``(n, d)``). Not modified by this function.
        y_train: Training labels, passed to ``utils.convert_one_hot``
            (presumably integer class ids -- confirm against that helper).
        X_test: NumPy array of evaluation inputs. Not modified.
        y_test: Evaluation labels, same format as ``y_train``.
        c: Second argument of ``utils.convert_one_hot`` (presumably the
            number of classes -- confirm against that helper).
        iters: Number of fit/update iterations to run.
        reg: Coefficient of the identity added to the train kernel before
            solving.
        L: Forwarded to the ``classic_kernel`` functions (presumably the
            kernel bandwidth -- confirm there).
        normalize: If true, divide every row of ``X_train`` and ``X_test``
            by its Euclidean norm first. Rows of zero norm are not
            special-cased.
        device: Torch device on which all tensors are placed.
        return_train_traj: If true, also return the list of ``(sol, M)``
            pairs used at every iteration.

    Returns:
        ``(best_acc, best_iter, best_M)``, plus ``train_traj`` as a fourth
        element when ``return_train_traj`` is true. If no iteration reaches
        an accuracy above 0 (or ``iters`` is 0), ``best_acc`` is ``0.0``,
        ``best_iter`` is ``0`` and ``best_M`` is the initial identity matrix.
    """
    y_train = utils.convert_one_hot(y_train, c)
    y_test = utils.convert_one_hot(y_test, c)

    if normalize:
        # Out-of-place division: keeps the caller's arrays untouched and
        # also works for integer-typed inputs, which `/=` would reject.
        X_train = X_train / np.linalg.norm(X_train, axis=-1).reshape(-1, 1)
        X_test = X_test / np.linalg.norm(X_test, axis=-1).reshape(-1, 1)

    X_train = torch.from_numpy(X_train).float().to(device)
    y_train = torch.from_numpy(y_train).float().to(device)
    X_test = torch.from_numpy(X_test).float().to(device)
    y_test = torch.from_numpy(y_test).float().to(device)

    n, d = X_train.shape
    M = torch.eye(d, device=device)

    best_acc = 0.
    best_iter = 0
    # Start from the initial M so a tensor is always returned, even when no
    # iteration ever scores above 0.
    best_M = M
    train_traj = []

    # The evaluation labels do not change between iterations.
    labels = torch.argmax(y_test, dim=-1)

    for i in range(iters):
        # Fit the regressor for the current M.
        K_train = classic_kernel.clamped_laplacian_M(X_train, X_train, L, M)
        ridge = reg * torch.eye(K_train.shape[0], device=device)
        sol = torch.linalg.solve(K_train + ridge, y_train).T
        if return_train_traj:
            train_traj.append((sol, M))

        # Score the current predictor on the evaluation split.
        K_test = classic_kernel.clamped_laplacian_M(X_train, X_test, L, M)
        y_pred = (sol @ K_test).T
        preds = torch.argmax(y_pred, dim=-1)
        count = torch.sum(labels == preds).detach().cpu().numpy()
        test_acc = count / len(labels)

        if test_acc > best_acc:
            best_iter = i
            best_acc = test_acc
            # M is rebound (never mutated in place) below, so keeping a
            # reference here is safe.
            best_M = M

        # Update M for the next iteration.
        fact = classic_kernel.clamped_laplacian_M_fact(X_train, L, sol, M)
        in_between = fact @ M.T
        M = utils.matrix_pow(in_between @ in_between.T, 1/4)

    if return_train_traj:
        return best_acc, best_iter, best_M, train_traj
    return best_acc, best_iter, best_M


def train(X_train, y_train, X_test, y_test, c, M,
          iters=5, reg=0, L=10, normalize=False, device='cpu'):
    """Fit a kernel regressor with a fixed ``M`` and return its accuracy.

    Builds the train kernel with ``classic_kernel.clamped_laplacian_M``,
    solves ``(K_train + reg * I) @ alpha = y_train`` once, predicts on
    ``X_test`` and reports argmax classification accuracy against
    ``y_test``.

    Args:
        X_train: NumPy array of training inputs. Not modified.
        y_train: Training labels, passed to ``utils.convert_one_hot``
            (presumably integer class ids -- confirm against that helper).
        X_test: NumPy array of evaluation inputs. Not modified.
        y_test: Evaluation labels, same format as ``y_train``.
        c: Second argument of ``utils.convert_one_hot`` (presumably the
            number of classes -- confirm against that helper).
        M: Matrix forwarded to ``classic_kernel.clamped_laplacian_M``. When
            it is a torch tensor it is moved to ``device`` first.
        iters: Unused; kept so existing call sites keep working.
        reg: Coefficient of the identity added to the train kernel before
            solving.
        L: Forwarded to ``classic_kernel.clamped_laplacian_M`` (presumably
            the kernel bandwidth -- confirm there).
        normalize: If true, divide every row of ``X_train`` and ``X_test``
            by its Euclidean norm first. Rows of zero norm are not
            special-cased.
        device: Torch device on which all tensors are placed.

    Returns:
        Fraction of rows of ``X_test`` whose predicted argmax class equals
        the argmax of the one-hot encoded ``y_test``.
    """
    y_train = utils.convert_one_hot(y_train, c)
    y_test = utils.convert_one_hot(y_test, c)

    if normalize:
        # Out-of-place division: keeps the caller's arrays untouched and
        # also works for integer-typed inputs, which `/=` would reject.
        X_train = X_train / np.linalg.norm(X_train, axis=-1).reshape(-1, 1)
        X_test = X_test / np.linalg.norm(X_test, axis=-1).reshape(-1, 1)

    X_train = torch.from_numpy(X_train).float().to(device)
    y_train = torch.from_numpy(y_train).float().to(device)
    X_test = torch.from_numpy(X_test).float().to(device)
    y_test = torch.from_numpy(y_test).float().to(device)

    # Keep M on the same device as the data to avoid a device mismatch
    # inside the kernel computation.
    if torch.is_tensor(M):
        M = M.to(device)

    K_train = classic_kernel.clamped_laplacian_M(X_train, X_train, L, M)
    ridge = reg * torch.eye(K_train.shape[0], device=device)
    sol = torch.linalg.solve(K_train + ridge, y_train).T

    K_test = classic_kernel.clamped_laplacian_M(X_train, X_test, L, M)
    y_pred = (sol @ K_test).T

    preds = torch.argmax(y_pred, dim=-1)
    labels = torch.argmax(y_test, dim=-1)
    count = torch.sum(labels == preds).detach().cpu().numpy()

    return count / len(labels)

