from functools import partial
from typing import NamedTuple, Sequence

import elegy as eg
import jax
import jax.nn.initializers as init
from elegy import nn
from jax import numpy as jn

from .config import ModelConfig


def make_network(cfg: ModelConfig) -> eg.Module:
    """Assemble the :class:`OperatorNet` described by ``cfg``.

    * Trunk (query embedding): the input is fed both to
      ``FourierFeatures(cfg.n_fourier_features)`` and to
      ``Linear(cfg.n_hidden) -> cfg.embedding_activation``; the two results
      are concatenated and projected by ``Linear(cfg.embedding_dimension)``.
    * Branch (conditioning embedding): a :class:`MeshCNN` with hidden sizes
      ``cfg.meshcnn_hidden_layers`` (activation ``cfg.meshcnn_activation``)
      and a final layer of ``cfg.embedding_dimension`` features, after which
      only the mesh's ``edge_features`` are kept.
    * ``cfg.n_heads``, ``cfg.n_attention_layers`` and
      ``cfg.output_dimension`` configure the attention stack itself.

    Args:
        cfg: Model hyper-parameters.

    Returns:
        The assembled network.

    Raises:
        ValueError: If ``cfg.embedding_activation`` or
            ``cfg.meshcnn_activation`` does not name a callable attribute of
            ``jax.nn``.
    """

    def act(name):
        # Resolve an activation function by name from ``jax.nn``. A plain
        # ``hasattr`` check is not enough: ``jax.nn`` also exposes
        # non-callable attributes (e.g. the ``initializers`` submodule), which
        # would otherwise only fail much later, when the network is applied.
        fn = getattr(jax.nn, name, None)
        if not callable(fn):
            raise ValueError(f"'{name}' is not a valid activation function")
        return fn

    return OperatorNet(
        nn.Sequential(
            Concat(
                FourierFeatures(cfg.n_fourier_features),
                nn.Sequential(
                    nn.Linear(cfg.n_hidden),
                    act(cfg.embedding_activation),
                ),
            ),
            nn.Linear(cfg.embedding_dimension),
        ),
        nn.Sequential(
            MeshCNN(
                list(cfg.meshcnn_hidden_layers) + [cfg.embedding_dimension],
                activation=act(cfg.meshcnn_activation),
                conv_radius=cfg.meshcnn_conv_radius,
            ),
            # The branch only needs the per-edge features, not the adjacency.
            lambda mesh: mesh.edge_features,
        ),
        num_heads=cfg.n_heads,
        num_layers=cfg.n_attention_layers,
        output_dimension=cfg.output_dimension,
    )


class Mesh(NamedTuple):
    """Edge-based mesh representation consumed by :class:`MeshConv`.

    Attributes:
        edge_features: Per-edge feature array; its leading axis indexes edges.
        edge_adjacency: Integer index array with one row per edge and four
            columns. Columns ``0:2`` and ``2:4`` are the two neighbour pairs
            that :meth:`expand_neighbours` compares against each other.
    """

    edge_features: jn.ndarray
    edge_adjacency: jn.ndarray

    @property
    def n_edges(self) -> int:
        """Number of edges, i.e. the leading dimension of ``edge_features``."""
        feature_shape = self.edge_features.shape
        return feature_shape[0]

    def expand_neighbours(self, radius=1) -> jn.ndarray:
        """Grow every edge's receptive field by ``radius`` hops.

        Each hop gathers the current features of both neighbour pairs listed
        in ``edge_adjacency`` (flattened per edge) and replaces them with
        ``[|pair_ab - pair_cd|, 0.5 * (pair_ab + pair_cd)]``. Both terms are
        unchanged if the two pairs are swapped. The original
        ``edge_features`` are prepended to the result. For input width ``d``
        the output width is therefore ``d + 4**radius * d``. Note that
        ``radius=0`` returns the features concatenated with themselves.
        """
        adjacency = self.edge_adjacency
        num_edges = self.n_edges
        current = self.edge_features
        for _hop in range(radius):
            pair_ab = current[adjacency[:, :2]].reshape(num_edges, -1)
            pair_cd = current[adjacency[:, 2:]].reshape(num_edges, -1)
            current = jn.concatenate(
                (jn.abs(pair_ab - pair_cd), 0.5 * (pair_ab + pair_cd)),
                axis=-1,
            )
        return jn.concatenate((self.edge_features, current), axis=-1)


class MeshConv(eg.Module):
    """One mesh-convolution layer.

    Expands each edge's features over ``conv_radius`` neighbourhood hops with
    :meth:`Mesh.expand_neighbours`, applies a dense layer with
    ``features_out`` outputs followed by ``activation`` (identity when
    ``None``), and returns a new :class:`Mesh` sharing the input adjacency.
    """

    def __init__(self, features_out: int, activation=None, conv_radius=1):
        self.activation = (lambda x: x) if activation is None else activation
        self.features_out = features_out
        self.conv_radius = conv_radius

    @eg.compact
    def __call__(self, mesh: Mesh) -> Mesh:
        gathered = mesh.expand_neighbours(self.conv_radius)
        dense = nn.Linear(self.features_out)
        new_features = self.activation(dense(gathered))
        return Mesh(new_features, mesh.edge_adjacency)


class MeshCNN(nn.Sequential):
    """Sequential stack of :class:`MeshConv` layers.

    One layer is created per entry of ``features`` (its output width). All
    layers except the last use ``activation``. The last uses
    ``final_activation``, and ``None`` means identity (see
    :class:`MeshConv`). Every layer uses the same ``conv_radius``.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(
        self,
        features: Sequence[int],
        activation=jax.nn.relu,
        final_activation=None,
        conv_radius=1,
    ):
        # Without this check an empty ``features`` failed with an opaque
        # IndexError from ``features[-1]``.
        if len(features) == 0:
            raise ValueError("MeshCNN requires at least one layer size in `features`")
        *hidden, last = features
        super().__init__(
            *(MeshConv(f, activation, conv_radius) for f in hidden),
            MeshConv(last, final_activation, conv_radius),
        )


class FourierFeatures(eg.Module):
    """Random-Fourier-feature embedding ``sin(x @ W + b)``.

    ``W`` is initialised from a zero-mean normal distribution with standard
    deviation ``frequency_stddev``. ``b`` is initialised uniformly on
    ``[0, 2*pi)``. Both are the parameters of an inner ``nn.Linear`` with
    ``features`` outputs.
    """

    def __init__(self, features: int, frequency_stddev: float = 1.0):
        self.features = features
        self.frequency_stddev = frequency_stddev

    @eg.compact
    def __call__(self, x: jn.ndarray) -> jn.ndarray:
        projection = nn.Linear(
            self.features,
            kernel_init=init.normal(self.frequency_stddev),
            bias_init=init.uniform(2 * jn.pi),
        )
        return jn.sin(projection(x))


class Concat(eg.Module):
    """Apply two modules to the same input and join their outputs on the last axis."""

    def __init__(self, left: eg.Module, right: eg.Module):
        self.left = left
        self.right = right

    def __call__(self, x):
        # ``left`` is evaluated before ``right``; keep this order stable.
        left_out = self.left(x)
        right_out = self.right(x)
        return jn.concatenate((left_out, right_out), axis=-1)


# def avg(axis=-1, keepdims=True):
#     return lambda x: jn.mean(x, axis=axis, keepdims=keepdims)


# class WeightedAttention(eg.Module):
#     def __init__(self, split_axis=-1, average_axis=-2):
#         self.split_axis = split_axis
#         self.average_axis = average_axis

#     def __call__(self, x: jn.ndarray):
#         assert (
#             x.shape[self.split_axis] % 2 == 0
#         ), f"x.shape[{self.split_axis}] must be even"
#         x, w = jn.split(x, 2, self.split_axis)
#         w = jax.nn.softmax(w, self.average_axis)
#         return (w * x).sum(self.average_axis)


# def to_linear_map(size_out: int):
#     def layer(x):
#         x = x.reshape(-1, size_out)
#         w = x[:-1, :]
#         b = x[-1, :]
#         return w, b

#     return layer


# class ConstantMap(eg.Module):
#     size_out: int

#     def __init__(self, size_in: int, size_out: int):
#         self.size_in = size_in
#         self.size_out = size_out

#     @eg.compact
#     def __call__(self, *args):
#         lin = nn.Linear(self.size_out)
#         # init linear layer
#         _ = lin(jn.zeros((1, self.size_in)))
#         return lin.kernel, lin.bias


class Attention(nn.FlaxModule):
    """Elegy wrapper around Flax's multi-head dot-product attention.

    Flax is imported inside the constructor, so it is only needed once this
    layer is actually instantiated.
    """

    def __init__(self, num_heads: int):
        import flax.linen as linen

        flax_attention = linen.MultiHeadDotProductAttention(num_heads=num_heads)
        super().__init__(flax_attention)


class OperatorNet(eg.Module):
    """Cross-attention operator network.

    ``trunk`` embeds the query input ``x``, and ``branch`` embeds the
    conditioning input ``phi``. The query embeddings then pass through
    ``num_layers`` blocks. Each block applies multi-head attention over the
    branch embeddings, a residual connection with LayerNorm, a SiLU MLP of
    hidden width ``4 * d``, and another residual connection with LayerNorm.
    A final linear layer maps to ``output_dimension``, and the result is
    flattened to shape ``(-1, output_dimension)``.
    """

    def __init__(
        self,
        trunk: eg.Module,
        branch: eg.Module,
        num_layers: int = 3,
        num_heads: int = 8,
        output_dimension: int = 1,
    ):
        self.trunk = trunk
        self.branch = branch
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.output_dimension = output_dimension

    @eg.compact
    def __call__(self, x, phi):
        # Submodule call/creation order below is significant (compact naming
        # and RNG consumption), so it mirrors the original exactly.
        context = self.branch(phi)[None, ...]
        queries = self.trunk(x)
        # Left-pad the query embedding with singleton axes up to rank 3.
        pad = (1,) * (3 - len(queries.shape))
        queries = queries.reshape(pad + queries.shape)

        for _layer in range(self.num_layers):
            attended = Attention(self.num_heads)(queries, context)
            queries = nn.LayerNorm()(queries + attended)
            width = queries.shape[-1]
            mixed = nn.MLP([width * 4, width], activation=jax.nn.silu)(queries)
            queries = nn.LayerNorm()(queries + mixed)

        projected = nn.Linear(self.output_dimension)(queries)
        return projected.reshape((-1, self.output_dimension))


# class SkipConnection(eg.Module):
#     def __init__(self, activation=jax.nn.silu):
#         self.activation = activation

#     @eg.compact
#     def __call__(self, x):
#         size = x.shape[-1]
#         h = self.activation(nn.Linear(size)(x))
#         return x + h
