import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from flax.linen.initializers import constant, orthogonal
from rl_x.environments.action_space_type import ActionSpaceType
from rl_x.environments.observation_space_type import ObservationSpaceType


def get_discriminator(config, env, reward_type='state-action'):
    """Build the AIRL discriminator module for ``env``.

    Only environments with a continuous action space and flat observation
    vectors are supported.

    Args:
        config: Run configuration; ``config.algorithm.nr_hidden_units_disc``
            and ``config.algorithm.gamma`` are forwarded to the module.
        env: Environment; ``env.general_properties`` selects whether a network
            can be built, and ``env.single_action_space.high`` / ``.low`` are
            forwarded to the module.
        reward_type: ``"shaped-state-only"`` returns a ``Discriminator``
            (reward network g sees the state only); ``"shaped-sa"`` returns a
            ``DiscriminatorSA`` (g sees state and action). The default
            ``"state-action"`` matches neither option and therefore raises.

    Returns:
        An unbound ``Discriminator`` or ``DiscriminatorSA`` flax module.

    Raises:
        ValueError: If the action/observation space types or ``reward_type``
            are unsupported (instead of silently returning ``None``).
    """
    action_space_type = env.general_properties.action_space_type
    observation_space_type = env.general_properties.observation_space_type

    if (action_space_type != ActionSpaceType.CONTINUOUS
            or observation_space_type != ObservationSpaceType.FLAT_VALUES):
        raise ValueError(
            f"Unsupported space types for the discriminator: "
            f"action={action_space_type}, observation={observation_space_type}"
        )

    if reward_type == "shaped-state-only":
        discriminator_cls = Discriminator
    elif reward_type == "shaped-sa":
        discriminator_cls = DiscriminatorSA
    else:
        raise ValueError(
            f"Unsupported reward_type {reward_type!r}; "
            f"expected 'shaped-state-only' or 'shaped-sa'"
        )

    return discriminator_cls(
        config.algorithm.nr_hidden_units_disc,
        config.algorithm.gamma,
        env.single_action_space.high,
        env.single_action_space.low,
    )


class DiscriminatorSA(nn.Module):
    """AIRL discriminator whose reward network g takes the state AND the action.

    ``__call__`` returns the discriminator logit of step 6 in Alg. 1 of
    https://arxiv.org/pdf/1710.11248v2 (D = sigmoid(logit)):

        f     = g(x, y) + shaping * (gamma * h(x_n) - h(x))
        logit = f - shaping * logp

    For an absorbing next state, ``gamma * h(x_n)`` is replaced by
    ``gamma / (1 - gamma) * g(x_n, y)``. g and h are separate MLPs with two
    ReLU hidden layers of ``nr_hidden_units_disc`` units and a 1-unit output.
    """

    # Width of each hidden layer of g and h.
    nr_hidden_units_disc: int
    # Discount factor; must differ from 1 because the absorbing-state term
    # divides by (1 - gamma).
    gamma: float
    # Action-space upper / lower bounds; only used to compute H_terminal.
    as_high: jnp.ndarray
    as_low: jnp.ndarray

    def setup(self):
        # Reward network g. The submodule attribute names define the parameter
        # tree, so renaming them would change the parameter structure.
        self.gnet_dense1 = nn.Dense(self.nr_hidden_units_disc, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
        self.gnet_dense2 = nn.Dense(self.nr_hidden_units_disc, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
        self.gnet_dense3 = nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))

        # Shaping / potential network h.
        self.hnet_dense1 = nn.Dense(self.nr_hidden_units_disc, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
        self.hnet_dense2 = nn.Dense(self.nr_hidden_units_disc, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
        self.hnet_dense3 = nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))

        # Sum of log action-range widths, i.e. the log-volume of the action box
        # (equal to the entropy of a uniform distribution over it). Not read
        # inside this module; NOTE(review): presumably used by callers — verify
        # before removing.
        self.H_terminal = jnp.sum(jnp.log(self.as_high - self.as_low))

    def _g_net(self, inputs):
        """Apply the reward MLP g; the output has a trailing axis of size 1."""
        out = nn.relu(self.gnet_dense1(inputs))
        out = nn.relu(self.gnet_dense2(out))
        return self.gnet_dense3(out)

    def _h_net(self, inputs):
        """Apply the shaping MLP h; the output has a trailing axis of size 1."""
        out = nn.relu(self.hnet_dense1(inputs))
        out = nn.relu(self.hnet_dense2(out))
        return self.hnet_dense3(out)

    def __call__(self, x, y, x_n, absorbing, logp, shaping: float = 1.0):
        """Compute the AIRL discriminator logit for a transition.

        Args:
            x: State.
            y: Action taken in ``x``.
            x_n: Next state.
            absorbing: Bool / 0-1 indicator of whether ``x_n`` is absorbing.
            logp: Log prob of the rollout action corresponding to ``x``.
            shaping: Scale applied to the shaping term and to ``logp``;
                1.0 gives the standard AIRL logit.

        Returns:
            ``f - shaping * logp``. D = sigmoid(result), and the loss is
            computed with sigmoid binary cross-entropy.
        """
        # flatten() merges every axis, so this assumes one unbatched transition
        # per call (presumably the caller vmaps over the batch) — TODO confirm.
        xa = jnp.concatenate([x.flatten(), y.flatten()])
        # NOTE(review): g at the next state is paired with the *current* action
        # y; it only contributes through the absorbing-state term below.
        xa_n = jnp.concatenate([x_n.flatten(), y.flatten()])

        g_x = self._g_net(xa)
        g_x_n = self._g_net(xa_n)
        h_x = self._h_net(x)
        h_x_n = self._h_net(x_n)

        # AIRL reward: step 6 in alg 1 in https://arxiv.org/pdf/1710.11248v2
        # Non-absorbing: shaping term is gamma * h(x_n) - h(x).
        # Absorbing: gamma * h(x_n) is replaced by gamma / (1 - gamma) * g(x_n),
        # i.e. sum_{k>=1} gamma^k * g(x_n), as if g(x_n) were received forever.
        f = g_x + shaping * ((1 - absorbing) * self.gamma * h_x_n + absorbing * ((self.gamma / (1 - self.gamma)) * g_x_n) - h_x)
        reward = f - shaping * logp  # D = sigmoid(reward) and we compute loss using sigmoid BCE

        return reward


class Discriminator(nn.Module):
    """AIRL discriminator whose reward network g takes the state only.

    ``__call__`` returns the discriminator logit of step 6 in Alg. 1 of
    https://arxiv.org/pdf/1710.11248v2 (D = sigmoid(logit)):

        f     = g(x) + shaping * (gamma * h(x_n) - h(x))
        logit = f - shaping * logp

    For an absorbing next state, ``gamma * h(x_n)`` is replaced by
    ``gamma / (1 - gamma) * g(x_n)``. g and h are separate MLPs with two ReLU
    hidden layers of ``nr_hidden_units_disc`` units and a 1-unit output.
    """

    # Width of each hidden layer of g and h.
    nr_hidden_units_disc: int
    # Discount factor; must differ from 1 because the absorbing-state term
    # divides by (1 - gamma).
    gamma: float
    # Action-space upper / lower bounds; only used to compute H_terminal.
    as_high: jnp.ndarray
    as_low: jnp.ndarray

    def setup(self):
        # Reward network g. The submodule attribute names define the parameter
        # tree, so renaming them would change the parameter structure.
        self.gnet_dense1 = nn.Dense(self.nr_hidden_units_disc, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
        self.gnet_dense2 = nn.Dense(self.nr_hidden_units_disc, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
        self.gnet_dense3 = nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))

        # Shaping / potential network h.
        self.hnet_dense1 = nn.Dense(self.nr_hidden_units_disc, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
        self.hnet_dense2 = nn.Dense(self.nr_hidden_units_disc, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
        self.hnet_dense3 = nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))

        # Sum of log action-range widths, i.e. the log-volume of the action box
        # (equal to the entropy of a uniform distribution over it). Not read
        # inside this module; NOTE(review): presumably used by callers — verify
        # before removing.
        self.H_terminal = jnp.sum(jnp.log(self.as_high - self.as_low))

    def _g_net(self, inputs):
        """Apply the reward MLP g; the output has a trailing axis of size 1."""
        out = nn.relu(self.gnet_dense1(inputs))
        out = nn.relu(self.gnet_dense2(out))
        return self.gnet_dense3(out)

    def _h_net(self, inputs):
        """Apply the shaping MLP h; the output has a trailing axis of size 1."""
        out = nn.relu(self.hnet_dense1(inputs))
        out = nn.relu(self.hnet_dense2(out))
        return self.hnet_dense3(out)

    def __call__(self, x, y, x_n, absorbing, logp, shaping: float = 1.0):
        """Compute the AIRL discriminator logit for a transition.

        Args:
            x: State.
            y: Action (unused; kept so the signature matches DiscriminatorSA).
            x_n: Next state.
            absorbing: Bool / 0-1 indicator of whether ``x_n`` is absorbing.
            logp: Log prob of the rollout action corresponding to ``x``.
            shaping: Scale applied to the shaping term and to ``logp``;
                1.0 gives the standard AIRL logit.

        Returns:
            ``f - shaping * logp``. D = sigmoid(result), and the loss is
            computed with sigmoid binary cross-entropy.
        """
        g_x = self._g_net(x)
        g_x_n = self._g_net(x_n)
        h_x = self._h_net(x)
        h_x_n = self._h_net(x_n)

        # AIRL reward: step 6 in alg 1 in https://arxiv.org/pdf/1710.11248v2
        # Non-absorbing: shaping term is gamma * h(x_n) - h(x).
        # Absorbing: gamma * h(x_n) is replaced by gamma / (1 - gamma) * g(x_n),
        # i.e. sum_{k>=1} gamma^k * g(x_n), as if g(x_n) were received forever.
        f = g_x + shaping * ((1 - absorbing) * self.gamma * h_x_n + absorbing * ((self.gamma / (1 - self.gamma)) * g_x_n) - h_x)
        reward = f - shaping * logp  # D = sigmoid(reward) and we compute loss using sigmoid BCE

        return reward
