import functools
from typing import Any, Callable, Optional, Tuple

from flax import linen as nn
import jax
from jax import numpy as jnp
from jaxrl.networks.initializer import default_kernel_init, default_bias_init
from .base import SequenceModelBase
from .lifgate import StackedLIFGate, BatchStackedEncoderModel

class LIFGate(SequenceModelBase):
    """Sequence model wrapping a stacked LIF-gate encoder (``BatchStackedEncoderModel``).

    Attributes:
        hidden_size: passed as both the first and second positional argument of
            ``BatchStackedEncoderModel`` (presumably input and hidden size —
            confirm against its signature).
        n_layer: forwarded as ``num_layers``; also the length of the carry list
            returned by ``initialize_carry``.
        pdrop: dropout probability. NOTE(review): this field is not forwarded to
            ``BatchStackedEncoderModel`` anywhere in this class — confirm whether
            dropout is meant to be wired in.
    """

    hidden_size: int
    n_layer: int = 1
    pdrop: float = 0.0

    def setup(self):
        self.lifgate = BatchStackedEncoderModel(
            self.hidden_size, self.hidden_size, num_layers=self.n_layer
        )

    def __call__(self, carry, x, rng=None):
        """Run the LIF-gate encoder on ``x`` starting from ``carry``.

        Args:
            carry: recurrent state, passed straight through to the encoder.
            x: input, passed straight through to the encoder.
            rng: optional PRNG key. When given, it is supplied to the encoder's
                ``apply`` as the ``"dropout"`` rng stream (training/exploration).
                When ``None`` (evaluation), no rngs are supplied.

        Returns:
            Whatever the encoder returns. ``forward``/``forward_per_step``
            unpack it as ``(hidden_states, outputs)``.
        """
        if not self.has_variable("params", "lifgate"):
            # init: the submodule has no params yet, so call it directly and let
            # Flax create them under "lifgate".
            return self.lifgate(carry, x)

        lifgate_params = {"params": self.variables["params"]["lifgate"]}

        # Module.apply() needs the dropout rng passed explicitly, see
        # https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.InvalidRngError
        # NOTE(review): only the rng is forwarded; no deterministic/train flag is
        # passed to the encoder — confirm whether BatchStackedEncoderModel needs one.
        dropout_rng = None if rng is None else {"dropout": rng}

        return self.lifgate.apply(lifgate_params, carry, x, rngs=dropout_rng)

    def forward(self, embedded_inputs, initial_states, **kwargs):
        """Run the encoder over a full sequence.

        Extra ``kwargs`` (e.g. ``rng``) are forwarded to ``__call__``.

        Returns:
            ``((hidden_states, outputs), None)``. The second element is always
            ``None``. Presumably the arrays are (B, T, D) — TODO confirm against
            the encoder.
        """
        hidden_states, outputs = self.__call__(initial_states, embedded_inputs, **kwargs)
        return (hidden_states, outputs), None

    def forward_per_step(self, embedded_inputs, initial_states, **kwargs):
        """Run the encoder for a single step.

        Extra ``kwargs`` (e.g. ``rng``) are forwarded to ``__call__``.

        Returns:
            ``((hidden_states, outputs), None)``. The second element is always
            ``None``. Presumably the arrays are (D,) per step — TODO confirm.
        """
        hidden_states, outputs = self.__call__(initial_states, embedded_inputs, **kwargs)
        return (hidden_states, outputs), None

    def initialize_carry(self, batch_dims):
        """Return the initial carry: a list of ``n_layer`` ``None`` entries.

        ``batch_dims`` is accepted for interface compatibility but is not used.
        Presumably the encoder builds its own initial state when it receives
        ``None`` — confirm in BatchStackedEncoderModel.
        """
        return [None] * self.n_layer