#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
Pairwise likelihood for pairwise preference model (e.g., PairwiseGP).
"""

from __future__ import annotations

import math
from abc import ABC, abstractmethod
from typing import Any, Tuple

import torch
from gpytorch.likelihoods import Likelihood
from torch import Tensor
from torch.distributions import Bernoulli


class PairwiseLikelihood(Likelihood, ABC):
    """
    Pairwise likelihood base class for pairwise preference GP (e.g., PairwiseGP).

    :meta private:
    """

    def __init__(self, max_plate_nesting: int = 1):
        """
        Initialized like a `gpytorch.likelihoods.Likelihood`.

        Args:
            max_plate_nesting: Defaults to 1.
        """
        super().__init__(max_plate_nesting)

    def forward(self, utility: Tensor, D: Tensor, **kwargs: Any) -> Bernoulli:
        """Given the difference in (estimated) utility util_diff = f(v) - f(u),
        return a Bernoulli distribution object representing the likelihood of
        the user prefer v over u.

        Note that this is not used by the `PairwiseGP` model,
        """
        return Bernoulli(probs=self.p(utility=utility, D=D))

    @abstractmethod
    def p(self, utility: Tensor, D: Tensor) -> Tensor:
        """Given the difference in (estimated) utility util_diff = f(v) - f(u),
        return the probability of the user prefer v over u.

        Args:
            utility: A Tensor of shape `(batch_size) x n`, the utility at MAP point
            D: D is `(batch_size x) m x n` matrix with all elements being zero in last
                dimension except at two positions D[..., i] = 1 and D[..., j] = -1
                respectively, representing item i is preferred over item j.
            log: if true, return log probability
        """

    def log_p(self, utility: Tensor, D: Tensor) -> Tensor:
        """return the log of p"""
        return torch.log(self.p(utility=utility, D=D))

    def negative_log_gradient_sum(self, utility: Tensor, D: Tensor) -> Tensor:
        """Calculate the sum of negative log gradient with respect to each item's latent
            utility values. Useful for models using laplace approximation.

        Args:
            utility: A Tensor of shape `(batch_size x) n`, the utility at MAP point
            D: D is `(batch_size x) m x n` matrix with all elements being zero in last
                dimension except at two positions D[..., i] = 1 and D[..., j] = -1
                respectively, representing item i is preferred over item j.

        Returns:
            A `(batch_size x) n` Tensor representing the sum of negative log gradient
            values of the likelihood over all comparisons (i.e., the m dimension)
            with respect to each item.
        """
        raise NotImplementedError

    def negative_log_hessian_sum(self, utility: Tensor, D: Tensor) -> Tensor:
        """Calculate the sum of negative log hessian with respect to each item's latent
            utility values. Useful for models using laplace approximation.

        Args:
            utility: A Tensor of shape `(batch_size) x n`, the utility at MAP point
            D: D is `(batch_size x) m x n` matrix with all elements being zero in last
                dimension except at two positions D[..., i] = 1 and D[..., j] = -1
                respectively, representing item i is preferred over item j.

        Returns:
            A `(batch_size x) n x n` Tensor representing the sum of negative log hessian
            values of the likelihood over all comparisons (i.e., the m dimension) with
            respect to each item.
        """
        raise NotImplementedError


class PairwiseProbitLikelihood(PairwiseLikelihood):
    """Pairwise likelihood using probit function

    Given two items v and u with utilities f(v) and f(u), the probability that we
    prefer v over u with probability std_normal_cdf((f(v) - f(u))/sqrt(2)). Note
    that this formulation implicitly assume the noise term is fixed at 1.
    """

    # Clamping z values for better numerical stability. See self._calc_z for detail
    # norm_cdf(z=3) ~= 0.999, top 0.1% percent
    _zlim = 3

    def _calc_z(self, utility: Tensor, D: Tensor) -> Tensor:
        """Calculate the z score given estimated utility values and
        the comparison matrix D.
        """
        scaled_util = (utility / math.sqrt(2)).unsqueeze(-1)
        z = D.to(scaled_util) @ scaled_util
        z = z.clamp(-self._zlim, self._zlim).squeeze(-1)
        return z

    def _calc_z_derived(self, z: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Calculate auxiliary statistics derived from z, including log pdf,
        log cdf, and the hazard function (pdf divided by cdf)"""
        std_norm = torch.distributions.normal.Normal(
            torch.zeros(1, dtype=z.dtype, device=z.device),
            torch.ones(1, dtype=z.dtype, device=z.device),
        )
        z_logpdf = std_norm.log_prob(z)
        z_cdf = std_norm.cdf(z)
        z_logcdf = torch.log(z_cdf)
        hazard = torch.exp(z_logpdf - z_logcdf)
        return z_logpdf, z_logcdf, hazard

    def p(self, utility: Tensor, D: Tensor, log: bool = False) -> Tensor:
        z = self._calc_z(utility=utility, D=D)
        std_norm = torch.distributions.normal.Normal(
            torch.zeros(1, dtype=z.dtype, device=z.device),
            torch.ones(1, dtype=z.dtype, device=z.device),
        )
        return std_norm.cdf(z)

    def negative_log_gradient_sum(self, utility: Tensor, D: Tensor) -> Tensor:
        # Compute the sum over of grad. of negative Log-LH wrt utility f.
        # Original grad should be of dimension m x n, as in (6) from
        # [Chu2005preference]_. The sum over the m dimension of grad. of
        # negative log likelihood with respect to the utility
        z = self._calc_z(utility, D)
        _, _, h = self._calc_z_derived(z)
        h_factor = h / math.sqrt(2)
        grad = (h_factor.unsqueeze(-2) @ (-D)).squeeze(-2)

        return grad

    def negative_log_hessian_sum(self, utility: Tensor, D: Tensor) -> Tensor:
        # Original hess should be of dimension m x n x n, as in (7) from
        # [Chu2005preference]_ Sum over the first dimension and return a tensor of
        # shape n x n.
        # The sum over the m dimension of hessian of negative log likelihood
        # with respect to the utility
        DT = D.transpose(-1, -2)
        z = self._calc_z(utility, D)
        _, _, h = self._calc_z_derived(z)
        mul_factor = h * (h + z) / 2
        mul_factor = mul_factor.unsqueeze(-2).expand(*DT.size())
        # multiply the hessian value by preference signs
        # (+1 if preferred or -1 otherwise) and sum over the m dimension
        hess = DT * mul_factor @ D
        return hess


class PairwiseLogitLikelihood(PairwiseLikelihood):
    """Pairwise likelihood using logistic (i.e., sigmoid) function

    Given two items v and u with utilities f(v) and f(u), the probability that we
    prefer v over u with probability sigmoid(f(v) - f(u)). Note
    that this formulation implicitly assume the beta term in logistic function is
    fixed at 1.
    """

    # Clamping logit values for better numerical stability.
    # See self._calc_logit for detail logistic(8) ~= 0.9997, top 0.03% percent
    _logit_lim = 8

    def _calc_logit(self, utility: Tensor, D: Tensor) -> Tensor:
        logit = D.to(utility) @ utility.unsqueeze(-1)
        logit = logit.clamp(-self._logit_lim, self._logit_lim).squeeze(-1)
        return logit

    def log_p(self, utility: Tensor, D: Tensor) -> Tensor:
        logit = self._calc_logit(utility=utility, D=D)
        return torch.nn.functional.logsigmoid(logit)

    def p(self, utility: Tensor, D: Tensor) -> Tensor:
        logit = self._calc_logit(utility=utility, D=D)
        return torch.sigmoid(logit)

    def negative_log_gradient_sum(self, utility: Tensor, D: Tensor) -> Tensor:
        indices_shape = utility.shape[:-1] + (-1,)
        winner_indices = (D == 1).nonzero(as_tuple=True)[-1].reshape(indices_shape)
        loser_indices = (D == -1).nonzero(as_tuple=True)[-1].reshape(indices_shape)
        ex = torch.exp(torch.gather(utility, -1, winner_indices))
        ey = torch.exp(torch.gather(utility, -1, loser_indices))
        unsigned_grad = ey / (ex + ey)
        grad = (unsigned_grad.unsqueeze(-2) @ (-D)).squeeze(-2)
        return grad

    def negative_log_hessian_sum(self, utility: Tensor, D: Tensor) -> Tensor:
        DT = D.transpose(-1, -2)
        # calculating f(v) - f(u) given u > v information in D
        neg_logit = -(D @ utility.unsqueeze(-1)).squeeze(-1)
        term = torch.sigmoid(neg_logit)
        mul_factor = term - (term) ** 2
        mul_factor = mul_factor.unsqueeze(-2).expand(*DT.size())
        # multiply the hessian value by preference signs
        # (+1 if preferred or -1 otherwise) and sum over the m dimension
        hess = DT * mul_factor @ D
        return hess
