from abc import ABC, abstractmethod
from typing import Optional, Union

import numpy as np
import scipy.special as sc
import scipy.stats as stats

from .utils import quad


class BayesianGaussianModel(ABC):
    """Normal-means model ``y | theta ~ N(theta, s2)`` with a prior on ``theta``.

    Subclasses describe the prior through its prior predictive (marginal)
    density ``m(y)`` and, where available, its derivative ``m'(y)``.
    """

    def __init__(
        self,
        s2: Union[float, np.ndarray],
    ) -> None:
        # Observation-noise variance of y given theta.
        self.s2 = s2

    @abstractmethod
    def prior_predictive(self, y: float, s2: Optional[float] = None) -> float:
        """Return the marginal density ``m(y)``.

        Args:
            y: Observation(s) at which to evaluate the density.
            s2: Optional noise variance overriding ``self.s2``.
        """
        raise NotImplementedError

    def prior_predictive_der(self, y: float) -> float:
        """Return ``m'(y)``, the derivative of the prior predictive density."""
        raise NotImplementedError

    def likelihood(self, y: float, theta: float) -> float:
        """Return the N(theta, self.s2) density evaluated at ``y``."""
        return stats.norm.pdf(y, loc=theta, scale=np.sqrt(self.s2))

    def posterior_mean(self, y: Union[float, np.ndarray]) -> float:
        """Return ``E[theta | y]`` via ``y + s2 * m'(y) / m(y)`` (Tweedie's formula).

        Entries equal to 0 return 0.0 directly; this presumes a prior that is
        symmetric about zero. ``np.where`` evaluates both branches, so the
        formula is still computed at ``y == 0`` and then discarded.
        """
        return np.where(
            y == 0.0,
            0.0,
            y + self.s2 * self.prior_predictive_der(y) / self.prior_predictive(y),
        )

    def root(
        self, w: float, theta: float, alpha: float = 0.05, s2: Optional[float] = None
    ) -> float:
        """Compare the log "tilted" marginal at the two tail quantiles.

        With ``z_l = Phi^{-1}(alpha * (1 - w))`` and ``z_r = -Phi^{-1}(alpha * w)``
        this returns::

            [log m(theta - s z_l) + z_l**2 / 2] - [log m(theta - s z_r) + z_r**2 / 2]

        The value is computed on the log scale, matching the subclasses that
        override this method. That avoids the overflow of ``exp(z**2 / 2)``
        for extremely small tail probabilities. The zero and sign agree with
        the raw-scale difference because ``log`` is monotone.

        A tail probability of exactly 0 puts its quantile at -/+inf. Its term
        is then replaced by the largest finite float.

        Args:
            w: Split of the total tail mass ``alpha`` between the two sides.
            theta: Parameter value at which the comparison is made.
            alpha: Total tail probability.
            s2: Optional noise variance overriding ``self.s2``.
        """
        s2 = self.s2 if s2 is None else s2
        s = np.sqrt(s2)

        def log_tilted(tail_prob: float, sign: float) -> float:
            # Log marginal at the shifted quantile plus the Gaussian tilt z**2 / 2.
            if tail_prob == 0.0:
                # The quantile diverges; use the largest finite float as a stand-in.
                return np.finfo(float).max
            z = sign * stats.norm.ppf(tail_prob)
            return np.log(self.prior_predictive(theta - s * z, s2=s2)) + 0.5 * np.square(z)

        left = log_tilted(alpha * (1 - w), 1.0)
        right = log_tilted(alpha * w, -1.0)
        return left - right


class GaussianGaussianModel(BayesianGaussianModel):
    """Conjugate model: ``y | theta ~ N(theta, s2)`` with prior ``theta ~ N(0, t2)``."""

    def __init__(
        self, s2: Union[float, np.ndarray], t2: Union[float, np.ndarray]
    ) -> None:
        super().__init__(s2)
        # Prior variance of theta.
        self.t2 = t2

    def prior_theta(self, theta: float) -> float:
        """Return the N(0, t2) prior density at ``theta``."""
        return stats.norm.pdf(theta, scale=np.sqrt(self.t2))

    def prior_predictive(self, y: float, s2: Optional[float] = None) -> float:
        """Return the marginal density N(y; 0, s2 + t2); ``s2`` overrides ``self.s2``."""
        s2 = self.s2 if s2 is None else s2
        return stats.norm.pdf(y, scale=np.sqrt(s2 + self.t2))

    def posterior_mean(self, y: float) -> float:
        """Return the conjugate posterior mean ``(1/s2) y / (1/s2 + 1/t2)``."""
        s2 = self.s2
        t2 = self.t2

        prec_y = 1 / s2
        prec_theta = 1 / t2

        return prec_y * y / (prec_y + prec_theta)

    def root(
        self, w: float, theta: float, alpha: float = 0.05, s2: Optional[float] = None
    ) -> float:
        """Closed-form log-scale tail comparison for the Gaussian prior.

        For ``y = theta - s z`` and marginal variance ``v2 = s2 + t2``::

            log m(y) + z**2 / 2 = -(theta**2 - 2 theta s z - t2 z**2) / (2 v2) + C

        ``C = -log(2 pi v2) / 2`` cancels in the difference. Here
        ``z_l = Phi^{-1}(alpha (1 - w))`` and ``z_r = -Phi^{-1}(alpha w)``.
        A tail probability of exactly 0 is replaced by the largest finite
        float.

        Args:
            w: Split of the total tail mass ``alpha`` between the two sides.
            theta: Parameter value at which the comparison is made.
            alpha: Total tail probability.
            s2: Optional noise variance overriding ``self.s2``.
        """
        s2 = self.s2 if s2 is None else s2
        s = np.sqrt(s2)
        t2 = self.t2
        # Use the resolved (possibly overridden) s2 so that ``s`` and ``v2``
        # describe the same model as ``prior_predictive(..., s2=s2)``.
        v2 = s2 + t2

        def log_tilted(tail_prob: float, sign: float) -> float:
            if tail_prob == 0.0:
                # The quantile diverges; use the largest finite float as a stand-in.
                return np.finfo(float).max
            z = sign * stats.norm.ppf(tail_prob)
            return (
                -0.5
                * (np.square(theta) - 2 * theta * s * z - t2 * np.square(z))
                / v2
            )

        left = log_tilted(alpha * (1 - w), 1.0)
        right = log_tilted(alpha * w, -1.0)
        return left - right


class HorseshoeGaussianModel(BayesianGaussianModel):
    """Gaussian observations under a horseshoe-type prior with closed-form marginal.

    The prior predictive is ``2 * D(|y| / sqrt(2 s2)) / (pi**1.5 * |y|)``, where
    ``D`` is Dawson's integral. Its limit ``2 / (pi * sqrt(2 pi s2))`` is used at
    ``y == 0``. The derivative uses the Beta(a, b) shrinkage-mixture formula
    with ``a = b = 0.5``.
    """

    def __init__(self, s2: Union[float, np.ndarray]) -> None:
        super().__init__(s2)

    def prior_predictive(self, y: float, s2: Optional[float] = None) -> float:
        """Return the marginal density at ``y``; ``s2`` overrides ``self.s2``.

        Both branches of ``np.where`` are evaluated, so the ``y == 0`` entries
        compute 0/0 in the discarded branch (NumPy may warn).
        """
        variance = s2 if s2 is not None else self.s2
        magnitude = np.abs(y)
        at_origin = 2.0 / np.pi / np.sqrt(2 * np.pi * variance)
        away_from_origin = (
            2
            * sc.dawsn(magnitude / np.sqrt(2 * variance))
            / np.power(np.pi, 3 / 2)
            / magnitude
        )
        return np.where(y == 0.0, at_origin, away_from_origin)

    def prior_predictive_der(self, y: float) -> float:
        """Return ``m'(y)`` using Kummer's function ``1F1(a+3/2; a+b+3/2; -y^2/(2 s2))``."""
        a, b = 0.5, 0.5
        variance = self.s2
        gamma_ratio = (
            sc.gamma(a + 1.5) * sc.gamma(a + b) / sc.gamma(a) / sc.gamma(a + b + 1.5)
        )
        kummer = sc.hyp1f1(a + 1.5, a + b + 1.5, -(np.square(y) / (2 * variance)))
        scaled = gamma_ratio / np.sqrt(2 * np.pi * variance) * kummer
        return -y / variance * scaled


class GeneralisedParetoGaussianModel(BayesianGaussianModel):
    """Gaussian observations whose marginal corresponds to Beta(a=1/2, b=1) shrinkage.

    Prior predictive: ``m(y) = (1 - exp(-x)) / x / (2 s sqrt(2 pi))`` with
    ``x = y**2 / (2 s2)``; its value at ``y == 0`` is ``1 / (2 s sqrt(2 pi))``.
    """

    def __init__(self, s2: Union[float, np.ndarray]) -> None:
        super().__init__(s2)

    def prior_predictive(self, y: float, s2: Optional[float] = None) -> float:
        """Return the marginal density at ``y`` (scalar or array).

        The factor ``(1 - exp(-x)) / x`` is evaluated as ``-expm1(-x) / x``.
        The naive form cancels catastrophically for small ``|y|`` and returns
        0 for ``|y|`` below about 1e-8. ``x == 0`` (including underflow of
        ``y**2``) takes the exact limit 1. The computation is vectorized,
        so array ``y`` is accepted, as ``posterior_mean`` expects.

        Args:
            y: Observation(s).
            s2: Optional noise variance overriding ``self.s2``.
        """
        s2 = self.s2 if s2 is None else s2
        s = np.sqrt(s2)
        x = np.square(y) / (2 * s2)
        at_zero = x == 0.0
        # Substitute 1 where x == 0 so the discarded branch never divides by zero.
        safe_x = np.where(at_zero, 1.0, x)
        ratio = np.where(at_zero, 1.0, -np.expm1(-safe_x) / safe_x)
        return ratio / (2 * s * np.sqrt(2 * np.pi))

    def prior_predictive_der(self, y: float) -> float:
        """Return ``m'(y)`` via the Beta(a, b) mixture formula with a = 1/2, b = 1."""
        s2 = self.s2
        a = 0.5
        b = 1.0
        y2_scaled = np.square(y) / (2 * s2)
        return (
            -y
            / s2
            * (
                sc.gamma(a + 1.5)
                * sc.gamma(a + b)
                / sc.gamma(a)
                / sc.gamma(a + b + 1.5)
                / np.sqrt(2 * np.pi * s2)
                * sc.hyp1f1(a + 1.5, a + b + 1.5, -y2_scaled)
            )
        )


class OneHalfBetaPrimeGaussianModel(BayesianGaussianModel):
    """Gaussian observations with a closed-form Bessel-function marginal.

    The derivative uses the Beta(a, b) shrinkage-mixture formula with
    ``a = 1`` and ``b = 0.5``.
    """

    def __init__(self, s2: Union[float, np.ndarray]) -> None:
        super().__init__(s2)

    def prior_predictive(self, y: float, s2: Optional[float] = None) -> float:
        """Return ``sqrt(pi / (2 s2)) / 4 * (ive(0, u) - ive(1, u))`` with ``u = y^2 / (4 s2)``.

        ``ive`` is the exponentially scaled modified Bessel function, so the
        difference stays finite for large ``|y|``.
        """
        variance = s2 if s2 is not None else self.s2
        bessel_arg = np.square(y) / (4 * variance)
        bessel_diff = sc.ive(0, bessel_arg) - sc.ive(1, bessel_arg)
        return np.sqrt(np.pi / (2 * variance)) / 4 * bessel_diff

    def prior_predictive_der(self, y: float) -> float:
        """Return ``m'(y)`` using Kummer's function ``1F1(a+3/2; a+b+3/2; -y^2/(2 s2))``."""
        a, b = 1.0, 0.5
        variance = self.s2
        gamma_ratio = (
            sc.gamma(a + 1.5) * sc.gamma(a + b) / sc.gamma(a) / sc.gamma(a + b + 1.5)
        )
        kummer = sc.hyp1f1(a + 1.5, a + b + 1.5, -(np.square(y) / (2 * variance)))
        scaled = gamma_ratio / np.sqrt(2 * np.pi * variance) * kummer
        return -y / variance * scaled


class LaplaceGaussianModel(BayesianGaussianModel):
    """``y | theta ~ N(theta, s2)`` with prior ``theta ~ Laplace(0, b)``; marginal by quadrature."""

    def __init__(
        self,
        s2: float,
        b: float,  # TODO: generalise to multivariate
    ) -> None:
        super().__init__(s2)
        # Scale of the Laplace prior on theta.
        self.b = b

    def prior_predictive(self, y: float, s2: Optional[float] = None) -> float:
        """Return ``m(y) = integral Laplace(theta; 0, b) * N(y; theta, s2) dtheta``.

        ``s2`` overrides ``self.s2``. The likelihood is evaluated with the
        resolved variance rather than through ``self.likelihood``, which
        always uses ``self.s2`` and would silently ignore the override
        that ``root`` passes in.
        """
        s2 = self.s2 if s2 is None else s2
        noise_sd = np.sqrt(s2)

        def integrand(theta: float) -> float:
            return stats.laplace.pdf(theta, scale=self.b) * stats.norm.pdf(
                y, loc=theta, scale=noise_sd
            )

        return quad(
            integrand,
            -np.inf,
            np.inf,
            midpoint=y,
            epsabs=1e-8,
            epsrel=1e-8,
        )[0]

    def root(
        self, w: float, theta: float, alpha: float = 0.05, s2: Optional[float] = None
    ) -> float:
        """Return ``[log m(theta - s z_l) + z_l^2/2] - [log m(theta - s z_r) + z_r^2/2]``.

        Here ``z_l = Phi^{-1}(alpha (1 - w))`` and ``z_r = -Phi^{-1}(alpha w)``.
        A tail probability of exactly 0 is replaced by the largest finite float.
        """
        s2 = self.s2 if s2 is None else s2
        s = np.sqrt(s2)

        q_left = alpha * (1 - w)
        if q_left == 0.0:
            # z_left -> -inf
            left = np.finfo(float).max
        else:
            z_left = stats.norm.ppf(q_left)
            left = np.log(self.prior_predictive(theta - s * z_left, s2=s2)) + (
                0.5 * np.square(z_left)
            )

        q_right = alpha * w
        if q_right == 0.0:
            # z_right -> inf
            right = np.finfo(float).max
        else:
            z_right = -stats.norm.ppf(q_right)
            right = np.log(self.prior_predictive(theta - s * z_right, s2=s2)) + (
                0.5 * np.square(z_right)
            )

        return left - right


class ScaledHorseshoeGaussianModel(BayesianGaussianModel):
    """Scaled horseshoe: ``theta | t ~ N(0, l2 * t**2)`` with ``t ~ HalfCauchy(0, 1)``.

    Integrating out ``theta`` gives ``y | t ~ N(0, s2 + l2 * t**2)``; the
    remaining integral over ``t`` is done numerically.
    """

    def __init__(
        self, s2: float, l2: float
    ) -> None:  # TODO: generalise to multivariate
        super().__init__(s2)
        self.l2 = l2  # l2 = s2 recovers HorseshoeGaussianModel

    def prior_predictive(self, y: float, s2: Optional[float] = None) -> float:
        """Return the marginal density at ``y``; ``s2`` overrides ``self.s2``."""
        variance = s2 if s2 is not None else self.s2
        l2 = self.l2

        def integrand(t: float) -> float:
            marginal_sd = np.sqrt(variance + l2 * np.square(t))
            return stats.halfcauchy.pdf(t) * stats.norm.pdf(
                y, loc=0.0, scale=marginal_sd
            )

        result = quad(integrand, 0.0, np.inf, epsabs=1e-8, epsrel=1e-8)
        return result[0]

    def prior_predictive_der(self, y: float) -> float:
        """Return ``m'(y)``, differentiating the N(0, s2 + l2 t^2) density under the integral."""
        variance = self.s2
        l2 = self.l2

        def integrand(t: float) -> float:
            total_var = variance + l2 * np.square(t)
            density = stats.halfcauchy.pdf(t) * stats.norm.pdf(
                y, loc=0.0, scale=np.sqrt(total_var)
            )
            return density * (-y / total_var)

        result = quad(integrand, 0.0, np.inf, epsabs=1e-8, epsrel=1e-8)
        return result[0]

    def root(
        self, w: float, theta: float, alpha: float = 0.05, s2: Optional[float] = None
    ) -> float:
        """Return ``[log m(theta - s z_l) + z_l^2/2] - [log m(theta - s z_r) + z_r^2/2]``.

        Here ``z_l = Phi^{-1}(alpha (1 - w))`` and ``z_r = -Phi^{-1}(alpha w)``.
        A tail probability of exactly 0 is replaced by the largest finite float.
        """
        variance = s2 if s2 is not None else self.s2
        scale = np.sqrt(variance)

        def log_tilted(z: float) -> float:
            shifted = theta - scale * z
            return np.log(self.prior_predictive(shifted, s2=variance)) + (
                0.5 * np.square(z)
            )

        lower_tail = alpha * (1 - w)
        upper_tail = alpha * w
        biggest = np.finfo(float).max
        left = biggest if lower_tail == 0.0 else log_tilted(stats.norm.ppf(lower_tail))
        right = (
            biggest if upper_tail == 0.0 else log_tilted(-stats.norm.ppf(upper_tail))
        )
        return left - right
