import torch
from models.classic_kernel import euclidean_distances, euclidean_distances_M
import numpy as np
import scipy
import utils

# Make float64 the default dtype for newly created floating-point tensors.
# NOTE: this is a process-wide setting applied at import time, so it also
# affects any other code running in the same process.
torch.set_default_dtype(torch.float64)

def gaussian(samples, centers, bandwidth, return_dist=False):
    '''Gaussian kernel.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth; must be positive (checked with assert).
        return_dist: if True, also return the distance matrix as produced by
            euclidean_distances, before negative values are clamped to 0.

    Returns:
        kernel matrix of shape (n_sample, n_center), computed as
        exp(-dist / (2 * bandwidth**2)); plus the raw distances when
        return_dist is True. Assumes euclidean_distances returns squared
        distances by default -- confirm in models.classic_kernel.
    '''
    assert bandwidth > 0
    sq_dist = euclidean_distances(samples, centers)

    # Snapshot before the in-place ops below overwrite sq_dist.
    raw = sq_dist.clone() if return_dist else None

    neg_gamma = -1. / (2 * bandwidth ** 2)
    # Clamp removes tiny negative values from floating-point round-off.
    kernel_mat = sq_dist.clamp_(min=0).mul_(neg_gamma).exp_()

    return (kernel_mat, raw) if return_dist else kernel_mat

def gaussian_M(samples, centers, bandwidth, M, return_dist=False):
    '''Gaussian kernel on squared distances induced by the matrix M.

    Args:
        samples: of shape (n_sample, n_feature).
        centers: of shape (n_center, n_feature).
        bandwidth: kernel bandwidth. Not validated here.
        M: matrix forwarded to euclidean_distances_M; presumably of shape
            (n_feature, n_feature) -- confirm in models.classic_kernel.
        return_dist: if True, also return the squared distances as produced
            by euclidean_distances_M, before negative values are clamped to 0.

    Returns:
        kernel matrix exp(-dist / (2 * bandwidth**2)); plus the raw squared
        distances when return_dist is True.
    '''
    sq_dist = euclidean_distances_M(samples, centers, M, squared=True)

    # Snapshot before the in-place ops below overwrite sq_dist.
    raw = sq_dist.clone() if return_dist else None

    neg_gamma = -1. / (2 * bandwidth ** 2)
    # Clamp removes tiny negative values from floating-point round-off.
    kernel_mat = sq_dist.clamp_(min=0).mul_(neg_gamma).exp_()

    if return_dist:
        return kernel_mat, raw
    return kernel_mat

def get_agop(X, sol, L, P, batch_size=2, K=None, centering=False, x=None,
             return_per_class_agop=False):
    '''Average gradient outer product (AGOP) of a Gaussian-kernel predictor.

    For the predictor f(z) = sol @ k(X, z), where k is the kernel computed by
    ``gaussian_M`` with bandwidth ``L`` and matrix ``P``, this builds the
    per-output gradients at every evaluation point x[j] as an (m, c, d)
    tensor G and averages G[j].T @ G[j] over j. G carries the opposite sign
    of the true gradient, which does not affect the outer products.
    The gradient formula assumes gaussian_M's squared distance is
    (a - b) P (a - b)^T with symmetric P -- TODO confirm in classic_kernel.

    Args:
        X: training points, shape (n, d).
        sol: coefficients, shape (c, n).
        L: kernel bandwidth (scalar).
        P: (d, d) matrix used by the kernel and in the gradient formula.
        batch_size: evaluation points per chunk when accumulating M.
        K: optional precomputed kernel matrix between X and x, shape (n, m).
            Computed as ``gaussian_M(X, x, L, P)`` when None.
        centering: if True, subtract the mean gradient over evaluation
            points before forming outer products.
        x: evaluation points, shape (m, d). Defaults to X.
        return_per_class_agop: if True, also compute one (d, d) AGOP per
            output (class).

    Returns:
        (M, per_class_agops): M is the (d, d) AGOP, moved to CPU.
        per_class_agops is a list of c (d, d) tensors on G's device whose sum
        equals M up to rounding, or [] when return_per_class_agop is False.

    Raises:
        ValueError: if x has no rows (the average would be undefined).
    '''
    if x is None:
        x = X

    if K is None:
        K = gaussian_M(X, x, L, P)

    n, d = X.shape
    c = sol.shape[0]
    m = x.shape[0]
    if m == 0:
        raise ValueError('get_agop: x must contain at least one evaluation point')

    # First term: step2[j] = sum_i K[i, j] * outer(sol[:, i], (X @ P)[i]).
    # Flattening to (n, c*d) lets one matmul with K.T perform the sum over i.
    step1 = (sol.T.reshape(n, c, 1) @ (X @ P).reshape(n, 1, d)).reshape(n, c * d)
    step2 = (K.T @ step1).reshape(m, c, d)  # (m, c, d)
    del step1

    # Second term: step3[j] = outer(sum_i sol[:, i] K[i, j], (x @ P)[j]).
    step3 = (sol @ K).T.reshape(m, c, 1) @ (x @ P).reshape(m, 1, d)  # (m, c, d)

    # The 1/L**2 factor is d/dz of exp(-dist / (2 L**2)); sign is flipped.
    G = (step2 - step3) * -1 / (L ** 2)  # (m, c, d)
    del step2, step3

    if centering:
        G = G - G.mean(0)

    # sum_j G[j].T @ G[j] equals flat.T @ flat once the (j, class) rows are
    # stacked, which avoids a (batch, d, d) temporary per chunk.
    M = 0.
    for chunk in torch.split(G, batch_size):
        flat = chunk.reshape(-1, d)
        M += (flat.T @ flat).cpu()
    torch.cuda.empty_cache()
    M /= m

    per_class_agops = []
    if return_per_class_agop:
        # per_class[k] = sum_j G[j, k].T @ G[j, k] / m, for all classes at once.
        per_class = torch.einsum('jkd,jke->kde', G, G) / m
        per_class_agops = list(per_class.unbind(0))

    return M, per_class_agops

def get_wagop(X, sol, L, M, y, batch_size=2, K=None, centering=False, x=None):
    '''Weighted AGOP-style (d, d) matrix for a Gaussian-kernel predictor.

    Computes
        R     = (sol.T @ sol) * K / L**2                         (n, n)
        XM    = X @ M                                             (n, d)
        wagop = XM.T @ (R @ XM - R.sum(0)[:, None] * XM)          (d, d)

    Args:
        X: training points, shape (n, d).
        sol: coefficients, shape (c, n).
        L: kernel bandwidth (scalar).
        M: (d, d) matrix applied to X and passed to gaussian_M.
        y: unused here; presumably labels of shape (n, c). Kept for
            interface compatibility.
        batch_size, centering, x: unused; kept so the signature mirrors
            get_agop.
        K: optional precomputed kernel matrix, shape (n, n). Computed as
            gaussian_M(X, X, L, M) when None.

    Returns:
        (wagop, None): the second element is always None, matching the
        two-element return shape of get_agop.
    '''
    if K is None:
        # Same kernel helper that get_agop uses (defined in this module).
        K = gaussian_M(X, X, L, M)

    R = (sol.T @ sol) * K / L ** 2  # (n, n)
    XM = X @ M  # (n, d)
    # Row j of XM scaled by the j-th column sum of R.
    mu = torch.sum(R, dim=0).view(-1, 1) * XM  # (n, d)
    nu = R @ XM  # (n, d)
    wagop = XM.T @ (nu - mu)  # (d, d)

    return wagop, None
