import jax
import jax.numpy as jnp
from flax import linen as nn
import chex
from typing import Tuple, Optional, List
from dataclasses import dataclass, field
from .shared import (
    identity_out,
    tanh_out,
    categorical_out,
    gaussian_out,
    default_bias_init,
    kernel_init_fn,
)

class VGG19(nn.Module):
    """VGG-style convolutional classifier (defaults give the VGG19 layout).

    The network is a sequence of stages. Stage ``i`` applies ``num_blocks[i]``
    conv+ReLU blocks with ``num_features[i]`` output channels, then a 2x2 max
    pool with stride 2. The result is flattened per example and fed through
    two ReLU-activated Dense layers of width ``dense_features`` and a final
    linear Dense readout of width ``num_classes`` (no softmax is applied, so
    the output is raw logits).

    Attributes:
        num_classes: Width of the final (linear) Dense layer.
        input_channels: Not read anywhere in this module; kept so existing
            constructor calls keep working.
        num_blocks: Number of conv+ReLU blocks in each stage.
        num_features: Output channels of the conv blocks in each stage. Must
            have the same length as ``num_blocks``.
        kernel_size: Convolution kernel size forwarded to ``conv_relu_block``.
        kernel_init_type: Key into ``kernel_init_fn`` selecting the kernel
            initializer used by the conv blocks and the Dense layers.
        dense_features: Width of the two hidden Dense layers (4096 in the
            standard VGG19 head).
    """
    num_classes: int = 1000
    input_channels: int = 3
    num_blocks: List[int] = field(default_factory=lambda: [2, 2, 4, 4, 4])
    num_features: List[int] = field(default_factory=lambda: [64, 128, 256, 512, 512])
    kernel_size: Tuple[int, int] = (3, 3)
    kernel_init_type: str = "lecun_normal"
    dense_features: int = 4096

    @nn.compact
    def __call__(self, x: chex.Array, rng: Optional[chex.PRNGKey] = None) -> chex.Array:
        """Compute class logits for a batch of inputs.

        Args:
            x: Input batch with the batch axis first. Presumably channels-last
                images ``(batch, height, width, channels)``, which is the
                layout Flax's ``nn.max_pool`` expects. TODO: confirm against
                callers.
            rng: Unused by this module; accepted so the call signature stays
                stable.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If ``num_blocks`` and ``num_features`` differ in length.
        """
        # zip() would silently drop trailing stages on a length mismatch and
        # build a smaller network than configured, so fail loudly instead.
        if len(self.num_blocks) != len(self.num_features):
            raise ValueError(
                "num_blocks and num_features must have the same length, got "
                f"{len(self.num_blocks)} and {len(self.num_features)}."
            )

        # Configurable VGG stages: conv+ReLU blocks followed by a 2x2 pooling.
        for num_block, num_feature in zip(self.num_blocks, self.num_features):
            for _ in range(num_block):
                # NOTE(review): (1, 1) is presumably the conv stride; confirm
                # against conv_relu_block's signature.
                x = conv_relu_block(
                    x,
                    num_feature,
                    self.kernel_size,
                    (1, 1),
                    kernel_init_type=self.kernel_init_type,
                )
            # Halve the spatial resolution at the end of every stage.
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten everything but the batch axis for the dense readout.
        x = x.reshape(x.shape[0], -1)

        # Look up the initializer factory once. It is still called once per
        # layer so each Dense layer receives its own initializer instance.
        make_kernel_init = kernel_init_fn[self.kernel_init_type]

        # Two hidden Dense+ReLU layers, then a linear classification head.
        x = nn.Dense(features=self.dense_features, kernel_init=make_kernel_init())(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.dense_features, kernel_init=make_kernel_init())(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.num_classes, kernel_init=make_kernel_init())(x)

        return x
