import functools

from jaxrl_m.dataset import Dataset
from jaxrl_m.typing import *
from jaxrl_m.networks import *
import jax


class LayerNormMLP(nn.Module):
    """MLP whose hidden blocks are Dense -> activation -> LayerNorm.

    Every entry of ``hidden_dims`` except the last produces one hidden block;
    the last entry is the width of the final Dense layer.

    Calling the module returns a pair ``(features, out)``:

    * ``features`` is the tensor that feeds the final Dense layer, i.e. the
      output of the last hidden block (the unmodified input when
      ``hidden_dims`` has a single entry).
    * ``out`` is the final Dense output, additionally passed through the
      activation and a LayerNorm when ``activate_final`` is set.

    ``hidden_dims`` must be non-empty, since ``hidden_dims[-1]`` is read.
    """

    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.gelu
    # Boolean flag (the default is ``False``); it is only ever truth-tested.
    activate_final: bool = False
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_init()

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> "Tuple[jnp.ndarray, jnp.ndarray]":
        # Submodules are created in the same order as before so that the
        # automatically assigned parameter names do not change.
        for size in self.hidden_dims[:-1]:
            x = nn.Dense(size, kernel_init=self.kernel_init)(x)
            x = self.activations(x)
            x = nn.LayerNorm()(x)

        out = nn.Dense(self.hidden_dims[-1], kernel_init=self.kernel_init)(x)
        if self.activate_final:
            out = self.activations(out)
            out = nn.LayerNorm()(out)

        return x, out


class LayerNormRepresentation(nn.Module):
    """Wrap ``LayerNormMLP``, optionally passing it through ``ensemblize``.

    When ``ensemble`` is true the MLP class is replaced by
    ``ensemblize(LayerNormMLP, 2)`` before being instantiated.
    NOTE(review): presumably this stacks two independent copies along a
    leading axis -- confirm against ``ensemblize``.

    The return value is whatever the (possibly ensembled) ``LayerNormMLP``
    returns for ``observations``.
    """

    hidden_dims: tuple = (256, 256)
    activate_final: bool = True
    ensemble: bool = True

    @nn.compact
    def __call__(self, observations):
        net_cls = ensemblize(LayerNormMLP, 2) if self.ensemble else LayerNormMLP
        net = net_cls(self.hidden_dims, activate_final=self.activate_final)
        return net(observations)


class Representation(nn.Module):
    """Wrap the plain ``MLP`` (GELU activations), optionally ensembled.

    When ``ensemble`` is true the MLP class is replaced by
    ``ensemblize(MLP, 2)`` before being instantiated.

    NOTE(review): the value/critic classes in this file unpack the result of
    their network as ``(feat, value)``; whether ``MLP`` returns such a pair
    cannot be seen from here -- verify before using ``use_layer_norm=False``.
    """

    hidden_dims: tuple = (256, 256)
    activate_final: bool = True
    ensemble: bool = True

    @nn.compact
    def __call__(self, observations):
        net_cls = ensemblize(MLP, 2) if self.ensemble else MLP
        net = net_cls(
            self.hidden_dims,
            activate_final=self.activate_final,
            activations=nn.gelu,
        )
        return net(observations)


class GoalConditionedValue(nn.Module):
    """Value network V(s) or V(s, g) built on a representation module.

    The network is ``LayerNormRepresentation`` when ``use_layer_norm`` is set
    and ``Representation`` otherwise, with layer widths
    ``(*hidden_dims, 1)`` and no final activation.  If ``encoder`` is given
    it is called with no arguments in ``setup`` and its result is placed in
    front of the value network inside an ``nn.Sequential``.

    ``readout_size`` and the ``info`` argument of ``__call__`` are not read by
    this class.

    NOTE(review): ``__call__`` unpacks the network output as
    ``(feat, value)``, which matches ``LayerNormMLP``; confirm that ``MLP``
    returns the same pair before relying on ``use_layer_norm=False``.
    """

    hidden_dims: tuple = (256, 256)
    readout_size: tuple = (256,)
    use_layer_norm: bool = True
    ensemble: bool = True
    # Called as ``self.encoder()`` in ``setup``, so this is expected to be a
    # zero-argument factory rather than an already constructed module.
    encoder: nn.Module = None

    def setup(self) -> None:
        repr_class = LayerNormRepresentation if self.use_layer_norm else Representation
        value_net = repr_class(
            (*self.hidden_dims, 1), activate_final=False, ensemble=self.ensemble
        )
        if self.encoder is not None:
            value_net = nn.Sequential([self.encoder(), value_net])
        self.value_net = value_net

    def __call__(self, observations, goals=None, info=False):
        """Return ``(feat, v)`` with the trailing size-1 axis of ``v`` removed.

        Args:
            observations: network input.
            goals: optional goals; when given they are concatenated to
                ``observations`` along the last axis.
            info: unused.
        """
        if goals is None:
            inputs = observations
        else:
            inputs = jnp.concatenate([observations, goals], axis=-1)

        # The network returns a (features, value) pair, so the pair is
        # unpacked first and only the value array is squeezed -- exactly once,
        # identically for the goal-free and goal-conditioned paths.
        feat, v = self.value_net(inputs)
        return feat, v.squeeze(-1)


class GoalConditionedPhiValue(nn.Module):
    """Value defined as the negative distance between embeddings.

    ``phi`` is a representation module with layer widths
    ``(*hidden_dims, skill_dim)`` and no final activation
    (``LayerNormRepresentation`` when ``use_layer_norm`` is set, otherwise
    ``Representation``).  If ``encoder`` is given, ``self.encoder()`` is
    placed in front of it inside an ``nn.Sequential``.

    ``readout_size`` and the ``info`` argument are not read by this class.
    """

    hidden_dims: tuple = (256, 256)
    readout_size: tuple = (256,)
    skill_dim: int = 2
    use_layer_norm: bool = True
    ensemble: bool = True
    encoder: nn.Module = None

    def setup(self) -> None:
        if self.use_layer_norm:
            repr_class = LayerNormRepresentation
        else:
            repr_class = Representation
        layer_dims = (*self.hidden_dims, self.skill_dim)
        phi = repr_class(layer_dims, activate_final=False, ensemble=self.ensemble)
        if self.encoder is not None:
            phi = nn.Sequential([self.encoder(), phi])
        self.phi = phi

    def get_phi(self, observations):
        """Return index 0 along the first axis of the embedding output.

        NOTE(review): with ``ensemble=True`` this presumably selects the first
        ensemble member; with ``ensemble=False`` the first axis would be
        something else -- confirm with callers.
        """
        _features, embedding = self.phi(observations)
        return embedding[0]  # Use the first vf

    def __call__(self, observations, goals=None, info=False):
        """Return ``-sqrt(max(||phi(s) - phi(g)||^2, 1e-6))`` over the last axis.

        NOTE(review): ``goals`` defaults to ``None`` but is passed straight to
        ``phi``; callers are presumably expected to always supply it.
        """
        _, state_embedding = self.phi(observations)
        _, goal_embedding = self.phi(goals)
        difference = state_embedding - goal_embedding
        squared_dist = jnp.sum(difference ** 2, axis=-1)
        # The lower bound keeps the argument of sqrt away from zero.
        return -jnp.sqrt(jnp.maximum(squared_dist, 1e-6))


class GoalConditionedCritic(nn.Module):
    """Critic Q(s, a) or Q(s, g, a) built on a representation module.

    The network has layer widths ``(*hidden_dims, 1)`` and no final
    activation (``LayerNormRepresentation`` when ``use_layer_norm`` is set,
    otherwise ``Representation``).  If ``encoder`` is given,
    ``self.encoder()`` is placed in front of it inside an ``nn.Sequential``.

    ``readout_size`` and the ``info`` argument are not read by this class.
    """

    hidden_dims: tuple = (256, 256)
    readout_size: tuple = (256,)
    use_layer_norm: bool = True
    ensemble: bool = True
    encoder: nn.Module = None

    def setup(self) -> None:
        if self.use_layer_norm:
            repr_class = LayerNormRepresentation
        else:
            repr_class = Representation
        layer_dims = (*self.hidden_dims, 1)
        critic_net = repr_class(
            layer_dims, activate_final=False, ensemble=self.ensemble
        )
        if self.encoder is not None:
            critic_net = nn.Sequential([self.encoder(), critic_net])
        self.critic_net = critic_net

    def __call__(self, observations, goals=None, actions=None, info=False):
        """Return ``(feat, q)`` with the trailing axis of ``q`` squeezed.

        The inputs are joined along the last axis in the order observations,
        goals (only when given), actions.
        """
        if goals is None:
            parts = [observations, actions]
        else:
            parts = [observations, goals, actions]
        feat, q = self.critic_net(jnp.concatenate(parts, axis=-1))
        return feat, q.squeeze(-1)


def get_rep(
    encoder: nn.Module,
    targets: jnp.ndarray,
    bases: jnp.ndarray = None,
):
    """Encode ``targets`` with ``encoder`` if one is supplied.

    Returns ``targets`` unchanged when ``encoder`` is ``None``.  Otherwise
    returns ``encoder(targets)``, or ``encoder(targets, bases)`` when
    ``bases`` is not ``None``.
    """
    if encoder is None:
        return targets
    call_args = (targets,) if bases is None else (targets, bases)
    return encoder(*call_args)


class HILPNetwork(nn.Module):
    """Bundle of named sub-networks with one accessor method per role.

    Each accessor looks up an entry of ``networks`` by a fixed key and
    forwards its arguments.  The ``skill_*`` accessors first pass their
    context arguments through ``unsqueeze_context``.
    """

    networks: Dict[str, nn.Module]

    def unsqueeze_context(self, observations, contexts):
        """Tile ``contexts`` over two spatial axes when observations have them.

        When ``observations`` has at most two axes, ``contexts`` is returned
        untouched.  Otherwise two axes are inserted before the last axis of
        ``contexts`` and repeated to the sizes of ``observations`` axes -3 and
        -2.
        """
        obs_rank = len(observations.shape)
        if obs_rank <= 2:
            return contexts

        # observations: (H, W, D) or (B, H, W, D)
        # contexts: (Z) -> (H, W, Z) or (B, Z) -> (B, H, W, Z)
        assert obs_rank == len(contexts.shape) + 2
        expanded = jnp.expand_dims(contexts, axis=-2)
        expanded = jnp.expand_dims(expanded, axis=-2)
        rows, cols = observations.shape[-3], observations.shape[-2]
        tiled = expanded.repeat(rows, axis=-3)
        return tiled.repeat(cols, axis=-2)

    def value(self, observations, goals=None, **kwargs):
        net = self.networks["value"]
        return net(observations, goals, **kwargs)

    def target_value(self, observations, goals=None, **kwargs):
        net = self.networks["target_value"]
        return net(observations, goals, **kwargs)

    def phi(self, observations, **kwargs):
        net = self.networks["value"]
        return net.get_phi(observations, **kwargs)

    def skill_value(self, observations, skills, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        net = self.networks["skill_value"]
        return net(observations, tiled_skills, **kwargs)

    def skill_target_value(self, observations, skills, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        net = self.networks["skill_target_value"]
        return net(observations, tiled_skills, **kwargs)

    def skill_critic(self, observations, skills, actions=None, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        tiled_actions = self.unsqueeze_context(observations, actions)
        net = self.networks["skill_critic"]
        return net(observations, tiled_skills, tiled_actions, **kwargs)

    def skill_target_critic(self, observations, skills, actions=None, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        tiled_actions = self.unsqueeze_context(observations, actions)
        net = self.networks["skill_target_critic"]
        return net(observations, tiled_skills, tiled_actions, **kwargs)

    def skill_actor(self, observations, skills, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        actor_inputs = jnp.concatenate([observations, tiled_skills], axis=-1)
        return self.networks["skill_actor"](actor_inputs, **kwargs)

    def __call__(self, observations, goals, actions, skills):
        # Only for initialization
        rets = {}
        rets["value"] = self.value(observations, goals)
        rets["target_value"] = self.target_value(observations, goals)
        rets["skill_actor"] = self.skill_actor(observations, skills)
        rets["skill_value"] = self.skill_value(observations, skills)
        rets["skill_critic"] = self.skill_critic(observations, skills, actions)
        rets["skill_target_critic"] = self.skill_target_critic(
            observations, skills, actions
        )
        return rets


class SimpleGRU(nn.Module):
    """A single ``nn.GRUCell`` applied repeatedly via ``nn.transforms.scan``."""

    # Scan configuration: inputs are sliced and outputs are stacked along
    # axis 1, the "params" collection is broadcast (not scanned over), and the
    # "params" RNG stream is not split between iterations.
    @functools.partial(
        nn.transforms.scan,
        variable_broadcast="params",
        in_axes=1,
        out_axes=1,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Apply ``nn.GRUCell`` to ``(carry, x)``; this is the scanned body."""
        return nn.GRUCell()(carry, x)

    @staticmethod
    def initialize_carry(batch_dims, hidden_size):
        """Return the carry from ``nn.GRUCell.initialize_carry``.

        A fixed ``jax.random.PRNGKey(0)`` is always passed as the RNG.
        NOTE(review): calling ``initialize_carry`` on the class with
        ``(rng, batch_dims, size)`` looks like the older Flax recurrent-cell
        API -- confirm against the pinned Flax version.
        """
        return nn.GRUCell.initialize_carry(
            jax.random.PRNGKey(0), batch_dims, hidden_size
        )


class SimpleBiGRU(nn.Module):
    """Bidirectional GRU layer made of two independent ``SimpleGRU`` modules.

    One GRU reads the input as given; the other reads it reversed along
    axis 1 and its outputs are reversed back.  The two output sequences are
    concatenated along the last axis.
    """

    hidden_size: int

    def setup(self):
        self.forward_gru = SimpleGRU()
        self.backward_gru = SimpleGRU()

    def __call__(self, embedded_inputs):
        carry_batch_dims = (embedded_inputs.shape[0],)
        # Index expression that reverses axis 1 of a 3-axis array.
        flip_axis1 = (slice(None), slice(None, None, -1), slice(None))

        forward_carry = SimpleGRU.initialize_carry(carry_batch_dims, self.hidden_size)
        _, forward_outputs = self.forward_gru(forward_carry, embedded_inputs)

        backward_carry = SimpleGRU.initialize_carry(
            carry_batch_dims, self.hidden_size
        )
        _, reversed_outputs = self.backward_gru(
            backward_carry, embedded_inputs[flip_axis1]
        )
        # Undo the reversal so both directions line up position by position.
        backward_outputs = reversed_outputs[flip_axis1]

        return jnp.concatenate([forward_outputs, backward_outputs], -1)


class SeqEncoder(nn.Module):
    """Encode an (observation, action) sequence into one vector per batch row.

    Pipeline: a two-layer ``MLP`` on every observation, concatenation with
    the actions, ``num_recur_layers`` ``SimpleBiGRU(256)`` layers, then a
    linear-width projection ``MLP([output_dim])``.

    ``recur_output == "concat"`` flattens the recurrent outputs per batch row
    before the projection; any other value takes index -1 along axis 1.
    """

    num_recur_layers: int = 2
    output_dim: int = 2
    recur_output: str = "concat"

    def setup(self) -> None:
        self.obs_mlp = MLP([256, 256], activate_final=True)
        self.recurs = [SimpleBiGRU(256) for _ in range(self.num_recur_layers)]
        self.projection = MLP([self.output_dim], activate_final=False)

    def __call__(
        self,
        seq_observations: jnp.ndarray,
        seq_actions: jnp.ndarray,
    ):
        batch, steps, obs_dim = seq_observations.shape

        # Run the observation MLP on all steps at once, then restore the
        # (batch, steps, features) layout.
        flat_observations = jnp.reshape(seq_observations, (batch * steps, obs_dim))
        obs_features = jnp.reshape(
            self.obs_mlp(flat_observations), (batch, steps, -1)
        )

        hidden = jnp.concatenate([obs_features, seq_actions], axis=-1)
        for layer in self.recurs:
            hidden = layer(hidden)

        if self.recur_output == "concat":
            summary = jnp.reshape(hidden, (batch, -1))
        else:
            summary = hidden[:, -1]

        return self.projection(summary)


class GaussianModule(nn.Module):
    """MLP trunk with two Dense heads producing a diagonal Gaussian.

    One head gives the mean, the other the log standard deviation, which is
    clipped to ``[log_std_min, log_std_max]`` before exponentiation.  Both
    heads use ``default_init(final_fc_init_scale)`` as kernel initializer.
    """

    hidden_dims: Sequence[int]
    output_dim: int
    log_std_min: Optional[float] = -20
    log_std_max: Optional[float] = 2
    final_fc_init_scale: float = 1e-2

    @nn.compact
    def __call__(
        self,
        inputs: jnp.ndarray,
        temperature: float = 1.0,
    ) -> distrax.Distribution:
        """Return ``MultivariateNormalDiag(mean, exp(log_std) * temperature)``."""
        trunk = MLP(self.hidden_dims, activate_final=True)(inputs)

        # The mean head is built and applied before the log-std head so the
        # automatically assigned submodule names keep their order.
        mean_head = nn.Dense(
            self.output_dim, kernel_init=default_init(self.final_fc_init_scale)
        )
        means = mean_head(trunk)

        log_std_head = nn.Dense(
            self.output_dim, kernel_init=default_init(self.final_fc_init_scale)
        )
        raw_log_stds = log_std_head(trunk)

        clipped_log_stds = jnp.clip(raw_log_stds, self.log_std_min, self.log_std_max)
        scale_diag = jnp.exp(clipped_log_stds) * temperature

        return distrax.MultivariateNormalDiag(loc=means, scale_diag=scale_diag)


class VAE(nn.Module):
    """Sequence VAE over skills.

    Components built in ``setup``:

    * ``seq_encoder``: a ``SeqEncoder`` with ``2 * skill_dim`` outputs.  The
      first ``skill_dim`` entries are used as the posterior mean; the rest are
      halved before exponentiation to obtain the posterior standard deviation.
    * ``prior_model``: a ``GaussianModule`` over skills, applied to the
      observation at index 0 of each sequence.
    * ``recon_model``: a ``GaussianModule`` over actions, applied to
      observations concatenated with a skill vector.
    """

    hidden_dims: Sequence[int]
    action_dim: int
    skill_dim: int
    recur_output: str

    def setup(self) -> None:
        self.seq_encoder = SeqEncoder(
            num_recur_layers=2,
            output_dim=self.skill_dim * 2,
            recur_output=self.recur_output,
        )
        self.prior_model = GaussianModule(self.hidden_dims, self.skill_dim)
        self.recon_model = GaussianModule(self.hidden_dims, self.action_dim)

    def encode(self, seq_observations: jnp.ndarray, seq_actions: jnp.ndarray):
        """Return the posterior mean (first ``skill_dim`` encoder outputs)."""
        outputs = self.seq_encoder(seq_observations, seq_actions)
        return outputs[..., : self.skill_dim]

    def act(
        self,
        observations: jnp.ndarray,
        skills: jnp.ndarray,
        temperature: float = 1.0,
    ) -> distrax.Distribution:
        """Return the action distribution for ``observations`` and ``skills``.

        The two inputs are concatenated along the last axis and passed to
        ``recon_model`` together with ``temperature``.
        """
        szs = jnp.concatenate([observations, skills], axis=-1)
        return self.recon_model(szs, temperature=temperature)

    def __call__(
        self,
        seq_observations: jnp.ndarray,
        seq_actions: jnp.ndarray,
        z_rng,
    ):
        """Return ``(recon_action_dists, priors, posteriors)``.

        Args:
            seq_observations: three-axis array; axis 1 is the sequence axis.
            seq_actions: actions passed to the sequence encoder.
            z_rng: PRNG key used to draw the latent noise.
        """
        # Unpacking keeps the three-axis requirement on the input explicit;
        # only the sequence length is needed below.
        _, seq_len, _ = seq_observations.shape

        outputs = self.seq_encoder(seq_observations, seq_actions)
        means = outputs[..., : self.skill_dim]
        log_stds = outputs[..., self.skill_dim :]
        stds = jnp.exp(0.5 * log_stds)
        posteriors = distrax.MultivariateNormalDiag(loc=means, scale_diag=stds)

        priors = self.prior_model(seq_observations[:, 0])

        # Draw the latent as mean + std * noise, then copy it to every step
        # of the sequence before decoding actions.
        zs = means + stds * jax.random.normal(z_rng, means.shape)
        zs = jnp.expand_dims(zs, axis=1).repeat(seq_len, axis=1)
        szs = jnp.concatenate([seq_observations, zs], axis=-1)
        recon_action_dists = self.recon_model(szs)

        return recon_action_dists, priors, posteriors


class GCMetraNetwork(nn.Module):
    """Bundle of named sub-networks with one accessor method per role.

    Each accessor looks up an entry of ``networks`` by a fixed key and
    forwards its arguments.  The ``skill_*`` accessors first pass their
    context arguments through ``unsqueeze_context``.  ``skill_value_only``
    and ``hrl`` are truth-tested in ``__call__`` to decide which networks are
    touched during initialization.
    """

    networks: Dict[str, nn.Module]
    skill_value_only: int
    hrl: int

    def unsqueeze_context(self, observations, contexts):
        """Tile ``contexts`` over two spatial axes when observations have them.

        When ``observations`` has at most two axes, ``contexts`` is returned
        untouched.  Otherwise two axes are inserted before the last axis of
        ``contexts`` and repeated to the sizes of ``observations`` axes -3 and
        -2.
        """
        obs_rank = len(observations.shape)
        if obs_rank <= 2:
            return contexts

        # observations: (H, W, D) or (B, H, W, D)
        # contexts: (Z) -> (H, W, Z) or (B, Z) -> (B, H, W, Z)
        assert obs_rank == len(contexts.shape) + 2
        expanded = jnp.expand_dims(contexts, axis=-2)
        expanded = jnp.expand_dims(expanded, axis=-2)
        rows, cols = observations.shape[-3], observations.shape[-2]
        tiled = expanded.repeat(rows, axis=-3)
        return tiled.repeat(cols, axis=-2)

    def value(self, observations, goals=None, **kwargs):
        net = self.networks["value"]
        return net(observations, goals, **kwargs)

    def target_value(self, observations, goals=None, **kwargs):
        net = self.networks["target_value"]
        return net(observations, goals, **kwargs)

    def phi(self, observations, **kwargs):
        net = self.networks["value"]
        return net.get_phi(observations, **kwargs)

    def actor(self, observations, goals=None, **kwargs):
        actor_inputs = jnp.concatenate([observations, goals], axis=-1)
        return self.networks["actor"](actor_inputs, **kwargs)

    def skill_value(self, observations, skills, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        net = self.networks["skill_value"]
        return net(observations, tiled_skills, **kwargs)

    def skill_target_value(self, observations, skills, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        net = self.networks["skill_target_value"]
        return net(observations, tiled_skills, **kwargs)

    def skill_critic(self, observations, skills, actions=None, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        tiled_actions = self.unsqueeze_context(observations, actions)
        net = self.networks["skill_critic"]
        return net(observations, tiled_skills, tiled_actions, **kwargs)

    def skill_target_critic(self, observations, skills, actions=None, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        tiled_actions = self.unsqueeze_context(observations, actions)
        net = self.networks["skill_target_critic"]
        return net(observations, tiled_skills, tiled_actions, **kwargs)

    def skill_actor(self, observations, skills, **kwargs):
        tiled_skills = self.unsqueeze_context(observations, skills)
        actor_inputs = jnp.concatenate([observations, tiled_skills], axis=-1)
        return self.networks["skill_actor"](actor_inputs, **kwargs)

    def high_actor(self, observations, **kwargs):
        net = self.networks["high_actor"]
        return net(observations, **kwargs)

    def high_value(self, observations, **kwargs):
        net = self.networks["high_value"]
        return net(observations, **kwargs)

    def high_critic(self, observations, skills, **kwargs):
        net = self.networks["high_critic"]
        return net(observations, skills, **kwargs)

    def high_target_critic(self, observations, skills, **kwargs):
        net = self.networks["high_target_critic"]
        return net(observations, skills, **kwargs)

    def __call__(self, observations, goals, actions, skills):
        # Only for initialization
        rets = {}
        rets["value"] = self.value(observations, goals)
        rets["target_value"] = self.target_value(observations, goals)
        rets["actor"] = self.actor(observations, goals)
        rets["skill_actor"] = self.skill_actor(observations, skills)

        # The skill value network is evaluated in both configurations.
        rets["skill_value"] = self.skill_value(observations, skills)
        if self.skill_value_only:
            rets["skill_target_value"] = self.skill_target_value(
                observations, skills
            )
        else:
            rets["skill_critic"] = self.skill_critic(observations, skills, actions)
            rets["skill_target_critic"] = self.skill_target_critic(
                observations, skills, actions
            )

        if self.hrl:
            rets["high_actor"] = self.high_actor(observations)
            rets["high_value"] = self.high_value(observations)
            rets["high_critic"] = self.high_critic(observations, skills)
            rets["high_target_critic"] = self.high_target_critic(
                observations, skills
            )
        return rets
