import jax
import jax.numpy as jnp
from flax import linen as nn
import chex
from typing import Tuple, Optional, List
from dataclasses import dataclass, field
from .shared import (
    identity_out,
    tanh_out,
    categorical_out,
    gaussian_out,
    default_bias_init,
    kernel_init_fn,
)

class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18 and ResNet-34.

    Computes ``relu(BN(conv(relu(BN(conv(x))))) + shortcut(x))``. Only the
    first 3x3 convolution uses ``strides``; the second always uses stride 1,
    so a strided block reduces the spatial size exactly once.

    Attributes:
        features: Number of output channels of both convolutions.
        downsample: Optional module applied to the block input to form the
            shortcut. It must be supplied whenever ``strides`` is not (1, 1)
            or the input channel count differs from ``features``; otherwise
            the residual addition has mismatched shapes.
        kernel_size: Spatial size of both convolution kernels.
        strides: Strides of the first convolution.
        kernel_init_type: Key into ``kernel_init_fn`` selecting the kernel
            initializer factory.
    """

    features: int
    downsample: Optional[nn.Module] = None
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)
    kernel_init_type: str = "lecun_normal"

    @nn.compact
    def __call__(self, x: chex.Array, train: bool = True) -> chex.Array:
        """Apply the residual block.

        Args:
            x: Channels-last input batch.
            train: If True, BatchNorm normalizes with batch statistics and
                updates its running averages (the ``batch_stats`` collection
                must then be mutable in ``apply``); if False, it uses the
                stored running averages.

        Returns:
            The block output with ``features`` channels.

        Raises:
            ValueError: If the residual branch and the shortcut differ in shape.
        """
        make_init = kernel_init_fn[self.kernel_init_type]

        out = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding="SAME",
            use_bias=False,  # the following BatchNorm provides the shift
            kernel_init=make_init(),
        )(x)
        out = nn.BatchNorm(use_running_average=not train)(out)
        out = nn.relu(out)

        # Stride 1 on purpose: any downsampling already happened in the first
        # conv, and the shortcut projection only downsamples once.
        out = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=(1, 1),
            padding="SAME",
            use_bias=False,
            kernel_init=make_init(),
        )(out)
        out = nn.BatchNorm(use_running_average=not train)(out)

        identity = x if self.downsample is None else self.downsample(x)
        # Shapes are static under tracing, so this check is free; it turns a
        # silent broadcast or an opaque add error into a clear message.
        if out.shape != identity.shape:
            raise ValueError(
                f"BasicBlock residual shape {out.shape} does not match shortcut "
                f"shape {identity.shape}; pass a `downsample` module that "
                "projects the input to the output shape."
            )

        return nn.relu(out + identity)


class ResNet18(nn.Module):
    """ResNet-18 image classifier.

    Stem (7x7 stride-2 conv with 64 filters, BatchNorm, ReLU, 3x3 stride-2
    max-pool), then ``len(layers)`` stages of ``block`` modules whose width
    starts at 64 and doubles every stage, then global average pooling and a
    dense classifier. The first block of every stage after the first uses
    stride 2 with a 1x1 convolutional projection on the shortcut.

    Attributes:
        num_classes: Number of output logits.
        input_channels: Kept for backward compatibility only; Flax infers the
            input channel count from the data, so this value is not used.
        block: Residual block class. It is constructed with ``features``,
            ``downsample``, ``strides`` and ``kernel_init_type`` and called as
            ``block(...)(x, train=...)``, matching ``BasicBlock``.
        layers: Number of blocks in each stage.
        kernel_init_type: Key into ``kernel_init_fn`` used for every kernel.
    """

    num_classes: int = 1000
    input_channels: int = 3
    block: nn.Module = BasicBlock
    # A tuple default keeps the module hashable (a list is not); callers may
    # still pass a list.
    layers: Tuple[int, ...] = (2, 2, 2, 2)
    kernel_init_type: str = "lecun_normal"

    @nn.compact
    def __call__(
        self,
        x: chex.Array,
        rng: Optional[chex.PRNGKey] = None,
        train: bool = True,
    ) -> chex.Array:
        """Compute class logits.

        Args:
            x: Channels-last image batch ``(N, H, W, C)``.
            rng: Unused; accepted for interface compatibility.
            train: Forwarded to every BatchNorm layer. If True, batch
                statistics are used and ``batch_stats`` must be mutable in
                ``apply``; if False, stored running averages are used.

        Returns:
            Logits of shape ``(N, num_classes)``.
        """
        make_init = kernel_init_fn[self.kernel_init_type]

        # Stem: always 64 output channels so the first stage's identity
        # shortcuts (which have no projection) line up with its 64-wide blocks.
        x = nn.Conv(
            features=64,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding="SAME",
            use_bias=False,
            kernel_init=make_init(),
        )(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")

        features = 64
        for stage, num_blocks in enumerate(self.layers):
            for block_idx in range(num_blocks):
                # Only the first block of each stage after the first downsamples.
                strides = (2, 2) if stage > 0 and block_idx == 0 else (1, 1)
                downsample = None
                if strides != (1, 1):
                    downsample = nn.Conv(
                        features=features,
                        kernel_size=(1, 1),
                        strides=strides,
                        kernel_init=make_init(),
                    )
                x = self.block(
                    features=features,
                    downsample=downsample,
                    strides=strides,
                    kernel_init_type=self.kernel_init_type,
                )(x, train=train)
            features *= 2

        # Global average pooling over the spatial axes. For 224x224 inputs the
        # final map is 7x7, so this equals the previous 7x7 average pool, but
        # it also works for any other input resolution.
        x = jnp.mean(x, axis=(1, 2))
        return nn.Dense(features=self.num_classes, kernel_init=make_init())(x)
