#!/usr/bin/env python3
import copy
import math
from functools import reduce
from typing import Any, Callable, Iterable, List, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.autograd import grad
from torch.utils.data import DataLoader, Dataset
from torch.utils.data._utils.collate import default_collate


def _concatenate_tensors(tensors: Iterable[Tensor]):
    """Flattens every tensor in `tensors` and joins them into a single 1-dimensional tensor.

    Args:
        tensors: The tensors to merge. Each one is flattened in its own element
            order and the pieces are joined in iteration order.

    Returns:
        A 1 dimensional tensor whose length is the total number of elements
        across all of `tensors`.
    """
    flattened = []
    for tensor in tensors:
        # `view` needs contiguous memory, so normalise the layout first.
        flattened.append(tensor.contiguous().view(-1))
    return torch.cat(tuple(flattened))


def _get_parameter_gradients(parameters: List[nn.Parameter]) -> Tensor:
    """Collects the accumulated `.grad` of every parameter, then resets those gradients.

    Args:
        parameters: The parameters whose gradients are read. A parameter whose
            `.grad` is `None` (for instance because it did not take part in the
            last backward pass) contributes zeros, so the output always lines up
            with the flattened parameters.

    Returns:
        A 1 dimensional tensor containing a copy of the gradients, flattened and
        concatenated in parameter order. As a side effect every existing
        `.grad` is detached and zeroed in place.
    """
    # Materialise once: the input is walked twice below, which would silently
    # skip the reset loop if a generator were passed in.
    parameters = list(parameters)
    result = [
        p.grad.clone() if p.grad is not None else torch.zeros_like(p)
        for p in parameters
    ]
    # zero_grad on a list of Parameters
    for p in parameters:
        if p.grad is not None:
            # pyre-fixme[16]: `Parameter` has no attribute `detach_`.
            p.grad.detach_()
            p.grad.zero_()
    return _concatenate_tensors(result)


def _get_loss_gradient(
    closure: Callable[[Any], Tensor],
    parameters: Iterable[nn.Parameter],
    datapoint: Any,
) -> Tensor:
    """Returns the gradient of the loss with respect to the model parameters as one flat tensor.

    Args:
        closure: A function that takes a datapoint and returns a single-element tensor corresponding to the loss.
            Note that this function should typically include the loss function of your model,
            typically something like `lambda datapoint: loss(model(datapoint[0]), datapoint[1])`
        parameters: A set of parameters over which gradients will be computed,
            typically `model.parameters()`. Any iterable is accepted; it is
            materialised into a list because it is traversed more than once.
        datapoint: The datapoint over which to compute the loss. Passed into closure.

    Returns:
        A 1 dimensional tensor containing the gradient, flattened and concatenated.
        Will have shape equal to the total number of elements in parameters.

    Side effects:
        The `.grad` attributes of `parameters` that already hold a tensor are
        zeroed both before and after the backward pass.

    Example:
        >>> model = nn.Linear(1, 1, bias=False)
        >>> _get_loss_gradient(lambda x: model(x), model.parameters(), Tensor([2.0]))
        tensor([2.])
    """
    parameters = list(parameters)
    # `backward()` accumulates into `.grad`; clear whatever is left over from
    # earlier backward passes so the result reflects this datapoint only.
    for p in parameters:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()
    with torch.enable_grad():
        loss_tensor = closure(datapoint)
        loss_tensor.backward()
    return _get_parameter_gradients(parameters)


def _hvp(
    closure: Callable[[Any], Tensor],
    parameters: List[nn.Parameter],
    v: Tensor,
    datapoint: Any,
) -> Tensor:
    """Returns the Hessian of the loss over datapoint multiplied by v, as one flat tensor.

    The product is obtained by differentiating twice: first the loss with
    respect to the parameters, then the dot product of that gradient with `v`.

    Args:
        closure: A function that takes a datapoint and returns a single-element tensor corresponding to the loss.
            Note that this function should typically include the loss function of your model,
            typically something like `lambda datapoint: loss(model(datapoint[0]), datapoint[1])`
        parameters: A set of parameters over which gradients will be computed,
            typically `model.parameters()`.
        v: A 1 dimensional tensor with the same number of elements as the model parameters.
        datapoint: The datapoint over which to compute the loss. Passed into closure.

    Returns:
        A 1 dimensional tensor containing the hessian vector product, flattened and concatenated.
        Will have shape equal to the total number of elements in parameters.
        Parameters that the loss (or its gradient) does not depend on
        contribute zeros instead of raising.
    """
    parameters = list(parameters)

    def _flatten_with_zeros(grads: Iterable[Optional[Tensor]]) -> Tensor:
        # `allow_unused=True` reports a missing dependency as `None`; the
        # matching derivative is zero, so fill in zeros of the parameter's shape.
        return _concatenate_tensors(
            g if g is not None else torch.zeros_like(p)
            for g, p in zip(grads, parameters)
        )

    with torch.enable_grad():
        loss_tensor = closure(datapoint)
        first_derivatives = _flatten_with_zeros(
            grad(
                loss_tensor,
                parameters,
                create_graph=True,
                retain_graph=True,
                allow_unused=True,
            )
        )
        dot_product = (first_derivatives * v).sum()
        if not dot_product.requires_grad:
            # The gradient has no graph back to the parameters (e.g. a loss that
            # is linear in them), so every second derivative is zero.
            return torch.zeros_like(first_derivatives)
        return _flatten_with_zeros(
            grad(
                dot_product,
                parameters,
                create_graph=False,
                retain_graph=False,
                allow_unused=True,
            )
        ).detach()


def get_influence_rhs(
    closure: Callable[[Any], Tensor],
    parameters: Iterable[nn.Parameter],
    hessian_dataset: Dataset,
    rhs_datapoint: Any,
    num_iterations: int = 500,
    num_chains: int = 1,
    damp: float = 0.00,
    scale: float = 25.0,
    hessian_dataset_collate_fn: Callable[[Any], Any] = default_collate,
) -> Tensor:
    """Returns the right hand side component of the influence score.

    Starting from the loss gradient `g` at `rhs_datapoint`, each chain repeats
    `h <- g + (1 - damp) * h - (1 / scale) * H h`, where `H h` is a Hessian
    vector product evaluated on one randomly drawn datapoint of
    `hessian_dataset`. The chain's sample is `h / scale`, and the samples of all
    chains are averaged. If the recursion converges, its fixed point satisfies
    `(H + damp * scale * I) (h / scale) = g`.

    Args:
        closure: A function that takes a datapoint and returns a single-element tensor corresponding to the loss.
        parameters: The parameters over which gradients are computed,
            typically `model.parameters()`.
        hessian_dataset: The dataset from which single datapoints are drawn for
            the Hessian vector products. Must not be empty.
        rhs_datapoint: The datapoint whose loss gradient seeds the recursion. Passed into closure.
        num_iterations: Number of recursion steps per chain.
        num_chains: Number of independent chains to average. Must be at least 1.
        damp: Damping applied to the running estimate at every step.
        scale: Divisor applied to each Hessian vector product and to the final
            estimate. Must be non-zero.
        hessian_dataset_collate_fn: Collate function used to turn one element of
            `hessian_dataset` into the datapoint handed to `closure`.

    Returns:
        A 1 dimensional tensor with one entry per parameter element.
    """
    parameters = list(parameters)
    rhs_gradient = _get_loss_gradient(closure, parameters, rhs_datapoint)

    # We repeat the (stochastic) process and aggregate together results
    samples = []
    for _ in range(num_chains):
        h_estimate = rhs_gradient
        hessian_dataloader_shuffled = DataLoader(
            hessian_dataset,
            shuffle=True,
            batch_size=1,
            collate_fn=hessian_dataset_collate_fn,
        )
        # Keep one iterator alive per chain. Building a new one for every step
        # would reshuffle the whole dataset just to read its first element.
        hessian_batches = iter(hessian_dataloader_shuffled)

        for _ in range(num_iterations):
            try:
                hessian_datapoint = next(hessian_batches)
            except StopIteration:
                # Epoch exhausted: reshuffle and continue drawing. An empty
                # dataset still surfaces as StopIteration from this `next`.
                hessian_batches = iter(hessian_dataloader_shuffled)
                hessian_datapoint = next(hessian_batches)
            hvp_result = _hvp(closure, parameters, h_estimate, hessian_datapoint)
            with torch.no_grad():
                h_estimate = (
                    rhs_gradient
                    + (1 - damp) * h_estimate
                    + (-1.0 / scale) * hvp_result
                )
        samples.append(h_estimate / scale)
    with torch.no_grad():
        return reduce(torch.add, samples) / len(samples)


def get_influence_with_rhs(
    closure: Callable[[Any], Tensor],
    parameters: Iterable[nn.Parameter],
    lhs_datapoint: Any,
    rhs_components: Iterable[Tensor],
) -> Iterable[float]:
    """Returns an iterable yielding influence scores between a single left hand side datapoint and a set of right hand side components.

    Args:
        closure: A function that takes a datapoint and returns a single-element tensor corresponding to the loss.
        parameters: The parameters over which the gradient is computed,
            typically `model.parameters()`.
        lhs_datapoint: The datapoint whose loss gradient forms the left hand side. Passed into closure.
        rhs_components: 1 dimensional tensors, each with one entry per parameter
            element, such as the values returned by `get_influence_rhs`.

    Returns:
        A lazy iterable with one float per element of `rhs_components`: the dot
        product of the left hand side gradient with that component. The
        gradient is computed once, when this function is called; the dot
        products are evaluated as the result is consumed.
    """
    parameters = list(parameters)
    lhs_component = _get_loss_gradient(closure, parameters, lhs_datapoint)
    return (
        torch.dot(lhs_component, rhs_component).item()
        for rhs_component in rhs_components
    )


def get_influence(
    closure: Callable[[Any], Tensor],
    parameters: Iterable[nn.Parameter],
    lhs_dataset: Dataset,
    rhs_dataset: Dataset,
    rhs_batch_size: int = 1,
    num_iterations: int = 500,
    num_chains: int = 1,
    damp: float = 0.00,
    scale: float = 25.0,
    lhs_collate_fn: Callable[[Any], Any] = default_collate,
    rhs_collate_fn: Callable[[Any], Any] = default_collate,
) -> Tensor:
    """Returns a tensor containing influence scores between a set of test datapoints and a set of training datapoints.

    NOTE(review): the body below does not match this signature or summary.
    It reads `coverage`, `init_forward`, `model`, `loss`, `forward` and
    `_get_confidence_bounds`, none of which are parameters of this function or
    defined in the visible part of the file. It never uses `closure`,
    `lhs_dataset` or `lhs_collate_fn`, and it returns a lambda rather than a
    Tensor. Because `coverage` is assigned inside the body it is a local name,
    so the first check raises UnboundLocalError on every call. Presumably the
    body of a different (confidence-bounds) function ended up under this
    signature -- confirm against the file history before relying on it.
    """
    # NOTE(review): `coverage` is read here before any assignment (see docstring).
    if coverage < 0 or coverage > 1:
        raise ValueError("`coverage` should be in [0, 1]")

    # Fold the value onto [0.5, 1]: `c` and `1 - c` are treated alike.
    coverage = max(coverage, 1.0 - coverage)
    # NOTE(review): `init_forward` and `model` are not defined in this scope.
    init_forward = _get_init_forward(init_forward, model)

    # One entry per right hand side datapoint, appended in dataloader order.
    all_residuals = []
    all_upweighted = []
    parameters = list(parameters)

    rhs_dataloader = DataLoader(rhs_dataset, batch_size=rhs_batch_size, collate_fn=rhs_collate_fn)
    for rhs_data in rhs_dataloader:
        # One `get_influence_rhs` result per row `i` of the batch. `rhs_dataset`
        # is also passed as the Hessian dataset and `rhs_collate_fn` as its
        # collate function.
        # assumes `rhs_data` is a sequence of tensors sharing a leading batch
        # dimension -- TODO confirm against the collate functions in use.
        # NOTE(review): `loss` and `forward` are not defined in this scope.
        influence_rhs = [
            get_influence_rhs(
                lambda d: loss(forward(d), d[1]),
                parameters,
                rhs_dataset,
                [x[i].unsqueeze(0) for x in rhs_data],
                num_iterations,
                num_chains,
                damp,
                scale,
                rhs_collate_fn,
            )
            for i in range(rhs_data[0].size(0))
        ]

        for i in range(len(influence_rhs)):
            # we need to deepcopy since `torch.nn.utils.vector_to_parameters` is an in-place op
            upweighted = copy.deepcopy(parameters)
            # Write the flat influence vector into the copies, then add the
            # original values: each entry becomes `parameter + influence`.
            torch.nn.utils.vector_to_parameters(influence_rhs[i], upweighted)
            upweighted = [u.data + p.data for (u, p) in zip(upweighted, parameters)]

            # Rebuild the single-row datapoint that produced `influence_rhs[i]`.
            datapoint = [x[i].unsqueeze(0) for x in rhs_data]
            output = forward(datapoint)
            upw_output = init_forward(upweighted, datapoint)

            # Absolute difference between the two predictions for this datapoint.
            residual = torch.abs(output - upw_output)
            all_residuals.append(residual)
            all_upweighted.append(upweighted)

    # Defer the final computation to the caller-supplied datapoint.
    # NOTE(review): `_get_confidence_bounds` is not defined in the visible file.
    return lambda datapoint: _get_confidence_bounds(
        coverage,
        all_residuals,
        all_upweighted,
        datapoint,
        init_forward,
    )


def _get_init_forward(
    init_forward: Optional[Callable[[Iterable[nn.Parameter], Any], Tensor]] = None,
    model: Optional[nn.Module] = None,
) -> Callable[[Iterable[nn.Parameter], Any], Tensor]:
    """Resolves `init_forward` / `model` into a function that predicts with a given parameter list.

    Args:
        init_forward: Either a callable taking `(parameters, datapoint)`, which
            is returned unchanged, or an :class:`nn.Module`, which is wrapped as
            described for `model`. Takes precedence over `model` whenever it is
            not `None`.
        model: An :class:`nn.Module` used when `init_forward` is `None`. The
            returned function deep-copies it on every call, switches the copy
            to eval mode, points the copy's parameters at the supplied values
            and runs the copy on the datapoint (or on `datapoint[0]` when the
            datapoint is not a tensor). The module itself is left untouched.

    Returns: A function that takes a list of :class:`nn.Parameter` and a datapoint and returns a tensor representing the
            prediction of the model parameterized by the parameter list.

    Raises:
        ValueError: If neither argument provides a module or a callable.
    """
    # Choose by identity rather than truthiness: a callable object may define
    # `__bool__` or `__len__` and evaluate as falsy without being absent.
    argument = init_forward if init_forward is not None else model

    # Modules are callable too, so this check has to come first.
    if isinstance(argument, nn.Module):

        def init_forward_fn(
            parameters: Iterable[nn.Parameter], datapoint: Any
        ) -> Tensor:
            cloned = copy.deepcopy(argument)
            cloned.eval()
            parameters = list(parameters)
            # Parameters are matched by position with `cloned.parameters()`.
            for i, param in enumerate(cloned.parameters()):
                param.data = parameters[i].data
            x = datapoint if isinstance(datapoint, Tensor) else datapoint[0]
            return cloned(x)

        return init_forward_fn

    if callable(argument):
        return argument

    raise ValueError("`init_forward` and `model` cannot be both None.")