from typing import Union, Optional, Callable
from functools import partial

import numpy as np
import torch
from torch import Tensor
from robustopt_torch.StochasticSolver import StochasticSolver


def sinkhorn_ifunc(x, v, nu_vals, nu_weights, nu_weights_log, cost, eps):
    """Evaluate the Sinkhorn deconvolution objective / influence function.

    ``cost(x, nu_vals)`` is called once and reduced over its last axis, so a
    broadcastable cost lets several ``x`` values be evaluated in one call.

    Args:
        x: Point(s) at which the function is evaluated; passed straight to
            ``cost``.
        v: Current dual function values at the atoms of the discrete
            distribution (combined with ``nu_weights`` via ``torch.dot``).
        nu_vals: Atoms of the discrete distribution; passed straight to
            ``cost``.
        nu_weights: Weights of the atoms.
        nu_weights_log: Logarithm of ``nu_weights``; only used when
            ``eps > 0``.
        cost: Callable ``cost(x, nu_vals)`` returning the cost values.
        eps: Sinkhorn regularisation strength.

    Returns:
        ``dot(v, nu_weights) - eps * logsumexp((v - cost) / eps + log w)``
        when ``eps > 0``; otherwise the unregularised limit
        ``dot(v, nu_weights) + min(cost - v)``. Both reductions run over the
        last axis.
    """
    if not eps > 0.0:
        # Unregularised case: the soft minimum degenerates to a hard minimum.
        weighted_dual = torch.dot(v, nu_weights)
        return weighted_dual + torch.amin(cost(x, nu_vals) - v, -1)

    # Regularised case: weighted log-sum-exp, with the weights folded in as
    # log-weights so the reduction stays in log space.
    log_terms = (v - cost(x, nu_vals)) / eps
    log_terms = log_terms + nu_weights_log
    soft_reduction = torch.logsumexp(log_terms, -1)
    return torch.dot(v, nu_weights) - eps * soft_reduction

class DeconvolutionIFunc:
    """Influence function for Sinkhorn deconvolution against a discrete measure.

    The discrete measure is described by its atoms (``nu_vals``) and their
    weights (``nu_weights``, normalised to sum to one). The dual values at the
    atoms are found by ``vsolve`` and the most recent solution is kept in
    ``cached_v`` so that ``get_ifunc`` can warm-start from it.
    """

    def __init__(self,
                 cost : Callable[[Tensor, Tensor], Tensor],
                 eps : float,
                 nu_vals : Union[np.ndarray, Tensor],
                 nu_weights : Optional[Union[np.ndarray, Tensor]] = None):
        """Store the problem data and normalise the weights.

        Args:
            cost: Cost function for the Sinkhorn cost. It should be
                broadcastable so that the all-pairs cost can be computed with
                a single call ``cost(x, nu_vals)``.
            eps: Sinkhorn regularisation strength; must be non-negative.
            nu_vals: Atoms of the discrete distribution. The size of the first
                dimension is used as the number of atoms when ``nu_weights``
                is omitted.
            nu_weights: Weights of the atoms. Defaults to uniform weights.
                Integer or boolean weights are converted to the default
                floating dtype; floating weights keep their dtype.

        Raises:
            ValueError: If ``eps`` is negative, if any weight is negative, or
                if the weights do not sum to a positive, finite value.
        """
        self.cost = cost

        # Value of Sinkhorn regularization
        if eps < 0:
            raise ValueError("Epsilon must be non-negative!")
        self.eps = eps

        # The point masses of the discrete distribution (private copy)
        self.nu_vals = torch.as_tensor(nu_vals).detach().clone()

        # The weights of the points in the discrete distribution
        if nu_weights is None:
            # Uniform weights, created on the same device as the atoms so the
            # two can be combined without a device mismatch.
            weights = torch.ones(self.nu_vals.size()[0],
                                 dtype=torch.get_default_dtype(),
                                 device=self.nu_vals.device)
        else:
            weights = torch.as_tensor(nu_weights).detach().clone()
            # In-place normalisation cannot turn an integer tensor into a
            # floating one, so promote integer/bool weights up front.
            if not (weights.is_floating_point() or weights.is_complex()):
                weights = weights.to(torch.get_default_dtype())

        # Zero weights are accepted here; only negative entries are rejected.
        if (weights < 0.0).any():
            raise ValueError("Weights must be positive!")

        # Guard the normalisation: an empty, all-zero or non-finite weight
        # vector would otherwise silently turn into NaNs.
        total = weights.sum()
        if not torch.isfinite(total) or not total > 0:
            raise ValueError("Weights must sum to a positive, finite value!")

        self.nu_weights = weights.div_(total)
        self.nu_weights_log = torch.log(self.nu_weights)

        # Parameters used in the solving for the influence function
        self.solve_params = {"num_iter" : 10000,
                             "grad_samp" : 100,
                             "valid_samp" : 5000,
                             "solver_config" : {}}

        # Stores the most recent computed solution
        self.cached_v = self._zero_dual()

    def _zero_dual(self) -> Tensor:
        """Return a fresh all-zero dual vector matching the weights.

        The vector is 1-D with one entry per weight and shares the weights'
        dtype and device, because ``torch.dot(v, nu_weights)`` requires both
        operands to agree on those.
        """
        return torch.zeros(self.nu_weights.numel(),
                           dtype=self.nu_weights.dtype,
                           device=self.nu_weights.device)

    def set_solve_params(self, **params):
        """Convenience method to update any solver parameters.

        The keys read by this class are ``num_iter``, ``grad_samp``,
        ``valid_samp`` and ``solver_config``.
        """
        self.solve_params.update(params)

    def set_solver_config(self, **params):
        """Convenience method to set any parameters on the solver.

        The accumulated mapping is passed as keyword arguments to the
        ``StochasticSolver`` constructor in ``vsolve``.
        """
        self.solve_params["solver_config"].update(params)

    def sinkhorn_ifunc(self, x : Tensor, v : Tensor) -> Tensor:
        """The actual objective/influence function for Sinkhorn deconvolution.

        Delegates to the module-level ``sinkhorn_ifunc`` with this object's
        atoms, weights, cost and regularisation. Since the cost is expected to
        be broadcastable, this function can take multiple x values. v is a
        tensor giving the current dual function values at the points in the
        discrete distribution.
        """
        return sinkhorn_ifunc(x, v, self.nu_vals, self.nu_weights,
                              self.nu_weights_log, self.cost, self.eps)

    def vsolve(self, sampler : Callable[[int], Tensor], init_val :
               Optional[Tensor] = None) -> Tensor:
        """Solve for optimal dual function values for a given distribution.

        The problem is handed to ``StochasticSolver`` (SGD/MD): the objective
        is the negated influence function, and the projection subtracts the
        mean in place so iterates stay centred.

        Args:
            sampler: Function which draws the requested number of samples from
                the target distribution.
            init_val: Starting point for the solver. Defaults to a zero
                vector matching the weights' size, dtype and device.

        Returns:
            Whatever ``StochasticSolver.solve`` returns; expected to be a
            detached version of the optimal v.
        """
        if init_val is None:
            init_val = self._zero_dual()

        problem_dict = {"num_iter" : self.solve_params["num_iter"],
                        "init_val" : init_val,
                        "objective" : lambda x, v : -self.sinkhorn_ifunc(x,v),
                        "projection" : lambda x : x.sub_(x.sum() / x.numel()),
                        "grad_sampler" : lambda: sampler(self.solve_params["grad_samp"]),
                        "valid_sampler" : lambda: sampler(self.solve_params["valid_samp"])}

        solver = StochasticSolver(**self.solve_params["solver_config"])
        return solver.solve(**problem_dict)

    def get_ifunc(self, sampler : Callable[[int], Tensor]) -> Callable[[Tensor], Tensor]:
        """Solve for the influence function and return a callable copy of it.

        The solve is warm-started from the cached solution, which is then
        overwritten with the new one. The returned callable owns clones of
        all tensors, so later changes to this object do not affect it.

        Args:
            sampler: The distribution to calculate the influence function at,
                given as a function drawing the requested number of samples.

        Returns:
            A callable taking ``x`` and returning the influence function
            values for the solved dual vector.
        """
        v = self.vsolve(sampler, self.cached_v)
        self.cached_v.copy_(v)
        wrapped_ifunc = partial(sinkhorn_ifunc,
                                v = v.detach().clone(),
                                nu_vals = self.nu_vals.detach().clone(),
                                nu_weights = self.nu_weights.detach().clone(),
                                nu_weights_log = self.nu_weights_log.detach().clone(),
                                cost = self.cost,
                                eps = self.eps)
        return wrapped_ifunc
