from typing import Callable, Sequence

import jax
import jax.numpy as jnp
from flax import linen as nn
from ott.neural.networks.icnn import ICNN


class TransportMap(nn.Module):
    """Plain MLP: SiLU-activated hidden layers followed by a linear head.

    Attributes:
        layers: Width of each hidden ``Dense`` layer, in application order.
        output_dim: Width of the final ``Dense`` layer (no activation after it).
    """

    layers: Sequence[int]
    output_dim: int

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through the hidden stack and project to ``output_dim``."""
        hidden = x
        for width in self.layers:
            hidden = nn.silu(nn.Dense(width)(hidden))
        # Linear head; presumably ``output_dim`` equals the input width so the
        # map preserves dimensionality -- confirm against the caller.
        return nn.Dense(self.output_dim)(hidden)


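# NOTE: superseded draft of UNetMLP, kept for reference only. It concatenates
# the skip AFTER each decoder Dense layer, whereas the active class below
# concatenates BEFORE it.
# class UNetMLP(nn.Module):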
#     layers: Sequence[int]
#     output_dim: int

#     @nn.compact
#     def __call__(self, x):
#         skips = []

#         for f in self.layers:
#             x = nn.silu(nn.Dense(f)(x))
#             skips.append(x)

#         for f, skip in zip(reversed(self.layers[:-1]), reversed(skips[:-1])):
#             x = nn.silu(nn.Dense(f)(x))
#             x = jnp.concatenate([x, skip], axis=-1)

#         return nn.Dense(self.output_dim)(x)


class UNetMLP(nn.Module):
    """MLP with U-Net style skip connections between encoder and decoder.

    Attributes:
        layers: Width of each encoder ``Dense`` layer. The decoder reuses all
            but the last width, in reverse order.
        output_dim: Width of the final ``Dense`` layer.
    """

    layers: Sequence[int]
    output_dim: int

    @nn.compact
    def __call__(self, x):
        """Encode ``x``, decode with skip concatenation, then project."""
        h = x

        # Encoder: remember every activation so the decoder can reuse it.
        encoder_outputs = []
        for width in self.layers:
            h = nn.Dense(width)(h)
            h = nn.silu(h)
            encoder_outputs.append(h)

        # Decoder: the last encoder output is the bottleneck (already in `h`),
        # so only the earlier activations are paired with their own widths
        # and visited deepest-first.
        skip_pairs = list(zip(self.layers[:-1], encoder_outputs[:-1]))
        for width, skip in reversed(skip_pairs):
            merged = jnp.concatenate([h, skip], axis=-1)
            # The Dense layer compresses the concatenation back to `width`.
            h = nn.silu(nn.Dense(width)(merged))

        return nn.Dense(self.output_dim)(h)


class GradientICNNMap(nn.Module):
    """Map defined as the gradient of an ``ICNN`` with respect to its input.

    Attributes:
        layers: Forwarded to ``ICNN`` as ``dim_hidden``.
        output_dim: Forwarded to ``ICNN`` as ``dim_data``.
        gaussian_map_samples: Forwarded unchanged to ``ICNN``; may be ``None``.
        init_fn: Initializer forwarded to ``ICNN``.
    """

    layers: Sequence[int]
    output_dim: int
    gaussian_map_samples: tuple[jnp.ndarray, jnp.ndarray] | None
    init_fn: Callable = nn.initializers.lecun_normal()

    def setup(self):
        """Build the underlying ``ICNN`` submodule."""
        icnn_config = {
            "dim_data": self.output_dim,
            "dim_hidden": self.layers,
            "gaussian_map_samples": self.gaussian_map_samples,
            "init_fn": self.init_fn,
        }
        self.icnn = ICNN(**icnn_config)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return the gradient of the ICNN output at ``x``.

        ``jax.grad`` needs a scalar-valued function, so ``x`` is presumably a
        single (unbatched) sample -- confirm against the caller.
        """
        grad_of_potential = jax.grad(lambda point: self.icnn(point))
        return grad_of_potential(x)


class TimeVaryingTransport(nn.Module):
    """Stack of ``K - 1`` independent sub-networks, one per row of the input.

    Attributes:
        K: Number of time points; ``K - 1`` sub-networks are created.
        layers: Passed to every sub-network as ``layers``.
        input_dim: Passed to every sub-network as ``output_dim``.
        net_cls: Factory used to build each sub-network.
        net_kwargs_list: Extra keyword arguments, one dict per sub-network.
            ``None`` means "no extra arguments" for every sub-network (note
            that the default ``net_cls`` has a required extra field, so it
            still needs explicit dicts).
        reg: Currently unused by ``__call__`` (see NOTE there).
    """

    K: int
    layers: Sequence[int]
    input_dim: int
    net_cls: Callable[..., nn.Module] = GradientICNNMap
    net_kwargs_list: list[dict] | None = None  # one dict per time step
    reg: float = 0.0

    def setup(self):
        """Instantiate the ``K - 1`` sub-networks.

        Raises:
            ValueError: If ``net_kwargs_list`` is given but does not contain
                exactly ``K - 1`` entries.
        """
        num_maps = self.K - 1

        kwargs_list = self.net_kwargs_list
        if kwargs_list is None:
            # The declared default is None; calling len() on it would raise an
            # opaque TypeError, so interpret it as "no extra kwargs".
            kwargs_list = [{} for _ in range(num_maps)]

        # Explicit raise instead of `assert`, which is stripped under -O.
        if len(kwargs_list) != num_maps:
            raise ValueError(
                f"Need K-1 config dicts: expected {num_maps}, "
                f"got {len(kwargs_list)}"
            )

        self.transport_list = [
            self.net_cls(
                layers=self.layers,
                output_dim=self.input_dim,
                **net_kwargs,
            )
            for net_kwargs in kwargs_list
        ]

    def __call__(self, x):  # x: (K-1, D)
        """Apply sub-network ``k`` to ``x[k]`` and stack results on axis 0."""
        # NOTE: `reg` is not applied. A regularised variant that added
        # `self.reg * x[k] * x[k]` to each output is currently disabled.
        outputs = [net(x[k]) for k, net in enumerate(self.transport_list)]
        return jnp.stack(outputs, axis=0)
