from typing import Any, Callable

import jax
import jax.numpy as jnp
from flax import linen as nn

from .nn import NN
from .. import activations
from .. import initializers

# import numpyro
# import numpyro.distributions as dist
# from numpyro.contrib.module import random_flax_module


class PFNN(NN):
    """Parallel fully-connected neural network.

    One independent fully-connected sub-network is built per network output. Every
    sub-network receives the same (optionally transformed) input and produces a
    single column; the columns are concatenated along the last axis.

    Attributes:
        layer_sizes: ``[n_input, hidden_1, ..., hidden_k, n_output]``. ``n_output``
            must be an int: the number of outputs, i.e. the number of sub-networks.
            Every hidden entry must be a list/tuple of length ``n_output`` whose
            ``i``-th element is the width of that layer in the ``i``-th sub-network.
            ``n_input`` is not read here; ``nn.Dense`` infers input widths from data.
        activation: Either a single activation (applied after every hidden layer) or
            a list/tuple with ``len(layer_sizes) - 1`` entries (hidden layers plus the
            output layer). The output layer is linear, so the last entry is resolved
            via ``activations.get`` but never applied.
        kernel_initializer: Identifier resolved with ``initializers.get`` and used for
            all kernels. Biases are zero-initialized.
        params: Not read by this class (presumably used by the ``NN`` base class or
            training code -- confirm).
        _input_transform: Optional callable applied to the inputs before the
            sub-networks.
        _output_transform: Optional callable invoked as ``f(inputs, outputs)`` with the
            *untransformed* inputs and the concatenated sub-network outputs.
    """

    layer_sizes: Any
    activation: Any
    kernel_initializer: Any

    params: Any = None
    _input_transform: Callable = None
    _output_transform: Callable = None

    def setup(self):
        # TODO: implement get regularizer
        if len(self.layer_sizes) < 2:
            raise ValueError("must specify input and output sizes")
        n_output = self.layer_sizes[-1]
        hidden_sizes = self.layer_sizes[1:-1]
        for k, sizes in enumerate(hidden_sizes, start=1):
            # Each hidden layer must give one width per sub-network (per output).
            if not isinstance(sizes, (list, tuple)) or len(sizes) != n_output:
                raise ValueError(
                    f"layer_sizes[{k}] must be a list of {n_output} sub-layer sizes "
                    "(one per network output)"
                )

        # Normalize to one activation per hidden layer. A tuple is stored because
        # Flax may freeze lists assigned in setup() into tuples, which would break
        # an ``isinstance(..., list)`` test at call time.
        if isinstance(self.activation, (list, tuple)):
            if not (len(self.layer_sizes) - 1) == len(self.activation):
                raise ValueError(
                    "Total number of activation functions do not match with sum of hidden layers and output layer!"
                )
            # The trailing (output-layer) entry is resolved but never used: zip()
            # in __call__ stops at the last hidden layer.
            self._activations = tuple(activations.get(a) for a in self.activation)
        else:
            self._activations = (activations.get(self.activation),) * len(hidden_sizes)

        kernel_initializer = initializers.get(self.kernel_initializer)
        bias_initializer = jax.nn.initializers.zeros

        # self.denses[i][j] is the j-th Dense of the i-th sub-network; the last Dense
        # of every sub-network has a single unit (that sub-network's output column).
        self.denses = [
            [
                nn.Dense(
                    units,
                    kernel_init=kernel_initializer,
                    bias_init=bias_initializer,
                )
                for units in [sizes[i] for sizes in hidden_sizes] + [1]
            ]
            for i in range(n_output)
        ]

    def __call__(self, inputs, training=False):
        """Evaluate all sub-networks and concatenate their outputs.

        Args:
            inputs: Network input; forwarded to ``_input_transform`` if set.
            training: Unused here.

        Returns:
            The per-output columns concatenated along the last axis, passed through
            ``_output_transform(inputs, ...)`` if set.
        """
        x = inputs
        if self._input_transform is not None:
            x = self._input_transform(x)

        columns = []
        for denses in self.denses:
            # Every sub-network starts from the shared (transformed) input, not from
            # the previous sub-network's output.
            y = x
            for activation, dense in zip(self._activations, denses[:-1]):
                y = activation(dense(y))
            columns.append(denses[-1](y))

        out = jnp.concatenate(columns, axis=-1)

        if self._output_transform is not None:
            out = self._output_transform(inputs, out)
        return out
