import numpy as np
import flax.linen as nn
import jax.numpy as jnp
from flax.linen.initializers import constant, orthogonal
from rl_x.environments.action_space_type import ActionSpaceType
from rl_x.environments.observation_space_type import ObservationSpaceType


def get_discriminator(config, env):
    """Build the discriminator network matching the environment's spaces.

    Reads the action and observation space types from
    ``env.general_properties``. A ``Discriminator`` sized by
    ``config.algorithm.nr_hidden_units_disc`` is returned only for
    continuous actions combined with flat-valued observations; for any
    other combination the function returns ``None``.
    """
    action_kind = env.general_properties.action_space_type
    observation_kind = env.general_properties.observation_space_type

    is_supported = (
        action_kind == ActionSpaceType.CONTINUOUS
        and observation_kind == ObservationSpaceType.FLAT_VALUES
    )
    if not is_supported:
        # No discriminator is defined for this combination of spaces.
        return None

    return Discriminator(config.algorithm.nr_hidden_units_disc)


class Discriminator(nn.Module):
    """Fully connected network producing one scalar output per call.

    The three inputs are flattened and concatenated (in the order ``x``,
    ``absorbing``, ``y``) into one 1-D vector, passed through two hidden
    ``Dense`` + ReLU layers of ``nr_hidden_units_disc`` units each, and
    finally through a single-unit ``Dense`` layer with no activation.
    Hidden kernels use orthogonal initialization with scale ``sqrt(2)``,
    the output kernel uses scale ``1``, and all biases start at zero.
    """

    # Width of each of the two hidden layers.
    nr_hidden_units_disc: int

    @nn.compact
    def __call__(self, x, y, absorbing):
        # Everything is flattened, so one call yields a single 1-D feature
        # vector and an output of shape (1,).
        # NOTE(review): batching is presumably done by the caller (e.g. vmap) - confirm.
        features = jnp.concatenate([part.flatten() for part in (x, absorbing, y)])

        # Two identical hidden layers; creation order is kept so Flax's
        # automatic submodule names (Dense_0, Dense_1, Dense_2) are unchanged.
        for _ in range(2):
            hidden_layer = nn.Dense(
                self.nr_hidden_units_disc,
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )
            features = nn.relu(hidden_layer(features))

        output_layer = nn.Dense(
            1,
            kernel_init=orthogonal(1),
            bias_init=constant(0.0),
        )
        return output_layer(features)
