from abc import abstractmethod
from typing import Any, cast

import distrax
import equinox as eqx
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


class Flow(eqx.Module):
    """Base class for invertible transformations that track a log-det-Jacobian.

    Subclasses implement `_forward` and `_inverse`; both take the running
    log-det-Jacobian `ldj` and return a `(value, ldj)` pair. A prior can be
    attached after construction with `add_prior`, which is required before
    calling `sample`, `log_prob` or `loss`.
    """

    prior: distrax.Distribution = eqx.field(init=False, static=True, default=None)
    prior_shape: tuple[int, ...] = eqx.field(init=False, default=None)

    def add_prior(self, prior, shape):
        """Attach `prior` and the sample shape used by `sample`; returns `self`."""
        # The attributes are written through object.__setattr__ rather than
        # normal assignment, exactly as the fields are declared non-init.
        for attr, value in (('prior', prior), ('prior_shape', shape)):
            object.__setattr__(self, attr, value)

        return self

    @abstractmethod
    def _forward(
        self,
        x: jax.Array,
        ldj: float,
        key: jax.Array | None = None,
        sideinfo: Any | None = None,
        params: Any | None = None,
    ):
        """Apply the transformation to `x`; return `(value, updated ldj)`."""
        raise NotImplementedError

    @abstractmethod
    def _inverse(
        self,
        x: jax.Array,
        ldj: float,
        key: jax.Array | None = None,
        sideinfo: Any | None = None,
        params: Any | None = None,
    ):
        """Undo the transformation on `x`; return `(value, updated ldj)`."""
        raise NotImplementedError

    def encode(
        self,
        x: jax.Array,
        key: jax.Array | None = None,
        sideinfo: Any | None = None,
        params: Any | None = None,
    ):
        """Run `_forward` from a zero ldj and return only the transformed value."""
        forward_out = self._forward(x, 0, key=key, sideinfo=sideinfo, params=params)

        return forward_out[0]

    def sample(
        self,
        key: jax.Array,
        sideinfo: Any | None = None,
        params: Any | None = None,
    ):
        """Draw from the prior and push the draw through `_inverse`."""
        assert self.prior is not None, 'Must call `add_prior` before calling `sample`'

        # One half of the key drives the prior draw, the other the inverse pass.
        prior_key, flow_key = jax.random.split(key)
        latent = self.prior.sample(seed=prior_key, sample_shape=self.prior_shape)
        inverse_out = self._inverse(
            latent, 0, key=flow_key, sideinfo=sideinfo, params=params
        )

        return inverse_out[0]

    def log_prob(
        self,
        x: jax.Array,
        reps: int,
        key: jax.Array | None = None,
        sideinfo: Any | None = None,
        params: Any | None = None,
    ):
        """Evaluate the log-probability of `x` under the flow and its prior.

        With `reps <= 1` a single forward pass is made and the tuple
        `(log_prob_z + ldj, log_prob_z, ldj)` is returned. With `reps > 1` the
        forward pass is vmapped over `reps` keys and a single value,
        logsumexp over the repetitions minus log(reps), is returned.
        """
        assert self.prior is not None, 'Must call `add_prior` before calling `log_prob`'

        if reps <= 1:
            z, ldj = self._forward(x, 0, key=key, sideinfo=sideinfo, params=params)
            log_prob_z = self.prior.log_prob(z).sum()

            return log_prob_z + ldj, log_prob_z, ldj

        assert key is not None

        rep_keys = jax.random.split(key, reps)
        # Only the key argument is batched; x, ldj, sideinfo and params are shared.
        batched_forward = jax.vmap(self._forward, in_axes=(None, None, 0, None, None))
        z, ldj = batched_forward(x, 0, rep_keys, sideinfo, params)

        # TODO: Outdated return signature, needs fixing if we decide to use it
        per_rep = self.prior.log_prob(z).reshape(z.shape[0], -1).sum(axis=-1) + ldj

        return logsumexp(per_rep) - jnp.log(reps)

    def loss(
        self,
        x: jax.Array,
        reps: int,
        key: jax.Array | None = None,
        sideinfo: Any | None = None,
        params: Any | None = None,
    ):
        """Return `(-log_prob_x, log_prob_z, ldj)` computed by `log_prob`."""
        log_prob_x, log_prob_z, ldj = self.log_prob(
            x,
            reps,
            key=key,
            sideinfo=sideinfo,
            params=params,
        )
        negative_log_likelihood = -log_prob_x

        return negative_log_likelihood, log_prob_z, ldj


class Identity(Flow):
    """Flow that leaves both the input and the log-det-Jacobian untouched."""

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Return the inputs unchanged."""
        return (x, ldj)

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Return the inputs unchanged."""
        return (x, ldj)


class Inverse(Flow):
    """Wrapper that swaps the forward and inverse directions of another flow.

    Constructing `Inverse(Inverse(f))` returns `f` itself instead of a
    doubly-wrapped object.
    """

    flow: Flow

    def __new__(cls, flow, *args, **kwargs):
        # Inverting an Inverse yields the wrapped flow directly.
        if flow.__class__ == Inverse:
            return flow.flow
        else:
            return super().__new__(cls, *args, **kwargs)

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Apply the wrapped flow's `_inverse`, forwarding every argument."""
        # sideinfo/params are passed through so conditional flows keep their
        # conditioning when wrapped.
        return self.flow._inverse(x, ldj, key=key, sideinfo=sideinfo, params=params)

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Apply the wrapped flow's `_forward`, forwarding every argument."""
        return self.flow._forward(x, ldj, key=key, sideinfo=sideinfo, params=params)


class Reverse(Flow):
    """Flow that reverses the order of all axes of its input.

    The operation is its own inverse and leaves the log-det-Jacobian unchanged.
    """

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Return `x` with its axes in reverse order, and `ldj` unchanged."""
        # With no explicit axes, jnp.transpose reverses the axis order. This
        # keeps the permutation static instead of building it from a
        # jnp.arange array, so the call stays usable under tracing.
        x = jnp.transpose(x)

        return x, ldj

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Return `x` with its axes in reverse order, and `ldj` unchanged."""
        x = jnp.transpose(x)

        return x, ldj


class Rescale(Flow):
    """Flow that multiplies every element by the constant `alpha`."""

    alpha: float

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Scale `x` by `alpha`; ldj grows by log(alpha) per element."""
        scaled = x * self.alpha
        ldj += jnp.log(self.alpha) * x.size

        return scaled, ldj

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Divide `x` by `alpha`; ldj shrinks by log(alpha) per element."""
        unscaled = x / self.alpha
        ldj -= jnp.log(self.alpha) * x.size

        return unscaled, ldj


class Reshape(Flow):
    """Flow that reshapes between `in_shape` and `out_shape`; ldj is unchanged."""

    in_shape: tuple
    out_shape: tuple

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Reshape `x` to `out_shape`."""
        return jnp.reshape(x, self.out_shape), ldj

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Reshape `x` back to `in_shape`."""
        return jnp.reshape(x, self.in_shape), ldj


class Sigmoid(Flow):
    """Element-wise logistic sigmoid flow with an optional `eps` rescaling.

    Forward applies the sigmoid and then, when `eps` is not None, maps the
    result through `(y - eps / 2) / (1 - eps)`. The inverse applies
    `y * (1 - eps) + eps / 2` first and then the logit.
    """

    eps: float = 0.0

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Apply the sigmoid (and eps rescaling); return `(value, ldj)`."""
        # log sigmoid'(x) = log s(x) + log s(-x) = -x + 2 * log s(x).
        # log_sigmoid avoids the -inf that log(sigmoid(x)) produces once
        # sigmoid(x) underflows for large negative x.
        ldj += (-x + 2 * jax.nn.log_sigmoid(x)).sum()
        x = jax.nn.sigmoid(x)

        if self.eps is not None:
            ldj -= jnp.log(1 - self.eps) * x.size
            x = (x - 0.5 * self.eps) / (1 - self.eps)

        return x, ldj

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Undo the eps rescaling and apply the logit; return `(value, ldj)`."""
        if self.eps is not None:
            ldj += jnp.log(1 - self.eps) * x.size
            x = x * (1 - self.eps) + 0.5 * self.eps

        # log(x - x**2) = log(x) + log(1 - x); log1p keeps precision when x is
        # close to zero.
        log_x = jnp.log(x)
        log_one_minus_x = jnp.log1p(-x)

        ldj -= (log_x + log_one_minus_x).sum()
        x = log_x - log_one_minus_x

        return x, ldj


class Sequential(Flow):
    """Composition of flows applied in order (and undone in reverse order).

    If any sub-flow carries its own prior, `loss` switches to `residual_loss`,
    which scores the intermediate value against that prior and cuts the
    gradient before continuing.
    """

    flows: list[Flow]

    def loss(self, *args, **kwargs):
        """Dispatch to `residual_loss` when a sub-flow has a prior, else `Flow.loss`.

        Note that `residual_loss` returns a single value whereas `Flow.loss`
        returns a 3-tuple.
        """
        if any(f.prior is not None for f in self.flows):
            return self.residual_loss(*args, **kwargs)
        else:
            return super().loss(*args, **kwargs)

    def _keys_per_flow(self, key):
        """Return one key per sub-flow, or a list of `None` when no key is given."""
        if key is None:
            return [None] * len(self.flows)

        return jax.random.split(key, len(self.flows))

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Apply every sub-flow's `_forward` in list order."""
        for f, k in zip(self.flows, self._keys_per_flow(key), strict=True):
            x, ldj = f._forward(x, ldj, key=k, sideinfo=sideinfo, params=params)

        return x, ldj

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Apply every sub-flow's `_inverse` in reverse list order."""
        # The first split key goes to the last flow, as flows are visited reversed.
        for f, k in zip(self.flows[::-1], self._keys_per_flow(key), strict=True):
            x, ldj = f._inverse(x, ldj, key=k, sideinfo=sideinfo, params=params)

        return x, ldj

    def residual_loss(self, x, reps=1, key=None, sideinfo=None, params=None):
        """Negative sum of log-probabilities collected at every prior in the chain.

        After each sub-flow that has a prior, the current value is scored under
        that prior (plus the accumulated ldj), the gradient through the value
        is stopped, and the ldj accumulator is reset. The container's own prior,
        if set, scores the final value.

        `sideinfo` and `params` are forwarded to every sub-flow so this method
        accepts the same arguments as `Flow.loss`. Only `reps == 1` is supported.
        """
        assert reps == 1

        rv = 0
        ldj = 0

        for f, k in zip(self.flows, self._keys_per_flow(key), strict=True):
            x, ldj = f._forward(x, ldj, key=k, sideinfo=sideinfo, params=params)

            if f.prior is not None:
                rv += f.prior.log_prob(x).sum() + ldj
                x = jax.lax.stop_gradient(x)
                ldj = 0

        if self.prior is not None:
            rv += self.prior.log_prob(x).sum() + ldj

        return -rv


class Dequantize(Flow):
    """Dequantization flow: adds noise in [0, 1) and rescales by 1 / (max_val + 1).

    Without `var_flow` the noise is uniform. With `var_flow`, the uniform noise
    is passed through `Inverse(Sigmoid) -> var_flow -> Sigmoid`, conditioned on
    the input rescaled as `x / max_val * 2 - 1`.
    """

    max_val: ... = 255
    in_dtype: ... = jnp.uint8
    out_dtype: ... = jnp.float32
    var_flow: Flow | None = None
    eps: float = 1e-5
    squash: Flow = eqx.field(init=False, default=None)

    def __post_init__(self):
        self.squash = Rescale(1 / (self.max_val + 1))

        if self.var_flow:
            # Sandwich the user flow between a logit and a sigmoid.
            wrapped = [
                Inverse(Sigmoid(self.eps)),
                self.var_flow,
                Sigmoid(self.eps),
            ]
            self.var_flow = Sequential(wrapped)

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Add dequantization noise to `x` and squash; requires `key`."""
        assert key is not None
        noise_key, remaining = jax.random.split(key)
        noise = jax.random.uniform(
            noise_key, x.shape, dtype=self.out_dtype, minval=0.0, maxval=1.0
        )

        if self.var_flow:
            context = x / self.max_val * 2 - 1
            flow_key, _unused = jax.random.split(remaining)
            noise, ldj = self.var_flow._forward(
                noise, ldj, key=flow_key, sideinfo=context
            )

        return self.squash._forward(x + noise, ldj)

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Undo the squash and floor to `in_dtype`; returns the floored value."""
        x, ldj = self.squash._inverse(x, ldj)

        quantized = jnp.floor(x).astype(self.in_dtype)
        residual = x - quantized

        if self.var_flow:
            context = quantized / self.max_val * 2 - 1
            assert key is not None
            # Only the ldj contribution of the noise flow is kept.
            residual, ldj = self.var_flow._inverse(
                residual, ldj, key=key, sideinfo=context
            )

        return quantized, ldj


class ParameterizedAffine(Flow):
    """One-dimensional affine transformation

    Computes `x * exp(a) + b`, where `a` and `b` are the first two entries of
    the parameters and `a` is soft-clipped with a tanh to the range given by
    `softplus(scale) + 1e-6`.
    """

    params: jax.Array | None
    scale: jax.Array

    def __init__(self, shape: tuple[int, ...], params: jax.Array | None = None):
        self.params = params
        self.scale = jnp.zeros(shape)

    def _log_scale_and_shift(self, params):
        """Pick call-time or stored parameters and return `(bounded a, b)`."""
        assert params is not None or self.params is not None, 'No parameters provided'
        chosen = cast(jax.Array, params if params is not None else self.params)
        raw_log_scale, shift = chosen[:2]

        bound = jax.nn.softplus(self.scale) + 1e-6
        log_scale = jnp.tanh(raw_log_scale / bound) * bound

        return log_scale, shift

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Apply `x * exp(a) + b` and add `sum(a)` to ldj."""
        log_scale, shift = self._log_scale_and_shift(params)

        x = x * jnp.exp(log_scale) + shift
        ldj += log_scale.sum()

        return x, ldj

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Apply `(x - b) / exp(a)` and subtract `sum(a)` from ldj."""
        log_scale, shift = self._log_scale_and_shift(params)

        x = (x - shift) / jnp.exp(log_scale)
        ldj -= log_scale.sum()

        return x, ldj


class ParameterizedHinge(Flow):
    """One-dimensional invertible hinge transformation

    Negative inputs are multiplied by `exp(a)` and the rest by `exp(b)`, where
    `a` and `b` are the first two parameter entries, each soft-clipped with a
    tanh to the range given by `softplus(scale) + 1e-6`.
    """

    params: jax.Array | None
    scale: jax.Array

    def __init__(self, shape: tuple[int, ...], params: jax.Array | None = None):
        self.params = params
        self.scale = jnp.zeros((2,) + shape)

    def _log_slopes(self, params):
        """Pick call-time or stored parameters and return the bounded `(a, b)`."""
        assert params is not None or self.params is not None, 'No parameters provided'
        chosen = cast(jax.Array, params if params is not None else self.params)

        bound = jax.nn.softplus(self.scale) + 1e-6
        bounded = jnp.tanh(chosen[:2] / bound) * bound
        neg_log_slope, pos_log_slope = bounded

        return neg_log_slope, pos_log_slope

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Scale each element by the slope of its side; add the log-slopes to ldj."""
        a, b = self._log_slopes(params)
        is_negative = x < 0

        ldj += jnp.where(is_negative, a, b).sum()
        x = jnp.where(is_negative, x * jnp.exp(a), x * jnp.exp(b))

        return x, ldj

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Divide each element by the slope of its side; subtract the log-slopes."""
        a, b = self._log_slopes(params)
        is_negative = x < 0

        ldj -= jnp.where(is_negative, a, b).sum()
        x = jnp.where(is_negative, x / jnp.exp(a), x / jnp.exp(b))

        return x, ldj


class ParameterizedNLSq(Flow):
    """One-dimensional non-linear squared transformation

    f(x) = a + bx + c / (1 + (dx + f)^2)

    This is guaranteed to be invertible under certain requirements
    which are guaranteed by the parameter representation.

    The derivative is b - 2cd*u / (1 + u^2)^2 with u = dx + f. Since
    |u / (1 + u^2)^2| peaks at 3*sqrt(3)/16, the derivative stays positive
    when |c| < 8*sqrt(3)/9 * b/d; `_get_params` enforces this bound with a
    0.95 safety factor.
    """

    # Parameters stored on the instance; used when none are passed at call time.
    params: jax.Array | None
    # Log of the tanh clipping range applied to the raw b parameter.
    scale_b: jax.Array
    # Log of the tanh clipping range applied to the raw d parameter.
    scale_d: jax.Array

    # log(0.95 * 8 * sqrt(3) / 9): log of the factor that bounds |c| by b / d.
    _log_k = jnp.log(8 * jnp.sqrt(3) / 9 * 0.95)

    def __init__(self, shape: tuple[int, ...], params: jax.Array | None = None):
        self.params = params
        self.scale_b = jnp.zeros(shape)
        self.scale_d = jnp.zeros(shape)

    def _get_params(self, params: jax.Array | None = None) -> tuple[jax.Array, ...]:
        """Map raw parameters to the constrained `(a, b, c, d, f)`.

        Uses `params` when given, otherwise the stored `self.params`; the first
        five entries along the leading axis are consumed.
        """
        assert params is not None or self.params is not None, 'No parameters provided'
        params = cast(jax.Array, params if params is not None else self.params)
        _a, _b, _c, _d, _f = params[0:5]

        # Soft-clip the raw b and d to (-scale, scale) before the softplus.
        scale_b = jnp.exp(self.scale_b)
        scale_d = jnp.exp(self.scale_d)
        _b = jnp.tanh(_b / scale_b) * scale_b
        _d = jnp.tanh(_d / scale_d) * scale_d

        a = _a
        # softplus makes b and d positive.
        b = jax.nn.softplus(_b)
        d = jax.nn.softplus(_d)
        # |c| < 0.95 * 8*sqrt(3)/9 * b / d; the 1e-8 guards the division.
        c = jnp.tanh(_c) * jnp.exp(self._log_k) * (b / (d + 1e-8))
        f = _f

        return a, b, c, d, f

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Evaluate f(x) and add the summed log-derivative to ldj."""
        a, b, c, d, f = self._get_params(params)

        u = d * x + f
        # df/dx = b - 2cd*u / (1 + u^2)^2
        ldj += jnp.log(b - (2 * c * d * u / ((1 + u**2) ** 2))).sum()
        x = a + b * x + c / (1 + u**2)

        return x, ldj

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Solve f(x_out) = x in closed form and subtract the log-derivative.

        The input `x` is the forward output; the returned value is the
        recovered forward input.
        """
        a, b, c, d, f = self._get_params(params)

        # Multiplying (x - a - b*t) * (1 + (d*t + f)^2) = c out gives the cubic
        # A*t^3 + B*t^2 + C*t + D = 0 in the unknown t.
        A = -b * d**2
        B = (x - a) * d**2 - 2 * b * d * f
        C = 2 * d * f * (x - a) - b * (f**2 + 1)
        D = (x - a) * (f**2 + 1) - c

        # xN is the cubic's inflection point and yN its value there.
        # NOTE(review): the xN / yN / delta / h quantities look like Nickalls'
        # closed-form cubic solution -- confirm against the original reference.
        xN = -B / (3 * A)
        yN = A * xN**3 + B * xN**2 + C * xN + D
        deltasq = (B**2 - 3 * A * C) / (9 * A**2)

        delta = jnp.sqrt(jnp.abs(deltasq))
        h = 2 * A * delta**3

        sign = jnp.sign(yN / h)

        # Two candidate roots: a cosh form used when deltasq >= 0 and a sinh
        # form used when deltasq < 0. Both are evaluated for every element and
        # the select below picks one per element.
        real_x = xN - 2 * sign * delta * jnp.cosh(jnp.arccosh(sign * yN / h) / 3)
        imag_x = xN - 2 * delta * jnp.sinh(jnp.arcsinh(yN / h) / 3)
        x = jnp.select([deltasq >= 0, deltasq < 0], [real_x, imag_x])

        # Subtract the forward log-derivative evaluated at the recovered input.
        u = d * x + f
        ldj -= jnp.log(b - (2 * c * d * u / ((1 + u**2) ** 2))).sum()

        return x, ldj


class Coupling(Flow):
    """Coupling layer: parameters for one part of `x` are computed from the other.

    `transform_param_net` sees `x * mask` (plus `sideinfo` when given) and its
    output, zeroed where the mask is one, parameterizes `transform`. With
    `dual=True` a second pass is made with the complementary mask.
    """

    mask: jax.Array
    transform_param_net: eqx.Module
    transform: Flow | None = None
    dual: bool = False

    def __post_init__(self):
        # Default to an affine transform shaped like the mask.
        if self.transform is None:
            self.transform = ParameterizedAffine(self.mask.shape)

    def _get_params(self, x: jax.Array, mask: jax.Array, sideinfo=None):
        """Run the parameter network on the masked input and mask its output."""
        masked_input = x * mask
        net_args = (masked_input,) if sideinfo is None else (masked_input, sideinfo)
        raw = self.transform_param_net(*net_args)

        # Leading axis indexes the individual parameters; entries at positions
        # where the mask is one are zeroed.
        return raw.reshape((-1,) + x.shape) * jnp.expand_dims(1 - mask, 0)

    def _mask_schedule(self):
        """Masks in the order the forward pass applies them."""
        if self.dual:
            return (self.mask, 1 - self.mask)

        return (self.mask,)

    def _forward(self, x, ldj, key=None, sideinfo=None, params=None):
        """Apply the transform once per mask in the schedule."""
        inner = cast(Flow, self.transform)

        for current_mask in self._mask_schedule():
            params = self._get_params(x, current_mask, sideinfo=sideinfo)
            x, ldj = inner._forward(x, ldj, key=key, params=params, sideinfo=sideinfo)

        return x, ldj

    def _inverse(self, x, ldj, key=None, sideinfo=None, params=None):
        """Undo the transform once per mask, in the reverse of the forward order."""
        inner = cast(Flow, self.transform)

        for current_mask in reversed(self._mask_schedule()):
            params = self._get_params(x, current_mask, sideinfo=sideinfo)
            x, ldj = inner._inverse(x, ldj, key=key, params=params, sideinfo=sideinfo)

        return x, ldj


def create_mask(shape, block_size, dtype=jnp.uint8):
    """Generates masks for use with Coupling layers.

    shape: shape of the mask
    block_size: size of a constant-mask block along each axis
    dtype: dtype of the returned mask

    If block_size[i] is negative, it partitions the i axis into -block_size[i]
    blocks, i.e. the block size becomes shape[i] // -block_size[i].

    The mask value at a position is the parity of the sum of its block indices
    over all axes.

    Examples:
    >>> create_mask((8, 8), (2, 2))
    [[0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [1 1 0 0 1 1 0 0]
     [1 1 0 0 1 1 0 0]
     [0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [1 1 0 0 1 1 0 0]
     [1 1 0 0 1 1 0 0]]

    >>> create_mask((8, 8), (-1, 2))
    [[0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]]

    >>> create_mask((8, 8), (-2, -4))
    [[0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [0 0 1 1 0 0 1 1]
     [1 1 0 0 1 1 0 0]
     [1 1 0 0 1 1 0 0]
     [1 1 0 0 1 1 0 0]
     [1 1 0 0 1 1 0 0]]
    """
    # Resolve negative entries ("number of blocks") into actual block sizes.
    block_size = [shape[i] // -d if d < 0 else d for i, d in enumerate(block_size)]
    ndim = len(shape)

    # Per-axis block index, shaped to broadcast along its own axis only.
    block_indices = [
        jnp.reshape(
            jnp.arange(shape[axis]) // block_size[axis],
            [-1 if j == axis else 1 for j in range(ndim)],
        )
        for axis in range(ndim)
    ]
    mask = sum(block_indices) % 2

    # Honour the requested dtype, which was previously accepted but ignored.
    return jnp.asarray(mask, dtype=dtype)


def checkerboard_mask(shape, dtype=jnp.uint8):
    """Build a mask with block size 1 along every axis of `shape`."""
    unit_blocks = (1,) * len(shape)

    return create_mask(shape, unit_blocks, dtype)


def channel_mask(shape, dtype=jnp.uint8):
    """Build a mask that splits the last axis of `shape` into two halves.

    Every axis except the last is a single block; the last axis (assumed to be
    the channel axis) is partitioned into two blocks, so its first half is 0
    and its second half is 1.

    `shape` must have at least one axis and its last axis must have at least
    two entries.
    """
    # One block-size entry per axis. The previous tuple had one entry too many,
    # which made create_mask index past the end of `shape`.
    block_size = (-1,) * (len(shape) - 1) + (-2,)

    return create_mask(shape, block_size, dtype)


def mask(dimension, i):
    """Return a 1-D mask of length `dimension`, alternating pattern by parity of `i`.

    Even `i` gives a checkerboard mask; odd `i` gives a mask with blocks of 3.
    """
    use_checkerboard = i % 2 == 0
    if not use_checkerboard:
        return create_mask((dimension,), (3,), dtype=jnp.uint8)

    return checkerboard_mask((dimension,), dtype=jnp.uint8)
