import jax.numpy as jnp
import jax
from typing import Sequence, Literal, Callable

import numpy as np


class RiskMeasureGenerator(object):
    """Factory for jitted transforms applied to selected rows of an array.

    Calling an instance returns a ``jax.jit``-compiled function ``f(x)`` chosen
    by ``name``. The transforms read ``x[..., i, :]``, so ``x`` needs at least
    two dimensions, and ``index`` picks positions along its second-to-last
    axis. NOTE(review): the values are presumably probability/quantile levels
    in [0, 1] -- ``wang`` and ``polar_wang`` feed them to ``norm.ppf``, which is
    only finite on the open interval (0, 1); confirm against callers.

    Args:
        alpha: Strength parameter of the transform (not used by ``neutral``,
            ``triangle`` or ``simplex``).
        reward_dim: Stored on the instance; none of the transforms here read it.
        index: Positions along axis -2 to transform. ``cvar`` and ``wang`` use
            every entry; ``polar``, ``polar_wang`` and ``triangle`` use the
            first two; ``simplex`` uses the first three.
        name: Which transform ``__call__`` builds.
    """

    def __init__(self,
                 alpha: float,
                 reward_dim: int,
                 index: Sequence[int],
                 name: Literal['cvar', 'wang', 'triangle', 'neutral', 'simplex',
                               'polar', 'polar_wang']
                 ):
        self.alpha = alpha
        self.reward_dim = reward_dim
        self.index = index
        self.name = name

    def __call__(self) -> Callable:
        """Build and return the jitted transform selected by ``self.name``.

        Raises:
            NotImplementedError: if ``self.name`` is not a known transform.
        """
        if self.name == 'neutral':
            return jax.jit(lambda x: x)
        builders = {
            'cvar': self.cvar,
            'wang': self.wang,
            'triangle': self.triangle,
            'simplex': self.simplex,
            'polar': self.polar,
            'polar_wang': self.polar_wang,
        }
        builder = builders.get(self.name)
        if builder is None:
            raise NotImplementedError(f"{self.name} is not implemented")
        return builder()

    def cvar(self):
        """Return a jitted function that multiplies each selected row by ``alpha``."""
        alpha = self.alpha
        index = self.index

        def cvar(x):
            # Rows are updated one index at a time, so a repeated index is
            # scaled once per occurrence.
            for i in index:
                x = x.at[..., i, :].set(alpha * x[..., i, :])
            return x

        return jax.jit(cvar)

    def wang(self):
        """Return a jitted function mapping each selected row p to ``Phi(Phi^-1(p) + alpha)``.

        ``Phi`` is the standard normal CDF (``jax.scipy.stats.norm.cdf``) and
        ``Phi^-1`` its inverse (``norm.ppf``).
        """
        alpha = self.alpha
        index = self.index

        def wang(x):
            for i in index:
                # Shrink p slightly away from 0 and 1, where ppf is -inf / +inf.
                p = (x[..., i, :] + 1e-6) * (1 - 1e-6)
                transform = jax.scipy.stats.norm.cdf(jax.scipy.stats.norm.ppf(p) + alpha)
                x = x.at[..., i, :].set(transform)
            return x

        return jax.jit(wang)

    def polar(self):
        """Return a jitted function that rescales the whole of ``x`` by ``alpha``.

        ``r`` is the element-wise max of rows ``index[0]`` and ``index[1]``
        (kept with a singleton axis -2 so it broadcasts over ``x``). The result
        is ``alpha * x`` where ``r == 0`` and ``alpha**2 * x`` elsewhere.
        """
        alpha = self.alpha
        index = self.index

        def polar(x):
            r = jnp.maximum(x[..., index[0], :], x[..., index[1], :])  # x[..., 0] ** 2 + x[..., 1] ** 2
            r = r[..., None, :]
            # NOTE(review): r only feeds the zero test, and alpha is applied
            # twice on the non-zero branch. By analogy with polar_wang this may
            # have been meant as jnp.where(r == 0, x, alpha * x). Behavior is
            # kept as-is pending confirmation.
            r = alpha * r
            return alpha * jnp.where(r == 0, x, alpha * x)

        return jax.jit(polar)

    def polar_wang(self):
        """Return a jitted function that rescales ``x`` by ``wang(r) / r``.

        ``r`` is the element-wise max of rows ``index[0]`` and ``index[1]``.
        ``wang(r) = Phi(Phi^-1(r') + alpha)``, where ``r'`` is ``r`` nudged
        away from 0 and 1. Wherever ``r == 0`` the input is returned unchanged.
        """
        alpha = self.alpha
        index = self.index

        def polar_wang(x):
            r = jnp.maximum(x[..., index[0], :], x[..., index[1], :])  # x[..., 0] ** 2 + x[..., 1] ** 2
            r = r[..., None, :]
            is_zero = r == 0
            target = (r + 1e-6) * (1 - 1e-6)
            transform = jax.scipy.stats.norm.cdf(jax.scipy.stats.norm.ppf(target) + alpha)
            # jnp.where differentiates both branches, so transform / 0 would
            # make the gradient NaN at r == 0. Dividing by a safe radius leaves
            # the selected (forward) values unchanged.
            safe_r = jnp.where(is_zero, 1.0, r)
            return jnp.where(is_zero, x, transform / safe_r * x)

        return jax.jit(polar_wang)

    def triangle(self):
        """Return a jitted function that rescales ``x`` by ``r_1 / r_infty``.

        With ``u, v`` the rows ``index[0]`` and ``index[1]``:
        ``r_1 = |u| + |v|`` and ``r_infty = max(u, v)``. The input is returned
        unchanged wherever ``r_1 == 0``.
        """
        index = self.index

        def triangle(x):
            u = x[..., index[0], :]
            v = x[..., index[1], :]
            r_1 = (jnp.abs(u) + jnp.abs(v))[..., None, :]
            # NOTE(review): r_infty uses max(u, v), not max(|u|, |v|). The two
            # agree for non-negative inputs. With negative entries r_infty can
            # be <= 0 while r_1 > 0 (inf / sign flip), so confirm inputs are
            # non-negative.
            r_infty = jnp.maximum(u, v)[..., None, :]
            is_zero = r_1 == 0
            # Safe denominator keeps gradients finite at the all-zero point
            # (jnp.where differentiates the unselected branch too).
            safe_r_infty = jnp.where(is_zero, 1.0, r_infty)
            return jnp.where(is_zero, x, r_1 * x / safe_r_infty)

        return jax.jit(triangle)

    def simplex(self):
        """Return a jitted function that rescales ``x`` by ``r_1 / r_infty``.

        With ``u, v, w`` the rows ``index[0..2]``:
        ``r_1 = |u| + |v| + |w|`` and ``r_infty = max(u, v, w)``. The input is
        returned unchanged wherever ``r_1 == 0``.
        """
        index = self.index

        def simplex_measure(x):
            u = x[..., index[0], :]
            v = x[..., index[1], :]
            w = x[..., index[2], :]
            r_1 = (jnp.abs(u) + jnp.abs(v) + jnp.abs(w))[..., None, :]
            # NOTE(review): as in triangle, r_infty omits abs(); only equivalent
            # to the infinity norm for non-negative inputs.
            r_infty = jnp.maximum(jnp.maximum(u, v), w)[..., None, :]
            is_zero = r_1 == 0
            # Safe denominator: avoids NaN gradients from 0/0 in the
            # unselected jnp.where branch; forward values are unchanged.
            safe_r_infty = jnp.where(is_zero, 1.0, r_infty)
            return jnp.where(is_zero, x, r_1 * x / safe_r_infty)

        return jax.jit(simplex_measure)


if __name__ == '__main__':
    # Smoke run: build the simplex measure over three rows, apply it to a
    # random (512, 3, 8) batch and report the resulting shape.
    simplex_fn = RiskMeasureGenerator(0.5, 3, [0, 1, 2], 'simplex')()
    samples = np.random.uniform(size=(512, 3, 8))
    print(simplex_fn(samples).shape)

