import jax
import jax.numpy as jnp
from flax import linen as nn
import chex
from typing import Tuple, Optional, List
from dataclasses import dataclass, field
from .shared import (
    identity_out,
    tanh_out,
    categorical_out,
    gaussian_out,
    default_bias_init,
    kernel_init_fn,
)

# Define the convolutional block with ReLU
def conv_relu_block(x, features, kernel_size, strides, padding="SAME", kernel_init_type="lecun_normal"):
    """Apply a biased ``nn.Conv`` followed by a ReLU non-linearity.

    Intended to be called from inside an ``@nn.compact`` method; the Conv
    submodule is created on each call.

    Args:
        x: Input feature map passed straight to ``nn.Conv``.
        features: Number of output channels.
        kernel_size: Convolution window shape.
        strides: Convolution strides.
        padding: Padding mode forwarded to ``nn.Conv`` (default ``"SAME"``).
        kernel_init_type: Key into ``kernel_init_fn`` selecting the kernel
            initializer factory.

    Returns:
        ``relu(conv(x))``.
    """
    bias_initializer = default_bias_init()
    kernel_initializer = kernel_init_fn[kernel_init_type]()
    conv = nn.Conv(
        features=features,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        use_bias=True,
        bias_init=bias_initializer,
        kernel_init=kernel_initializer,
    )
    return nn.relu(conv(x))


# Define the basic residual block
def residual_block(x, features, kernel_size, strides, kernel_init_type='lecun_normal'):
    """Basic two-convolution residual block with an optional projection shortcut.

    Computes ``relu(conv2(relu(conv1(x))) + shortcut(x))``, where ``shortcut``
    is the identity unless the main path changed the shape, in which case a
    1x1 convolution projects it.

    Intended to be called from inside an ``@nn.compact`` method.

    Args:
        x: Input feature map passed to ``nn.Conv``.
        features: Output channels of both convolutions (and of the projection).
        kernel_size: Window shape of the two main-path convolutions.
        strides: Stride of the FIRST convolution and of the projection
            shortcut. The second convolution always uses stride 1.
        kernel_init_type: Key into ``kernel_init_fn`` selecting the kernel
            initializer factory.

    Returns:
        The block output with ``features`` channels.
    """
    shortcut = x

    # First convolution + ReLU; this is the only main-path layer that strides.
    x = conv_relu_block(x, features, kernel_size, strides, kernel_init_type=kernel_init_type)

    # Second convolution (ReLU is applied after the residual add). It must use
    # stride 1: striding here too would downsample the main path twice while
    # the shortcut is downsampled only once, so the addition below would hit a
    # shape mismatch whenever strides > 1.
    x = nn.Conv(features, kernel_size, strides=1, padding='SAME',
                bias_init=default_bias_init(),  # same bias init as conv_relu_block
                kernel_init=kernel_init_fn[kernel_init_type]())(x)

    # Project the shortcut when channels or spatial size differ from the main path.
    if x.shape != shortcut.shape:
        shortcut = nn.Conv(features, (1, 1), strides=strides, padding='SAME',
                           bias_init=default_bias_init(),
                           kernel_init=kernel_init_fn[kernel_init_type]())(shortcut)

    x = x + shortcut
    return nn.relu(x)



# Flexible ResNet class
class ResNet(nn.Module):
    """Configurable ResNet-style network built from basic residual blocks.

    Attributes:
        num_output_units: Width of the final Dense layer.
        num_blocks_per_layer: Number of residual blocks in each stage.
        features_per_layer: Channel count for each stage, indexed in lock-step
            with ``num_blocks_per_layer``; its first entry also sets the stem
            convolution width.
        kernel_size: Square kernel size used by every convolution.
        strides: Stride passed to the stem convolution and to every residual
            block.
        output_activation: Only ``"identity"`` triggers the extra
            ``identity_out`` step; any other value returns the Dense output.
        kernel_init_type: Key into ``kernel_init_fn`` for kernel initializers.
    """
    num_output_units: int = 100
    num_blocks_per_layer: List[int] = field(default_factory=lambda: [2, 2, 2, 2])
    features_per_layer: List[int] = field(default_factory=lambda: [64, 128, 256, 512])
    kernel_size: int = 3
    strides: int = 1
    output_activation: str = "identity"
    kernel_init_type: str = "lecun_normal"

    @nn.compact
    def __call__(self, x: chex.Array, rng: Optional[chex.PRNGKey] = None) -> chex.Array:
        """Run the network.

        Args:
            x: Batched image tensor; assumed NHWC (flax ``nn.Conv`` layout) —
                TODO confirm with callers.
            rng: Unused by this module body.

        Returns:
            Output of shape ``(batch, num_output_units)``; all size-1 axes are
            squeezed away when the batch size is 1.
        """
        kernel = (self.kernel_size, self.kernel_size)
        stride = (self.strides, self.strides)

        # Stem convolution
        x = conv_relu_block(x, self.features_per_layer[0], kernel, stride,
                            kernel_init_type=self.kernel_init_type)

        # Residual stages
        for layer_num, num_blocks in enumerate(self.num_blocks_per_layer):
            for _ in range(num_blocks):
                x = residual_block(x, self.features_per_layer[layer_num], kernel, stride,
                                   kernel_init_type=self.kernel_init_type)

        # Global average pooling over the two spatial axes -> (batch, channels).
        # A fixed 8x8 avg_pool window (whose default stride is 1) only matches
        # global pooling when the final feature map is exactly 8x8; averaging
        # over the axes is correct for any input resolution.
        x = jnp.mean(x, axis=(1, 2))

        # Output layer
        x = nn.Dense(features=self.num_output_units, bias_init=default_bias_init(),
                     kernel_init=kernel_init_fn[self.kernel_init_type]())(x)

        if self.output_activation == "identity":
            # NOTE(review): `identity_out` is defined in `.shared` and not
            # visible here; if it applies its own Dense projection, this path
            # stacks two affine layers back-to-back — confirm intent.
            x = identity_out(x, self.num_output_units, self.kernel_init_type)

        return x.squeeze() if x.shape[0] == 1 else x


# Flexible VGG-like class
class VGG(nn.Module):
    """Configurable VGG-style network: stacked conv/ReLU stages, then a dense head.

    Attributes:
        num_output_units: Width of the final Dense layer.
        block_depths: Number of conv/ReLU layers in each stage.
        features_per_block: Channel count for each stage. It is paired with
            ``block_depths`` via ``zip``, so trailing entries of the longer
            list are ignored.
        dense_layers: Widths of the hidden Dense/ReLU layers after flattening.
        kernel_size: Square kernel size of every convolution.
        strides: Stride of every convolution.
        padding: Padding mode of every convolution.
        output_activation: Only ``"identity"`` triggers the extra
            ``identity_out`` step; any other value returns the Dense output.
        kernel_init_type: Key into ``kernel_init_fn``; applied to every
            convolution and every Dense layer.
    """
    num_output_units: int = 10
    block_depths: List[int] = field(default_factory=lambda: [2, 2, 3, 3, 3])
    features_per_block: List[int] = field(default_factory=lambda: [64, 128, 256, 512, 512])
    dense_layers: List[int] = field(default_factory=lambda: [128])
    kernel_size: int = 3
    strides: int = 1
    padding: str = "SAME"
    output_activation: str = "identity"
    kernel_init_type: str = "lecun_normal"

    @nn.compact
    def __call__(self, x: chex.Array, rng: Optional[chex.PRNGKey] = None) -> chex.Array:
        """Run the network.

        Args:
            x: Batched image tensor; assumed NHWC (flax ``nn.Conv`` layout) —
                TODO confirm with callers.
            rng: Unused by this module body.

        Returns:
            Output of shape ``(batch, num_output_units)``; all size-1 axes are
            squeezed away when the batch size is 1.
        """
        for depth, features in zip(self.block_depths, self.features_per_block):
            for _ in range(depth):
                x = conv_relu_block(x, features, (self.kernel_size, self.kernel_size),
                                    (self.strides, self.strides), self.padding, self.kernel_init_type)
            # 2x2 window, stride 2, VALID: downsamples each spatial axis by 2
            # (flooring odd sizes) at the end of every stage.
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding="VALID")

        # Flatten everything except the leading batch axis.
        x = x.reshape(x.shape[0], -1)

        # Hidden dense layers use the same initializers as every other layer
        # in this file, so `kernel_init_type` is honoured throughout.
        for units in self.dense_layers:
            x = nn.Dense(features=units, bias_init=default_bias_init(),
                         kernel_init=kernel_init_fn[self.kernel_init_type]())(x)
            x = nn.relu(x)

        x = nn.Dense(features=self.num_output_units, bias_init=default_bias_init(),
                     kernel_init=kernel_init_fn[self.kernel_init_type]())(x)

        if self.output_activation == "identity":
            # NOTE(review): `identity_out` is defined in `.shared` and not
            # visible here; if it applies its own Dense projection, this path
            # stacks two affine layers back-to-back — confirm intent.
            x = identity_out(x, self.num_output_units, self.kernel_init_type)

        return x.squeeze() if x.shape[0] == 1 else x