
import logging
import pdb

import numpy as np
import torch
import tqdm
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Module-level logger, named after this module so that applications can
# configure its level/handlers through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)

def intdiv_ceil(dividend, divisor):
    """Integer division of dividend / divisor, but round the result upward.
    This is different than simply wrapping float division in ceil() if the
    numbers are too big to be accurately represented as floats.

    Correct for any sign of dividend and divisor: Python's // rounds toward
    negative infinity, so negating the dividend, floor-dividing, and negating
    the result rounds toward positive infinity. (The common
    `(dividend + divisor - 1) // divisor` form is only correct for positive
    divisors.)

    Raises ZeroDivisionError if divisor is 0."""
    return -(-dividend // divisor)

def _calc_dataset_grad(model_loss_fn, model_parameters, dataset):
    """Return the average gradient of model_loss_fn over dataset.

    The dataset is iterated in order, one example per batch (batch_size=1).
    Each batch's loss is weighted by 1/len(dataset) and differentiated with
    respect to model_parameters using torch.autograd.grad, so the parameters'
    .grad attributes are neither read nor modified.

    Returns a list of tensors, one per entry of model_parameters, each with
    that parameter's shape, dtype and device (all zeros if dataset is empty).
    """
    # TODO: Make batch size configurable (but be careful about last batch
    # getting truncated if bs>1)
    batch_size = 1
    num_batches = len(dataset) // batch_size

    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    # zeros_like keeps each accumulator on its parameter's own device/dtype.
    grad = [torch.zeros_like(p) for p in model_parameters]
    for batch in tqdm.tqdm(dataloader, desc='Calc dataset grad'):
        # Apply the 1/num_batches averaging weight exactly once (here, on the
        # loss); scaling the resulting gradient again would yield sum/N**2.
        loss = model_loss_fn(batch) / num_batches
        batch_grad = torch.autograd.grad(loss, model_parameters)
        for acc, g in zip(grad, batch_grad):
            acc.add_(g)

    return grad

class KohInfluenceFunction:
    """Computes influence functions based on the paper: "Understanding
    Black-box Predictions via Influence Functions" by Koh and Liang 2017
    (arxiv:1703.04730v2).

    Basic usage of this class is as follows:
        inf_func = KohInfluenceFunction(model_loss_fn, model_parameters, train_dataset, test_dataset)
        influences = inf_func.calc_influences(query_example_dataset)

    After running the above, influences is a list such that influences[i] is
    the influence of the example query_example_dataset[i] on the test dataset
    with respect to the given model and train dataset.

    All derivatives in this class are taken with torch.autograd.grad, so the
    model parameters' .grad attributes are not used.

    NOTE: This class must be reconstructed if the model or train/test datasets
    change. If this class is constructed and some of the data passed to the
    constructor is then modified (e.g., the model parameter values are
    updated), the results may not be correct."""

    def __init__(self, model_loss_fn, model_parameters, train_dataset, test_dataset, config=None, rng=None):
        """Constructor.

        model_loss_fn is a function that takes a batch of examples as input and
        returns the model loss. The model loss should be a pytorch tensor that
        can be differentiated with respect to the model_parameters. The input
        batch will be passed to the model in the same format that would be
        returned from a torch.DataLoader initialized from the train/test/query
        dataset passed to the calc_s_test and calc_influences methods of this
        class. That is, assuming dataset[i] returns a tuple of tensors,
        model_loss_fn will be passed a tuple where each element j is
        concat(dataset[i][j], dataset[i2][j], ...).

        model_parameters is a list of model parameter tensors used to compute
        model_loss_fn (e.g., the `list(model.parameters())`)

        train_dataset and test_dataset are torch `TensorDataset`s that contain the
        train examples for which to compute the Hessian and the test examples that
        the influence will be computed for, respectively.

        config['hivp_r'] is the r parameter from Section 3 "Stochastic
        Estimation" from the Koh and Liang paper (i.e., the number of estimates
        of the hessian-inverse vector product (HIVP) to take, which are
        averaged to approximate the "true" HIVP). Defaults to 1.

        config['hivp_t'] is the t parameter from the Koh and Liang paper (the
        number of iterations to use when approximating the HIVP). Defaults to
        the number of examples in the train_dataset.

        config['hivp_scale'] is an intermediate scaling value used to ensure
        the norm of the (scaled) hessian is at most 1 during calculation of the
        inverse (otherwise the recursion would not converge). The hessian is
        *divided* by this value, so larger values increase the chance of
        convergence (but may slow down the rate of convergence or lead to
        floating-point instability). Defaults to 1.

        config['hivp_damping'] is the damping amount for the HIVP calculation.
        Defaults to 0.01.

        config['hivp_batch_size'] is the number of training examples drawn
        (with replacement) for each of the t iterations; that iteration's
        Hessian-vector product is averaged over these examples. Defaults to 8.

        rng is a numpy.random.Generator object. If None, constructs a new
        generator. NOTE: it is stored on the instance, but the training
        minibatches are drawn by a torch RandomSampler that is not seeded from
        it.

        Raises ValueError if hivp_r or hivp_t is less than 1 (hivp_t is 0 by
        default when train_dataset is empty)."""
        self._config = config
        if self._config is None:
            self._config = dict()

        self._model_loss_fn = model_loss_fn
        self._model_parameters = model_parameters
        self._hivp_r = self._config.get('hivp_r', 1)
        self._hivp_t = self._config.get('hivp_t', len(train_dataset))
        self._hivp_scale = self._config.get('hivp_scale', 1)
        self._hivp_damping = self._config.get('hivp_damping', 0.01)
        self._hivp_batch_size = self._config.get('hivp_batch_size', 8)
        # Validate before any expensive gradient computation.
        if self._hivp_r < 1:
            raise ValueError(f'hivp_r must be at least 1, got {self._hivp_r}')
        if self._hivp_t < 1:
            raise ValueError(f'hivp_t must be at least 1, got {self._hivp_t}')
        self._rng = rng if rng is not None else np.random.default_rng(np.random.randint(0, 2**63))

        self._s_test = self._calc_s_test(train_dataset, test_dataset)

    def _calc_s_test(self, train_dataset, test_dataset):
        """Compute the variable s_test from the Koh and Liang paper. This is
        a Hessian-inverse vector product H^-1*v ("hivp"), where the Hessian H
        comes from the model loss on the train set and the vector v is the
        gradient of the model loss on the test set.

        Uses the algorithm described in Koh and Liang (sec 3 "stochastic
        estimation", a variant of LiSSA): each of the hivp_r estimates starts
        at v and is refined for hivp_t steps with
            hivp <- v + (1 - damping) * hivp - (H_z * hivp) / scale
        where H_z is the Hessian on a random minibatch z of training examples.
        The estimates are averaged and divided by scale.

        Returns a list of tensors (one per model parameter) on the device of
        the first model parameter."""
        device = self._model_parameters[0].device

        # Keep v on the CPU between iterations to reduce device memory use;
        # each piece is moved to the device only while it is needed.
        v = [tensor.cpu() for tensor in self._calc_dataset_grad(test_dataset)]
        torch.cuda.empty_cache()

        sampler = RandomSampler(train_dataset, replacement=True, num_samples=(self._hivp_r * self._hivp_t))
        dataloader = DataLoader(train_dataset, sampler=sampler, batch_size=self._hivp_batch_size)
        dataiter = iter(dataloader)

        # Computing the diagnostics requires device->host syncs, so only do it
        # when they would actually be emitted.
        log_diagnostics = logger.isEnabledFor(logging.INFO)

        estimates = []
        for _ in tqdm.trange(self._hivp_r, desc='HIVP r-loop'):
            hivp = [tensor.to(device, copy=True) for tensor in v]
            for _ in tqdm.trange(self._hivp_t, desc='HIVP t-loop'):
                # Restart the sampler whenever it is exhausted so that every
                # t-step gets a freshly drawn minibatch.
                try:
                    z = next(dataiter)
                except StopIteration:
                    dataiter = iter(dataloader)
                    z = next(dataiter)
                # max_batch_size=1: accumulate the HVP one example at a time
                # to limit peak memory.
                h_times_old_hivp = self._hvp(fn=self._model_loss_fn, examples=z, vector_list=hivp, max_batch_size=1)
                hivp = self._hivp_update(v, hivp, h_times_old_hivp, device, log_diagnostics)

            # Append the last hivp of this r-iteration as its estimate
            estimates.append([tensor.cpu() for tensor in hivp])

        final_hivp = []
        for per_param_estimates in zip(*estimates):
            total = sum(per_param_estimates)
            # We once again divide by hivp_scale here to "undo" the scaling
            # from earlier. I.e., our hivp is currently (H/scale)^-1*v, so it
            # is `scale` times too big, and dividing fixes that.
            final_hivp.append(((total / self._hivp_r) / self._hivp_scale).to(device))

        return final_hivp

    def _hivp_update(self, v, hivp, h_times_old_hivp, device, log_diagnostics):
        """Apply one step of the recursion
            hivp <- v + (1 - damping) * hivp - h_times_old_hivp / scale
        and return the new hivp as a list of tensors on device.

        v, hivp and h_times_old_hivp are lists of tensors aligned with the
        model parameters; h_times_old_hivp is H*hivp as returned by _hvp.
        If log_diagnostics is true, logs the norm of the update term (dp), of
        v (dv), of the scaled HVP (dhto), and the cosine similarity between v
        and the scaled HVP (dn; nan if either norm is 0)."""
        new_hivp = []
        diff_sq = hvp_sq = v_sq = v_dot_hvp = 0.0
        with torch.no_grad():
            for v_i, old_i, hv_i in zip(v, hivp, h_times_old_hivp):
                v_i = v_i.to(device)
                # Dividing by hivp_scale here is to ensure we meet convergence
                # conditions (see hivp_scale in __init__).
                hv_i = hv_i.to(device) / self._hivp_scale
                diff = v_i - hv_i
                new_hivp.append((1 - self._hivp_damping) * old_i.to(device) + diff)
                if log_diagnostics:
                    diff_sq += (diff * diff).sum().item()
                    hvp_sq += (hv_i * hv_i).sum().item()
                    v_sq += (v_i * v_i).sum().item()
                    v_dot_hvp += (v_i * hv_i).sum().item()

        if log_diagnostics:
            dp, dhto, dv = np.sqrt(diff_sq), np.sqrt(hvp_sq), np.sqrt(v_sq)
            norm_prod = dhto * dv
            dn = v_dot_hvp / norm_prod if norm_prod > 0 else float('nan')
            logger.info(f'dp: {dp:0.5f}, dv: {dv:0.5f}, dhto: {dhto:0.5f}, dn: {dn:0.5f}')

        return new_hivp

    def _hvp(self, fn, examples, vector_list, max_batch_size=-1):
        """Calculate the hessian-vector product between the hessian of the
        given (pytorch-differentiable) function fn evaluated at the given
        examples and the given vector_list.

        The reason for having vector_list instead of vector is that the model
        parameter "vector" is actually a list of parameter tensors, and rather
        than flattening these to a vector we do all computations in this form.
        Thus, vector_list is a list of tensors matching the shape of the model
        parameter tensors.

        examples is a tuple of tensors whose first dimension indexes examples
        (e.g. a batch from a DataLoader). If max_batch_size > 0, the examples
        are processed in sub-batches of at most that many examples (so a batch
        that does not fit in memory can still be handled) and the per-sub-batch
        HVPs are averaged; otherwise all examples form a single batch.

        Both derivatives use torch.autograd.grad, so the parameters' .grad
        attributes are neither read nor modified. Parameters on which the
        gradient does not depend get a zero HVP.

        Returns a list of CPU tensors shaped like the model parameters."""
        params = self._model_parameters
        device = params[0].device

        num_examples = len(examples[0])
        if max_batch_size <= 0:
            # Single batch (at least 1 so the range below is valid; an empty
            # input then yields a zero HVP).
            max_batch_size = max(num_examples, 1)
        batch_starts = range(0, num_examples, max_batch_size)
        num_batches = len(batch_starts)  # == ceil(num_examples / max_batch_size)

        hvp = [torch.zeros_like(p) for p in params]
        for start in batch_starts:
            batch = tuple(t[start:start + max_batch_size].to(device) for t in examples)
            loss = fn(batch)
            # create_graph=True so the gradient itself can be differentiated.
            grad = torch.autograd.grad(loss, params, create_graph=True)

            # d/dparams <grad, vector> = H * vector (vector treated as constant).
            grad_vec_dotprod = 0
            for g, vec in zip(grad, vector_list):
                grad_vec_dotprod = grad_vec_dotprod + (g * vec.to(g.device)).sum() / num_batches
            if not grad_vec_dotprod.requires_grad:
                # The gradient is constant in the parameters, so this batch's
                # Hessian (and hence its HVP contribution) is zero.
                continue

            batch_hvp = torch.autograd.grad(grad_vec_dotprod, params, allow_unused=True)
            for acc, h in zip(hvp, batch_hvp):
                if h is not None:
                    acc.add_(h)

        return [h.cpu() for h in hvp]

    def _calc_dataset_grad(self, dataset):
        """Gradient of the model loss over dataset (see module-level helper)."""
        return _calc_dataset_grad(self._model_loss_fn, self._model_parameters, dataset)

    def calc_influence_from_grad(self, grad):
        """Calculate influence given the gradient of loss wrt parameters on
        some particular example(s).

        grad is a list of tensors aligned with the model parameters. Returns
        -sum_i <s_test[i], grad[i]> (0 if grad is empty)."""
        influence = 0
        for i, g in enumerate(grad):
            influence = influence - (self._s_test[i] * g).sum()
        return influence

    def calc_influences(self, query_examples_dataset):
        """Calculate the influences of the examples in the query set on the
        test set with respect to the model and train dataset of this influence
        function.

        Returns a list whose i-th entry is the influence of
        query_examples_dataset[i]."""
        # Important that batch size = 1 here, since we want to compute
        # influence individually for each example in the query dataset
        batch_size = 1

        sampler = SequentialSampler(query_examples_dataset)
        dataloader = DataLoader(query_examples_dataset, sampler=sampler, batch_size=batch_size)
        influences = []
        for batch in dataloader:
            loss = self._model_loss_fn(batch)
            grad = torch.autograd.grad(loss, self._model_parameters)
            influences.append(self.calc_influence_from_grad(grad))

        return influences

class DevgradInfluenceFunction:
    """Computes influence functions based on simply the dot-product of the
    query grad by the average dev grad.

    The dev gradient is computed once, in the constructor, from test_dataset
    via the module-level _calc_dataset_grad helper. The influence of a query
    example is the negated dot product between its loss gradient and that dev
    gradient. Reconstruct the object if the model or test_dataset change."""

    def __init__(self, model_loss_fn, model_parameters, test_dataset, config=None, rng=None):
        """Constructor.

        model_loss_fn maps a batch (as yielded by a torch DataLoader over the
        dataset) to a loss tensor differentiable w.r.t. model_parameters, a
        list of parameter tensors. test_dataset is the "dev" set whose gradient
        is used. config (default: empty dict) and rng (default: a new numpy
        Generator) are stored on the instance but not otherwise used here."""
        self._config = {} if config is None else config

        self._model_loss_fn = model_loss_fn
        self._model_parameters = model_parameters
        if rng is None:
            rng = np.random.default_rng(np.random.randint(0, 2**63))
        self._rng = rng

        self._dev_grad = self._calc_dataset_grad(test_dataset)

    def _calc_dataset_grad(self, dataset):
        """Delegate to the module-level gradient helper with this model."""
        return _calc_dataset_grad(self._model_loss_fn, self._model_parameters, dataset)

    def calc_influence_from_grad(self, grad):
        """Return the influence for a query whose loss gradient is grad.

        grad is a list of tensors aligned with model_parameters; the result is
        -sum_i <dev_grad[i], grad[i]> (the int 0 when grad is empty)."""
        influence_total = 0
        for idx, query_grad in enumerate(grad):
            influence_total -= (self._dev_grad[idx] * query_grad).sum()
        return influence_total

    def calc_influences(self, query_examples_dataset):
        """Calculate the influence of every example in query_examples_dataset
        on the dev (test) set, with respect to this object's model.

        Returns a list whose i-th entry is the influence of
        query_examples_dataset[i]."""
        # One example per batch so each query gets its own gradient; the
        # sequential order keeps results aligned with dataset indices.
        loader = DataLoader(
            query_examples_dataset,
            sampler=SequentialSampler(query_examples_dataset),
            batch_size=1,
        )

        def query_grad(example_batch):
            example_loss = self._model_loss_fn(example_batch)
            return torch.autograd.grad(example_loss, self._model_parameters)

        return [self.calc_influence_from_grad(query_grad(example_batch)) for example_batch in loader]

