from typing import Tuple, Type, TypeVar, Union

import torch
import torch.nn.functional as F
from torch import Tensor

# Tiny epsilon constant. NOTE(review): not referenced in the code visible in
# this file section — confirm whether other modules import it.
tanh_eps: float = 1e-20
# Euler–Mascheroni constant; used as the ``2 * euler_gamma * intersection_temp``
# side-length correction in ``BoxTensor._log_soft_volume_adjusted``.
euler_gamma: float = 0.57721566490153286060


def _box_shape_ok(t: Tensor, learnt_temp=False) -> bool:
    if len(t.shape) < 2:
        return False
    if not learnt_temp:
        if t.size(-2) != 2:
            return False
        return True
    else:
        if t.size(-2) != 4:
            return False

        return True


def _shape_error_str(tensor_name, expected_shape, actual_shape):
    return "Shape of {} has to be {} but is {}".format(
        tensor_name, expected_shape, tuple(actual_shape)
    )


# A TypeVar bound to BoxTensor lets methods annotated with ``Type[TBoxTensor]``
# / ``TBoxTensor`` (e.g. ``from_zZ``) be typed as returning the subclass they
# are called on rather than the base class.
# See: https://realpython.com/python-type-checking/#type-hints-for-methods
TBoxTensor = TypeVar("TBoxTensor", bound="BoxTensor")


class BoxTensor:
    """A wrapper which contains a single tensor representing one or more boxes.

    Composition is used instead of inheritance because it is not safe to
    inherit from :class:`torch.Tensor`: creating an instance of such a class
    always makes it a leaf node. That works for :class:`torch.nn.Parameter`
    but not for a general box tensor.
    """

    def __init__(self, data: Tensor, learnt_temp: bool = False) -> None:
        """Wrap ``data`` after validating its shape.

        .. todo:: Validate the values of z, Z ? z < Z

        Arguments:
            data: Tensor of shape (**, zZ, num_dims). Here, zZ=2, where
                the 0th dim is for bottom left corner and 1st dim is for
                top right corner of the box. With ``learnt_temp=True`` the
                size of dim -2 must be 4 instead of 2.
            learnt_temp: switches the expected size of dim -2 from 2 to 4.

        Raises:
            ValueError: if ``data`` has fewer than two dims or the wrong size
                on dim -2.
        """
        if not _box_shape_ok(data, learnt_temp):
            # Report the layout that was actually expected for this mode.
            expected_shape = "(**,4,num_dims)" if learnt_temp else "(**,2,num_dims)"
            raise ValueError(_shape_error_str("data", expected_shape, data.shape))
        self.data = data
        super().__init__()

    def __repr__(self):
        return "box_tensor_wrapper(" + self.data.__repr__() + ")"

    def __getitem__(self, index):
        """Index the underlying raw tensor (returns a Tensor, not a box)."""
        return self.data[index]

    @property
    def z(self) -> Tensor:
        """Lower left coordinate as Tensor (index 0 on dim -2)."""
        return self.data[..., 0, :]

    @property
    def volume(self) -> Tensor:
        """Log of the hard volume: ``sum(log(Z - z))`` over the last dim."""
        return self._volume(self.z, self.Z)

    def _volume(self, z, Z) -> Tensor:
        """Sum of ``log(Z - z)`` over the last dim.

        Gives ``nan`` for any dimension where ``Z < z`` (flipped box) and
        ``-inf`` where ``Z == z``.
        """
        return torch.sum(torch.log(Z - z), dim=-1)

    @property
    def Z(self) -> Tensor:
        """Top right coordinate as Tensor (index 1 on dim -2)."""
        return self.data[..., 1, :]

    @property
    def box_type(self):
        return "BoxTensor"

    @property
    def centre(self) -> Tensor:
        """Centre coordinate as Tensor."""
        return (self.z + self.Z) / 2

    @classmethod
    def from_zZ(cls: Type[TBoxTensor], z: Tensor, Z: Tensor) -> TBoxTensor:
        """Create a box by stacking z and Z along dim -2.

        That is, if ``z.shape == Z.shape == (**, num_dim)``, the result is a
        box of shape ``(**, 2, num_dim)``.

        Raises:
            ValueError: if ``z`` and ``Z`` have different shapes.
        """
        if z.shape != Z.shape:
            raise ValueError(
                "Shape of z and Z should be same but is {} and {}".format(
                    z.shape, Z.shape
                )
            )
        box_val: Tensor = torch.stack((z, Z), -2)

        return cls(box_val)

    @classmethod
    def from_split(cls: Type[TBoxTensor], t: Tensor, dim: int = -1) -> TBoxTensor:
        """Create a BoxTensor by splitting ``t`` on dimension ``dim`` at its midpoint.

        The first half becomes ``z`` and the second half ``Z``.

        Args:
            t: input
            dim: dimension to split on

        Returns:
            BoxTensor: output BoxTensor

        Raises:
            ValueError: if the size of ``dim`` is not even.
        """
        len_dim = t.size(dim)

        if len_dim % 2 != 0:
            raise ValueError(
                "dim has to be even to split on it but is {}".format(len_dim)
            )
        half = len_dim // 2
        # narrow returns views; from_zZ stacks them into a fresh tensor.
        z = t.narrow(dim, 0, half)
        Z = t.narrow(dim, half, half)

        return cls.from_zZ(z, Z)

    def _intersection(
        self: TBoxTensor,
        other: TBoxTensor,
        intersection_temp: float = 1.0,
        bayesian: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Corners ``(z, Z)`` of the intersection of ``self`` and ``other``.

        Hard mode: elementwise max of the lower corners and min of the upper
        corners. Bayesian (Gumbel) mode, with ``T = intersection_temp``:
        ``z = T * logaddexp(z1 / T, z2 / T)`` and
        ``Z = -T * logaddexp(-Z1 / T, -Z2 / T)``, each then clamped by the
        corresponding hard bound.

        Raises:
            ValueError: if the Gumbel intersection cannot be computed
                (e.g. the operands do not broadcast).
        """
        t1 = self
        t2 = other

        if not bayesian:
            return torch.max(t1.z, t2.z), torch.min(t1.Z, t2.Z)

        try:
            z = intersection_temp * torch.logaddexp(
                t1.z / intersection_temp, t2.z / intersection_temp
            )
            z = torch.max(z, torch.max(t1.z, t2.z))
            Z = -intersection_temp * torch.logaddexp(
                -t1.Z / intersection_temp, -t2.Z / intersection_temp
            )
            Z = torch.min(Z, torch.min(t1.Z, t2.Z))
        except (RuntimeError, TypeError) as e:
            # Surface the failure instead of continuing with undefined z / Z.
            raise ValueError(
                "Gumbel intersection is not possible: {}".format(e)
            ) from e

        return z, Z

    def hard_intersection_volume(self, other):
        """Log hard volume of the hard intersection (``nan`` if it is empty)."""
        z, Z = self._intersection(other)
        return self._volume(z, Z)

    def intersection(
        self: TBoxTensor,
        other: TBoxTensor,
        intersection_temp: float = 1.0,
        bayesian: bool = False,
    ) -> TBoxTensor:
        """Give the intersection of self and other as a new box.

        .. note:: This function can give flipped boxes, i.e. where z[i] > Z[i]
        """
        z, Z = self._intersection(
            other, intersection_temp=intersection_temp, bayesian=bayesian
        )

        return self.from_zZ(z, Z)

    def _join(
        self: TBoxTensor, other: TBoxTensor, bayesian: bool = False
    ) -> Tuple[Tensor, Tensor]:
        """Corners ``(z, Z)`` of the box enclosing both ``self`` and ``other``.

        Hard mode: elementwise min of the lower corners and max of the upper
        corners. Bayesian mode: ``z = -logaddexp(-z1, -z2)`` and
        ``Z = logaddexp(Z1, Z2)``, each clamped by the hard bound.

        NOTE(review): for finite inputs ``-logaddexp(-a, -b) < min(a, b)`` and
        ``logaddexp(a, b) > max(a, b)``, so the clamps make the bayesian
        result equal to the hard join. Confirm this is intended.

        Raises:
            ValueError: if the bayesian join cannot be computed.
        """
        if not bayesian:
            return torch.min(self.z, other.z), torch.max(self.Z, other.Z)

        try:
            z = -torch.logaddexp(-self.z, -other.z)
            z = torch.max(z, torch.min(self.z, other.z))
            Z = torch.logaddexp(self.Z, other.Z)
            Z = torch.min(Z, torch.max(self.Z, other.Z))
        except (RuntimeError, TypeError) as e:
            raise ValueError("Gumbel join is not possible: {}".format(e)) from e

        return z, Z

    def join_volume(self: TBoxTensor, other: TBoxTensor, bayesian: bool = False):
        """Log soft volume of the join (enclosing box) of ``self`` and ``other``."""
        z, Z = self._join(other, bayesian=bayesian)
        return self._log_soft_volume(z, Z)

    def join(
        self: TBoxTensor, other: TBoxTensor, bayesian: bool = False
    ) -> TBoxTensor:
        """Join (enclosing box) of ``self`` and ``other`` as a new box."""
        z, Z = self._join(other, bayesian=bayesian)
        return self.from_zZ(z, Z)

    @classmethod
    def _log_soft_volume(cls, z: Tensor, Z: Tensor, volume_temp: float = 1.0) -> Tensor:
        """Sum over the last dim of ``log(softplus(Z - z, beta=volume_temp) + 1e-23)``.

        The small additive constant keeps the log, and its derivative, finite
        when the softplus underflows to zero.
        """
        return torch.sum(
            torch.log(F.softplus(Z - z, beta=volume_temp) + 1e-23), dim=-1
        )

    def log_soft_volume(self, volume_temp: float = 1.0) -> Tensor:
        """Log soft volume of this box (see ``_log_soft_volume``)."""
        return self._log_soft_volume(self.z, self.Z, volume_temp=volume_temp)

    @classmethod
    def _log_soft_volume_adjusted(
        cls,
        z: Tensor,
        Z: Tensor,
        volume_temp: float = 1.0,
        intersection_temp: float = 1.0,
    ) -> Tensor:
        """Like ``_log_soft_volume`` with the side lengths shifted down.

        Computes the sum over the last dim of
        ``log(softplus(Z - z - 2 * euler_gamma * intersection_temp,
        beta=volume_temp) + 1e-23)``.
        """
        return torch.sum(
            torch.log(
                F.softplus(
                    Z - z - 2 * euler_gamma * intersection_temp, beta=volume_temp
                )
                + 1e-23
            ),
            dim=-1,
        )

    def log_soft_volume_adjusted(
        self, volume_temp: float = 1.0, intersection_temp: float = 1.0
    ) -> Tensor:
        """Adjusted log soft volume of this box (see ``_log_soft_volume_adjusted``)."""
        return self._log_soft_volume_adjusted(
            self.z, self.Z, volume_temp=volume_temp, intersection_temp=intersection_temp
        )

    def intersection_log_soft_volume(
        self,
        other: TBoxTensor,
        volume_temp: float = 1.0,
        intersection_temp: float = 1.0,
        bayesian: bool = False,
        scale: Union[float, Tensor] = 1.0,
    ) -> Tensor:
        """Log soft volume of the intersection of ``self`` and ``other``.

        ``scale`` is currently unused and only accepted for interface
        compatibility.
        """
        z, Z = self._intersection(
            other, intersection_temp=intersection_temp, bayesian=bayesian
        )
        return self._log_soft_volume(z, Z, volume_temp=volume_temp)

    def gumbel_intersection_log_volume(
        self: TBoxTensor,
        other: TBoxTensor,
        volume_temp: float = 1.0,
        intersection_temp: float = 1.0,
    ) -> Tensor:
        """Adjusted log soft volume of the Gumbel (bayesian) intersection."""
        z, Z = self._intersection(
            other, intersection_temp=intersection_temp, bayesian=True
        )
        return self._log_soft_volume_adjusted(
            z, Z, volume_temp=volume_temp, intersection_temp=intersection_temp
        )

    def log_gumbel_membership(
        self: TBoxTensor,
        v: Tensor,
        membership_temp=1.0,
    ) -> Tensor:
        """Log membership score of point(s) ``v`` in this box.

        With ``T = membership_temp`` this is
        ``-sum(exp((v - Z) / T)) - sum(exp(-(v - z) / T))`` over the last dim.
        """
        upper_term = torch.exp((v - self.Z) / membership_temp).sum(dim=-1)
        lower_term = torch.exp(-(v - self.z) / membership_temp).sum(dim=-1)
        return -upper_term - lower_term

    @classmethod
    def get_wW(cls, z, Z):
        """Convert corners to the stored parameterization (identity for this class)."""
        return z, Z

    @classmethod
    def _weights_init(cls, weights: torch.Tensor):
        """An in-place weight initializer method
        which can be used to do sensible init
        of weights depending on box type.
        For this base class, this method does nothing"""
        pass


class CenterDeltaBoxTensor(BoxTensor):
    """Same as BoxTensor but with a different parameterization: (**, wW, num_dims)

    Index 0 on dim -2 stores the centre ``c`` and index 1 stores a raw,
    unconstrained delta. The corners are

        z = c - delta
        Z = c + delta

    where ``delta = softplus(raw_delta, beta=10)`` is always positive.
    """

    @property
    def z(self) -> Tensor:
        """Lower left coordinate, ``center - delta``."""
        return self.center - self.delta

    @property
    def Z(self) -> Tensor:
        """Top right coordinate, ``center + delta``."""
        return self.center + self.delta

    @property
    def center(self) -> Tensor:
        """Stored centre (index 0 on dim -2)."""
        return self.data[..., 0, :]

    @property
    def delta(self) -> Tensor:
        """Positive half-width: ``softplus(raw_delta, beta=10)``."""
        return torch.nn.functional.softplus(self.data[..., 1, :], beta=10)

    @classmethod
    def from_zZ(cls: Type[TBoxTensor], z: Tensor, Z: Tensor) -> TBoxTensor:
        """Create a box from corners by converting them to (center, raw delta).

        Raises:
            ValueError: if ``z`` and ``Z`` have different shapes.
        """
        if z.shape != Z.shape:
            raise ValueError(
                "Shape of z and Z should be same but is {} and {}".format(
                    z.shape, Z.shape
                )
            )
        w, W = cls.get_wW(z, Z)  # type:ignore

        box_val: Tensor = torch.stack((w, W), -2)

        return cls(box_val)

    @classmethod
    def get_wW(cls, z: Tensor, Z: Tensor):
        """Convert from (z, Z) coordinates to (center, delta_raw) coordinates.

        For CenterDeltaBoxTensor:
        - center = (z + Z) / 2
        - delta = (Z - z) / 2
        - delta_raw = softplus_inverse(delta, beta=10)

        ``delta`` must be positive: the softplus inverse of a non-positive
        value is ``-inf``/``nan``.
        """
        center = (z + Z) / 2
        delta = (Z - z) / 2
        delta_raw = _softplus_inverse(delta, beta=10)
        return center, delta_raw

    @classmethod
    def from_split(cls: Type[TBoxTensor], t: Tensor, dim: int = -1) -> TBoxTensor:
        """Create a box by splitting ``t`` on dimension ``dim`` at its midpoint.

        The first half is stored as the centre and the second half as the raw
        delta, without any conversion.

        Args:
            t: input
            dim: dimension to split on

        Returns:
            BoxTensor: output BoxTensor

        Raises:
            ValueError: if the size of ``dim`` is not even.
        """
        len_dim = t.size(dim)

        if len_dim % 2 != 0:
            raise ValueError(
                "dim has to be even to split on it but is {}".format(len_dim)
            )
        half = len_dim // 2
        # narrow returns views; torch.stack below copies them into new storage.
        w = t.narrow(dim, 0, half)
        W = t.narrow(dim, half, half)
        box_val: Tensor = torch.stack((w, W), -2)

        return cls(box_val)


class CenterScalarDeltaBoxTensor(CenterDeltaBoxTensor):
    """Same as CenterDeltaBoxTensor but with a scalar delta parameterization: (**, wW, num_dims)

    z = c - delta
    Z = c + delta (which is always positive)
    delta is the same for all dimensions.
    """

    @classmethod
    def from_split(cls: Type[TBoxTensor], t: Tensor, dim: int = -1) -> TBoxTensor:
        """Create a box from ``t`` of shape (**, num_dims + 1).

        The first ``num_dims`` entries of the last dimension are the centre,
        and the final entry is a single raw delta. That delta is broadcast
        across all dimensions.

        Args:
            t: input tensor; any number of leading dims is supported.
            dim: ignored; the split is always taken on the last dimension.

        Returns:
            A box whose data has shape (**, 2, num_dims).
        """
        w = t[..., :-1]
        # Keep a trailing singleton axis, (**, 1), so it broadcasts onto w.
        delta = t[..., -1:]
        # expand_as handles any number of leading dims, unlike repeat(1, D).
        W = delta.expand_as(w)

        box_val: Tensor = torch.stack((w, W), -2)
        return cls(box_val)


def _softplus_inverse(t: torch.Tensor, beta=1.0, threshold=20):
    below_thresh = beta * t < threshold
    res = t
    res[below_thresh] = torch.log(torch.exp(beta * t[below_thresh]) - 1.0) / beta

    return res
