import abc
from typing import List

import numpy as np
import torch
import torch.nn as nn


class AbstractBound(abc.ABC, nn.Module):
    def __init__(self):
        """
        Specify data bound and perform related computations.
        For equations, refer to https://mc-stan.org/docs/2_29/reference-manual/variable-transforms.html.

        Subclasses map bounded data to an unconstrained space ("forward") and back ("inverse"). Each direction also
        returns logj, the log absolute determinant of the Jacobian of that direction's transform, expected to hold one
        value per sample.
        """
        super().__init__()

    @abc.abstractmethod
    def forward_with_logj(self, x: torch.Tensor):
        """
        Map bounded data to the unconstrained space.

        :param x: bounded data.
        :return: tuple ``(x_transformed, logj)``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def inverse_with_logj(self, x: torch.Tensor):
        """
        Map unconstrained data back to the bounded space.

        :param x: unconstrained data.
        :return: tuple ``(x_transformed, logj)``.
        """
        raise NotImplementedError

    def forward(self, x: torch.Tensor):
        """Return only the transformed data produced by ``forward_with_logj``."""
        transformed, _ = self.forward_with_logj(x)
        return transformed

    def logj_forward(self, x: torch.Tensor):
        """Return only the logj term produced by ``forward_with_logj``."""
        _, logj = self.forward_with_logj(x)
        return logj

    def inverse(self, x: torch.Tensor):
        """Return only the transformed data produced by ``inverse_with_logj``."""
        transformed, _ = self.inverse_with_logj(x)
        return transformed

    def logj_inverse(self, x: torch.Tensor):
        """Return only the logj term produced by ``inverse_with_logj``."""
        _, logj = self.inverse_with_logj(x)
        return logj

    @abc.abstractmethod
    def clip_(self, x: torch.Tensor, min_magnitude: float = 1e-8):
        """
        Clip ``x`` to the bound in place, staying ``min_magnitude`` inside finite endpoints for numerical stability.

        :param x: data to clip (modified in place).
        :param min_magnitude: tolerance kept between the data and each finite endpoint.
        """
        # An unimplemented abstract method signals NotImplementedError, not ValueError.
        raise NotImplementedError


class LowerUpperBound(AbstractBound):
    def __init__(self, lower: float, upper: float):
        """
        Data bound on a [lower, upper] closed interval.
        Data is transformed according to a log-odds transform: forward maps x to
        logit((x - lower) / (upper - lower)) and inverse maps y to lower + (upper - lower) * sigmoid(y).
        """
        super().__init__()
        self.lower = lower
        self.upper = upper

    def log_abs_det_jacobian(self, inv_logit):
        """
        Per-sample log absolute determinant of the Jacobian of the inverse transform.

        :param inv_logit: values in (0, 1), i.e. sigmoid of the unconstrained data, which equals
            (x - lower) / (upper - lower) for the bounded data x. Summed over dim=1, so it must be at least 2-D.
        :return: tensor with one value per sample.
        """
        # d/dy [lower + (upper - lower) * sigmoid(y)] = (upper - lower) * s * (1 - s), with s = sigmoid(y).
        # The transformation is independent for each dimension, so the Jacobian is diagonal and its log determinant is
        # the sum of log diagonal terms; summing across dimensions (dim=1) gives one logj value per sample.
        return (np.log(self.upper - self.lower) + torch.log(inv_logit) + torch.log(1 - inv_logit)).sum(dim=1)

    def forward_with_logj(self, x: torch.Tensor):
        """
        Map bounded data to the unconstrained space.

        :return: tuple ``(logit of the normalized data, per-sample log |det J| of the forward map)``.
        """
        # Normalized position of x inside the interval. It equals sigmoid(x_transformed), which is the quantity the
        # Jacobian formula is expressed in (sigmoid of the *bounded* x would be wrong here).
        unit = (x - self.lower) / (self.upper - self.lower)
        x_transformed = torch.logit(unit)
        # The forward Jacobian is the inverse of the inverse-map Jacobian at the same point, hence the negation.
        logj = -self.log_abs_det_jacobian(unit)
        return x_transformed, logj

    def inverse_with_logj(self, x: torch.Tensor):
        """
        Map unconstrained data back into [lower, upper].

        :return: tuple ``(bounded data, per-sample log |det J| of the inverse map)``.
        """
        inv_logit = torch.sigmoid(x)
        x_transformed = self.lower + (self.upper - self.lower) * inv_logit
        # Mathematically equal to self.log_abs_det_jacobian(inv_logit). log(sigmoid(x)) and
        # log(1 - sigmoid(x)) = log(sigmoid(-x)) are computed with logsigmoid so they stay finite when sigmoid
        # saturates to exactly 0 or 1 for large |x|.
        log_sig = nn.functional.logsigmoid(x)
        log_one_minus_sig = nn.functional.logsigmoid(-x)
        logj = (np.log(self.upper - self.lower) + log_sig + log_one_minus_sig).sum(dim=1)
        return x_transformed, logj

    def clip_(self, x: torch.Tensor, min_magnitude: float = 1e-8):
        """Clip ``x`` in place into [lower + min_magnitude, upper - min_magnitude]."""
        torch.clip_(x, self.lower + min_magnitude, self.upper - min_magnitude)


class LowerBound(AbstractBound):
    def __init__(self, lower: float):
        """
        Data bound on a [lower, infinity) right-open interval.
        Data is transformed according to a log transform: forward maps x to log(x - lower) and inverse maps y to
        exp(y) + lower.
        """
        super().__init__()
        self.lower = lower

    @staticmethod
    def log_abs_det_jacobian(x):
        """
        Elementwise log absolute derivative of the inverse transform, evaluated at unconstrained values ``x``.

        For y -> exp(y) + lower, |dx/dy| = exp(y), so the log is y itself. The forward map's log derivative at the
        same point is its negation.
        """
        return x

    @staticmethod
    def _sum_per_sample(logj_elementwise):
        # Sum elementwise log-derivatives across dimensions (dim=1) so logj has one value per sample, matching the
        # other bounds; returning (n, d) would silently broadcast against (n,) log-densities. A 1-D input is treated
        # as one dimension per sample and is returned as a fresh tensor (not aliasing the caller's data).
        if logj_elementwise.dim() > 1:
            return logj_elementwise.sum(dim=1)
        return torch.clone(logj_elementwise)

    def forward_with_logj(self, x: torch.Tensor):
        """
        Map bounded data to the unconstrained space.

        :return: tuple ``(log(x - lower), per-sample log |det J| of the forward map)``.
        """
        # It is implied that x > self.lower everywhere (see clip_).
        x_transformed = torch.log(x - self.lower)
        logj = -self._sum_per_sample(self.log_abs_det_jacobian(x_transformed))
        return x_transformed, logj

    def inverse_with_logj(self, x: torch.Tensor):
        """
        Map unconstrained data back into (lower, infinity).

        :return: tuple ``(exp(x) + lower, per-sample log |det J| of the inverse map)``.
        """
        x_transformed = torch.exp(x) + self.lower
        logj = self._sum_per_sample(self.log_abs_det_jacobian(x))
        return x_transformed, logj

    def clip_(self, x: torch.Tensor, min_magnitude: float = 1e-8):
        """Clip ``x`` in place so that it is at least lower + min_magnitude."""
        torch.clip_(x, self.lower + min_magnitude, torch.inf)


class UpperBound(AbstractBound):
    def __init__(self, upper: float):
        """
        Data bound on a (-infinity, upper] left-open interval.
        Data is transformed according to a log transform: forward maps x to log(upper - x) and inverse maps y to
        upper - exp(y).
        """
        super().__init__()
        self.upper = upper

    @staticmethod
    def log_abs_det_jacobian(x):
        """
        Elementwise log absolute derivative of the inverse transform, evaluated at unconstrained values ``x``.

        For y -> upper - exp(y), |dx/dy| = exp(y), so the log is y itself. The forward map's log derivative at the
        same point is its negation.
        """
        return x

    @staticmethod
    def _sum_per_sample(logj_elementwise):
        # Sum elementwise log-derivatives across dimensions (dim=1) so logj has one value per sample, matching the
        # other bounds; returning (n, d) would silently broadcast against (n,) log-densities. A 1-D input is treated
        # as one dimension per sample and is returned as a fresh tensor (not aliasing the caller's data).
        if logj_elementwise.dim() > 1:
            return logj_elementwise.sum(dim=1)
        return torch.clone(logj_elementwise)

    def forward_with_logj(self, x: torch.Tensor):
        """
        Map bounded data to the unconstrained space.

        :return: tuple ``(log(upper - x), per-sample log |det J| of the forward map)``.
        """
        x_transformed = torch.log(self.upper - x)
        # |d/dx log(upper - x)| = 1 / (upper - x), whose log is -log(upper - x) = -x_transformed. The Jacobian must be
        # evaluated at the unconstrained value, not at the bounded input x.
        logj = -self._sum_per_sample(self.log_abs_det_jacobian(x_transformed))
        return x_transformed, logj

    def inverse_with_logj(self, x: torch.Tensor):
        """
        Map unconstrained data back into (-infinity, upper).

        :return: tuple ``(upper - exp(x), per-sample log |det J| of the inverse map)``.
        """
        x_transformed = self.upper - torch.exp(x)
        logj = self._sum_per_sample(self.log_abs_det_jacobian(x))
        return x_transformed, logj

    def clip_(self, x: torch.Tensor, min_magnitude: float = 1e-8):
        """Clip ``x`` in place so that it is at most upper - min_magnitude."""
        torch.clip_(x, -torch.inf, self.upper - min_magnitude)


class NoBound(AbstractBound):
    def __init__(self):
        """
        A no-op data bound. Data is kept as is, so the log abs det jacobian term is equal to 0.
        """
        super().__init__()

    @staticmethod
    def _zero_logj(x: torch.Tensor):
        # Identity transform: the Jacobian is the identity, so log|det J| = 0 for every sample. The zeros follow the
        # input's dtype and device, like the logj produced by the other bounds, so they combine without silent casts.
        return torch.zeros(len(x), dtype=x.dtype, device=x.device)

    def forward_with_logj(self, x: torch.Tensor):
        """Return ``x`` unchanged together with a per-sample zero logj."""
        return x, self._zero_logj(x)

    def inverse_with_logj(self, x: torch.Tensor):
        """Return ``x`` unchanged together with a per-sample zero logj."""
        return x, self._zero_logj(x)

    def clip_(self, x: torch.Tensor, min_magnitude: float = 1e-8):
        """Nothing to clip for unbounded data; ``x`` is returned untouched."""
        return x


class CompositeBound(AbstractBound):
    def __init__(self, bounds: List[AbstractBound]):
        """
        Apply a separate bound to each data dimension: ``bounds[i]`` handles column ``i`` of the data.

        :param bounds: one bound per data dimension.
        """
        super().__init__()
        # nn.ModuleList registers the bounds as submodules (so .to(), .train()/.eval() and repr() reach them) while
        # still supporting indexing and len() like a plain list.
        self.bounds = nn.ModuleList(bounds)

    def _transform_columns(self, x: torch.Tensor, inverse: bool):
        # Shared loop for both directions: transform each column with its own bound and accumulate per-sample logj.
        output_x = torch.zeros_like(x)
        # Allocate logj with the data's dtype and device; a default CPU float32 tensor cannot accumulate CUDA results.
        output_logj = torch.zeros(len(x), dtype=x.dtype, device=x.device)
        for i in range(x.shape[1]):
            bound = self.bounds[i]
            transform = bound.inverse_with_logj if inverse else bound.forward_with_logj
            # Slice i:i + 1 keeps each column 2-D, shape (n_samples, 1).
            column_x, column_logj = transform(x[:, i:i + 1])
            output_x[:, i] = column_x.reshape(-1)
            output_logj += column_logj.reshape(-1)
        return output_x, output_logj

    def forward_with_logj(self, x: torch.Tensor):
        """
        Map bounded data to the unconstrained space, column by column.

        :param x: 2-D data; column ``i`` is transformed by ``bounds[i]``.
        :return: tuple ``(x_transformed, logj)`` with logj summed over columns (one value per sample).
        """
        return self._transform_columns(x, inverse=False)

    def inverse_with_logj(self, x: torch.Tensor):
        """
        Map unconstrained data back to the bounded space, column by column.

        :param x: 2-D data; column ``i`` is transformed by ``bounds[i]``.
        :return: tuple ``(x_transformed, logj)`` with logj summed over columns (one value per sample).
        """
        return self._transform_columns(x, inverse=True)

    def clip_(self, x: torch.Tensor, min_magnitude: float = 1e-8):
        """Clip each column of ``x`` in place with its own bound."""
        for i in range(x.shape[1]):
            # x[:, i] is a view, so the per-bound in-place clip writes back into x.
            self.bounds[i].clip_(x[:, i], min_magnitude)
