import numpy as np
import torch
from numpy.linalg import solve, svd, norm
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from . import utils
# from . import agop_rfm_gauss, wagop_rfm_gauss, agop_rfm_laplace, wagop_rfm_laplace, wagop_rfm_quad
from . import classic_kernel
from tqdm.contrib import tenumerate
import hickle
from torchmetrics.functional.classification import accuracy
from .eigenpro import KernelModel
import scipy

import time
from tqdm import tqdm


class GenericRFM:
    """
    Copied from RFM github, with support for update_power, and small changes to inheritance.

    Base class for Recursive Feature Machines: ``fit`` alternates between solving for
    kernel-predictor weights under the current Mahalanobis matrix ``M`` and updating
    ``M`` from the fitted predictor via ``fit_M``.

    Subclasses should implement update_M and kernel_M.
    """
    def __init__(self, device=torch.device('cpu'), mem_gb=8, diag=False, centering=False, reg=1e-3, update_power=1.0, standardize_max=True):
        super().__init__()
        self.M = None
        self.model = None
        self.diag = diag # if True, Mahalanobis matrix M will be diagonal (stored as a length-d vector)
        self.centering = centering # if True, update_M will center the gradients before taking an outer product
        self.device = device
        self.mem_gb = mem_gb
        self.reg = reg # only used when fit using direct solve
        self.update_power = update_power
        self.standardize_max = standardize_max # whether max entry of M should be forced to have value 1
        # Solver selection is (re)set by ``fit``; these defaults select the torch
        # direct solve so that ``fit_predictor`` can also be called on its own.
        self.fit_using_eigenpro = False
        self.fit_using_numpy = False

    def update_M(self, samples, p_batch_size):
        """Return one batch's additive contribution to the new M (subclass hook)."""
        raise NotImplementedError("Must implement this method in a subclass")

    def kernel_M(self, samples, centers):
        """Return the (len(samples), len(centers)) kernel matrix under self.M (subclass hook)."""
        raise NotImplementedError("Must implement this method in a subclass")

    def get_data(self, data_loader):
        """Concatenate every (inputs, labels) batch of ``data_loader`` along dim 0."""
        X, y = [], []
        for inputs, labels in data_loader:
            X.append(inputs)
            y.append(labels)
        return torch.cat(X, dim=0), torch.cat(y, dim=0)

    def fit_predictor(self, centers, targets, **kwargs):
        """
        Solve for the predictor weights with the current M and store them in ``self.weights``.

        Initialises M to the identity (or a ones-vector when ``diag``) on first use.
        ``kwargs`` are only forwarded to the EigenPro solver.
        """
        self.centers = centers
        if self.M is None:
            if self.diag:
                self.M = torch.ones(centers.shape[-1], device=self.device, dtype=centers.dtype)
            else:
                self.M = torch.eye(centers.shape[-1], device=self.device, dtype=centers.dtype)
        if self.fit_using_eigenpro:
            self.weights = self.fit_predictor_eigenpro(centers, targets, **kwargs)
        elif self.fit_using_numpy:
            self.weights = self.fit_predictor_lstsq_numpy(centers, targets)
        else:
            self.weights = self.fit_predictor_lstsq(centers, targets)

    def fit_predictor_lstsq(self, centers, targets):
        """Solve (K + reg*I) w = targets with torch (the ridge term is skipped when reg <= 0)."""
        centers = centers.to(self.device)
        targets = targets.to(self.device)
        K = self.kernel_M(centers, centers)
        if self.reg > 0:
            K = K + self.reg*torch.eye(len(centers), device=centers.device)
        return torch.linalg.solve(K, targets)

    def fit_predictor_lstsq_numpy(self, centers, targets):
        """
        Same system as ``fit_predictor_lstsq`` but solved on the CPU with numpy.

        The solution is converted back to a tensor on ``self.device``.
        """
        centers = centers.to(self.device)
        targets = targets.to(self.device)
        K = self.kernel_M(centers, centers).cpu().numpy()
        if self.reg > 0:
            # np.eye takes no ``device`` argument; match K's dtype so a float32 kernel
            # is not silently promoted to float64 (which would break ``predict``'s matmul).
            K = K + self.reg*np.eye(len(centers), dtype=K.dtype)
        return torch.from_numpy(np.linalg.solve(K, targets.cpu().numpy())).to(self.device)

    def fit_predictor_eigenpro(self, centers, targets, **kwargs):
        """Fit the weights iteratively with EigenPro's ``KernelModel``."""
        n_classes = 1 if targets.dim()==1 else targets.shape[-1]
        self.model = KernelModel(self.kernel_M, centers, n_classes, device=self.device)
        _ = self.model.fit(centers, targets, mem_gb=self.mem_gb, **kwargs)
        return self.model.weight

    def predict(self, samples):
        """Return K(samples, centers) @ weights, on the same device as ``samples``."""
        out = self.kernel_M(samples.to(self.device), self.centers.to(self.device)) @ self.weights.to(self.device)
        return out.to(samples.device)

    def fit(self, train_loader, test_loader,
            iters=3, name=None, method='lstsq',
            train_acc=False, loader=True, classif=True,
            return_mse=False, verbose=True, M_batch_size=None, **kwargs):
        """
        Run ``iters`` rounds of (fit predictor, update M), then fit the predictor once more.

        Args:
            train_loader, test_loader: loaders yielding (inputs, labels) batches, or
                (X, y) tensor pairs when ``loader`` is False.
            iters: number of M updates.
            name: if given, each updated M is saved with hickle to ``saved_Ms/M_{name}_{i}.h``.
            method: 'eigenpro' or 'numpy' (case-insensitive); anything else uses the torch solve.
            train_acc: unused; kept for backward compatibility.
            classif: if True, accuracies are also computed and reported.
            return_mse: if True, return per-round Ms and test MSEs (plus test accuracies when classif).
            M_batch_size: batch size for ``fit_M``; estimated from ``mem_gb`` when None.
            **kwargs: forwarded to ``fit_predictor`` and ``fit_M``.

        Returns:
            ``(Ms, mses, test_accs)`` or ``(Ms, mses)`` when ``return_mse``; otherwise the final test MSE.
        """
        self.fit_using_eigenpro = (method.lower()=='eigenpro')
        self.fit_using_numpy = (method.lower()=='numpy')

        if loader:
            print("Loaders provided")
            X_train, y_train = self.get_data(train_loader)
            X_test, y_test = self.get_data(test_loader)
        else:
            X_train, y_train = train_loader
            X_test, y_test = test_loader

        test_accs = []
        mses = []
        Ms = []
        for i in range(iters):
            self.fit_predictor(X_train, y_train, X_val=X_test, y_val=y_test, **kwargs)

            if classif:
                if verbose or return_mse:
                    test_acc = self.score(X_test, y_test, metric='accuracy')
                    # Separate name so the ``train_acc`` parameter is not shadowed.
                    round_train_acc = self.score(X_train, y_train, metric='accuracy')
                    test_accs.append(test_acc)
                if verbose:
                    print(f"Round {i}, Train Acc: {100*round_train_acc:.2f}%")
                    print(f"Round {i}, Test Acc: {100*test_acc:.2f}%")

            test_mse = self.score(X_test, y_test, metric='mse')

            if verbose:
                print(f"Round {i}, Test MSE: {test_mse:.4f}")

            self.fit_M(X_train, y_train, verbose=verbose, M_batch_size=M_batch_size, **kwargs)

            if return_mse:
                Ms.append(self.M+0)  # +0 stores a copy, not a reference
                mses.append(test_mse)

            if name is not None:
                hickle.dump(self.M, f"saved_Ms/M_{name}_{i}.h")

        self.fit_predictor(X_train, y_train, X_val=X_test, y_val=y_test, **kwargs)
        final_mse = self.score(X_test, y_test, metric='mse')

        if verbose:
            print(f"Final MSE: {final_mse:.4f}")
        if classif:
            final_test_acc = self.score(X_test, y_test, metric='accuracy')
            test_accs.append(final_test_acc)
            if verbose:
                print(f"Final Test Acc: {100*final_test_acc:.2f}%")

        if return_mse:
            if classif:
                return Ms, mses, test_accs
            else:
                return Ms, mses

        return final_mse

    def _compute_optimal_M_batch(self, p, c, d, scalar_size=4):
        """Computes the optimal batch size for EGOP."""
        THREADS_PER_BLOCK = 512 # pytorch default
        def tensor_mem_usage(numels):
            """Calculates memory footprint of tensor based on number of elements."""
            return np.ceil(scalar_size * numels / THREADS_PER_BLOCK) * THREADS_PER_BLOCK

        def max_tensor_size(mem):
            """Calculates maximum possible tensor given memory budget (bytes)."""
            return int(np.floor(mem / THREADS_PER_BLOCK) * (THREADS_PER_BLOCK / scalar_size))

        curr_mem_use = torch.cuda.memory_allocated() # in bytes
        M_mem = tensor_mem_usage(d if self.diag else d**2)
        centers_mem = tensor_mem_usage(p * d)
        mem_available = (self.mem_gb *1024**3) - curr_mem_use - (M_mem + centers_mem) * scalar_size
        M_batch_size = max_tensor_size((mem_available - 3*tensor_mem_usage(p) - tensor_mem_usage(p*c*d)) / (2*scalar_size*(1+p)))

        return M_batch_size

    def fit_M(self, samples, labels, p_batch_size=None, M_batch_size=None,
              verbose=True, total_points_to_sample=50000, **kwargs):
        """
        Applies EGOP to update the Mahalanobis matrix M.

        Sums ``update_M`` over consecutive batches of ``samples`` (stopping after
        ``1 + total_points_to_sample // M_batch_size`` batches), optionally rescales
        so the max entry is 1, applies ``update_power`` and stores the result in ``self.M``.
        """
        n, d = samples.shape
        if self.M is not None:
            M = torch.zeros_like(self.M)
        elif self.diag:
            M = torch.zeros(d, dtype=samples.dtype, device=self.device)
        else:
            M = torch.zeros(d, d, dtype=samples.dtype, device=self.device)

        if M_batch_size is None:
            # Use the accumulator's element size: self.M may still be None here.
            BYTES_PER_SCALAR = M.element_size()
            p, d = samples.shape
            c = 1 if labels.dim() == 1 else labels.shape[-1]
            M_batch_size = self._compute_optimal_M_batch(p, c, d, scalar_size=BYTES_PER_SCALAR)

            if verbose:
                print(f"Using batch size of {M_batch_size}")

        # batches = torch.randperm(n).split(M_batch_size)
        batches = torch.arange(n).split(M_batch_size)

        num_batches = 1 + total_points_to_sample//M_batch_size
        batches = batches[:num_batches]
        if verbose:
            print(f'Sampling AGOP on {num_batches*M_batch_size} total points')

        batch_iter = tenumerate(batches) if verbose else enumerate(batches)
        for _, bids in batch_iter:
            torch.cuda.empty_cache()
            M.add_(self.update_M(samples[bids], p_batch_size))

        if self.standardize_max:
            M = M / M.max()
        if self.update_power != 1.0: # Take matrix power of updated M
            M = utils.matrix_pow(M, self.update_power)
        self.M = M
        if self.standardize_max:
            self.M = self.M / self.M.max()

    def score(self, samples, targets, metric='mse'):
        """Return accuracy (float) or MSE (0-d tensor) of the predictor on (samples, targets)."""
        preds = self.predict(samples.to(self.device)).to(targets.device)
        if metric=='accuracy':
            if preds.shape[-1]==1:
                num_classes = len(torch.unique(targets))
                if num_classes==2:
                    return accuracy(preds, targets, task="binary").item()
                else:
                    return accuracy(preds, targets, task="multiclass", num_classes=num_classes).item()
            else:
                preds_ = torch.argmax(preds,dim=-1)
                targets_ = torch.argmax(targets,dim=-1)
                return accuracy(preds_, targets_, task="multiclass", num_classes=preds.shape[-1]).item()

        elif metric=='mse':
            return (targets - preds).pow(2).mean()
        else:
            raise NotImplementedError("Only mse and accuracy are supported metrics.")

class Gauss_AGOP_RFM(GenericRFM):
    """AGOP-RFM using the Gaussian kernel ``classic_kernel.gaussian_M``."""

    def __init__(self, bandwidth=1., **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth

    def kernel_M(self, samples, centers):
        return classic_kernel.gaussian_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        """
        Return this batch's contribution to the AGOP of the fitted predictor.

        With centers z_j and weight rows a_j, the per-sample gradient computed here is
            G[i] = (sum_j K[i, j] a_j (M z_j)^T - (sum_j K[i, j] a_j) (M x_i)^T) / bandwidth**2.
        The result is sum_i G[i]^T G[i], a (d, d) matrix. When ``self.diag`` is set it is
        only the diagonal, a length-d vector. G is mean-centered first if ``self.centering``.
        """
        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")

        # Move inputs to the model device, as the other update_M implementations do;
        # self.M is allocated on self.device.
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        K = self.kernel_M(samples, self.centers)  # (n, p)

        p, d = self.centers.shape
        p, c = self.weights.shape
        n, d = samples.shape

        # Row-wise application of M: elementwise scaling when M is a diagonal vector.
        if self.diag:
            centers_M = self.centers * self.M
            samples_M = samples * self.M
        else:
            centers_M = self.centers @ self.M
            samples_M = samples @ self.M

        centers_term = (
            K  # (n, p)
            @ (self.weights.view(p, c, 1) * centers_M.view(p, 1, d)).reshape(p, c*d)  # (p, cd)
        ).view(n, c, d)  # (n, c, d)

        samples_term = (K @ self.weights).reshape(n, c, 1) * samples_M.reshape(n, 1, d)  # (n, c, d)

        G = (centers_term - samples_term) / self.bandwidth**2  # (n, c, d)

        if self.centering:
            G = G - G.mean(0)  # (n, c, d)

        if self.diag:
            return torch.einsum('ncd, ncd -> d', G, G)
        return torch.einsum("ncd, ncD -> dD", G, G)


class Gauss_AGOP_TestAlternate_RFM(GenericRFM):
    """
    Gaussian-kernel AGOP-RFM with an alternate, step-by-step formulation of the
    gradient computation (presumably kept to cross-check Gauss_AGOP_RFM).

    Unlike Gauss_AGOP_RFM, the returned AGOP is divided by the batch size.
    """

    def __init__(self, bandwidth=1., **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth

    def kernel_M(self, samples, centers):
        return classic_kernel.gaussian_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        """
        Return the batch-averaged AGOP of the fitted predictor over ``samples``.

        Gradients are evaluated at the samples (m of them) using all p centers. Any
        batch size therefore works; previously the samples were reshaped using the
        number of centers. Returns a (d, d) matrix, or a length-d vector when ``self.diag``.
        """
        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")

        samples = samples.to(self.device)
        centers = self.centers.to(self.device)
        P = self.M
        L = self.bandwidth
        weights = self.weights  # (p, c)

        def apply_P(A):
            # Row-wise application of P: elementwise scaling when P is a diagonal vector.
            return A * P if self.diag else A @ P

        K = self.kernel_M(samples, centers)  # (m, p)
        p, c = weights.shape
        m, d = samples.shape

        # step2[i] = sum_j K[i, j] a_j (P z_j)^T over centers z_j -> (m, c, d)
        step1 = (weights.reshape(p, c, 1) * apply_P(centers).reshape(p, 1, d)).reshape(p, c*d)
        step2 = (K @ step1).reshape(m, c, d)
        del step1

        # step3[i] = (sum_j K[i, j] a_j) (P x_i)^T at the samples x_i -> (m, c, d)
        step3 = (K @ weights).reshape(m, c, 1) * apply_P(samples).reshape(m, 1, d)
        del K

        G = (step2 - step3) * -1/(L**2)

        if self.centering:
            G = G - G.mean(0)

        if self.diag:
            M = torch.einsum('ncd, ncd -> d', G, G)
        else:
            M = torch.einsum('ncd, ncD -> dD', G, G)
        M /= len(G)
        return M
    # Unused reference code retained from the original implementation (matrix power of M):
    # if agop_power == 0.5:
    #     M = np.real(scipy.linalg.sqrtm(M.numpy()))
    # elif agop_power.is_integer():
    #     if agop_power == 1:
    #         M = M.numpy()
    #     else:
    #         M = np.real(np.linalg.matrix_power(M.numpy(), int(agop_power)))
    # else:
    #     eigs, vecs = np.linalg.eigh(M.numpy())
    #     eigs = np.power(eigs, agop_power)
    #     eigs[np.isnan(eigs)] = 0.0
    #     M = vecs @ np.diag(eigs) @ vecs.T

# def gaussian_M_update(samples, centers, bandwidth, M, weights, K=None, \
#                       centers_bsize=-1, centering=False, agop_power=0.5,
#                       return_per_class_agop=False):
#     return get_grads(samples, weights.T, bandwidth, M, K=K, centering=centering, x=centers,
#                      agop_power=agop_power, return_per_class_agop=return_per_class_agop)

import itertools
def wagop_gauss_naive_M(X,L,sol,M,diag_only=False):
    """
    Reference double-loop computation of the Gaussian-kernel wAGOP (slow; for debugging).

    Args:
        X: (n, d) data; the same points serve as samples and centers.
        L: bandwidth passed to ``classic_kernel.gaussian_M``.
        sol: weights transposed, (c, n) (callers pass ``self.weights.T``).
        M: (d, d) Mahalanobis matrix.
        diag_only: must be False (diagonal M is not supported).

    Returns:
        (d, d) tensor  -sum_{i,j} (a_i . a_j) K_ij M^T (x_i - x_j)(x_i - x_j)^T M / (2 L^2).
    """
    assert(not diag_only)
    n, d = X.shape
    A = sol.t() @ sol  # A[i, j] = a_i . a_j
    AK = A * classic_kernel.gaussian_M(X, X, L, M)
    # Accumulate in the input dtype; a default float32 buffer would truncate float64 data.
    wagop = torch.zeros(d, d, device=X.device, dtype=X.dtype)
    for i, j in tqdm(itertools.product(range(n), repeat=2), total=n**2):
        diff = X[i] - X[j]
        term = M.T @ torch.outer(diff, diff) @ M / 2
        wagop -= AK[i, j] * term / L**2
    return wagop

def wagop_gauss_compact_M(X, L, sol, M, diag_only=False):
    """
    Vectorized Gaussian-kernel wAGOP.

    With R = (sol^T sol) * K / L^2 and XM = X @ M, returns
    XM^T (R @ XM - colsum(R)[:, None] * XM), a (d, d) tensor.
    ``sol`` is the transposed weight matrix; diagonal M is not supported.
    """
    assert not diag_only
    kernel = classic_kernel.gaussian_M(X, X, L, M)
    weighted = (sol.T @ sol) * kernel / L**2
    projected = X @ M
    column_totals = weighted.sum(dim=0).unsqueeze(1)
    return projected.T @ (weighted @ projected - column_totals * projected)

def hmwagop_gauss_compact_M(X, L, sol, M, diag_only=False):
    """
    "Hail Mary" variant of the compact Gaussian-kernel wAGOP.

    The pair weights are A[i, j] = 2 a_i.a_j - a_i.a_i - a_j.a_j, i.e. -||a_i - a_j||^2.
    The result is XM^T (R @ XM - colsum(R)[:, None] * XM) with R = A * K / L^2 and
    XM = X @ M. ``sol`` is the transposed weight matrix; diagonal M is not supported.
    """
    # Hail Mary wAGOP
    assert(not diag_only)
    K = classic_kernel.gaussian_M(X, X, L, M)
    S = sol.T @ sol  # S[i, j] = alpha_i . alpha_j; computed once instead of three times
    sq_norms = torch.diag(S)  # alpha_i . alpha_i
    A = 2 * S - sq_norms.reshape(-1, 1) - sq_norms.reshape(1, -1)
    R = A * K / L**2
    XM = X @ M
    mu = torch.sum(R, dim=0).view(-1, 1) * XM
    nu = R @ XM
    return XM.T @ (nu - mu)


def lgop_gauss_compact_M(X, L, sol, M, diag_only=False):
    """
    Gaussian-kernel "LGOP": like ``wagop_gauss_compact_M`` but the outer factor is X
    instead of X @ M.

    Returns X^T (R @ XM - colsum(R)[:, None] * XM), with R = (sol^T sol) * K / L^2
    and XM = X @ M. Diagonal M is not supported.
    """
    assert not diag_only
    kernel = classic_kernel.gaussian_M(X, X, L, M)
    weighted = (sol.T @ sol) * kernel / L**2
    projected = X @ M
    column_totals = weighted.sum(dim=0).unsqueeze(1)
    return X.T @ (weighted @ projected - column_totals * projected)

def wagop_gauss_inner_compact_M(X, L, sol, M, diag_only=False):
    """
    wAGOP for ``classic_kernel.gaussian_inner_M``: returns (XM)^T R (XM),
    with R = (sol^T sol) * K / L^2 and XM = X @ M. Diagonal M is not supported.
    """
    assert not diag_only
    kernel = classic_kernel.gaussian_inner_M(X, X, L, M)
    weighted = (sol.T @ sol) * kernel / L**2
    projected = X @ M
    return projected.T @ weighted @ projected

def wagop_gauss_hack_M(X, L, sol, M, diag_only=False):
    """
    Simplified ("hack") Gaussian-kernel wAGOP: returns (XM)^T R (XM),
    with R = (sol^T sol) * K / L^2 and XM = X @ M. Diagonal M is not supported.
    """
    assert not diag_only
    kernel = classic_kernel.gaussian_M(X, X, L, M)
    weighted = (sol.T @ sol) * kernel / L**2
    projected = X @ M
    return projected.T @ weighted @ projected

class GaussInner_wAGOP_RFM(GenericRFM):
    """
    wAGOP-RFM using ``classic_kernel.gaussian_inner_M``.

    ``update_M`` only handles ``wagop_implementation='compact'``. Any other value,
    including the default ``'classic'``, makes ``update_M`` raise ValueError.
    """

    def __init__(self, bandwidth=1., wagop_implementation='classic', **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth
        self.wagop_implementation = wagop_implementation

    def kernel_M(self, samples, centers):
        return classic_kernel.gaussian_inner_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        # Guard clauses for unsupported configurations.
        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")
        if self.centering:
            raise NotImplementedError("Centering not yet implemented for GaussInner wAGOP method")
        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of GaussInner wAGOP")
        if self.wagop_implementation != 'compact':
            raise ValueError("Invalid wagop_implementation")

        return wagop_gauss_inner_compact_M(samples, self.bandwidth, self.weights.T, self.M, diag_only=self.diag)

class Gauss_LGOP_RFM(GenericRFM):
    """
    Gaussian-kernel RFM whose M update is ``lgop_gauss_compact_M``.

    ``update_M`` only handles ``wagop_implementation='compact'``. Any other value,
    including the default ``'classic'``, makes ``update_M`` raise ValueError.
    """

    def __init__(self, bandwidth=1., wagop_implementation='classic', **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth
        self.wagop_implementation = wagop_implementation

    def kernel_M(self, samples, centers):
        return classic_kernel.gaussian_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        # Guard clauses for unsupported configurations.
        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")
        if self.centering:
            raise NotImplementedError("Centering not yet implemented for Gauss LGOP method")
        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of Gauss LGOP")
        if self.wagop_implementation != 'compact':
            raise ValueError("Invalid wagop_implementation")

        return lgop_gauss_compact_M(samples, self.bandwidth, self.weights.T, self.M, diag_only=self.diag)

class Gauss_HMwAGOP_RFM(GenericRFM):
    """
    Gaussian-kernel RFM whose M update is ``hmwagop_gauss_compact_M`` ("Hail Mary" wAGOP).

    ``update_M`` only handles ``wagop_implementation='compact'``. Any other value,
    including the default ``'classic'``, makes ``update_M`` raise ValueError.
    """

    def __init__(self, bandwidth=1., wagop_implementation='classic', **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth
        self.wagop_implementation = wagop_implementation

    def kernel_M(self, samples, centers):
        return classic_kernel.gaussian_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)
        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")

        # Error messages name this method (they were copied from the LGOP class).
        if self.centering:
            raise NotImplementedError("Centering not yet implemented for Gauss HMwAGOP method")

        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of Gauss HMwAGOP")

        if self.wagop_implementation == 'compact':
            return hmwagop_gauss_compact_M(samples, self.bandwidth, self.weights.T, self.M, diag_only=self.diag)
        raise ValueError("Invalid wagop_implementation")

class Gauss_wAGOP_RFM(GenericRFM):
    """
    Gaussian-kernel wAGOP-RFM. ``wagop_implementation`` selects how the update is computed:
    'classic' (classic_kernel.gaussian_M_wagop), 'naive', 'compact' or 'hack'
    (the module-level wagop_gauss_*_M functions).
    """

    def __init__(self, bandwidth=1., wagop_implementation='classic', **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth
        self.wagop_implementation = wagop_implementation

    def kernel_M(self, samples, centers):
        return classic_kernel.gaussian_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)
        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")

        if self.centering:
            raise NotImplementedError("Centering not yet implemented for Gauss wAGOP method")

        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of Gauss wAGOP")

        # 'classic' is resolved lazily so the classic_kernel attribute is only
        # looked up when that implementation is actually requested.
        dispatch = {
            'classic': lambda *args, **kw: classic_kernel.gaussian_M_wagop(*args, **kw),
            'naive': wagop_gauss_naive_M,
            'compact': wagop_gauss_compact_M,
            'hack': wagop_gauss_hack_M,
        }
        wagop_fn = dispatch.get(self.wagop_implementation)
        if wagop_fn is None:
            raise ValueError("Invalid wagop_implementation")
        return wagop_fn(samples, self.bandwidth, self.weights.T, self.M, diag_only=self.diag)


class Laplace_AGOP_RFM(GenericRFM):
    """AGOP-RFM using the Laplacian kernel ``classic_kernel.laplacian_M``."""

    def __init__(self, bandwidth=1., **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth

    def kernel_M(self, samples, centers):
        return classic_kernel.laplacian_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        """
        Performs a batched update of M.

        Computes per-sample gradients
            G[i] = sum_j (K[i, j] / dist[i, j]) (a_j (M z_j)^T - a_j (M x_i)^T) / bandwidth,
        processing the centers in chunks of ``p_batch_size`` (all at once when None).
        Returns sum_i G[i]^T G[i], a (d, d) matrix, or only its diagonal as a length-d
        vector when ``self.diag`` is set.
        """
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)
        K = self.kernel_M(samples, self.centers)  # (n, p)
        if p_batch_size is None:
            p_batch_size = self.centers.shape[0]

        # Divide the kernel by the M-distance. Coincident pairs (distance < 1e-10) become
        # K / 0 = inf and are then zeroed, so they contribute nothing to the gradient.
        dist = classic_kernel.euclidean_distances_M(samples, self.centers, self.M, squared=False)
        dist = torch.where(dist < 1e-10, torch.zeros(1, device=dist.device).float(), dist)
        K.div_(dist)
        del dist
        K[K == float("Inf")] = 0.0

        p, d = self.centers.shape
        p, c = self.weights.shape
        n, d = samples.shape

        def apply_M(A):
            # Row-wise application of M: elementwise scaling when M is a diagonal vector.
            return A * self.M if self.diag else A @ self.M

        # sum_j K[i, j] a_j (M z_j)^T, accumulated over chunks of centers to bound memory.
        temp = 0
        for p_batch in torch.arange(p).split(p_batch_size):
            b = len(p_batch)
            temp += K[:, p_batch] @ (  # (n, b)
                self.weights[p_batch, :].view(b, c, 1) * apply_M(self.centers[p_batch, :]).view(b, 1, d)
            ).reshape(b, c * d)  # (b, cd)
        centers_term = temp.view(n, c, d)

        samples_term = (K @ self.weights).reshape(n, c, 1) * apply_M(samples).reshape(n, 1, d)

        G = (centers_term - samples_term) / self.bandwidth  # (n, c, d)

        del centers_term, samples_term, K

        if self.centering:
            G = G - G.mean(0)  # (n, c, d)

        # The caller sums these per-batch contributions; no averaging over samples is applied here.
        if self.diag:
            return torch.einsum('ncd, ncd -> d', G, G)
        return torch.einsum("ncd, ncD -> dD", G, G)


class Quadratic_AGOP_RFM(GenericRFM):
    """
    AGOP-RFM with the quadratic kernel k(x,y,M) = (x^T M y)^2.
    (Note: This is not the centered version of the quadratic kernel, which would be k(x,y,M) = ((x-y)^T M (x-y))^2.)

    Only a full (non-diagonal) M is supported by update_M.
    """

    def __init__(self, bandwidth=1., **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth

    def kernel_M(self, samples, centers):
        return classic_kernel.quadratic_kernel_L_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        """
        Return sum_k G[k]^T G[k] (a (d, d) matrix), where
        G[k] = (2 / bandwidth**2) * sum_p a_p (x_k^T M z_p) (M z_p)^T.
        """
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")
        if self.diag:
            raise NotImplementedError("Diagonal M not yet implemented for Quadratic kernel")

        inner = samples @ self.M @ self.centers.T  # (n, p)
        centers_M = self.centers @ self.M  # (p, d)
        grads = torch.einsum('pc,np,pd->ncd', self.weights, inner, centers_M) * (2/self.bandwidth**2)

        if self.centering:
            grads = grads - grads.mean(0)  # (n, c, d)

        # Per-batch contribution; the caller sums these without averaging.
        return torch.einsum("ncd, ncD -> dD", grads, grads)


class Quadratic_AGOP_RFM_Naive(GenericRFM):
    """
    AGOP-RFM with the quadratic kernel k(x,y,M) = (x^T M y)^2.
    (Note: This is not the centered version of the quadratic kernel, which would be k(x,y,M) = ((x-y)^T M (x-y))^2.)

    NAIVE implementation to closely match with wAGOP implementation -- for debugging
    """

    def __init__(self, bandwidth=1., **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth

    def kernel_M(self, samples, centers):
        return classic_kernel.quadratic_kernel_L_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        """
        Return the AGOP over ``samples`` written as a weighted sum over pairs of centers:
            sum_{i,j} (a_i . a_j) (sum_k kp[k, i] kp[k, j]) (M z_i)(M z_j)^T,
        where kp[k, i] = 2 x_k^T M z_i / bandwidth**2. This equals sum_k G[k]^T G[k]
        with G[k] = sum_i a_i kp[k, i] (M z_i)^T.

        Assumes a full (d, d) M — TODO confirm; when ``self.diag`` is set only the
        diagonal of the result is returned.
        """
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")

        if self.centering:
            raise NotImplementedError("Centering not yet implemented for Quadratic wAGOP method")

        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of Quadratic wAGOP")

        kp = samples @ self.M @ self.centers.T * 2/self.bandwidth**2  # (n, p): rows = samples, cols = centers
        alphamul = torch.einsum('ic,jc->ij', self.weights, self.weights)  # (p, p)
        # Contract over the SAMPLE index k (rows of kp). The previous 'ij,ik,jk' contracted
        # over centers and was only correct when kp happened to be symmetric.
        multiplier = torch.einsum('ij,ki,kj->ij', alphamul, kp, kp)
        Mx = self.centers @ self.M
        agop = torch.einsum('ij,id,je->de', multiplier, Mx, Mx)

        if self.diag:
            agop = torch.diag(agop)
        return agop

class Quadratic_wAGOP_RFM(GenericRFM):
    """
    AGOP-RFM with the quadratic kernel k(x,y,M) = (x^T M y)^2.
    (Note: This is not the centered version of the quadratic kernel, which would be k(x,y,M) = ((x-y)^T M (x-y))^2.)

    update_M returns the weighted AGOP (wAGOP) computed on ``samples`` alone:
    (X M)^T [ (2 X M X^T / bandwidth^2) * (W W^T) ] (X M).
    """

    def __init__(self, bandwidth=1., **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth

    def kernel_M(self, samples, centers):
        return classic_kernel.quadratic_kernel_L_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")
        if self.centering:
            raise NotImplementedError("Centering not yet implemented for Quadratic wAGOP method")
        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of Quadratic wAGOP")

        pair_scale = samples @ self.M @ samples.T * 2/self.bandwidth**2  # (n, n)
        pair_weight = self.weights @ self.weights.T  # (n, n): a_i . a_j
        pair_scale = pair_scale * pair_weight
        projected = samples @ self.M
        result = projected.T @ pair_scale @ projected
        return torch.diag(result) if self.diag else result



class Quadratic_AGOP_RFM_TestAlternate(GenericRFM):
    """
    AGOP-RFM with the kernel k(x,y,M) = 3 (x^T M y)^2 as computed by ``kernel_M``
    (``bandwidth`` is stored but not used by this class).
    (Note: This is not the centered version of the quadratic kernel, which would be k(x,y,M) = ((x-y)^T M (x-y))^2.)

    NAIVE implementation to closely match with wAGOP implementation -- for debugging
    """

    def __init__(self, bandwidth=1., **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth

    def _apply_M(self, A):
        """Row-wise application of M: elementwise scaling when M is a diagonal vector."""
        return A * self.M if self.diag else A @ self.M

    def kernel_M(self, samples, centers):
        return 3 * (self._apply_M(samples) @ centers.T)**2

    def update_M(self, samples, p_batch_size=None):
        """
        Return the batch-averaged AGOP of the fitted predictor at ``samples``.

        The derivative of 3 (x^T M z)^2 with respect to x is 6 (x^T M z) M z. So
        G[k] = sum_i 6 (x_k^T M z_i) a_i (M z_i)^T, with gradients evaluated at the
        samples x_k and summed over the centers z_i. Returns mean_k G[k]^T G[k] as a
        (d, d) matrix, or its diagonal as a length-d vector when ``self.diag``.
        This matches the vector layout GenericRFM uses for a diagonal M.
        """
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")

        if self.centering:
            raise NotImplementedError("Centering not yet implemented for Quadratic wAGOP method")

        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of Quadratic wAGOP")

        centers = self.centers
        weights = self.weights  # (p, c)
        p, c = weights.shape
        n, d = samples.shape

        K = 3 * 2 * (self._apply_M(samples) @ centers.T)  # (n, p)
        # Weight each center's M-projection by its coefficients: (p, c*d)
        step1 = (weights.reshape(p, c, 1) * self._apply_M(centers).reshape(p, 1, d)).reshape(p, c*d)
        G = (K @ step1).reshape(n, c, d)
        del step1, K

        if self.diag:
            M = torch.einsum('ncd, ncd -> d', G, G)
        else:
            M = torch.einsum('ncd, ncD -> dD', G, G)
        M /= len(G)
        return M



class Quadratic_wAGOP_sqsim_RFM(GenericRFM):
    """
    Weighted-AGOP RFM for the quadratic kernel k(x,y,M) = (x^T M y)^2, using a squared similarity term.
    (Note: This is not the centered version of the quadratic kernel, which would be k(x,y,M) = ((x-y)^T M (x-y))^2.)

    update_M returns the (d, d) matrix (X M)^T S (X M), where X are the samples and
    S_ij = (2 * x_i^T M x_j / bandwidth**2)**2 * <w_i, w_j>, with w_i the i-th row of self.weights.
    Squaring the similarity keeps that factor nonnegative.

    self.centers, self.weights and self.M must be set before update_M is called
    (presumably by the GenericRFM fitting loop -- not visible here, TODO confirm).
    """

    def __init__(self, bandwidth=1., **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth

    def kernel_M(self, samples, centers):
        """Kernel matrix between samples and centers, delegated to classic_kernel.quadratic_kernel_L_M."""
        return classic_kernel.quadratic_kernel_L_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        """
        Compute the squared-similarity weighted AGOP used as the next M.

        Args:
            samples: (n, d) tensor; must have as many rows as self.centers.
            p_batch_size: not supported; must be None.

        Returns:
            (d, d) tensor. When self.diag is True, off-diagonal entries are zeroed
            but the result is still a (d, d) matrix, so it can be used as M directly.

        Raises:
            NotImplementedError: if p_batch_size is given or self.centering is True.
            ValueError: if samples and self.centers have different numbers of rows.
        """
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")

        if self.centering:
            raise NotImplementedError("Centering not yet implemented for Quadratic wAGOP method")

        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of Quadratic wAGOP")

        # Keep all operands on the same device as samples/centers.
        weights = self.weights.to(self.device)

        # Pairwise similarity of the samples under M, scaled by 2 / bandwidth^2: (n, n).
        kp_ij = samples @ self.M @ samples.T
        kp_ij = kp_ij * 2 / self.bandwidth**2
        kp_ij = kp_ij * kp_ij  # square similarity to ensure nonnegative
        # Weight each pair (i, j) by <w_i, w_j>; weights has one row per sample (n, c).
        kp_ij = kp_ij * (weights @ weights.T)

        Mx = samples @ self.M  # (n, d)
        wagop = Mx.T @ kp_ij @ Mx  # (d, d)
        if self.diag:
            # torch.diag on a 2-D tensor returns a 1-D vector, which would break the
            # `samples @ self.M @ samples.T` product on the next update. Keep only the
            # diagonal but return it as a (d, d) diagonal matrix.
            wagop = torch.diag(torch.diagonal(wagop))
        return wagop

class Quadratic_wAGOP_powered_RFM(GenericRFM):
    """
    Weighted-AGOP RFM for the quadratic kernel k(x,y,M) = (x^T M y)^2, with a powered M in the similarity term.
    (Note: This is not the centered version of the quadratic kernel, which would be k(x,y,M) = ((x-y)^T M (x-y))^2.)

    update_M builds S_ij = (2 * x_i^T P c_j / bandwidth**2) * <w_i, w_j>, where x are samples,
    c are centers and w_i is the i-th row of self.weights. P depends on Mpow:
        Mpow == 1: P = M
        Mpow == 2: P = M^T M
        Mpow == 4: P = (M^T M)^T (M^T M)
    It returns the (d, d) matrix (C M)^T S (C M), where C are the centers.

    self.centers, self.weights and self.M must be set before update_M is called
    (presumably by the GenericRFM fitting loop -- not visible here, TODO confirm).
    """

    def __init__(self, Mpow=1, bandwidth=1., **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth
        self.Mpow = Mpow  # supported values: 1, 2, 4 (checked in update_M)

    def kernel_M(self, samples, centers):
        """Kernel matrix between samples and centers, delegated to classic_kernel.quadratic_kernel_L_M."""
        return classic_kernel.quadratic_kernel_L_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        """
        Compute the weighted AGOP with a powered M used as the next M.

        Args:
            samples: (n, d) tensor; must have as many rows as self.centers.
            p_batch_size: not supported; must be None.

        Returns:
            (d, d) tensor. When self.diag is True, off-diagonal entries are zeroed
            but the result is still a (d, d) matrix, so it can be used as M directly.

        Raises:
            NotImplementedError: if p_batch_size is given, self.centering is True,
                or self.Mpow is not one of 1, 2, 4.
            ValueError: if samples and self.centers have different numbers of rows.
        """
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")

        if self.centering:
            raise NotImplementedError("Centering not yet implemented for Quadratic wAGOP method")

        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of Quadratic wAGOP")

        Mpowed = self.M.clone().detach()
        if self.Mpow == 1:
            pass
        elif self.Mpow == 2:
            Mpowed = Mpowed.T @ Mpowed
        elif self.Mpow == 4:
            Mpowed = Mpowed.T @ Mpowed
            Mpowed = Mpowed.T @ Mpowed
        else:
            # Raise instead of `assert` so the check survives `python -O`.
            raise NotImplementedError(f"Mpow={self.Mpow} not implemented; supported values are 1, 2, 4")

        # Keep all operands on the same device as samples/centers.
        weights = self.weights.to(self.device)

        # (n, n) similarity between samples and centers through the powered matrix.
        kp_ij = samples @ Mpowed @ self.centers.T * 2 / self.bandwidth**2
        # Pairwise weight products <w_i, w_j> (same as weights @ weights.T).
        alphamul = torch.einsum('ic,jc->ij', weights, weights)
        kp_ij = kp_ij * alphamul

        Mx = self.centers @ self.M  # (n, d)
        # wagop[d, e] = sum_{i,j} kp_ij[i, j] * Mx[i, d] * Mx[j, e]  (= Mx.T @ kp_ij @ Mx)
        wagop = torch.einsum('ij,id,je->de', kp_ij, Mx, Mx)
        if self.diag:
            # torch.diag on a 2-D tensor returns a 1-D vector; return a (d, d)
            # diagonal matrix instead so the result can serve as M.
            wagop = torch.diag(torch.diagonal(wagop))
        return wagop

class Quadratic_AGOP_surrogate_RFM(GenericRFM):
    """
    AGOP-RFM for the quadratic kernel k(x,y,M) = (x^T M y)^2, using a surrogate covariance in the similarity term.
    (Note: This is not the centered version of the quadratic kernel, which would be k(x,y,M) = ((x-y)^T M (x-y))^2.)

    update_M builds S_ij = (2 * x_i^T M^T Sigma M c_j / bandwidth**2) * <w_i, w_j>, where
    Sigma = Sigma_surrogate, x are samples, c are centers and w_i is the i-th row of self.weights.
    It returns the (d, d) matrix (C M)^T S (C M), where C are the centers.

    Sigma_surrogate must be provided (presumably a (d, d) tensor -- TODO confirm) before update_M is called.
    self.centers, self.weights and self.M must also be set beforehand
    (presumably by the GenericRFM fitting loop -- not visible here).
    """

    def __init__(self, bandwidth=1., Sigma_surrogate=None, **kwargs):
        super().__init__(**kwargs)
        self.bandwidth = bandwidth
        self.Sigma_surrogate = Sigma_surrogate

    def kernel_M(self, samples, centers):
        """Kernel matrix between samples and centers, delegated to classic_kernel.quadratic_kernel_L_M."""
        return classic_kernel.quadratic_kernel_L_M(samples, centers, self.bandwidth, self.M)

    def update_M(self, samples, p_batch_size=None):
        """
        Compute the surrogate-covariance weighted AGOP used as the next M.

        Args:
            samples: (n, d) tensor; must have as many rows as self.centers.
            p_batch_size: not supported; must be None.

        Returns:
            (d, d) tensor. When self.diag is True, off-diagonal entries are zeroed
            but the result is still a (d, d) matrix, so it can be used as M directly.

        Raises:
            NotImplementedError: if p_batch_size is given or self.centering is True.
            ValueError: if samples and self.centers have different numbers of rows,
                or if Sigma_surrogate was not provided.
        """
        samples = samples.to(self.device)
        self.centers = self.centers.to(self.device)

        if p_batch_size is not None:
            raise NotImplementedError("p_batch_size not yet implemented")

        if self.centering:
            raise NotImplementedError("Centering not yet implemented for Quadratic wAGOP method")

        if samples.shape[0] != self.centers.shape[0]:
            raise ValueError("Number of samples must equal centers in current implementation of Quadratic wAGOP")

        if self.Sigma_surrogate is None:
            # Without this check, `self.M.T @ None` fails with an opaque TypeError.
            raise ValueError("Sigma_surrogate must be provided to use Quadratic_AGOP_surrogate_RFM")

        # Keep all operands on the same device as samples/centers.
        Sigma = self.Sigma_surrogate.to(self.device)
        weights = self.weights.to(self.device)

        # (n, n) similarity between samples and centers through M^T Sigma M.
        kp_ij = samples @ self.M.T @ Sigma @ self.M @ self.centers.T * 2 / self.bandwidth**2
        # Pairwise weight products <w_i, w_j> (same as weights @ weights.T).
        alphamul = torch.einsum('ic,jc->ij', weights, weights)
        kp_ij = kp_ij * alphamul

        Mx = self.centers @ self.M  # (n, d)
        # wagop[d, e] = sum_{i,j} kp_ij[i, j] * Mx[i, d] * Mx[j, e]  (= Mx.T @ kp_ij @ Mx)
        wagop = torch.einsum('ij,id,je->de', kp_ij, Mx, Mx)
        if self.diag:
            # torch.diag on a 2-D tensor returns a 1-D vector; return a (d, d)
            # diagonal matrix instead so the result can serve as M.
            wagop = torch.diag(torch.diagonal(wagop))
        return wagop
