import torch
from .utils import unitnorm_svd,calculate_time
import ipdb
from sys import getsizeof


class GradientCorrection:
    """Nystrom-based data preconditioner that produces gradient corrections.

    A rank-``level`` spectral approximation of the kernel is built from
    ``nystrom_samples``. For every device, a matrix is precomputed that maps
    ``(E_x.T @ Kmat_xs_xbatch) @ grad`` to a correction on the model knots.
    The constructor also derives a batch size and learning rate from the first
    discarded eigenvalue (``tail_eig_x``).
    """

    @calculate_time
    def __init__(
            self, kernel_fn, nystrom_samples, knots, device, level,
            batch_size=100, learning_rate_prefactor=0.1, wandb_run=None, verbose=False):
        """Build the spectral preconditioner and per-device correction matrices.

        Args:
            kernel_fn: callable ``kernel_fn(a, b)`` returning the kernel matrix
                between the rows of ``a`` and the rows of ``b``.
            nystrom_samples: tensor of Nystrom points; ``shape[0]`` is the
                number of samples.
            knots: sequence indexed by device position; ``knots[i]`` holds the
                model centres used on ``device[i]``.
            device: sequence of devices (anything accepted by ``Tensor.to``).
                Its length is used as the number of workers.
            level: number of leading eigenpairs kept. Eigenvalue ``level``
                (0-based) becomes ``tail_eig_x``.
            batch_size: stored, but overridden by the spectrum-derived batch
                size below. Kept for backward compatibility.
            learning_rate_prefactor: multiplies the learning rate in the
                second (large-batch) branch only.
            wandb_run: optional run handle, stored on the instance.
            verbose: currently unused.
        """
        print("Initilizing data precondtioner...")
        self.device = device
        self.kernel_fn = kernel_fn
        self.nystrom_samples = nystrom_samples
        self.nystrom_samples_all = []
        nystrom_size = self.nystrom_samples.shape[0]
        n_devices = len(self.device)

        # Ask for one extra eigenpair so that eigenvalue ``level`` can serve as the tail.
        self.Lam_x, self.E_x = unitnorm_svd.nystrom_kernel_svd(
            self.nystrom_samples.to('cpu'),
            self.kernel_fn, level + 1
        )
        self.tail_eig_x = self.Lam_x[level]
        self.Lam_x = self.Lam_x[:level]
        self.E_x = self.E_x[:, :level]
        # Per-eigenvector scaling (1 - tail / lam) / lam / n used by the correction.
        self.D_x = (1 - self.tail_eig_x / self.Lam_x) / self.Lam_x / nystrom_size

        self.batch_size = batch_size
        self.wandb_run = wandb_run

        knorms = 1 - torch.sum(self.E_x ** 2, axis=1)
        beta = torch.max(knorms)
        print(f'beta is:{beta}')
        # NOTE(review): the data-derived beta is only reported; a fixed beta = 1
        # is what the learning-rate rule below actually uses.
        beta = 1
        print(f'beta is:{beta}')

        # The batch size is derived from the tail eigenvalue and overrides the
        # ``batch_size`` argument. Clamp it to 1 so that a large tail eigenvalue
        # split across several devices cannot truncate it to 0 (which would
        # also make lr == 0).
        self.batch_size = max(1, int(1 / self.tail_eig_x / n_devices))
        total_batch = self.batch_size * n_devices
        # Without the clamp, total_batch <= 1 / tail_eig_x, so with beta = 1 the
        # first branch is taken.
        if total_batch < beta / self.tail_eig_x + 1:
            print("learning 1st branch")
            self.lr = total_batch / beta / (2)
        else:
            print("learning 2nd branch")
            self.lr = learning_rate_prefactor * self.batch_size * n_devices / (
                beta + (total_batch - 1) * self.tail_eig_x)

        self.E_x_all = []
        # Kept for backward compatibility; not populated by this constructor.
        self.D_x_all = []
        self.preconditioner_matrix = []
        for ind, g in enumerate(self.device):
            # Transfer each tensor to the device once and reuse the copy,
            # rather than issuing the same transfer twice per device.
            nystrom_on_g = self.nystrom_samples.to(g)
            E_x_on_g = self.E_x.to(g)
            self.nystrom_samples_all.append(nystrom_on_g)
            self.E_x_all.append(E_x_on_g)

            Kmat_xs_z = self.kernel_fn(nystrom_on_g, knots[ind])
            # D_x (1-D, length ``level``) broadcasts over the columns of E_x,
            # scaling eigenvector column j by D_x[j].
            self.preconditioner_matrix.append(
                Kmat_xs_z.to(g).T @ (self.D_x.to(g) * E_x_on_g))
        self.time_track_dict = {'precondition_kmat_xs_xbatch': [0, 0],
                                'Kmat_xs_xbatch_eval': [0, 0], 'Kmat_xs_xbatch_mult': [0, 0]}

        print(f'outer: learning rate: {self.lr},batch_size={self.batch_size}, top eigenvalue:{self.Lam_x[0]},'
              f' new top eigenvalue:{self.tail_eig_x}')
        print("Data preconditioner is ready.")

    @calculate_time
    def Kmat_xs_xbatch_eval(self, X_batch, ind):
        """Return the kernel matrix between device ``ind``'s Nystrom samples and ``X_batch``."""
        return self.kernel_fn(self.nystrom_samples_all[ind], X_batch)

    @calculate_time
    def __call__(self, grad, Kmat_xs_xbatch, ind):
        """Return the preconditioner correction for ``grad`` on device ``ind``.

        Computes ``preconditioner_matrix[ind] @ ((E_x.T @ Kmat_xs_xbatch) @ grad)``.
        The inner product is formed first so that the Nystrom-by-batch kernel
        block is reduced to ``level`` rows before ``grad`` is applied.

        Args:
            grad: gradient whose leading dimension matches the columns of
                ``Kmat_xs_xbatch``.
            Kmat_xs_xbatch: kernel block as returned by ``Kmat_xs_xbatch_eval``.
            ind: device position selecting the per-device tensors.
        """
        return self.preconditioner_matrix[ind] @ ((self.E_x_all[ind].T @ Kmat_xs_xbatch) @ grad)

if __name__ == "__main__":
    # Smoke-test demo. The relative imports mean this must be run as a module
    # (``python -m <package>.<this_module>``), not as a plain script.
    from .utils.setup_utils import RectangleKernelRegressionProblem
    from .utils.kernel import gaussian
    import numpy as np

    # GradientCorrection takes a sequence of devices and one set of knots per device.
    devices = [torch.device('cpu')]

    kernel_fn = lambda x, y: gaussian(x, y, bandwidth=1)
    syn_data = RectangleKernelRegressionProblem(d=2, model_size=10)
    nystrom_ids = np.random.choice(range(len(syn_data.x_train)), 100, replace=False)
    knot_ids = np.random.choice(range(len(syn_data.x_train)), syn_data.model_size, replace=False)
    nystrom_samples = syn_data.x_train[nystrom_ids]
    syn_data.knots = syn_data.x_train[knot_ids]
    syn_data.compute_kernel_matrices()
    gradient_correction_fn = GradientCorrection(
        kernel_fn=kernel_fn,
        nystrom_samples=nystrom_samples,
        knots=[syn_data.knots.to(g) for g in devices],
        device=devices,
        level=2,
    )
    batch_ids = np.random.choice(range(len(syn_data.x_train)), 10, replace=False)
    X_batch = syn_data.x_train[batch_ids]
    grad = syn_data.Kmat_xz[batch_ids] @ torch.randn(len(batch_ids), syn_data.model_size)
    # __call__ expects (grad, Kmat_xs_xbatch, device_index), so build the
    # Nystrom-vs-batch kernel block for device 0 first.
    Kmat_xs_xbatch = gradient_correction_fn.Kmat_xs_xbatch_eval(X_batch, 0)
    h = gradient_correction_fn(grad, Kmat_xs_xbatch, 0)
