from torch import nn
from flax import linen
from jax import numpy as jnp

from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union)


class FlaxSequential(linen.Module):
    """Feed an input through an ordered chain of stages.

    Every stage receives the output produced by the stage before it; the
    first stage receives the call argument directly. With no stages the
    input is returned untouched.
    """

    # Ordered stages: Flax submodules or plain array -> array callables.
    layers: Sequence[Union[linen.Module, Callable[[jnp.ndarray], jnp.ndarray]]]

    @linen.compact
    def __call__(self, x):
        activations = x
        for stage in self.layers:
            activations = stage(activations)
        return activations


def construct_nonlinearity(activation, target='torch'):
    """Return the nonlinearity named *activation* for the requested framework.

    Args:
        activation: One of ``'relu'``, ``'sigmoid'``, ``'tanh'`` or ``'eye'``.
            ``'eye'`` means "no nonlinearity" and yields ``None`` for every
            target.
        target: ``'torch'`` selects a freshly constructed ``torch.nn`` module.
            Any other value selects the corresponding Flax activation function.

    Returns:
        A new ``torch.nn`` module instance, a Flax activation callable, or
        ``None`` when *activation* is ``'eye'``.

    Raises:
        KeyError: If *activation* is not one of the supported names.
    """
    # Torch entries are module *classes*, so only the requested one is
    # instantiated (a fresh instance per call, as before), rather than
    # building every torch module on each call, including Flax-only calls.
    table = {
        'relu': (nn.ReLU, linen.relu),
        'sigmoid': (nn.Sigmoid, linen.sigmoid),
        'tanh': (nn.Tanh, linen.tanh),
        'eye': (None, None),
    }
    try:
        torch_cls, flax_fn = table[activation]
    except KeyError:
        # Keep KeyError for callers that catch it, but say what is allowed.
        raise KeyError(
            f"unknown activation {activation!r}; expected one of {sorted(table)}"
        ) from None
    if target != 'torch':
        return flax_fn
    return None if torch_cls is None else torch_cls()

