import torch
import torch.nn as nn
import numpy as np
from .utils import weighted_max, batch_subset_kernel

# Pytorch Implementation of https://arxiv.org/abs/1605.07583
# Original Matlab implementation https://github.com/cnmusco/recursive-nystrom
def recursiveNystrom(X, kernel_obj, samples, inverse=False, epsilon=1e-7):
    """Select Nystrom landmarks by recursive ridge-leverage-score sampling.

    Follows the reference MATLAB implementation of Musco & Musco,
    "Recursive Sampling for the Nystrom Method" (arXiv:1605.07583).

    Args:
        X: data tensor; its first dimension indexes the ``n`` points.
        kernel_obj: kernel object. ``kernel_obj.diag(X)`` must return the kernel
            diagonal, and it is passed to ``batch_subset_kernel`` to build kernel
            sub-blocks (presumably K(X[rows], X[cols]) -- confirm in utils).
        samples: number of landmarks to select; must satisfy ``2 <= samples < n``.
        inverse: if False, ``W`` is the regularized inverse of the landmark block
            ``C[rInd, :]``. If True, ``W`` is the landmark block itself and ``C``
            is replaced by ``C.pinverse().t()``.
        epsilon: ridge per landmark (scaled by the number of landmarks) added to
            the landmark block before inversion; only used when ``inverse`` is False.

    Returns:
        Tuple ``(C, W, rInd)`` where ``rInd`` holds the indices of the selected
        landmark rows of ``X``.

    Raises:
        ValueError: if ``samples >= n`` or ``samples < 2``.
    """
    n = X.size(0)
    if samples >= n:
        raise ValueError("Samples greater than or equal to # of rows. Use direct kernel instead.")
    if samples < 2:
        # log(samples) <= 0 would make k below infinite / undefined.
        raise ValueError("samples must be at least 2 for recursive Nystrom sampling.")
    kernel_batch_size = 5000

    oversamp = np.log(samples)
    k = int(np.ceil(samples / (4 * oversamp)))
    nLevels = int(np.ceil(np.log(n / samples) / np.log(2)))

    perm = torch.randperm(n).to(X.device)

    # lSize[l] is the size of the uniform subsample used at recursion level l:
    # level 0 is the whole data set and each level halves the previous one.
    lSize = [n]
    for _ in range(nLevels):
        lSize.append(int(np.ceil(lSize[-1] / 2)))

    # At the deepest level the "selected" points are just a uniform sample.
    samp = torch.arange(lSize[-1], dtype=torch.long, device=X.device)
    rInd = perm[:lSize[-1]]
    with torch.no_grad():
        weights = torch.ones(len(rInd), device=X.device)
        kDiag = kernel_obj.diag(X)

        for l in reversed(range(nLevels)):
            rIndCurr = perm[:lSize[l]]
            KS = batch_subset_kernel(X, kernel_obj, rIndCurr, rInd, kernel_batch_size)
            SKS = KS[samp, :]
            SKSn = SKS.size(0)

            if k >= SKSn:
                # Fewer selected points than k: use a tiny ridge rather than 0
                # to keep the inverse below numerically stable.
                _lambda = 1e-6
            else:
                # lambda = (trace(D SKS D) - sum of the k largest |eigenvalues|
                # of D SKS D) / k, with D = diag(weights).
                weighted_SKS = SKS * weights.unsqueeze(1) * weights.unsqueeze(0)
                top_eig_sum = _top_k_abs_eigval_sum(weighted_SKS, k)
                _lambda = ((SKS.diag() * weights.pow(2)).sum() - top_eig_sum) / k
                del weighted_SKS

            R = (SKS + (_lambda * weights.pow(-2)).diag_embed()).inverse()
            # Unscaled ridge leverage residual; clamp_min(0) only guards round-off.
            residual = torch.clamp_min(kDiag[rIndCurr] - (KS.mm(R) * KS).sum(-1), 0.)
            if l > 0:
                # Intermediate level: keep each point independently with
                # probability equal to its (oversampled) leverage score.
                levs = (oversamp * (1 / _lambda) * residual).clamp_max(1.)
                # view(-1) (not squeeze) keeps a 1-d index even for one hit.
                samp = (torch.rand(lSize[l],).to(levs.device) < levs).nonzero().view(-1)

                if samp.numel() == 0:
                    # Nothing sampled: fall back to a uniform sample of fixed size.
                    levs.fill_(samples / lSize[l])
                    samp = torch.randperm(lSize[l])[:samples].to(levs.device)
                weights = (1 / levs[samp]).sqrt()
            else:
                # Top level: draw exactly `samples` landmarks without replacement.
                levs = ((1 / _lambda) * residual).clamp_max(1.)
                samp = np.random.choice(np.arange(n), samples, replace=False, p=weighted_max(levs).cpu().numpy())
            rInd = perm[samp]

            del KS, SKS, R, levs, residual
            torch.cuda.empty_cache()

        C = batch_subset_kernel(X, kernel_obj, np.arange(n), rInd, kernel_batch_size)
        W = C[rInd, :]
        if not inverse:
            ridge = len(rInd) * epsilon * torch.eye(samples, dtype=W.dtype, device=W.device)
            W = (W + ridge).inverse()
        else:
            C = C.pinverse().t()
    return C, W, rInd


def _top_k_abs_eigval_sum(mat, k):
    """Return the sum of the ``k`` largest-magnitude eigenvalues of ``mat``.

    ``mat`` is assumed to be a (numerically) symmetric kernel block; it is
    symmetrized before a symmetric eigen-solver is applied.
    """
    sym = 0.5 * (mat + mat.t())
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "eigvalsh"):
        eigvals = linalg.eigvalsh(sym)
    else:
        # Older PyTorch releases without torch.linalg.eigvalsh.
        eigvals = torch.symeig(sym)[0]
    return eigvals.abs().topk(k)[0].sum()

class Nystrom_PerronFreboniusOperator(nn.Module):
    """Kernel operator built from two Nystrom factorizations, on ``x`` and ``y``.

    At construction ``recursiveNystrom`` is run on ``x`` with ``inverse=True``
    and then on ``y`` with the default ``inverse=False``. The returned factors
    and landmark indices are stored as non-trainable parameters.

    ``forward(k_prime)`` computes
    ``KSyy @ SKSyy @ KSyy.T @ KSxx @ SKSxx @ KSxx.T @ k_prime``,
    multiplying one factor at a time from the right.
    """

    def __init__(self, x, y, kernel_obj, samples, epsilon=1e-7):
        super().__init__()

        def frozen(tensors):
            # Wrap each tensor so it registers with the module but is not trained.
            return [nn.Parameter(tensor, requires_grad=False) for tensor in tensors]

        # The x factorization runs before the y one (same RNG draw order).
        x_parts = recursiveNystrom(x, kernel_obj, samples, inverse=True, epsilon=epsilon)
        self.KSxx, self.SKSxx, self.x_nystrom_point_inds = frozen(x_parts)

        y_parts = recursiveNystrom(y, kernel_obj, samples, epsilon=epsilon)
        self.KSyy, self.SKSyy, self.y_nystrom_point_inds = frozen(y_parts)

    def forward(self, k_prime):
        # Apply the factors right-to-left, so the product of the factors
        # themselves is never materialized.
        chain = (
            self.KSxx.t(),
            self.SKSxx,
            self.KSxx,
            self.KSyy.t(),
            self.SKSyy,
            self.KSyy,
        )
        result = k_prime
        for factor in chain:
            result = factor.mm(result)
        return result
