from typing import Any, Callable

import flax.linen as nn
import jax.numpy as jnp
from flax.linen.initializers import constant, orthogonal


class CNN(nn.Module):
    """Convolutional feature extractor for overcooked observations.

    The input is passed through two convolution stages (32 features each,
    first with a 5x5 kernel, then with a 3x3 kernel), each followed by the
    configured activation. The three trailing axes of the result are then
    merged into a single feature axis; any leading axes are kept as-is.

    Attributes:
        activation: Name of the non-linearity. ``"relu"`` selects
            ``nn.relu``; every other value selects ``nn.tanh``.
    """

    activation: str = "relu"

    @nn.compact
    def __call__(self, x):
        """Apply the conv stack to ``x`` and flatten its last three axes.

        Args:
            x: Input array laid out as ``(*B, H, W, C)``.

        Returns:
            Array with the leading ``*B`` axes of ``x`` preserved and the
            remaining axes collapsed into one.
        """
        # Anything that is not exactly "relu" falls back to tanh.
        if self.activation == "relu":
            act_fn = nn.relu
        else:
            act_fn = nn.tanh

        # Conv stages are created in order, so their auto-generated
        # submodule names follow creation order.
        # NOTE: a third 3x3 / 32-feature conv stage used to follow these two
        # and is currently disabled.
        for kernel in ((5, 5), (3, 3)):
            x = nn.Conv(features=32, kernel_size=kernel)(x)
            x = act_fn(x)

        # Flatten: keep every axis before the last three, merge the rest.
        # NOTE: a trailing Dense(features=64) + activation projection is
        # currently disabled, so the flattened features are returned directly.
        batch_dims = tuple(x.shape[:-3])
        return x.reshape(batch_dims + (-1,))


# class CNN(nn.Module):
#     output_size: int = 64
#     activation: Callable[..., Any] = nn.relu
#
#     @nn.compact
#     def __call__(self, x):
#         x = nn.Conv(
#             features=128,
#             kernel_size=(1, 1),
#             kernel_init=orthogonal(jnp.sqrt(2)),
#             bias_init=constant(0.0),
#         )(x)
#         x = self.activation(x)
#         x = nn.Conv(
#             features=128,
#             kernel_size=(1, 1),
#             kernel_init=orthogonal(jnp.sqrt(2)),
#             bias_init=constant(0.0),
#         )(x)
#         x = self.activation(x)
#         x = nn.Conv(
#             features=8,
#             kernel_size=(1, 1),
#             kernel_init=orthogonal(jnp.sqrt(2)),
#             bias_init=constant(0.0),
#         )(x)
#         x = self.activation(x)
#
#         x = nn.Conv(
#             features=16,
#             kernel_size=(3, 3),
#             kernel_init=orthogonal(jnp.sqrt(2)),
#             bias_init=constant(0.0),
#         )(x)
#         x = self.activation(x)
#
#         x = nn.Conv(
#             features=32,
#             kernel_size=(3, 3),
#             kernel_init=orthogonal(jnp.sqrt(2)),
#             bias_init=constant(0.0),
#         )(x)
#         x = self.activation(x)
#
#         x = nn.Conv(
#             features=32,
#             kernel_size=(3, 3),
#             kernel_init=orthogonal(jnp.sqrt(2)),
#             bias_init=constant(0.0),
#         )(x)
#         x = self.activation(x)
#
#         x = x.reshape((x.shape[0], -1))
#
#         x = nn.Dense(
#             features=self.output_size,
#             kernel_init=orthogonal(jnp.sqrt(2)),
#             bias_init=constant(0.0),
#         )(x)
#         x = self.activation(x)
#
#         return x
