from functools import partial
#from typing import Callable, Optional, Sequence, Tuple
from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple,
                    Union, Sized, List)

import jax
from jax import lax
import jax.numpy as jnp
from flax import linen as nn

from src.utils.model_utils import FlaxSequential as Sequential

#from .common import ConvBlock, ModuleDef, Sequential
#from .splat import SplAtConv2d

# Type of a module factory: any callable (e.g. a Flax module class, possibly
# pre-configured with functools.partial) that returns a callable module.
ModuleDef = Callable[..., Callable]
# Parameter-initializer signature, conceptually (PRNGKey, shape, dtype) -> array;
# spelled with loose types to avoid depending on JAX-internal type aliases.
InitFn = Callable[[Any, Iterable[int], Any], Any]

# Matmul/conv algorithm precision handed to XLA. jax.lax.Precision selects an
# accuracy/speed trade-off, NOT a storage bit-width:
#   DEFAULT (0) - fastest; may use reduced-precision (e.g. bfloat16) passes
#   HIGH    (1) - intermediate accuracy
#   HIGHEST (2) - most accurate algorithm available for the operand dtype
# The named member is equivalent to the former magic value `Precision(2)`.
PRECISION = jax.lax.Precision.HIGHEST
# Computation dtype for the convolutions configured below.
# NOTE(review): JAX only honours float64 when `jax_enable_x64` is enabled;
# otherwise it falls back to float32 (with a warning). Confirm the entry point
# enables x64 if 64-bit arithmetic is actually required.
DTYPE = jnp.float64

# nn.Conv preconfigured with the module-wide computation dtype and precision.
# NOTE(review): in recent Flax versions `dtype` is the computation dtype only;
# parameters are stored in `param_dtype` (float32 by default). Pass
# `param_dtype=DTYPE` as well if 64-bit weights are intended.
PreciseConv = partial(nn.Conv, dtype=DTYPE, precision=PRECISION)


class ConvNet(nn.Module):
    """Fully convolutional classifier: stacked unpadded 3x3 convs + global mean pool.

    Every conv is a ``PreciseConv`` (module-wide dtype/precision) with a 3x3
    kernel, stride 1 and zero padding, and each is followed by ReLU. Because no
    padding is applied, every conv shrinks both spatial dimensions by 2, so the
    input height and width must exceed twice the number of conv layers.

    Attributes:
        depth: Number of conv layers when ``depth >= 2``; any smaller value
            still builds two layers (one hidden, one output).
        widen_factor: Hidden layers have ``16 * widen_factor`` channels.
        num_classes: Channel count of the final conv, i.e. the output size.
        num_input_channels: Not read by this module; Flax infers the input
            channel count from the data at initialisation.
    """

    depth: int
    widen_factor: int
    num_classes: int = 10
    num_input_channels: int = 3

    def setup(self):
        hidden_width = 16 * self.widen_factor

        def conv_then_relu(features):
            # Zero padding on both sides of both spatial axes == "VALID":
            # each conv trims one pixel from every spatial border.
            return [
                PreciseConv(features, kernel_size=(3, 3), strides=(1, 1),
                            padding=[(0, 0), (0, 0)]),
                nn.relu,
            ]

        # The input-side hidden layer plus (depth - 2) further hidden layers,
        # never fewer than one in total, then the num_classes-wide output layer.
        n_hidden = max(self.depth - 1, 1)
        widths = [hidden_width] * n_hidden + [self.num_classes]

        # NOTE(review): the output conv is also followed by ReLU, so every
        # output value is >= 0. Confirm this is intended if the outputs are
        # consumed as logits.
        self.conv_layers = Sequential(
            [op for features in widths for op in conv_then_relu(features)])

    def __call__(self, input):
        """Run the conv stack, then average over axes 1 and 2.

        Assumes channels-last input ``(batch, height, width, channels)`` as
        Flax's ``nn.Conv`` expects (TODO confirm with callers). Under that
        assumption the result has shape ``(batch, num_classes)``.
        """
        feature_maps = self.conv_layers(input)
        return feature_maps.mean(axis=(1, 2))
