import torch
from torch.autograd import Variable
from torch.nn.functional import softmax

from goal_set_planning.util.misc import euclidean_distance
from goal_set_planning.util.inference_cardinality import inference_cardinality, TreeMarginals

from .distributions import Gaussian
from .classification import train_sample_classifier


NINF = -1.0e5  # Large finite stand-in for -inf; SmoothKNNStatistic fills its count potential with it.


def chamfer(p_samples, q_samples, reduction='mean'):
    """Symmetric Chamfer distance between two point sets, on squared Euclidean distances.

    For every point, the squared distance to its nearest neighbour in the other
    set is taken. Each direction is reduced with ``reduction``, and the two
    results are added.

    Args:
        p_samples: first point set (presumably shaped (N_p, D) -- TODO confirm
            against ``euclidean_distance``).
        q_samples: second point set (presumably shaped (N_q, D)).
        reduction: ``'mean'`` averages the nearest-neighbour distances of each
            direction; ``'sum'`` adds them up.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: if ``reduction`` is not ``'mean'`` or ``'sum'``. This is
            checked before the (potentially large) distance matrix is built.
    """
    if reduction not in ('mean', 'sum'):
        # ValueError subclasses Exception, so existing `except Exception` callers still work.
        raise ValueError("DISTANCES: Chamfer distance: Unrecognized reduction: {}".format(reduction))

    dists = euclidean_distance(p_samples, q_samples, squared=True)
    # Minimum over each axis: nearest neighbour in the other set for every point.
    nearest_along_0 = dists.min(0)[0]
    nearest_along_1 = dists.min(1)[0]

    if reduction == 'mean':
        return nearest_along_0.mean() + nearest_along_1.mean()
    return nearest_along_0.sum() + nearest_along_1.sum()


def cross_entropy_sample(p_samples, p):
    """Monte-Carlo cross-entropy estimate from samples.

    Returns ``-(1/N) * sum_i log p(x_i)``, where the ``x_i`` are the N rows of
    ``p_samples`` and ``p`` is any object exposing ``log_pdf``.
    """
    total_log_density = p.log_pdf(p_samples).sum()
    return -total_log_density / p_samples.shape[0]


def cross_entropy_distances(distances, sigma=0.05):
    """Two-sample cross-entropy estimate from a matrix of pairwise distances.

    Every distance is scored with the log-density of the project's
    ``Gaussian(0, sigma)`` error model. For each row, the scores are combined by
    a numerically stable log-mean-exp over the columns. The result is the
    negated average over rows:
    ``-(1/N_p) * sum_i log((1/N_q) * sum_j g(d_ij))``.

    Args:
        distances: tensor of shape ``(N_p, N_q)``. Rows index the samples being
            scored; columns index the reference samples. It does not need to
            be contiguous (e.g. a transposed distance matrix is fine).
        sigma: scale parameter passed to ``Gaussian``.

    Returns:
        Scalar tensor on the same device/dtype as ``distances``.
    """
    N_p, N_q = distances.shape
    params = {"dtype": distances.dtype, "device": distances.device}
    error_px = Gaussian(0, sigma, **params)
    # reshape (not view) so non-contiguous inputs do not raise.
    log_qi_x = error_px.log_pdf(distances.reshape(N_p * N_q, 1)).reshape(N_p, N_q)
    # log((1/N_q) * sum_j q_j(x)) computed via logsumexp for numerical stability.
    log_qx = torch.logsumexp(log_qi_x, dim=-1) - torch.log(torch.as_tensor(N_q, **params))
    return -log_qx.sum() / N_p


def kernel_mmd(x, y, kernel):
    """Unbiased estimate of the squared maximum mean discrepancy (MMD^2).

    ``kernel(a, b)`` must return the Gram matrix between the rows of ``a`` and
    ``b``. The within-sample terms exclude the diagonal, so each sample needs at
    least two rows. With a single row the normaliser ``m * (m - 1)`` is zero.
    The estimate can be negative.
    """
    m = x.shape[0]
    n = y.shape[0]
    gram_xx = kernel(x, x)
    gram_yy = kernel(y, y)
    # Off-diagonal means of the within-sample Gram matrices.
    within_x = (gram_xx.sum() - torch.trace(gram_xx)) / (m * (m - 1))
    within_y = (gram_yy.sum() - torch.trace(gram_yy)) / (n * (n - 1))
    cross = kernel(x, y).sum() * 2 / (m * n)
    return within_x + within_y - cross


def kl_gaussians(p, q):
    """Closed-form KL divergence KL(p || q) between two ``Gaussian`` objects.

    Evaluates
        0.5 * [tr(S_q^-1 S_p) - D + (mu_q - mu_p)^T S_q^-1 (mu_q - mu_p)
               + log(det_q / det_p)]
    using ``q._sigma_inv`` and the ``_sigma_det`` attributes. These are
    presumably the cached covariance inverse and determinant -- confirm
    against ``Gaussian``.
    """
    assert isinstance(p, Gaussian) and isinstance(q, Gaussian)
    assert p.dim == q.dim
    precision_q = q._sigma_inv
    delta = q.mu - p.mu
    log_det_ratio = torch.log(q._sigma_det / p._sigma_det)
    trace_term = torch.trace(precision_q.mm(p.sigma))
    mahalanobis = torch.matmul(delta, precision_q).dot(delta)
    kl = trace_term - p.dim + mahalanobis
    kl += log_det_ratio
    return 0.5 * kl


def kl_sample(p_samples, p, q):
    """Monte-Carlo estimate of KL(p || q) from samples drawn from p.

    Averages ``log p(x) - log q(x)`` over the rows of ``p_samples``; ``p`` and
    ``q`` are any objects exposing ``log_pdf``.
    """
    log_ratio = p.log_pdf(p_samples) - q.log_pdf(p_samples)
    return log_ratio.sum() / p_samples.size(0)


def kl_knn(p_samples, q_samples, k=1):
    """k-nearest-neighbour estimate of KL(p(x) || q(x)) from samples of p and q.

    Computes ``D/M * sum_i log(nu_k(i) / rho_k(i)) + log(N / (M - 1))``, where
    ``rho_k(i)`` is the distance from the i-th p sample to its k-th nearest
    other p sample and ``nu_k(i)`` its distance to the k-th nearest q sample.
    M and N are the two sample counts and D the dimension.

    NOTE(review): ``euclidean_distance`` is called without ``squared=`` and its
    output is square-rooted below, which assumes the default returns squared
    distances -- confirm against goal_set_planning.util.misc.
    NOTE(review): duplicate points inside ``p_samples`` make ``rho_k`` zero and
    the estimate infinite.
    """
    n_p, dim = p_samples.shape
    n_q, _ = q_samples.shape
    # k + 1 within p, because each point's nearest "neighbour" is presumably itself.
    within_p = euclidean_distance(p_samples, p_samples).topk(k + 1, largest=False).values
    to_q = euclidean_distance(p_samples, q_samples).topk(k, largest=False).values
    rho = torch.sqrt(within_p[:, -1])
    nu = torch.sqrt(to_q[:, -1])
    correction = torch.log(torch.tensor(n_q / (n_p - 1)))
    return torch.log(nu / rho).sum() * dim / n_p + correction


def kl_classifier(samples, classifier):
    """Classifier-based estimate of KL(p(x) || q(x)).

    ``classifier`` must return raw logits (no sigmoid applied) of the
    probability that a sample comes from p. The estimate is the mean logit over
    ``samples``, which are presumably drawn from p.
    """
    logits = classifier(samples).squeeze()
    return logits.sum() / samples.size(0)


"""Callable object versions of distances."""


class DistributionDistance:
    """Base class for callable distribution distances.

    Calling an instance forwards all positional arguments to :meth:`forward`,
    which subclasses must override. :meth:`reset` is a no-op hook that
    subclasses may override to drop cached state.
    """

    def __init__(self):
        """The base class keeps no state."""

    def __call__(self, *args):
        """Alias for :meth:`forward`."""
        return self.forward(*args)

    def forward(self, *args):
        """Compute the distance; subclasses must implement this."""
        raise NotImplementedError

    def reset(self, **kwargs):
        """Reset internal state; does nothing in the base class."""
        return None


class KernelMMDSampleDistance(DistributionDistance):
    """Unbiased kernel MMD^2 between two sample sets, packaged as a callable.

    Args:
        kernel: callable returning Gram matrices. When ``estimate_params`` is
            true it must also provide ``set_params(p_sample, q_sample)``.
        estimate_params: if true, ``kernel.set_params`` is called with detached
            versions of both samples before every evaluation.
    """

    def __init__(self, kernel, estimate_params=True):
        super().__init__()
        self.kernel = kernel
        self.estimate_params = estimate_params

    def forward(self, p_sample, q_sample):
        """Return ``kernel_mmd(p_sample, q_sample, self.kernel)``."""
        if not self.estimate_params:
            return kernel_mmd(p_sample, q_sample, self.kernel)
        # Fit the kernel on gradient-free tensors so parameter estimation stays out of the graph.
        self.kernel.set_params(p_sample.detach(), q_sample.detach())
        return kernel_mmd(p_sample, q_sample, self.kernel)


class KLDivergenceClassifier(DistributionDistance):
    """Classifier-based estimate of KL(p || q) from two sample sets.

    A binary classifier is fitted with ``train_sample_classifier`` to separate
    p samples from q samples. The divergence is read off with
    ``kl_classifier`` on the p samples.

    Args:
        model: optional initial classifier. ``None`` leaves model creation to
            ``train_sample_classifier``.
        lr: default learning rate for :meth:`train_classifier`.
        weight_decay: default weight decay for :meth:`train_classifier`.
        epochs: default number of epochs for :meth:`train_classifier`.
        warm_start: if true, training continues from the current model across
            calls; if false, the model is discarded before every training run.
        stop_early: forwarded to ``train_sample_classifier``.
        tensor_kwargs: ``device``/``dtype`` dict forwarded to
            ``train_sample_classifier``. ``None`` (the default) means
            ``{"device": "cpu", "dtype": torch.float32}``. A fresh dict is made
            per instance so instances never share (and mutate) one default.
    """

    def __init__(self, model=None, lr=0.01, weight_decay=1e-5, epochs=2,
                 warm_start=True, stop_early=True, tensor_kwargs=None):
        super(KLDivergenceClassifier, self).__init__()
        if tensor_kwargs is None:
            tensor_kwargs = {"device": "cpu", "dtype": torch.float32}
        self.model = model
        self.lr = lr
        self.weight_decay = weight_decay
        self.epochs = epochs
        self.warm_start = warm_start
        self.stop_early = stop_early
        self.tensor_kwargs = tensor_kwargs

    def init(self, p_sample, q_sample, epochs=100):
        """Discard the current model and train a new one for ``epochs`` epochs.

        Meant to be called once before :meth:`forward` in warm-start mode.
        """
        self.model = None
        self.train_classifier(p_sample, q_sample, epochs=epochs)

    def reset(self, **kwargs):
        """Drop the current model; the next training run starts without one."""
        self.model = None

    def forward(self, p_sample, q_sample, retrain=True):
        """Estimate KL(p || q) from ``p_sample`` and ``q_sample``.

        The classifier is (re)trained first when ``retrain`` is true or when no
        model exists yet. It is then switched to eval mode and scored on
        ``p_sample``.
        """
        if self.warm_start and self.model is None:
            print("WARNING: When in warm start mode, init() should be called before using the ratio estimator.")

        if retrain or self.model is None:
            self.train_classifier(p_sample, q_sample)

        self.model.eval()
        return kl_classifier(p_sample, self.model)

    def train_classifier(self, p_sample, q_sample, epochs=None, lr=None, weight_decay=None):
        """Train the classifier on detached versions of both samples.

        ``epochs``, ``lr`` and ``weight_decay`` override the instance defaults
        for this call only; ``None`` means "use the instance default".
        """
        epochs = self.epochs if epochs is None else epochs
        lr = self.lr if lr is None else lr
        weight_decay = self.weight_decay if weight_decay is None else weight_decay

        # Reset the model if not in warm start mode.
        if not self.warm_start:
            self.model = None

        # Pass the resolved per-call values (not the instance defaults) so that
        # overrides such as init(..., epochs=100) actually take effect.
        self.model = train_sample_classifier(p_sample.detach(), q_sample.detach(),
                                             lr=lr, weight_decay=weight_decay,
                                             epochs=epochs, model=self.model,
                                             stop_early=self.stop_early,
                                             tensor_kwargs=self.tensor_kwargs)


"""BELOW: Taken from: https://github.com/josipd/torch-two-sample"""


class SmoothFRStatistic(object):
    r"""The smoothed Friedman-Rafsky test :cite:`djolonga17graphtests`.

    Arguments
    ---------
    n_1: int
        The number of points in the first sample.
    n_2: int
        The number of points in the second sample.
    compute_t_stat: bool
        If true, :py:meth:`~.SmoothFRStatistic.__call__` returns the
        t-statistic under the permutation null; otherwise it returns the
        (negated) smoothed statistic itself.
    tensor_kwargs: dict, optional
        ``device``/``dtype`` for the precomputed edge tables; must contain a
        ``"device"`` key. Defaults to
        ``{"device": "cpu", "dtype": torch.float32}``. The arguments to
        :py:meth:`~.SmoothFRStatistic.__call__` should live on the same device.
    """

    def __init__(self, n_1, n_2, compute_t_stat=True, tensor_kwargs=None):
        if tensor_kwargs is None:
            tensor_kwargs = {"device": "cpu", "dtype": torch.float32}
        n = n_1 + n_2
        self.n_1, self.n_2 = n_1, n_2
        self.tensor_kwargs = tensor_kwargs
        device = tensor_kwargs["device"]

        # Enumerate the n * (n - 1) / 2 edges (i, j), i < j, of the complete
        # graph on the pooled sample. triu_indices yields them in row-major
        # order, i.e. the order of two nested loops over i < j.
        rows, cols = torch.triu_indices(n, n, offset=1, device=device)
        # idx_within[k] is True iff edge k connects two points of the same sample.
        self.idx_within = (rows < n_1) == (cols < n_1)
        if compute_t_stat:
            # self.nbs is (n, n_edges) with a 1 at (i, k) iff node i is incident
            # to edge k. Hence self.nbs @ mu sums, for every node, the marginals
            # of its incident edges, which the variance formula needs.
            n_edges = rows.numel()
            edge_ids = torch.arange(n_edges, device=device)
            self.nbs = torch.zeros(n, n_edges, **tensor_kwargs)
            self.nbs[rows, edge_ids] = 1
            self.nbs[cols, edge_ids] = 1

        self.marginals_fn = TreeMarginals(n_1 + n_2, self.idx_within.is_cuda)
        self.compute_t_stat = compute_t_stat

    def __call__(self, sample_1, sample_2, alphas, norm=2, ret_matrix=False):
        r"""Evaluate the smoothed Friedman-Rafsky test statistic.

        The test accepts several **inverse temperatures** in ``alphas``. The
        tree marginals are averaged over the alphas before the statistic is
        formed. With ``compute_t_stat=False`` the statistic is linear in the
        marginals, so the returned value equals (with :math:`m` = ``len(alphas)``):

        .. math::
            -\frac{1}{m}\sum_{j=1}^{m} T_{\pi^*}^{1/\alpha_j}(\textrm{sample}_1,
                                                             \textrm{sample}_2).

        If ``compute_t_stat=True``, the returned value is the t-statistic of
        the above quantity under the permutation null. Note that we compute the
        negated statistic of what is used in :cite:`djolonga17graphtests`, as
        it is exactly what we want to minimize when used as an objective for
        training implicit models.

        Arguments
        ---------
        sample_1: :class:`torch.Tensor`
            The first sample, of size ``(n_1, d)`` with ``n_1`` as given to the
            constructor (not checked).
        sample_2: :class:`torch.Tensor`
            The second sample, of size ``(n_2, d)``.
        alphas: sequence of :class:`float`
            The inverse temperatures; must be non-empty.
        norm : float
            Unused; kept for API compatibility. Squared Euclidean distances are
            always used.
        ret_matrix: bool
            If set, the call will also return the averaged edge marginals.

        Returns
        -------
        :class:`torch.Tensor`
            The (negated) test statistic, a t-statistic if ``compute_t_stat=True``.
        :class:`torch.Tensor`
            The averaged edge marginals; returned only if ``ret_matrix`` is true.

        Raises
        ------
        ValueError
            If ``alphas`` is empty.
        """
        if alphas is None or len(alphas) == 0:
            raise ValueError("alphas must be a non-empty sequence of inverse temperatures")
        sample_12 = torch.cat((sample_1, sample_2), 0)
        diffs = euclidean_distance(sample_12, sample_12, squared=True)
        margs = None
        for alpha in alphas:
            margs_a = self.marginals_fn(
                self.marginals_fn.triu(- alpha * diffs))
            if margs is None:
                margs = margs_a
            else:
                margs = margs + margs_a

        margs = margs / len(alphas)
        idx_within = Variable(self.idx_within, requires_grad=False)
        n_1, n_2, n = self.n_1, self.n_2, self.n_1 + self.n_2
        m = margs.sum()
        # Total marginal mass on edges that connect the two samples.
        t_stat = m - torch.masked_select(margs, idx_within).sum()
        if self.compute_t_stat:
            nbs = Variable(self.nbs, requires_grad=False)
            nbs_sum = (nbs.mm(margs.unsqueeze(1))**2).sum()
            chi_1 = n_1 * n_2 / (n * (n - 1))
            chi_2 = 4 * (n_1 - 1) * (n_2 - 1) / ((n - 2) * (n - 3))
            var = (chi_1 * (1 - chi_2) * nbs_sum +
                   chi_1 * chi_2 * (margs**2).sum() +
                   chi_1 * (chi_2 - 4 * chi_1) * m**2)
            mean = 2 * m * n_1 * n_2 / (n * (n - 1))
            # Small epsilon keeps sqrt (and its gradient) finite when var is ~0.
            std = torch.sqrt(1e-5 + var)
        else:
            mean = 0.
            std = 1.

        if ret_matrix:
            return - (t_stat - mean) / std, margs
        else:
            return - (t_stat - mean) / std


class SmoothKNNStatistic(object):
    r"""The smoothed k-nearest neighbours test :cite:`djolonga17graphtests`.

    Note that the ``k=1`` case is computed directly using a SoftMax and should
    execute much faster than the statistics with ``k > 1``, which go through
    ``inference_cardinality`` on the CPU.

    Arguments
    ---------
    n_1: int, optional
        Nominal size of the first sample. Stored only; the sample sizes are
        read from the tensors passed to :py:meth:`__call__`.
    n_2: int, optional
        Nominal size of the second sample (likewise stored only).
    k: int
        The number of nearest neighbours (k in kNN).
    compute_t_stat: bool
        If true, :py:meth:`__call__` returns the t-statistic under the
        permutation null; otherwise it returns the (negated) smoothed statistic.
    alphas: sequence of float, optional
        Default inverse temperatures, used when :py:meth:`__call__` gets none.
    tensor_kwargs: dict, optional
        Stored ``device``/``dtype`` settings; defaults to
        ``{"device": "cpu", "dtype": torch.float32}``. The off-diagonal mask
        and the marginal matrix are allocated to match the computed tensors.
    """

    def __init__(self, n_1=None, n_2=None, k=1, compute_t_stat=True, alphas=None,
                 tensor_kwargs=None):
        if tensor_kwargs is None:
            tensor_kwargs = {"device": "cpu", "dtype": torch.float32}
        # (1, k + 1) potential over neighbour counts: NINF everywhere except the
        # last entry (count == k), which is 0. Consumed by inference_cardinality.
        self.count_potential = torch.FloatTensor(1, k + 1)
        self.count_potential.fill_(NINF)
        self.count_potential[0, -1] = 0
        self.k = k
        self.n_1 = n_1
        self.n_2 = n_2
        self.compute_t_stat = compute_t_stat
        self.alphas = alphas
        self.tensor_kwargs = tensor_kwargs

    def __call__(self, sample_1, sample_2, alphas=None, norm=2, ret_matrix=False):
        r"""Evaluate the smoothed kNN statistic.

        The test accepts several **inverse temperatures** in ``alphas``. The
        neighbour marginals are averaged over the alphas before the statistic
        is formed. With ``compute_t_stat=False`` the returned value equals
        (with :math:`m` = ``len(alphas)``):

        .. math::
            -\frac{1}{m}\sum_{j=1}^{m} T_{\pi^*}^{1/\alpha_j}(\textrm{sample}_1,
                                                             \textrm{sample}_2).

        If ``compute_t_stat=True``, the returned value is the t-statistic of
        the above quantity under the permutation null. Note that we compute the
        negated statistic of what is used in :cite:`djolonga17graphtests`, as
        it is exactly what we want to minimize when used as an objective for
        training implicit models.

        Arguments
        ---------
        sample_1: :class:`torch.Tensor`
            The first sample, of size ``(n_1, d)``.
        sample_2: :class:`torch.Tensor`
            The second sample, of size ``(n_2, d)``.
        alphas: sequence of :class:`float`, optional
            The inverse temperatures; falls back to the constructor's
            ``alphas``. Must end up non-empty.
        norm : float
            Unused; kept for API compatibility. Squared Euclidean distances are
            always used.
        ret_matrix: bool
            If set, the call will also return the ``(n, n)`` matrix of averaged
            neighbour marginals (zero diagonal).

        Returns
        -------
        :class:`torch.Tensor`
            The (negated) test statistic, a t-statistic if ``compute_t_stat=True``.
        :class:`torch.Tensor`
            The marginal matrix; returned only if ``ret_matrix`` is true.

        Raises
        ------
        ValueError
            If no inverse temperatures are available.
        """
        if alphas is None:
            alphas = self.alphas
        if alphas is None or len(alphas) == 0:
            raise ValueError("alphas must be a non-empty sequence of inverse temperatures")
        n_1 = sample_1.size(0)
        n_2 = sample_2.size(0)
        n = n_1 + n_2
        sample_12 = torch.cat((sample_1, sample_2), 0)
        diffs = euclidean_distance(sample_12, sample_12, squared=True)
        # Off-diagonal mask, built on the distances' device so masked_select
        # works regardless of what tensor_kwargs names.
        self.indices = ~torch.eye(n, dtype=torch.bool, device=diffs.device)
        indices = Variable(self.indices, requires_grad=False)
        k = self.count_potential.size()[1] - 1
        assert k == self.k
        count_potential = Variable(
            self.count_potential.expand(n, k + 1), requires_grad=False)

        # Drop self-distances: row i holds the distances to the other n - 1 points.
        diffs = torch.masked_select(diffs, indices).view(n, n - 1)

        margs_ = None
        for alpha in alphas:
            if self.k == 1:
                margs_a = softmax(-alpha * diffs, dim=1)
            else:
                margs_a = inference_cardinality(
                    - alpha * diffs.cpu(), count_potential)
            if margs_ is None:
                margs_ = margs_a
            else:
                margs_ = margs_ + margs_a

        margs_ = margs_ / len(alphas)
        # Re-embed the n x (n-1) marginals into an n x n matrix with a zero
        # diagonal. Allocating with margs_.new_zeros keeps margs_' dtype and
        # device, so masked_scatter also works for float64 inputs and for
        # tensors on a non-default CUDA device.
        margs = margs_.new_zeros(n, n)
        margs = margs.masked_scatter(indices.to(device=margs.device), margs_.view(-1))

        # Marginal mass on neighbour relations that cross between the samples.
        t_stat = margs[:n_1, n_1:].sum() + margs[n_1:, :n_1].sum()
        if self.compute_t_stat:
            m = margs.sum()
            mean = 2 * m * n_1 * n_2 / (n * (n - 1))
            nbs_sum = ((
                margs.sum(0).view(-1) + margs.sum(1).view(-1))**2).sum()
            flip_sum = (margs * margs.transpose(1, 0)).sum()
            chi_1 = n_1 * n_2 / (n * (n - 1))
            chi_2 = 4 * (n_1 - 1) * (n_2 - 1) / ((n - 2) * (n - 3))
            var = (chi_1 * (1 - chi_2) * nbs_sum +
                   chi_1 * chi_2 * (margs**2).sum() +
                   chi_1 * chi_2 * flip_sum +
                   chi_1 * (chi_2 - 4 * chi_1) * m ** 2)
            # Small epsilon keeps sqrt (and its gradient) finite when var is ~0.
            std = torch.sqrt(1e-5 + var)
        else:
            mean = 0.
            std = 1.

        if ret_matrix:
            return - (t_stat - mean) / std, margs
        else:
            return - (t_stat - mean) / std


class EnergyStatistic:
    r"""The energy test of :cite:`szekely2013energy`.

    Arguments
    ---------
    n_1: int, optional
        Nominal size of the first sample. Stored only; the sizes actually used
        are read from the samples passed to :py:meth:`__call__`.
    n_2: int, optional
        Nominal size of the second sample (likewise stored only).
    """

    def __init__(self, n_1=None, n_2=None):
        self.n_1, self.n_2 = n_1, n_2

    def __call__(self, sample_1, sample_2, ret_matrix=False):
        r"""Evaluate the statistic on squared Euclidean distances.

        Returns ``2/(n_1 n_2) * S_12 - 1/n_1^2 * S_11 - 1/n_2^2 * S_22``, where
        ``S_ab`` is the sum of pairwise squared distances between the two sets.
        The weights are also stored on the instance as ``a01``, ``a00`` and
        ``a11``.

        NOTE(review): the original energy statistic uses plain (non-squared)
        distances. Assuming ``euclidean_distance(..., squared=True)`` returns
        squared Euclidean distances, this version reduces to
        ``2 * ||mean(sample_1) - mean(sample_2)||^2`` -- confirm this is intended.

        Arguments
        ---------
        sample_1: :class:`torch.Tensor`
            The first sample, of size ``(n_1, d)``.
        sample_2: :class:`torch.Tensor`
            The second sample, of size ``(n_2, d)``.
        ret_matrix: bool
            If set, the call will also return the ``(n_1 + n_2, n_1 + n_2)``
            squared-distance matrix of the pooled sample.

        Returns
        -------
        :class:`torch.Tensor`
            The test statistic.
        :class:`torch.Tensor`
            Returned only if ``ret_matrix`` was set to true.
        """
        size_1 = sample_1.size(0)
        size_2 = sample_2.size(0)

        self.a00 = - 1. / (size_1 * size_1)
        self.a11 = - 1. / (size_2 * size_2)
        self.a01 = 1. / (size_1 * size_2)

        within_1 = euclidean_distance(sample_1, sample_1, squared=True).sum()
        within_2 = euclidean_distance(sample_2, sample_2, squared=True).sum()
        between = euclidean_distance(sample_1, sample_2, squared=True).sum()

        loss = 2 * self.a01 * between + self.a00 * within_1 + self.a11 * within_2

        if not ret_matrix:
            return loss
        pooled = torch.cat((sample_1, sample_2), 0)
        return loss, euclidean_distance(pooled, pooled, squared=True)
