import jax
from flax import linen as nn
from jax import numpy as jnp
from functools import partial

from src.utils.model_utils import FlaxSequential

# Matmul precision used by every PreciseDense layer. jax.lax.Precision selects
# the backend's dot-product precision mode (DEFAULT=0, HIGH=1, HIGHEST=2); it is
# not a storage bit-width. HIGHEST is the slowest, most accurate mode.
PRECISION = jax.lax.Precision.HIGHEST
# Computation dtype for PreciseDense.
# NOTE(review): JAX only keeps float64 when the ``jax_enable_x64`` flag is on;
# otherwise float64 requests are truncated to float32 (with a warning).
# Confirm the flag is enabled wherever this module is used.
DTYPE = jnp.float64
# nn.Dense preconfigured with the dtype/precision above. Only ``dtype`` (the
# computation dtype) and ``precision`` are set here; ``param_dtype`` is left at
# nn.Dense's default. NOTE(review): confirm whether float64 parameters are
# also wanted.
PreciseDense = partial(nn.Dense, dtype=DTYPE, precision=PRECISION)


def _append_if_not_none(list, item):
    if item is not None:
        list.append(item)


class FCNetNTK(nn.Module):
    """Fully connected network built from ``PreciseDense`` layers.

    Layer stack assembled in ``setup``::

        PreciseDense(width) -> activation
        PreciseDense(width) [-> activation]   # repeated (depth - 2) times
        PreciseDense(num_classes)

    Within the repeated section, the activation follows every hidden layer
    except the last one (``i == depth - 2``). The network therefore contains
    ``max(depth, 2)`` dense layers in total. If ``activation`` is ``None``, no
    activation layers are inserted at all.

    Attributes:
        input_size: number of features each example is flattened to in
            ``__call__``.
        depth: controls how many dense layers are built (see above).
        width: output features of every hidden dense layer.
        num_classes: output features of the final dense layer.
        activation: layer appended after hidden dense layers; ``None`` skips
            it. Presumably a callable such as ``nn.relu`` or a module.
        num_input_channels: not read anywhere in this class. Presumably kept
            for interface parity with other models; confirm before removing.
    """

    input_size: int
    depth: int
    width: int
    num_classes: int
    activation: nn.Module
    num_input_channels: int = 1

    def setup(self):
        """Build the layer list and wrap it in a ``FlaxSequential``."""
        layers = []
        layers.append(PreciseDense(self.width))
        _append_if_not_none(layers, self.activation)
        for i in range(1, self.depth - 1):
            layers.append(PreciseDense(self.width))
            # NOTE(review): skipping the activation on the last hidden layer
            # means that, for depth >= 3, two dense layers run back to back
            # with no nonlinearity between them (last hidden -> output). If
            # every hidden layer should be activated, drop this condition.
            # Doing so changes model outputs and possibly FlaxSequential's
            # child naming, so confirm intent first.
            if i != self.depth - 2:
                _append_if_not_none(layers, self.activation)
        layers.append(PreciseDense(self.num_classes))
        self.linear = FlaxSequential(layers)

    def __call__(self, x):
        """Flatten ``x`` to ``(-1, input_size)`` and apply the layer stack.

        Args:
            x: array whose total size must be divisible by ``input_size``.

        Returns:
            The result of ``self.linear`` on the flattened input. Assuming
            FlaxSequential applies its layers in order, this has shape
            ``(N, num_classes)``, where ``N = x.size // input_size``.
        """
        # Pass the target shape positionally: the ``newshape=`` keyword is
        # deprecated in recent JAX releases (renamed ``shape``), and the
        # positional form works on both old and new versions.
        x = jnp.reshape(x, (-1, self.input_size))
        return self.linear(x)