'''Construct kernel model with EigenPro optimizer.'''
import collections
import time
import torch
import random
from sklearn.utils import gen_batches

import torch.nn as nn
import numpy as np

from .utils import svd
from .utils import float_x
import ipdb
import pickle
from .utils import calculate_time
from timeit import default_timer as timer
from sklearn.utils import gen_batches

def asm_eigenpro_fn(samples, map_fn, top_q, bs_gpu, alpha, min_q=5, seed=1):
    """Prepare gradient map for EigenPro and calculate the
    scale factor for the learning rate such that the update rule,
        p <- p - eta * g
    becomes,
        p <- p - scale * eta * (g - eigenpro_fn(g))

    Arguments:
        samples:    matrix of shape (n_sample, n_feature).
        map_fn:     kernel k(samples, centers) where centers are specified.
        top_q:      top-q eigensystem for constructing eigenpro iteration/kernel.
                    If None, q is chosen automatically from the spectrum.
        bs_gpu:     maximum batch size corresponding to GPU memory.
        alpha:      exponential factor (<= 1) for eigenvalue rescaling due to approximation.
        min_q:      minimum value of q when q (if None) is calculated automatically.
        seed:       seed for numpy's global random number generator.

    Returns:
        eigenpro_fn:    function (grad, kmat) -> preconditioner correction tensor.
        scale:          factor that rescales learning rate.
        top_eigval:     largest eigenvalue.
        beta:           bound on k(x, x) used for the step size (fixed to 1 here).
        eigvals:        the q - 1 leading eigenvalues that were kept.
        eigvecs:        the matching eigenvectors, one column per kept eigenvalue.
    """

    np.random.seed(seed)  # set random seed for subsamples
    start = time.time()
    n_sample, _ = samples.shape

    if top_q is None:
        svd_q = min(n_sample - 1, 1000)
    else:
        svd_q = top_q

    eigvals, eigvecs = svd.nystrom_kernel_svd(samples, map_fn, svd_q)

    # Choose q such that the batch size is bounded by
    #   the subsample size and the memory size.
    #   Keep the original q if it is pre-specified.
    if top_q is None:
        max_bs = min(max(n_sample / 5, bs_gpu), n_sample)
        top_q = np.sum(np.power(1 / eigvals, alpha) < max_bs) - 1
        top_q = max(top_q, min_q)
    # eigvals[top_q - 1] is read below as the tail eigenvalue, so q must not
    # exceed the number of eigenvalues the SVD returned (min_q alone could push
    # it past that for very small subsamples).
    top_q = min(top_q, len(eigvals))

    eigvals, tail_eigval = eigvals[:top_q - 1], eigvals[top_q - 1]
    eigvecs = eigvecs[:, :top_q - 1]

    device = samples.device
    eigvals_t = torch.tensor(eigvals.copy()).to(device)
    eigvecs_t = torch.tensor(eigvecs).to(device)
    tail_eigval_t = torch.tensor(tail_eigval, dtype=torch.float).to(device)

    scale = float_x(np.power(eigvals[0] / tail_eigval, alpha))
    # Per-direction damping: (1 - (tail / lambda_i)^alpha) / lambda_i.
    diag_t = (1 - torch.pow(tail_eigval_t / eigvals_t, alpha)) / eigvals_t

    def eigenpro_fn(grad, kmat):
        '''Function to apply EigenPro preconditioner.

        Computes (V * D) @ ((grad^T @ kmat) @ V)^T, where V are the kept
        eigenvectors and D the damping factors; kmat's column count must
        match the number of rows of V.
        '''
        return torch.mm(eigvecs_t * diag_t,
                        torch.t(torch.mm(torch.mm(torch.t(grad),
                                                  kmat),
                                         eigvecs_t)))

    print("SVD time: %.2f, top_q: %d, top_eigval: %.2f, new top_eigval: %.2e" %
          (time.time() - start, top_q, eigvals[0], eigvals[0] / scale))

    # NOTE(review): beta is hard-coded to 1 rather than estimated from the
    # eigenvectors; 1 equals max k(x, x) only for kernels with k(x, x) = 1 —
    # confirm for the kernel in use.
    beta = 1

    return eigenpro_fn, scale, eigvals[0], float_x(beta), eigvals, eigvecs


class HilbertProjection(nn.Module):
    '''Fast Kernel Regression using EigenPro iteration.

    Fits ``self.weight`` so that ``self.kzz @ self.weight`` matches the
    training targets, where ``self.kzz = kernel_fn(centers, centers)`` is
    computed once in ``__init__``. Mini-batch gradient steps on the weights
    are corrected by an EigenPro preconditioner built from a Nystrom
    subsample of the centers.

    NOTE(review): training/eval rows are looked up in ``kzz`` purely by row
    index, so ``z_train`` is assumed to be row-for-row the same as
    ``centers`` — confirm with callers.
    '''

    def __init__(self, kernel_fn, centers, y_dim, device="cuda", weight_init=None,wandb=None):
        super().__init__()
        self.kernel_fn = kernel_fn
        self.n_centers, self.x_dim = centers.shape
        self.device = device
        self.pinned_list = []  # tensors created via self.tensor(release=True)
        self.eigenpro_f = None  # built lazily by the first fit call
        self.precond_verbose = True

        self.wandb_run = wandb
        self.centers = self.tensor(centers, release=True)
        # Center-by-center kernel matrix; forward() only slices rows of it.
        self.kzz = self.kernel_fn(self.centers,self.centers).to(self.device)

        self.mse_error = torch.tensor(100_000, device=self.device)

        if weight_init is not None:
            self.weight = self.tensor(weight_init, release=True)
        else:
            self.weight = self.tensor(torch.zeros(
                self.n_centers, y_dim), release=True)
        self.weight_decay = None

        # Presumably updated by the @calculate_time decorator (key matches the
        # decorated method name) — confirm in .utils.
        self.time_track_dict = {'fit_hilbert_projection': [0, 0]}

    def __del__(self):
        # NOTE(review): Tensor.to("cpu") returns a copy and does not move the
        # tracked tensor in place, so this loop likely frees no device memory.
        for pinned in self.pinned_list:
            _ = pinned.to("cpu")
        torch.cuda.empty_cache()

    def tensor(self, data, dtype=None, release=True):
        """Copy ``data`` onto ``self.device`` (optionally casting to ``dtype``).

        When ``release`` is true the result is recorded in ``self.pinned_list``.
        """
        if torch.is_tensor(data):
            tensor = data.detach().clone().to(dtype=dtype, device=self.device)
        else:
            tensor = torch.tensor(data, dtype=dtype,
                              requires_grad=False).to(self.device)
        if release:
            self.pinned_list.append(tensor)
        return tensor

    def kernel_matrix(self, samples,ids):
        """Return rows ``ids`` of the precomputed kernel matrix (``samples`` is unused)."""
        return self.kzz[ids,:]

    def forward(self, samples,ids, weight=None):
        """Predict ``kzz[ids] @ weight``; ``samples`` is not used.

        ``weight`` defaults to ``self.weight``.
        """
        if weight is None:
            weight = self.weight
        kmat = self.kzz[ids,:]
        pred = kmat.mm(weight)
        return pred

    def get_predictions(self):
        """Return predictions ``kzz @ weight`` for every center."""
        return self.kzz @ self.weight

    def primal_gradient(self, samples, labels,batch_ids, weight):
        """Residual ``forward(...) - labels`` for the rows ``batch_ids``."""
        pred = self.forward(samples,batch_ids, weight)
        grad = pred - labels
        return grad

    @staticmethod
    def _compute_opt_params(bs, bs_gpu, beta, top_eigval):
        """Return ``(batch_size, step_size)`` for the given spectrum.

        If ``bs`` is None it is set to ``beta / top_eigval + 1`` capped by
        ``bs_gpu``.
        """
        if bs is None:
            bs = min(np.int32(beta / top_eigval + 1), bs_gpu)

        if bs < beta / top_eigval + 1:
            eta = bs / beta /2
        else:
            eta = 0.99 * 1 * bs / (beta + (bs - 1) * top_eigval)
        return bs, float_x(eta)

    def eigenpro_iterate(self, z_batch, gz_batch, eta, batch_ids):
        """Apply one preconditioned update to ``self.weight`` in place."""
        # update random coordinate block (for mini-batch)
        grad = self.primal_gradient(z_batch, gz_batch,batch_ids, self.weight)
        self.weight.index_add_(0, batch_ids, -eta * grad)

        # update fixed coordinate block (for EigenPro)
        kmat = self.kernel_fn(z_batch, self.nystrom_samples)
        correction = self.eigenpro_f(grad, kmat)
        self.weight.index_add_(0, self.nystrom_ids, eta * correction)
        self.weight.mul_(1 - eta * self.weight_decay)
        return

    def evaluate(self, x_eval, y_eval, bs,
                 metrics=('mse', 'multiclass-acc'),
                 clf_threshold=None, bayes_opt=None):
        """Return an OrderedDict of metrics for predictions on the eval rows.

        Prediction for eval row i is read from row i of ``kzz``. Only 'mse'
        is computed; ``clf_threshold`` and ``bayes_opt`` are ignored.
        """
        p_list = []
        n_sample, _ = x_eval.shape

        for batch_ids in gen_batches(n_sample,bs):

            z_batch = x_eval[batch_ids]
            p_batch = self.forward(z_batch,batch_ids)
            p_list.append(p_batch)

        p_eval = torch.cat(p_list,dim=0)

        eval_metrics = collections.OrderedDict()
        if 'mse' in metrics:
            eval_metrics['mse'] = torch.mean(torch.square(p_eval - y_eval))
        return eval_metrics

    def setup_preconditioner(self, *args):
        """Build the EigenPro map via ``asm_eigenpro_fn(*args)`` and store its outputs."""
        (self.eigenpro_f, self.gap, self.top_eigval,
         self.beta, self.eigvals, self.eigvecs) = asm_eigenpro_fn(*args)
        self.new_top_eigval = self.top_eigval / self.gap

    def fit_batch(self, z_batch, gz_batch, eta, batch_ids):
        """Run one EigenPro update on a mini-batch."""
        self.eigenpro_iterate(z_batch, gz_batch, eta, batch_ids)

    @staticmethod
    def _print_projection_progress(epoch, step, train_sec, tr_score, metrics):
        """Print one training-progress line in the 'Proj:' format."""
        print(f'Proj: {epoch} epochs,{step} step, {train_sec:.1f}s\t', end='')
        for metric in metrics:
            print(f'Proj: train {metric}: {tr_score[metric]:.10f} ', end='')
        print()

    @calculate_time
    def fit_hilbert_projection(
        self, z_train, gz_train, max_epochs=200, mem_gb=12,
        x_val=None, y_val=None, cutoff=1e-5, weight_decay=None,
        n_nystrom_subsamples=None, top_q=None, bs=None, eta=None,
        n_train_eval=5000, run_epoch_eval=True, scale=1, seed=1,
        clf_threshold=0.5, bayes_opt=None, metrics=('mse',), return_log=True
    ):
        """Fit ``self.weight`` so that ``kzz @ weight`` matches ``gz_train``.

        Runs shuffled EigenPro mini-batch epochs until the training MSE on
        the first 1000 rows drops below 1e-6, or ``max_epochs`` epochs have
        run (``max_epochs=None`` means no epoch limit).

        The preconditioner, batch size and step size are built only on the
        first call (while ``self.eigenpro_f`` is None); later calls reuse
        them and ignore ``mem_gb``, ``n_nystrom_subsamples``, ``top_q``,
        ``bs``, ``eta``, ``scale`` and ``seed``.

        Arguments:
            z_train:     training inputs; row i is looked up as row i of kzz.
            gz_train:    targets of shape (n_samples, n_labels).
            max_epochs:  epoch budget (None = run until convergence).
            mem_gb:      device memory in GB used to bound the batch size.
            cutoff:      only printed; the stopping threshold is fixed at 1e-6.
            weight_decay: per-step multiplicative decay factor (default 0).
            n_nystrom_subsamples: Nystrom subsample size (None = automatic).
            top_q:       number of eigen-directions (None = automatic).
            bs, eta:     batch size / step size (None = computed).
            scale:       extra multiplier on the step size.
            seed:        numpy seed for choosing the Nystrom subsample.
            metrics:     metric names to request from ``evaluate``.
            clf_threshold, bayes_opt: forwarded to ``evaluate`` (ignored there).
            x_val, y_val, n_train_eval, run_epoch_eval, return_log: unused.

        Returns:
            (weight, predictions) with predictions = ``kzz @ weight``.
        """
        # NOTE(review): stopping threshold is hard-coded; `cutoff` is only printed.
        tol = 1e-6
        self.weight_decay = 0.0 if weight_decay is None else weight_decay
        n_samples, n_labels = gz_train.shape

        # Calculate batch size / learning rate for improved EigenPro iteration.
        if self.eigenpro_f is None:

            if n_nystrom_subsamples is None:
                if n_samples < 100000:
                    n_nystrom_subsamples = min(n_samples, 2000)
                else:
                    n_nystrom_subsamples = 12000

            mem_bytes = (mem_gb - 1) * 1024 ** 3  # preserve 1GB
            bsizes = np.arange(n_samples)
            mem_usages = ((self.x_dim + 3 * n_labels + bsizes + 1)
                          * self.n_centers + n_nystrom_subsamples * 1000) * 4
            bs_gpu = np.sum(mem_usages < mem_bytes)  # device-dependent batch size

            np.random.seed(seed)
            sample_ids = np.random.choice(n_samples, n_nystrom_subsamples, replace=False)
            self.nystrom_ids = self.tensor(sample_ids).long()
            self.nystrom_samples = self.centers[self.nystrom_ids]
            self.setup_preconditioner(self.nystrom_samples, self.kernel_fn, top_q, bs_gpu, .95)

            self.bs, opt_eta = self._compute_opt_params(
                bs, bs_gpu, self.beta, self.new_top_eigval)
            # A caller-supplied eta overrides the computed optimum; either way
            # self.eta must be set before the print and rescaling below.
            self.eta = opt_eta if eta is None else eta

            if self.precond_verbose:
                print("Projection: Nystrom size=%d, bs_gpu=%d, eta=%.2f, bs=%d, top_eigval=%.2e, beta=%.2f" %
                      (n_nystrom_subsamples, bs_gpu, self.eta, self.bs, self.top_eigval, self.beta))

            self.eta = self.tensor(scale * self.eta / self.bs, dtype=torch.double)

        z_train_eval, gz_train_eval = z_train[0:1000], gz_train[0:1000]
        start = time.time()
        epoch = 0
        self.mse_error = 10000

        n_rows = z_train.size()[0]
        step_size = int(self.bs)
        # Index of the last mini-batch in an epoch (ceil(n_rows / step_size) - 1),
        # so the end of every epoch is always evaluated.
        final_step = (n_rows - 1) // step_size

        print(f'cut_off is: {cutoff}')
        while self.mse_error > tol:
            if max_epochs is not None and epoch >= max_epochs:
                print(f'Proj: stopped after {epoch} epochs without reaching mse < {tol}')
                break

            permutation = torch.randperm(n_rows, device=self.device)

            for step, begin in enumerate(range(0, n_rows, step_size)):
                batch_ids = permutation[begin:begin + step_size]

                self.fit_batch(
                    z_train[batch_ids], gz_train[batch_ids], self.eta, batch_ids
                )

                if step % 4 == 0 or step == final_step:
                    train_sec = time.time() - start
                    tr_score = self.evaluate(
                        z_train_eval, gz_train_eval, self.bs, clf_threshold=clf_threshold,
                        bayes_opt=bayes_opt, metrics=metrics
                    )
                    self.mse_error = tr_score["mse"]
                    self._print_projection_progress(epoch, step, train_sec, tr_score, metrics)

                if self.mse_error < tol:
                    # Converged: report with the count of batches completed.
                    self._print_projection_progress(epoch, step + 1, train_sec, tr_score, metrics)
                    break

            epoch = epoch + 1

        predictions = self.get_predictions()

        return (self.weight, predictions)
