import flax.linen as nn
import jax.numpy as jnp

def default_init(scale: float = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)

def xavier_init():
    return nn.initializers.xavier_normal()

def kaiming_init():
    return nn.initializers.kaiming_normal()

class ResnetStack(nn.Module):
    num_ch: int
    num_blocks: int
    use_max_pooling: bool = True

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        initializer = nn.initializers.xavier_uniform()
        conv_out = nn.Conv(
            features=self.num_ch,
            kernel_size=(3, 3),
            strides=1,
            kernel_init=initializer,
            padding='SAME'
        )(observations)

        if self.use_max_pooling:
            conv_out = nn.max_pool(
                conv_out,
                window_shape=(3, 3),
                padding='SAME',
                strides=(2, 2)
            )

        for _ in range(self.num_blocks):
            block_input = conv_out
            conv_out = nn.relu(conv_out)
            conv_out = nn.Conv(
                features=self.num_ch, kernel_size=(3, 3), strides=1,
                padding='SAME',
                kernel_init=initializer)(conv_out)

            conv_out = nn.relu(conv_out)
            conv_out = nn.Conv(
                features=self.num_ch, kernel_size=(3, 3), strides=1,
                padding='SAME', kernel_init=initializer
            )(conv_out)
            conv_out += block_input

        return conv_out


class ImpalaEncoder(nn.Module):
    width: int = 1
    use_multiplicative_cond: bool = False
    stack_sizes: tuple = (16, 32, 32)
    num_blocks: int = 2
    dropout_rate: float = None

    def setup(self):
        stack_sizes = self.stack_sizes
        self.stack_blocks = [
            ResnetStack(
                num_ch=stack_sizes[i] * self.width,
                num_blocks=self.num_blocks)
            for i in range(len(stack_sizes))

        ]
        if self.dropout_rate is not None:
            self.dropout = nn.Dropout(rate=self.dropout_rate)

    @nn.compact
    def __call__(self, x, train=True, cond_var=None):
        x = x.astype(jnp.float32) / 255.0
        # x = jnp.reshape(x, (*x.shape[:-2], -1))

        conv_out = x

        for idx in range(len(self.stack_blocks)):
            conv_out = self.stack_blocks[idx](conv_out)
            if self.dropout_rate is not None:
                conv_out = self.dropout(conv_out, deterministic=not train)
            if self.use_multiplicative_cond:
                assert cond_var is not None, "Cond var shouldn't be done when using it"
                print("Using Multiplicative Cond!")
                temp_out = nn.Dense(conv_out.shape[-1], kernel_init=xavier_init())(cond_var)
                x_mult = jnp.expand_dims(jnp.expand_dims(temp_out, 1), 1)
                print ('x_mult shape in IMPALA:', x_mult.shape, conv_out.shape)
                conv_out = conv_out * x_mult

        conv_out = nn.relu(conv_out)
        # print(conv_out.shape, conv_out.reshape((*x.shape[:-3], -1)).shape)
        return conv_out.reshape((*x.shape[:-3], -1))
    
import functools as ft
impala_configs = {
    'impala': ImpalaEncoder,
    'impala_large': ft.partial(ImpalaEncoder, stack_sizes=(16, 32, 32, 32)),
    'impala_larger': ft.partial(ImpalaEncoder, stack_sizes=(16, 32, 32, 32, 32)),
    'impala_largest': ft.partial(ImpalaEncoder, stack_sizes=(16, 32, 32, 32, 32, 32)),
    'impala_wider': ft.partial(ImpalaEncoder, width=2),
    'impala_widest': ft.partial(ImpalaEncoder, width=4),
    'impala_deeper': ft.partial(ImpalaEncoder, num_blocks=4),
    'impala_deepest': ft.partial(ImpalaEncoder, num_blocks=8),
}