"""
JAX networks: Haiku implementations of all the DeepONet model variants discussed in the paper.
"""

import haiku as hk
import copy
import jax
import jax.numpy as jnp
import os
import sys

sys.path.append((os.path.dirname(os.path.dirname(__file__))))
from utils import split_conv_string, split_linear_string


def EnsembleVanillaTrunkDeepONetCartesianProd(config):
    """Build an ensemble DeepONet with an MLP branch and several MLP trunks.

    The outputs of ``config["num_trunks"]`` trunk MLPs are concatenated
    column-wise into one trunk basis, and every branch row is paired with
    every trunk row via ``branch @ trunk.T + bias``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``MLP``),
            ``"trunk_config"`` (template passed to ``MLP`` for every trunk;
            each deep copy gets ``"name"`` set to ``f"trunk_{i}"``) and
            ``"num_trunks"``.

    Returns:
        ``forward(input_)`` for use under ``hk.transform``; ``input_[0]``
        feeds the branch and ``input_[1]`` feeds every trunk.
    """
    branch_config = config["branch_config"]
    num_trunks = config["num_trunks"]

    # Deep copies so the caller's trunk_config is never mutated; the distinct
    # names are presumably used by MLP as module names (one scope per trunk).
    config_list = []
    for i in range(num_trunks):
        trunk_config_ = copy.deepcopy(config["trunk_config"])
        trunk_config_["name"] = f"trunk_{i}"
        config_list.append(trunk_config_)

    def forward(input_):
        branch_forward = MLP(branch_config)
        trunk_pred_list = [MLP(trunk_cfg)(input_[1]) for trunk_cfg in config_list]

        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        branch_pred = branch_forward(input_[0])
        trunk_pred = jnp.hstack(trunk_pred_list)  # concatenate ensemble trunk outputs
        return jnp.matmul(branch_pred, trunk_pred.T) + bias_param

    return forward


def EnsemblePODVanillaTrunkDeepONetCartesianConvVectorProd(config):
    """Build a two-component POD-DeepONet with a CNN branch and an MLP trunk.

    For each component, the trunk basis is the matching half of the vanilla
    trunk output stacked next to that component's POD basis
    (``trunk_pod_basis_x`` / ``trunk_pod_basis_y``). The matching half of the
    branch output is contracted with it, and the result is divided by
    ``half_total_p``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``; its
            ``"output_dim"`` is split in half between the components),
            ``"vanilla_trunk_config"`` (passed to ``MLP``), ``"p"`` (vanilla
            trunk width, split in half), ``"trunk_pod_basis_x"`` and
            ``"trunk_pod_basis_y"``.

    Returns:
        ``forward(input_)`` for ``hk.transform``; it returns both components
        stacked on axis 2.
    """
    branch_config = config["branch_config"]
    vanilla_trunk_config = config["vanilla_trunk_config"]
    trunk_pod_basis_x = config["trunk_pod_basis_x"]
    trunk_pod_basis_y = config["trunk_pod_basis_y"]
    half_p = int(config["p"]/2)
    half_total_p = int(branch_config["output_dim"]/2)  # half of the total branch/trunk width

    def forward(input_):
        branch_forward = CNN(branch_config)
        vanilla_trunk_forward = MLP(vanilla_trunk_config)
        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        branch_pred = branch_forward(input_[0])
        vanilla_trunk_pred = vanilla_trunk_forward(input_[1])

        def component(branch_part, trunk_part):
            # POD-DeepONet output scaling by the per-component basis count.
            return (jnp.matmul(branch_part, trunk_part.T) + bias_param) / half_total_p

        trunk_pred_x = jnp.hstack([vanilla_trunk_pred[:, :half_p], trunk_pod_basis_x])
        trunk_pred_y = jnp.hstack([vanilla_trunk_pred[:, half_p:], trunk_pod_basis_y])
        first_component = component(branch_pred[:, :half_total_p], trunk_pred_x)
        second_component = component(branch_pred[:, half_total_p:], trunk_pred_y)
        return jnp.stack([first_component, second_component], axis=2)

    return forward


def EnsemblePODVanillaTrunkDeepONetCartesianConvProd(config):
    """Build a POD-DeepONet with a CNN branch and an MLP trunk.

    The trunk basis is the vanilla trunk output stacked next to the fixed
    ``trunk_pod_basis``. The prediction is
    ``(branch @ trunk.T + bias) / total_p``, where ``total_p`` is the branch
    ``"output_dim"``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``),
            ``"vanilla_trunk_config"`` (passed to ``MLP``) and
            ``"trunk_pod_basis"``.

    Returns:
        ``forward(input_)`` for ``hk.transform``.
    """
    branch_config = config["branch_config"]
    vanilla_trunk_config = config["vanilla_trunk_config"]
    trunk_pod_basis = config["trunk_pod_basis"]
    total_p = branch_config["output_dim"]  # total width of the branch/trunk output

    def forward(input_):
        branch_forward = CNN(branch_config)
        vanilla_trunk_forward = MLP(vanilla_trunk_config)

        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        branch_pred = branch_forward(input_[0])
        vanilla_trunk_pred = vanilla_trunk_forward(input_[1])
        trunk_pred = jnp.hstack([vanilla_trunk_pred, trunk_pod_basis])
        pred = jnp.matmul(branch_pred, trunk_pred.T) + bias_param

        # output scaling/transform for POD-DeepONet
        return pred / total_p

    return forward


def EnsemblePODVanillaTrunkDeepONetCartesianProd(config):
    """Build a POD-DeepONet with an MLP branch and an MLP trunk.

    The trunk basis is the vanilla trunk output stacked next to the fixed
    ``trunk_pod_basis``. The prediction is
    ``(branch @ trunk.T + bias) / total_p``, where ``total_p`` is the branch
    ``"output_dim"``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``MLP``),
            ``"vanilla_trunk_config"`` (passed to ``MLP``) and
            ``"trunk_pod_basis"``.

    Returns:
        ``forward(input_)`` for ``hk.transform``.
    """
    branch_config = config["branch_config"]
    vanilla_trunk_config = config["vanilla_trunk_config"]
    trunk_pod_basis = config["trunk_pod_basis"]
    total_p = branch_config["output_dim"]  # total width of the branch/trunk output

    def forward(input_):
        branch_forward = MLP(branch_config)
        vanilla_trunk_forward = MLP(vanilla_trunk_config)
        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        branch_pred = branch_forward(input_[0])
        vanilla_trunk_pred = vanilla_trunk_forward(input_[1])
        trunk_pred = jnp.hstack([vanilla_trunk_pred, trunk_pod_basis])

        pred = jnp.matmul(branch_pred, trunk_pred.T) + bias_param

        # output scaling/transform for POD-DeepONet
        return pred / total_p

    return forward


def EnsembleVanillaTrunkDeepONetCartesianConvVectorProd(config):
    """Build a two-component ensemble DeepONet with a CNN branch and several MLP trunks.

    Each trunk's output is split in half at ``p / 2``. The first halves of
    all trunks form the basis for the first component, and the second halves
    form the basis for the second. Each component is
    ``branch_half @ basis.T + bias``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``; its
            ``"output_dim"`` is split in half between the components),
            ``"trunk_config"`` (template passed to ``MLP`` for every trunk;
            each deep copy gets ``"name"`` set to ``f"trunk_{i}"``),
            ``"num_trunks"`` and ``"p"`` (per-trunk width).

    Returns:
        ``forward(input_)`` for ``hk.transform``; it returns both components
        stacked on axis 2.
    """
    branch_config = config["branch_config"]
    num_trunks = config["num_trunks"]
    half_p = int(config["p"]/2)
    half_total_p = int(branch_config["output_dim"]/2)  # half of the total branch/trunk width

    # Deep copies so the caller's trunk_config is never mutated.
    config_list = []
    for i in range(num_trunks):
        trunk_config_ = copy.deepcopy(config["trunk_config"])
        trunk_config_["name"] = f"trunk_{i}"
        config_list.append(trunk_config_)

    def forward(input_):
        branch_forward = CNN(branch_config)
        trunk_pred_list = [MLP(trunk_cfg)(input_[1]) for trunk_cfg in config_list]

        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        branch_pred = branch_forward(input_[0])

        # components of vector valued function
        trunk_pred_x = jnp.hstack([temp[:, :half_p] for temp in trunk_pred_list])
        first_component = jnp.matmul(branch_pred[:, :half_total_p], trunk_pred_x.T) + bias_param

        trunk_pred_y = jnp.hstack([temp[:, half_p:] for temp in trunk_pred_list])
        second_component = jnp.matmul(branch_pred[:, half_total_p:], trunk_pred_y.T) + bias_param

        return jnp.stack([first_component, second_component], axis=2)

    return forward


def EnsembleVanillaTrunkDeepONetCartesianConvProd(config):
    """Build an ensemble DeepONet with a CNN branch and several MLP trunks.

    The outputs of ``config["num_trunks"]`` trunk MLPs are concatenated
    column-wise, and the prediction is ``branch @ trunk.T + bias``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``),
            ``"trunk_config"`` (template passed to ``MLP`` for every trunk;
            each deep copy gets ``"name"`` set to ``f"trunk_{i}"``) and
            ``"num_trunks"``.

    Returns:
        ``forward(input_)`` for ``hk.transform``.
    """
    branch_config = config["branch_config"]
    num_trunks = config["num_trunks"]

    # Deep copies so the caller's trunk_config is never mutated.
    config_list = []
    for i in range(num_trunks):
        trunk_config_ = copy.deepcopy(config["trunk_config"])
        trunk_config_["name"] = f"trunk_{i}"
        config_list.append(trunk_config_)

    def forward(input_):
        branch_forward = CNN(branch_config)
        trunk_pred_list = [MLP(trunk_cfg)(input_[1]) for trunk_cfg in config_list]

        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        branch_pred = branch_forward(input_[0])
        trunk_pred = jnp.hstack(trunk_pred_list)  # concatenate ensemble trunk outputs

        return jnp.matmul(branch_pred, trunk_pred.T) + bias_param

    return forward


def EnsembleVanillaSparseTrunkDeepONetCartesianConvVectorProd(config):
    """Build a two-component DeepONet with a CNN branch, a vanilla MLP trunk
    and a partition-of-unity (PU) sparse trunk.

    For each component, the trunk basis is the matching half of the activated
    PU trunk output stacked next to the matching half of the vanilla trunk
    output. Each component is ``branch_half @ basis.T + bias``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``; its
            ``"activation"`` selects the nonlinearity applied to the PU trunk
            output and its ``"output_dim"`` is split between the components),
            ``"vanilla_trunk_config"`` (passed to ``MLP``) and
            ``"trunk_config_list"`` (one ``MLP`` config per partition; the
            first entry's ``"output_dim"`` is the PU trunk width).

    Returns:
        ``forward(input_, num_groups, num_partitions)`` for ``hk.transform``;
        it returns both components stacked on axis 2.

    Raises:
        ValueError: if ``branch_config["activation"]`` is not one of
            ``"tanh"``, ``"relu"``, ``"elu"`` or ``"leaky_relu"``.
    """
    branch_config = config["branch_config"]
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = branch_config["activation"]
    if activation_name not in activations:
        # Fail fast; otherwise ``activation`` stays unbound and only
        # surfaces as a NameError inside ``forward``.
        raise ValueError(
            f"Unsupported activation {activation_name!r}; "
            f"expected one of: {', '.join(sorted(activations))}"
        )
    activation = activations[activation_name]

    vanilla_trunk_config = config["vanilla_trunk_config"]
    trunk_config_list = config["trunk_config_list"]
    pu_p = trunk_config_list[0]["output_dim"]  # PU trunk output width
    half_p = int(pu_p/2)
    half_total_p = int(branch_config["output_dim"]/2)  # half of the total branch/trunk width

    def forward(input_, num_groups, num_partitions):
        """Evaluate the network.

        ``input_[0]`` feeds the branch and ``input_[1]`` the trunks. Outside
        init, the PU trunk is evaluated per group ``g`` from ``input_[2][g]``
        (points), ``input_[3][g]`` (blending weights), ``input_[4][g]`` (output
        row indices) and ``input_[5][g]`` (partition indices).
        NOTE(review): exact per-array shapes come from the caller's data
        pipeline; confirm there.
        """
        # param init
        branch_forward = CNN(branch_config)
        vanilla_trunk_forward = MLP(vanilla_trunk_config)
        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        trunk_applys = [MLP(trunk_config_list[i]) for i in range(num_partitions)]

        def func(i, x):
            # Evaluate the partition trunk selected by index ``i`` at ``x``.
            return hk.switch(i, trunk_applys, x)

        point_vmap = hk.vmap(func, in_axes=(0, None), split_rng=False)

        def group_func(i, point, point_weight):
            # Weighted sum of the selected partition trunks at one point;
            # point_weight is expected to be pre-broadcast.
            point_eval = point_vmap(i, point) * point_weight
            return point_eval.sum(0)

        group_vmap = hk.vmap(group_func, in_axes=(0, 0, 0), split_rng=False)

        if hk.running_init():
            # Apply every partition trunk directly to all trunk points,
            # presumably so each one's parameters are created outside
            # hk.switch/hk.vmap.
            pu_trunk_pred = activation(jnp.stack([trunk(input_[1]) for trunk in trunk_applys]).sum(0))
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])
        else:
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])

            # Size the buffer by the PU width (not the vanilla width) so the
            # apply path matches the init path even if the widths differ.
            pu_trunk_pred = jnp.zeros((vanilla_trunk_pred.shape[0], pu_p), dtype=vanilla_trunk_pred.dtype)
            for g in range(num_groups):
                pu_trunk_pred = pu_trunk_pred.at[input_[4][g]].set(
                    group_vmap(input_[5][g], input_[2][g], input_[3][g])
                )
            pu_trunk_pred = activation(pu_trunk_pred)

        # components of vector valued function
        trunk_pred_x = jnp.hstack([pu_trunk_pred[:, :half_p], vanilla_trunk_pred[:, :half_p]])
        first_component = jnp.matmul(branch_pred[:, :half_total_p], trunk_pred_x.T) + bias_param

        trunk_pred_y = jnp.hstack([pu_trunk_pred[:, half_p:], vanilla_trunk_pred[:, half_p:]])
        second_component = jnp.matmul(branch_pred[:, half_total_p:], trunk_pred_y.T) + bias_param

        return jnp.stack([first_component, second_component], axis=2)

    return forward


def EnsembleVanillaSparseTrunkDeepONetCartesianConvProd(config):
    """Build a DeepONet with a CNN branch, a vanilla MLP trunk and a
    partition-of-unity (PU) sparse trunk.

    The trunk basis is the activated PU trunk output stacked next to the
    vanilla trunk output. The prediction is ``branch @ trunk.T + bias``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``; its
            ``"activation"`` selects the nonlinearity applied to the PU trunk
            output), ``"vanilla_trunk_config"`` (passed to ``MLP``) and
            ``"trunk_config_list"`` (one ``MLP`` config per partition; the
            first entry's ``"output_dim"`` is the PU trunk width).

    Returns:
        ``forward(input_, num_groups, num_partitions)`` for ``hk.transform``.

    Raises:
        ValueError: if ``branch_config["activation"]`` is not one of
            ``"tanh"``, ``"relu"``, ``"elu"`` or ``"leaky_relu"``.
    """
    branch_config = config["branch_config"]
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = branch_config["activation"]
    if activation_name not in activations:
        # Fail fast; otherwise ``activation`` stays unbound and only
        # surfaces as a NameError inside ``forward``.
        raise ValueError(
            f"Unsupported activation {activation_name!r}; "
            f"expected one of: {', '.join(sorted(activations))}"
        )
    activation = activations[activation_name]

    vanilla_trunk_config = config["vanilla_trunk_config"]
    trunk_config_list = config["trunk_config_list"]
    p = trunk_config_list[0]["output_dim"]  # PU trunk output width

    def forward(input_, num_groups, num_partitions):
        """Evaluate the network.

        ``input_[0]`` feeds the branch and ``input_[1]`` the trunks. Outside
        init, the PU trunk is evaluated per group ``g`` from ``input_[2][g]``
        (points), ``input_[3][g]`` (blending weights), ``input_[4][g]`` (output
        row indices) and ``input_[5][g]`` (partition indices).
        NOTE(review): exact per-array shapes come from the caller's data
        pipeline; confirm there.
        """
        # param init
        branch_forward = CNN(branch_config)
        vanilla_trunk_forward = MLP(vanilla_trunk_config)
        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        trunk_applys = [MLP(trunk_config_list[i]) for i in range(num_partitions)]

        def func(i, x):
            # Evaluate the partition trunk selected by index ``i`` at ``x``.
            return hk.switch(i, trunk_applys, x)

        point_vmap = hk.vmap(func, in_axes=(0, None), split_rng=False)

        def group_func(i, point, point_weight):
            # Weighted sum of the selected partition trunks at one point;
            # point_weight is expected to be pre-broadcast.
            point_eval = point_vmap(i, point) * point_weight
            return point_eval.sum(0)

        group_vmap = hk.vmap(group_func, in_axes=(0, 0, 0), split_rng=False)

        if hk.running_init():
            # Apply every partition trunk directly to all trunk points,
            # presumably so each one's parameters are created outside
            # hk.switch/hk.vmap.
            pu_pred = jnp.stack([trunk(input_[1]) for trunk in trunk_applys]).sum(0)
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])
        else:
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])

            # Size the buffer by the PU width ``p``; the vanilla trunk may be
            # wider or narrower (the two are only hstacked below).
            pu_pred = jnp.zeros((vanilla_trunk_pred.shape[0], p), dtype=vanilla_trunk_pred.dtype)
            for g in range(num_groups):
                pu_pred = pu_pred.at[input_[4][g]].set(
                    group_vmap(input_[5][g], input_[2][g], input_[3][g])
                )

        trunk_pred = jnp.hstack([activation(pu_pred), vanilla_trunk_pred])
        return jnp.matmul(branch_pred, trunk_pred.T) + bias_param

    return forward


def DeepONetCartesianConvVectorProd(config):
    """Build a two-component DeepONet from a CNN branch and a single MLP trunk.

    Both the branch and the trunk outputs are cut at column ``p / 2``. The
    leading columns produce the first component and the trailing columns
    produce the second. Each component is ``branch_cols @ trunk_cols.T + bias``
    with one shared scalar bias.

    Args:
        config: dict with ``"branch_config"`` (for ``CNN``),
            ``"trunk_config"`` (for ``MLP``) and ``"p"``.

    Returns:
        ``forward(input_)`` for ``hk.transform``; it returns the components
        stacked along axis 2.
    """
    branch_config = config["branch_config"]
    trunk_config = config["trunk_config"]
    half_p = int(config["p"] / 2)

    def forward(input_):
        branch_net = CNN(branch_config)
        trunk_net = MLP(trunk_config)
        bias = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)

        branch_out = branch_net(input_[0])
        trunk_out = trunk_net(input_[1])

        # Column ranges feeding the first and second output components.
        halves = (slice(None, half_p), slice(half_p, None))
        components = [
            jnp.matmul(branch_out[:, cols], trunk_out[:, cols].T) + bias
            for cols in halves
        ]
        return jnp.stack(components, axis=2)

    return forward


def DeepONetCartesianConvProd(config):
    """Build a single-trunk DeepONet from a CNN branch and an MLP trunk.

    The prediction pairs every branch row with every trunk row:
    ``branch(input_[0]) @ trunk(input_[1]).T + bias``.

    Args:
        config: dict with ``"branch_config"`` (for ``CNN``) and
            ``"trunk_config"`` (for ``MLP``).

    Returns:
        ``forward(input_)`` for ``hk.transform``.
    """
    branch_config, trunk_config = config["branch_config"], config["trunk_config"]

    def forward(input_):
        branch_input, trunk_input = input_[0], input_[1]
        branch_net = CNN(branch_config)
        trunk_net = MLP(trunk_config)
        bias = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)

        # Branch is evaluated before trunk, as the operands are left-to-right.
        return branch_net(branch_input) @ trunk_net(trunk_input).T + bias

    return forward


def EnsembleVanillaSparseTrunkDeepONetCartesianProd(config):
    """Build a DeepONet with an MLP branch, a vanilla MLP trunk and a
    partition-of-unity (PU) sparse trunk.

    The trunk basis is the activated PU trunk output stacked next to the
    vanilla trunk output. The prediction is ``branch @ trunk.T + bias``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``MLP``; its
            ``"activation"`` selects the nonlinearity applied to the PU trunk
            output), ``"vanilla_trunk_config"`` (passed to ``MLP``) and
            ``"trunk_config_list"`` (one ``MLP`` config per partition; the
            first entry's ``"output_dim"`` is the PU trunk width).

    Returns:
        ``forward(input_, num_groups, num_partitions)`` for ``hk.transform``.

    Raises:
        ValueError: if ``branch_config["activation"]`` is not one of
            ``"tanh"``, ``"relu"``, ``"elu"`` or ``"leaky_relu"``.
    """
    branch_config = config["branch_config"]
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = branch_config["activation"]
    if activation_name not in activations:
        # Fail fast; otherwise ``activation`` stays unbound and only
        # surfaces as a NameError inside ``forward``.
        raise ValueError(
            f"Unsupported activation {activation_name!r}; "
            f"expected one of: {', '.join(sorted(activations))}"
        )
    activation = activations[activation_name]

    vanilla_trunk_config = config["vanilla_trunk_config"]
    trunk_config_list = config["trunk_config_list"]
    p = trunk_config_list[0]["output_dim"]  # PU trunk output width

    def forward(input_, num_groups, num_partitions):
        """Evaluate the network.

        ``input_[0]`` feeds the branch and ``input_[1]`` the trunks. Outside
        init, the PU trunk is evaluated per group ``g`` from ``input_[2][g]``
        (points), ``input_[3][g]`` (blending weights), ``input_[4][g]`` (output
        row indices) and ``input_[5][g]`` (partition indices).
        NOTE(review): exact per-array shapes come from the caller's data
        pipeline; confirm there.
        """
        # param init
        branch_forward = MLP(branch_config)
        vanilla_trunk_forward = MLP(vanilla_trunk_config)
        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        trunk_applys = [MLP(trunk_config_list[i]) for i in range(num_partitions)]

        def func(i, x):
            # Evaluate the partition trunk selected by index ``i`` at ``x``.
            return hk.switch(i, trunk_applys, x)

        point_vmap = hk.vmap(func, in_axes=(0, None), split_rng=False)

        def group_func(i, point, point_weight):
            # Weighted sum of the selected partition trunks at one point;
            # point_weight is expected to be pre-broadcast.
            point_eval = point_vmap(i, point) * point_weight
            return point_eval.sum(0)

        group_vmap = hk.vmap(group_func, in_axes=(0, 0, 0), split_rng=False)

        if hk.running_init():
            # Apply every partition trunk directly to all trunk points,
            # presumably so each one's parameters are created outside
            # hk.switch/hk.vmap.
            pu_pred = jnp.stack([trunk(input_[1]) for trunk in trunk_applys]).sum(0)
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])
        else:
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])

            # Size the buffer by the PU width ``p``; the vanilla trunk may be
            # wider or narrower (the two are only hstacked below).
            pu_pred = jnp.zeros((vanilla_trunk_pred.shape[0], p), dtype=vanilla_trunk_pred.dtype)
            for g in range(num_groups):
                pu_pred = pu_pred.at[input_[4][g]].set(
                    group_vmap(input_[5][g], input_[2][g], input_[3][g])
                )

        trunk_pred = jnp.hstack([activation(pu_pred), vanilla_trunk_pred])
        return jnp.matmul(branch_pred, trunk_pred.T) + bias_param

    return forward


def EnsemblePODSparseTrunkDeepONetCartesianConvVectorProd(config):
    """Build a two-component POD-DeepONet with a CNN branch and a
    partition-of-unity (PU) sparse trunk.

    For each component, the trunk basis is the matching half of the activated
    PU trunk output stacked next to that component's POD basis
    (``trunk_pod_basis_x`` / ``trunk_pod_basis_y``). Each component is
    ``(branch_half @ basis.T + bias) / half_total_p``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``; its
            ``"activation"`` selects the nonlinearity applied to the PU trunk
            output and its ``"output_dim"`` is split between the components),
            ``"trunk_config_list"`` (one ``MLP`` config per partition),
            ``"p"`` (PU trunk width), ``"trunk_pod_basis_x"`` and
            ``"trunk_pod_basis_y"``.

    Returns:
        ``forward(input_, num_groups, num_partitions)`` for ``hk.transform``;
        it returns both components stacked on axis 2.

    Raises:
        ValueError: if ``branch_config["activation"]`` is not one of
            ``"tanh"``, ``"relu"``, ``"elu"`` or ``"leaky_relu"``.
    """
    branch_config = config["branch_config"]
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = branch_config["activation"]
    if activation_name not in activations:
        # Fail fast; otherwise ``activation`` stays unbound and only
        # surfaces as a NameError inside ``forward``.
        raise ValueError(
            f"Unsupported activation {activation_name!r}; "
            f"expected one of: {', '.join(sorted(activations))}"
        )
    activation = activations[activation_name]

    trunk_config_list = config["trunk_config_list"]
    p = config["p"]
    half_p = int(p/2)
    half_total_p = int(branch_config["output_dim"]/2)  # half of the total branch/trunk width
    trunk_pod_basis_x = config["trunk_pod_basis_x"]
    trunk_pod_basis_y = config["trunk_pod_basis_y"]

    def forward(input_, num_groups, num_partitions):
        """Evaluate the network.

        ``input_[0]`` feeds the branch and ``input_[1]`` the trunks. Outside
        init, the PU trunk is evaluated per group ``g`` from ``input_[2][g]``
        (points), ``input_[3][g]`` (blending weights), ``input_[4][g]`` (output
        row indices) and ``input_[5][g]`` (partition indices).
        NOTE(review): exact per-array shapes come from the caller's data
        pipeline; confirm there.
        """
        # param init
        branch_forward = CNN(branch_config)
        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        trunk_applys = [MLP(trunk_config_list[i]) for i in range(num_partitions)]

        def func(i, x):
            # Evaluate the partition trunk selected by index ``i`` at ``x``.
            return hk.switch(i, trunk_applys, x)

        point_vmap = hk.vmap(func, in_axes=(0, None), split_rng=False)

        def group_func(i, point, point_weight):
            # Weighted sum of the selected partition trunks at one point;
            # point_weight is expected to be pre-broadcast.
            point_eval = point_vmap(i, point) * point_weight
            return point_eval.sum(0)

        group_vmap = hk.vmap(group_func, in_axes=(0, 0, 0), split_rng=False)

        if hk.running_init():
            # Apply every partition trunk directly to all trunk points,
            # presumably so each one's parameters are created outside
            # hk.switch/hk.vmap.
            pu_pred = activation(jnp.stack([trunk(input_[1]) for trunk in trunk_applys]).sum(0))
            branch_pred = branch_forward(input_[0])
        else:
            branch_pred = branch_forward(input_[0])

            pu_pred = jnp.zeros((input_[1].shape[0], p))
            for g in range(num_groups):
                pu_pred = pu_pred.at[input_[4][g]].set(
                    group_vmap(input_[5][g], input_[2][g], input_[3][g])
                )
            pu_pred = activation(pu_pred)

        # components of vector valued function (POD-DeepONet scaling)
        trunk_pred_x = jnp.hstack([pu_pred[:, :half_p], trunk_pod_basis_x])
        first_component = (jnp.matmul(branch_pred[:, :half_total_p], trunk_pred_x.T) + bias_param) / half_total_p

        trunk_pred_y = jnp.hstack([pu_pred[:, half_p:], trunk_pod_basis_y])
        second_component = (jnp.matmul(branch_pred[:, half_total_p:], trunk_pred_y.T) + bias_param) / half_total_p

        return jnp.stack([first_component, second_component], axis=2)

    return forward


def EnsemblePODSparseTrunkDeepONetCartesianConvProd(config):
    """Build a POD-DeepONet with a CNN branch and a partition-of-unity (PU)
    sparse trunk.

    The trunk basis is the activated PU trunk output stacked next to the
    fixed ``trunk_pod_basis``. The prediction is
    ``(branch @ trunk.T + bias) / total_p``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``; its
            ``"activation"`` selects the nonlinearity applied to the PU trunk
            output and its ``"output_dim"`` is ``total_p``),
            ``"trunk_config_list"`` (one ``MLP`` config per partition),
            ``"p"`` (PU trunk width) and ``"trunk_pod_basis"``.

    Returns:
        ``forward(input_, num_groups, num_partitions)`` for ``hk.transform``.

    Raises:
        ValueError: if ``branch_config["activation"]`` is not one of
            ``"tanh"``, ``"relu"``, ``"elu"`` or ``"leaky_relu"``.
    """
    branch_config = config["branch_config"]
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = branch_config["activation"]
    if activation_name not in activations:
        # Fail fast; otherwise ``activation`` stays unbound and only
        # surfaces as a NameError inside ``forward``.
        raise ValueError(
            f"Unsupported activation {activation_name!r}; "
            f"expected one of: {', '.join(sorted(activations))}"
        )
    activation = activations[activation_name]

    trunk_config_list = config["trunk_config_list"]
    total_p = branch_config["output_dim"]  # total width of the branch/trunk output
    p = config["p"]
    trunk_pod_basis = config["trunk_pod_basis"]

    def forward(input_, num_groups, num_partitions):
        """Evaluate the network.

        ``input_[0]`` feeds the branch and ``input_[1]`` the trunks. Outside
        init, the PU trunk is evaluated per group ``g`` from ``input_[2][g]``
        (points), ``input_[3][g]`` (blending weights), ``input_[4][g]`` (output
        row indices) and ``input_[5][g]`` (partition indices).
        NOTE(review): exact per-array shapes come from the caller's data
        pipeline; confirm there.
        """
        # param init
        branch_forward = CNN(branch_config)
        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        trunk_applys = [MLP(trunk_config_list[i]) for i in range(num_partitions)]

        def func(i, x):
            # Evaluate the partition trunk selected by index ``i`` at ``x``.
            return hk.switch(i, trunk_applys, x)

        point_vmap = hk.vmap(func, in_axes=(0, None), split_rng=False)

        def group_func(i, point, point_weight):
            # Weighted sum of the selected partition trunks at one point;
            # point_weight is expected to be pre-broadcast.
            point_eval = point_vmap(i, point) * point_weight
            return point_eval.sum(0)

        group_vmap = hk.vmap(group_func, in_axes=(0, 0, 0), split_rng=False)

        if hk.running_init():
            # Apply every partition trunk directly to all trunk points,
            # presumably so each one's parameters are created outside
            # hk.switch/hk.vmap.
            pu_pred = jnp.stack([trunk(input_[1]) for trunk in trunk_applys]).sum(0)
            branch_pred = branch_forward(input_[0])
        else:
            branch_pred = branch_forward(input_[0])

            pu_pred = jnp.zeros((input_[1].shape[0], p))
            for g in range(num_groups):
                pu_pred = pu_pred.at[input_[4][g]].set(
                    group_vmap(input_[5][g], input_[2][g], input_[3][g])
                )

        trunk_pred = jnp.hstack([activation(pu_pred), trunk_pod_basis])
        pred = jnp.matmul(branch_pred, trunk_pred.T) + bias_param

        # output scaling/transform for POD-DeepONet
        return pred / total_p

    return forward


def EnsembleVanillaPODSparseTrunkDeepONetCartesianConvVectorProd(config):
    """Build a two-component POD-DeepONet with a CNN branch, a vanilla MLP
    trunk and a partition-of-unity (PU) sparse trunk.

    For each component, the trunk basis is ``[PU half, vanilla half, POD
    basis]`` stacked column-wise. Each component is
    ``(branch_half @ basis.T + bias) / half_total_p``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``; its
            ``"activation"`` selects the nonlinearity applied to the PU trunk
            output and its ``"output_dim"`` is split between the components),
            ``"trunk_config_list"`` (one ``MLP`` config per partition),
            ``"vanilla_trunk_config"`` (passed to ``MLP``), ``"p"`` (PU trunk
            width, also used to split the vanilla trunk output),
            ``"trunk_pod_basis_x"`` and ``"trunk_pod_basis_y"``.

    Returns:
        ``forward(input_, num_groups, num_partitions)`` for ``hk.transform``;
        it returns both components stacked on axis 2.

    Raises:
        ValueError: if ``branch_config["activation"]`` is not one of
            ``"tanh"``, ``"relu"``, ``"elu"`` or ``"leaky_relu"``.
    """
    branch_config = config["branch_config"]
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = branch_config["activation"]
    if activation_name not in activations:
        # Fail fast; otherwise ``activation`` stays unbound and only
        # surfaces as a NameError inside ``forward``.
        raise ValueError(
            f"Unsupported activation {activation_name!r}; "
            f"expected one of: {', '.join(sorted(activations))}"
        )
    activation = activations[activation_name]

    trunk_config_list = config["trunk_config_list"]
    vanilla_trunk_config = config["vanilla_trunk_config"]
    trunk_pod_basis_x = config["trunk_pod_basis_x"]
    trunk_pod_basis_y = config["trunk_pod_basis_y"]
    p = config["p"]
    half_p = int(p/2)
    half_total_p = int(branch_config["output_dim"]/2)  # half of the total branch/trunk width

    def forward(input_, num_groups, num_partitions):
        """Evaluate the network.

        ``input_[0]`` feeds the branch and ``input_[1]`` the trunks. Outside
        init, the PU trunk is evaluated per group ``g`` from ``input_[2][g]``
        (points), ``input_[3][g]`` (blending weights), ``input_[4][g]`` (output
        row indices) and ``input_[5][g]`` (partition indices).
        NOTE(review): exact per-array shapes come from the caller's data
        pipeline; confirm there.
        """
        # param init
        branch_forward = CNN(branch_config)
        vanilla_trunk_forward = MLP(vanilla_trunk_config)
        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        trunk_applys = [MLP(trunk_config_list[i]) for i in range(num_partitions)]

        def func(i, x):
            # Evaluate the partition trunk selected by index ``i`` at ``x``.
            return hk.switch(i, trunk_applys, x)

        point_vmap = hk.vmap(func, in_axes=(0, None), split_rng=False)

        def group_func(i, point, point_weight):
            # Weighted sum of the selected partition trunks at one point;
            # point_weight is expected to be pre-broadcast.
            point_eval = point_vmap(i, point) * point_weight
            return point_eval.sum(0)

        group_vmap = hk.vmap(group_func, in_axes=(0, 0, 0), split_rng=False)

        if hk.running_init():
            # Apply every partition trunk directly to all trunk points,
            # presumably so each one's parameters are created outside
            # hk.switch/hk.vmap.
            pu_pred = activation(jnp.stack([trunk(input_[1]) for trunk in trunk_applys]).sum(0))
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])
        else:
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])

            # Size the buffer by the PU width ``p`` (as the sibling PU models
            # do) rather than implicitly by the vanilla trunk width.
            pu_pred = jnp.zeros((vanilla_trunk_pred.shape[0], p), dtype=vanilla_trunk_pred.dtype)
            for g in range(num_groups):
                pu_pred = pu_pred.at[input_[4][g]].set(
                    group_vmap(input_[5][g], input_[2][g], input_[3][g])
                )
            pu_pred = activation(pu_pred)

        # components of vector valued function (POD-DeepONet scaling)
        trunk_pred_x = jnp.hstack([pu_pred[:, :half_p], vanilla_trunk_pred[:, :half_p], trunk_pod_basis_x])
        first_component = (jnp.matmul(branch_pred[:, :half_total_p], trunk_pred_x.T) + bias_param) / half_total_p

        trunk_pred_y = jnp.hstack([pu_pred[:, half_p:], vanilla_trunk_pred[:, half_p:], trunk_pod_basis_y])
        second_component = (jnp.matmul(branch_pred[:, half_total_p:], trunk_pred_y.T) + bias_param) / half_total_p

        return jnp.stack([first_component, second_component], axis=2)

    return forward


def EnsembleVanillaPODSparseTrunkDeepONetCartesianConvProd(config):
    """Build a POD-DeepONet with a CNN branch, a vanilla MLP trunk and a
    partition-of-unity (PU) sparse trunk.

    The trunk basis is ``[activated PU output, vanilla output, POD basis]``
    stacked column-wise. The prediction is
    ``(branch @ trunk.T + bias) / total_p``.

    Args:
        config: dict with ``"branch_config"`` (passed to ``CNN``; its
            ``"activation"`` selects the nonlinearity applied to the PU trunk
            output and its ``"output_dim"`` is ``total_p``),
            ``"trunk_config_list"`` (one ``MLP`` config per partition),
            ``"vanilla_trunk_config"`` (passed to ``MLP``), ``"p"`` (PU trunk
            width) and ``"trunk_pod_basis"``.

    Returns:
        ``forward(input_, num_groups, num_partitions)`` for ``hk.transform``.

    Raises:
        ValueError: if ``branch_config["activation"]`` is not one of
            ``"tanh"``, ``"relu"``, ``"elu"`` or ``"leaky_relu"``.
    """
    branch_config = config["branch_config"]
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = branch_config["activation"]
    if activation_name not in activations:
        # Fail fast; otherwise ``activation`` stays unbound and only
        # surfaces as a NameError inside ``forward``.
        raise ValueError(
            f"Unsupported activation {activation_name!r}; "
            f"expected one of: {', '.join(sorted(activations))}"
        )
    activation = activations[activation_name]

    trunk_config_list = config["trunk_config_list"]
    vanilla_trunk_config = config["vanilla_trunk_config"]
    total_p = branch_config["output_dim"]  # total width of the branch/trunk output
    p = config["p"]
    trunk_pod_basis = config["trunk_pod_basis"]

    def forward(input_, num_groups, num_partitions):
        """Evaluate the network.

        ``input_[0]`` feeds the branch and ``input_[1]`` the trunks. Outside
        init, the PU trunk is evaluated per group ``g`` from ``input_[2][g]``
        (points), ``input_[3][g]`` (blending weights), ``input_[4][g]`` (output
        row indices) and ``input_[5][g]`` (partition indices).
        NOTE(review): exact per-array shapes come from the caller's data
        pipeline; confirm there.
        """
        # param init
        branch_forward = CNN(branch_config)
        vanilla_trunk_forward = MLP(vanilla_trunk_config)
        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)
        trunk_applys = [MLP(trunk_config_list[i]) for i in range(num_partitions)]

        def func(i, x):
            # Evaluate the partition trunk selected by index ``i`` at ``x``.
            return hk.switch(i, trunk_applys, x)

        point_vmap = hk.vmap(func, in_axes=(0, None), split_rng=False)

        def group_func(i, point, point_weight):
            # Weighted sum of the selected partition trunks at one point;
            # point_weight is expected to be pre-broadcast.
            point_eval = point_vmap(i, point) * point_weight
            return point_eval.sum(0)

        group_vmap = hk.vmap(group_func, in_axes=(0, 0, 0), split_rng=False)

        if hk.running_init():
            # Apply every partition trunk directly to all trunk points,
            # presumably so each one's parameters are created outside
            # hk.switch/hk.vmap.
            pu_pred = jnp.stack([trunk(input_[1]) for trunk in trunk_applys]).sum(0)
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])
        else:
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])

            pu_pred = jnp.zeros((input_[1].shape[0], p))
            for g in range(num_groups):
                pu_pred = pu_pred.at[input_[4][g]].set(
                    group_vmap(input_[5][g], input_[2][g], input_[3][g])
                )

        trunk_pred = jnp.hstack([activation(pu_pred), vanilla_trunk_pred, trunk_pod_basis])
        pred = jnp.matmul(branch_pred, trunk_pred.T) + bias_param

        # output scaling/transform for POD-DeepONet
        return pred / total_p

    return forward


def EnsembleVanillaPODSparseTrunkDeepONetCartesianProd(config):
    """Build a DeepONet whose trunk features combine a sparse partition-of-unity (PU) trunk
    ensemble, a vanilla MLP trunk and a POD basis taken from ``config``.

    The trunk feature matrix is ``hstack([activation(pu), vanilla, trunk_pod_basis])`` and the
    prediction is ``(branch @ trunk.T + bias) / total_p``.

    Args:
        config: dict with keys
            ``"branch_config"``: MLP config of the branch net. Its ``"activation"`` is also
                applied to the PU trunk output and its ``"output_dim"`` is used as the
                output divisor ``total_p``.
            ``"trunk_config_list"``: one MLP config per partition trunk.
            ``"vanilla_trunk_config"``: MLP config of the vanilla trunk.
            ``"p"``: number of columns of the PU trunk output.
            ``"trunk_pod_basis"``: array appended column-wise to the trunk features.

    Returns:
        ``forward(input_, num_groups, num_partitions)`` to be wrapped by ``hk.transform``.
        During init every partition trunk is called directly on ``input_[1]`` and their
        outputs are summed. During apply the PU output is assembled group by group:
        ``input_[5][g]``, ``input_[2][g]`` and ``input_[3][g]`` are presumably the partition
        ids, points and weights of group ``g`` and ``input_[4][g]`` the rows they are written
        to (NOTE(review): confirm against the data pipeline).

    Raises:
        ValueError: if ``branch_config["activation"]`` is not a supported name.
    """
    branch_config = config["branch_config"]
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = branch_config["activation"]
    if activation_name not in activations:
        raise ValueError(
            f"Unsupported activation {activation_name!r}; expected one of {sorted(activations)}"
        )
    activation = activations[activation_name]

    trunk_config_list = config["trunk_config_list"]
    vanilla_trunk_config = config["vanilla_trunk_config"]
    total_p = branch_config["output_dim"]  # total p of the trunk output/branch output
    p = config["p"]
    trunk_pod_basis = config["trunk_pod_basis"]

    def forward(input_, num_groups, num_partitions):
        # param init
        branch_forward = MLP(branch_config)
        vanilla_trunk_forward = MLP(vanilla_trunk_config)
        trunk_applys = [MLP(trunk_config_list[k]) for k in range(num_partitions)]

        # Evaluate partition trunk number ``k`` at ``x``; vmapped over ``k`` so a single
        # point can be pushed through several partition trunks.
        def func(k, x):
            return hk.switch(k, trunk_applys, x)

        point_vmap = hk.vmap(func, in_axes=(0, None), split_rng=False)

        def group_func(k, point, point_weight):
            point_eval = point_vmap(k, point) * point_weight  # point_weight is already broadcasted
            return point_eval.sum(0)

        group_vmap = hk.vmap(group_func, in_axes=(0, 0, 0), split_rng=False)

        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)

        # The call order inside each branch is kept as-is: Haiku draws init RNG keys in
        # the order parameters are created.
        if hk.running_init():
            # Call every partition trunk directly so each one is invoked during init.
            pu_pred = jnp.stack([trunk(input_[1]) for trunk in trunk_applys]).sum(0)
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])
        else:
            branch_pred = branch_forward(input_[0])
            vanilla_trunk_pred = vanilla_trunk_forward(input_[1])

            pu_pred = jnp.zeros((input_[1].shape[0], p))
            for g in range(num_groups):
                group_pred = group_vmap(input_[5][g], input_[2][g], input_[3][g])
                pu_pred = pu_pred.at[input_[4][g]].set(group_pred)

        trunk_pred = jnp.hstack([activation(pu_pred), vanilla_trunk_pred, trunk_pod_basis])
        pred = jnp.matmul(branch_pred, trunk_pred.T) + bias_param

        # output scaling/transform for POD-DeepONet
        return pred / total_p

    return forward


def EnsemblePODSparseTrunkDeepONetCartesianProd(config):
    """Build a DeepONet whose trunk features combine a sparse partition-of-unity (PU) trunk
    ensemble with a POD basis taken from ``config``.

    The trunk feature matrix is ``hstack([activation(pu), trunk_pod_basis])`` and the
    prediction is ``(branch @ trunk.T + bias) / total_p``.

    Args:
        config: dict with keys
            ``"branch_config"``: MLP config of the branch net. Its ``"activation"`` is also
                applied to the PU trunk output and its ``"output_dim"`` is used as the
                output divisor ``total_p``.
            ``"trunk_config_list"``: one MLP config per partition trunk.
            ``"p"``: number of columns of the PU trunk output.
            ``"trunk_pod_basis"``: array appended column-wise to the trunk features.

    Returns:
        ``forward(input_, num_groups, num_partitions)`` to be wrapped by ``hk.transform``.
        During init every partition trunk is called directly on ``input_[1]`` and their
        outputs are summed. During apply the PU output is assembled group by group:
        ``input_[5][g]``, ``input_[2][g]`` and ``input_[3][g]`` are presumably the partition
        ids, points and weights of group ``g`` and ``input_[4][g]`` the rows they are written
        to (NOTE(review): confirm against the data pipeline).

    Raises:
        ValueError: if ``branch_config["activation"]`` is not a supported name.
    """
    branch_config = config["branch_config"]
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = branch_config["activation"]
    if activation_name not in activations:
        raise ValueError(
            f"Unsupported activation {activation_name!r}; expected one of {sorted(activations)}"
        )
    activation = activations[activation_name]

    trunk_config_list = config["trunk_config_list"]
    total_p = branch_config["output_dim"]  # total p of the trunk output/branch output
    p = config["p"]
    trunk_pod_basis = config["trunk_pod_basis"]

    def forward(input_, num_groups, num_partitions):
        # param init
        branch_forward = MLP(branch_config)
        trunk_applys = [MLP(trunk_config_list[k]) for k in range(num_partitions)]

        # Evaluate partition trunk number ``k`` at ``x``; vmapped over ``k`` so a single
        # point can be pushed through several partition trunks.
        def func(k, x):
            return hk.switch(k, trunk_applys, x)

        point_vmap = hk.vmap(func, in_axes=(0, None), split_rng=False)

        def group_func(k, point, point_weight):
            point_eval = point_vmap(k, point) * point_weight  # point_weight is already broadcasted
            return point_eval.sum(0)

        group_vmap = hk.vmap(group_func, in_axes=(0, 0, 0), split_rng=False)

        bias_param = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)

        # The call order inside each branch is kept as-is: Haiku draws init RNG keys in
        # the order parameters are created.
        if hk.running_init():
            # Call every partition trunk directly so each one is invoked during init.
            pu_pred = jnp.stack([trunk(input_[1]) for trunk in trunk_applys]).sum(0)
            branch_pred = branch_forward(input_[0])
        else:
            branch_pred = branch_forward(input_[0])

            pu_pred = jnp.zeros((input_[1].shape[0], p))
            for g in range(num_groups):
                group_pred = group_vmap(input_[5][g], input_[2][g], input_[3][g])
                pu_pred = pu_pred.at[input_[4][g]].set(group_pred)

        trunk_pred = jnp.hstack([activation(pu_pred), trunk_pod_basis])
        pred = jnp.matmul(branch_pred, trunk_pred.T) + bias_param

        # output scaling/transform for POD-DeepONet
        return pred / total_p

    return forward


def PODDeepONetCartesianConvVectorProd(config):
    """Build a POD-DeepONet with a CNN branch for a two-component (vector-valued) output.

    The branch output is split column-wise at ``int(p / 2)``. The first half is contracted
    with ``config["trunk_pod_basis_x"]`` and the second half with
    ``config["trunk_pod_basis_y"]``. Each product is divided by the half width and, when
    ``config["mean_bool"]`` is truthy, offset by ``pod_mean_x`` / ``pod_mean_y``. The two
    components are stacked along a new last axis (``axis=2``).

    Note that the POD bases come from ``config``; ``input_[1]`` is not used.

    Args:
        config: dict with keys ``"branch_config"`` (CNN config), ``"p"``,
            ``"trunk_pod_basis_x"``, ``"trunk_pod_basis_y"``, ``"pod_mean_x"``,
            ``"pod_mean_y"`` and ``"mean_bool"``.

    Returns:
        ``forward(input_)`` to be used inside ``hk.transform``.
    """
    branch_config = config["branch_config"]
    half_p = int(config["p"] / 2)
    bases = (config["trunk_pod_basis_x"], config["trunk_pod_basis_y"])
    means = (config["pod_mean_x"], config["pod_mean_y"])
    add_mean = config["mean_bool"]

    def forward(input_):
        coefficients = CNN(branch_config)(input_[0])
        coefficient_halves = (coefficients[:, :half_p], coefficients[:, half_p:])

        components = []
        for coeff_half, basis, mean in zip(coefficient_halves, bases, means):
            component = jnp.matmul(coeff_half, basis.T) / half_p
            if add_mean:
                component = component + mean
            components.append(component)

        return jnp.stack(components, axis=2)

    return forward


def PODDeepONetCartesianConvProd(config):
    """Build a POD-DeepONet whose branch net is a CNN.

    The branch coefficients for ``input_[0]`` are contracted with the trunk POD basis
    supplied as ``input_[1]``. The result is divided by ``p`` and, when
    ``config["mean_bool"]`` is truthy, offset by ``config["pod_mean"]``.

    Args:
        config: dict with keys ``"branch_config"`` (CNN config), ``"p"``, ``"pod_mean"``
            and ``"mean_bool"``.

    Returns:
        ``forward(input_)`` to be used inside ``hk.transform``.
    """
    branch_config = config["branch_config"]
    num_modes = config["p"]
    mean_field = config["pod_mean"]
    add_mean = config["mean_bool"]

    def forward(input_):
        coefficients = CNN(branch_config)(input_[0])
        pod_basis = input_[1]  # the trunk POD basis is passed in with the input
        # output scaling/transform for POD-DeepONet
        scaled = jnp.matmul(coefficients, pod_basis.T) / num_modes
        return scaled + mean_field if add_mean else scaled

    return forward


def PODDeepONetCartesianProd(config):
    """Build a POD-DeepONet whose branch net is an MLP.

    The branch coefficients for ``input_[0]`` are contracted with the trunk POD basis
    supplied as ``input_[1]``. The result is divided by ``p`` and, when
    ``config["mean_bool"]`` is truthy, offset by ``config["pod_mean"]``.

    Args:
        config: dict with keys ``"branch_config"`` (MLP config), ``"p"``, ``"pod_mean"``
            and ``"mean_bool"``.

    Returns:
        ``forward(input_)`` to be used inside ``hk.transform``.
    """
    branch_config = config["branch_config"]
    num_modes = config["p"]
    mean_field = config["pod_mean"]
    add_mean = config["mean_bool"]

    def forward(input_):
        coefficients = MLP(branch_config)(input_[0])
        pod_basis = input_[1]  # the trunk POD basis is passed in with the input
        # output scaling/transform for POD-DeepONet
        scaled = jnp.matmul(coefficients, pod_basis.T) / num_modes
        return scaled + mean_field if add_mean else scaled

    return forward


def DeepONetCartesianProd(config):
    """Build a vanilla DeepONet evaluated as a Cartesian product.

    Every branch output row (from ``input_[0]``) is dotted with every trunk output row
    (from ``input_[1]``) via ``branch @ trunk.T``, and a learned scalar ``bias`` is added.

    Args:
        config: dict with MLP configs under ``"branch_config"`` and ``"trunk_config"``.

    Returns:
        ``forward(input_)`` to be used inside ``hk.transform``.
    """
    branch_config = config["branch_config"]
    trunk_config = config["trunk_config"]

    def forward(input_):
        branch_net = MLP(branch_config)
        trunk_net = MLP(trunk_config)
        bias = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)

        # Keep branch-then-trunk call order (Haiku init RNG follows creation order).
        branch_out = branch_net(input_[0])
        trunk_out = trunk_net(input_[1])
        return jnp.matmul(branch_out, trunk_out.T) + bias

    return forward


def DeepONet(config):
    """Build a vanilla (point-wise) DeepONet.

    Branch output of ``input_[0]`` and trunk output of ``input_[1]`` are multiplied
    element-wise and summed over axis 1 (keeping that axis), then a learned scalar
    ``bias`` is added.

    Args:
        config: dict with MLP configs under ``"branch_config"`` and ``"trunk_config"``.

    Returns:
        ``forward(input_)`` to be used inside ``hk.transform``.
    """
    branch_config = config["branch_config"]
    trunk_config = config["trunk_config"]

    def forward(input_):
        branch_net = MLP(branch_config)
        trunk_net = MLP(trunk_config)
        bias = hk.get_parameter("bias", shape=(1,), init=jnp.zeros)

        # Keep branch-then-trunk call order (Haiku init RNG follows creation order).
        branch_out = branch_net(input_[0])
        trunk_out = trunk_net(input_[1])
        return jnp.sum(jnp.multiply(branch_out, trunk_out), axis=1, keepdims=True) + bias

    return forward


def CNN(config):
    """Build a Haiku forward function for a network described by a list of layer strings.

    ``config["layers"]`` is processed in order:

    * a string containing ``"conv"`` is parsed by ``split_conv_string`` into
      ``(_, num_channels, kernel_size, stride, padding)`` and becomes ``hk.Conv2D``;
    * ``"activation"`` inserts the activation named by ``config["activation"]``;
    * ``"flatten"`` inserts ``hk.Flatten()``;
    * a string containing ``"linear"`` is parsed by ``split_linear_string`` into
      ``(_, num_neurons)`` and becomes ``hk.Linear``.

    Args:
        config: dict with keys ``"activation"`` (``"tanh"``, ``"relu"``, ``"elu"`` or
            ``"leaky_relu"``), ``"layers"`` and ``"name"`` (prefix of the module names).

    Returns:
        ``forward(x)`` to be called inside ``hk.transform``.

    Raises:
        ValueError: from ``forward``, if a layer string is not recognised, or if an
            ``"activation"`` layer is requested but ``config["activation"]`` is unsupported.
    """
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    # Resolved lazily: a CNN without any "activation" layer never needs a valid name.
    activation = activations.get(config["activation"])

    def forward(x):
        layers_ = []
        for i, layer in enumerate(config["layers"]):
            if "conv" in layer:
                _, num_channels, kernel_size, stride, padding = split_conv_string(layer)
                layers_.append(hk.Conv2D(output_channels=num_channels, kernel_shape=kernel_size, stride=stride, padding=padding, name=config["name"]+f"_conv_{i}"))
            elif layer == "activation":
                if activation is None:
                    raise ValueError(
                        f"Unsupported activation {config['activation']!r}; expected one of {sorted(activations)}"
                    )
                layers_.append(activation)
            elif layer == "flatten":
                layers_.append(hk.Flatten())
            elif "linear" in layer:
                _, num_neurons = split_linear_string(layer)
                layers_.append(hk.Linear(num_neurons, name=config["name"]+f"_linear_{i}"))
            else:
                # Raising a bare str is itself a TypeError; raise a real exception instead.
                raise ValueError(f"Layer {layer} not configured")

        cnn = hk.Sequential(layers_)
        return cnn(x)

    return forward


def MLP(config):
    """Build a Haiku forward function for a fully connected network (``hk.nets.MLP``).

    Layer widths come from ``config["layer_sizes"]`` (hidden widths) when present and not
    ``None``. Otherwise they are ``config["num_hidden_layers"]`` copies of
    ``config["nodes"]``. ``config["output_dim"]`` is always appended as the last width. If
    ``nodes`` or ``num_hidden_layers`` is 0, the network is a single linear layer.

    Args:
        config: dict with keys ``"activation"`` (``"tanh"``, ``"relu"``, ``"elu"`` or
            ``"leaky_relu"``), ``"output_dim"``, ``"name"``, and either ``"layer_sizes"``
            (any sequence of ints) or ``"nodes"`` plus ``"num_hidden_layers"``. Optional:
            ``"use_bias"`` (default True) and ``"last_layer_activate"`` (default False).

    Returns:
        ``forward(x)`` to be called inside ``hk.transform``.

    Raises:
        ValueError: if ``config["activation"]`` is not a supported name.
    """
    activations = {
        "tanh": jax.nn.tanh,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
    }
    activation_name = config["activation"]
    if activation_name not in activations:
        raise ValueError(
            f"Unsupported activation {activation_name!r}; expected one of {sorted(activations)}"
        )
    activation = activations[activation_name]

    if config.get("layer_sizes", None) is None:
        if config["nodes"] == 0 or config["num_hidden_layers"] == 0:
            layer_sizes = [config["output_dim"]]
        else:
            layer_sizes = [config["nodes"]] * config["num_hidden_layers"] + [config["output_dim"]]
    else:
        # list() so tuples (or other sequences) of hidden widths are accepted too.
        layer_sizes = list(config["layer_sizes"]) + [config["output_dim"]]

    def forward(x):
        mlp_module = hk.nets.MLP(output_sizes=layer_sizes, with_bias=config.get("use_bias", True), activation=activation, activate_final=config.get("last_layer_activate", False), name=config["name"])
        return mlp_module(x)

    return forward


def Linear(output_dim, use_bias=True):
    """Return ``forward(x)`` applying a single ``hk.Linear`` layer of width ``output_dim``.

    Args:
        output_dim: number of output features.
        use_bias: whether the layer has a bias term.
    """
    def forward(x):
        return hk.Linear(output_dim, with_bias=use_bias)(x)

    return forward


def get_model(model_name, config):
    """Look up a network factory by name, build it from ``config`` and wrap it with Haiku.

    Args:
        model_name: registry key, e.g. ``"mlp"`` or ``"deeponet_cartesian_prod"``.
        config: configuration passed to the selected factory.

    Returns:
        ``(init, apply)`` from ``hk.transform`` (or ``hk.transform_with_state`` for models
        flagged in ``_USE_STATE``).

    Raises:
        NameError: if ``model_name`` is not a registered model.
    """
    _MODELS = dict(
        mlp=MLP,
        # NOTE(review): Linear's first parameter is ``output_dim``, so ``config`` is passed
        # through as the output width here — confirm this is intended.
        linear=Linear,
        cnn=CNN,
        deeponet=DeepONet,
        deeponet_cartesian_prod=DeepONetCartesianProd,
        pod_deeponet_cartesian_prod=PODDeepONetCartesianProd,
        pod_deeponet_cartesian_conv_prod=PODDeepONetCartesianConvProd,
        pod_deeponet_cartesian_conv_vector_prod=PODDeepONetCartesianConvVectorProd,
        deeponet_cartesian_conv_prod=DeepONetCartesianConvProd,
        deeponet_cartesian_conv_vector_prod=DeepONetCartesianConvVectorProd,
        ensemble_vanilla_sparse_trunk_deeponet_cartesian_conv_prod=EnsembleVanillaSparseTrunkDeepONetCartesianConvProd,
        ensemble_vanilla_sparse_trunk_deeponet_cartesian_conv_vector_prod=EnsembleVanillaSparseTrunkDeepONetCartesianConvVectorProd,
        ensemble_vanilla_sparse_trunk_deeponet_cartesian_prod=EnsembleVanillaSparseTrunkDeepONetCartesianProd,
        ensemble_vanilla_trunk_deeponet_cartesian_conv_prod=EnsembleVanillaTrunkDeepONetCartesianConvProd,
        ensemble_vanilla_trunk_deeponet_cartesian_conv_vector_prod=EnsembleVanillaTrunkDeepONetCartesianConvVectorProd,
        ensemble_vanilla_trunk_deeponet_cartesian_prod=EnsembleVanillaTrunkDeepONetCartesianProd,
        ensemble_pod_vanilla_trunk_deeponet_cartesian_prod=EnsemblePODVanillaTrunkDeepONetCartesianProd,
        ensemble_pod_vanilla_trunk_deeponet_cartesian_conv_prod=EnsemblePODVanillaTrunkDeepONetCartesianConvProd,
        ensemble_pod_vanilla_trunk_deeponet_cartesian_conv_vector_prod=EnsemblePODVanillaTrunkDeepONetCartesianConvVectorProd,
        ensemble_pod_sparse_trunk_deeponet_cartesian_prod=EnsemblePODSparseTrunkDeepONetCartesianProd,
        ensemble_pod_sparse_trunk_deeponet_cartesian_conv_prod=EnsemblePODSparseTrunkDeepONetCartesianConvProd,
        ensemble_pod_sparse_trunk_deeponet_cartesian_conv_vector_prod=EnsemblePODSparseTrunkDeepONetCartesianConvVectorProd,
        ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_prod=EnsembleVanillaPODSparseTrunkDeepONetCartesianProd,
        ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_conv_prod=EnsembleVanillaPODSparseTrunkDeepONetCartesianConvProd,
        ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_conv_vector_prod=EnsembleVanillaPODSparseTrunkDeepONetCartesianConvVectorProd,
    )

    # Whether a model needs hk.transform_with_state (e.g. for mutable state).
    _USE_STATE = dict(
        mlp=False,
        linear=False,
        cnn=False,
        deeponet=False,
        deeponet_cartesian_prod=False,
        deeponet_cartesian_conv_prod=False,
        deeponet_cartesian_conv_vector_prod=False,
        ensemble_vanilla_sparse_trunk_deeponet_cartesian_prod=False,
        ensemble_vanilla_sparse_trunk_deeponet_cartesian_conv_prod=False,
        ensemble_vanilla_sparse_trunk_deeponet_cartesian_conv_vector_prod=False,
        ensemble_vanilla_trunk_deeponet_cartesian_conv_prod=False,
        ensemble_vanilla_trunk_deeponet_cartesian_conv_vector_prod=False,
        ensemble_vanilla_trunk_deeponet_cartesian_prod=False,
        pod_deeponet_cartesian_prod=False,
        pod_deeponet_cartesian_conv_prod=False,
        pod_deeponet_cartesian_conv_vector_prod=False,
        ensemble_pod_vanilla_trunk_deeponet_cartesian_prod=False,
        ensemble_pod_vanilla_trunk_deeponet_cartesian_conv_prod=False,
        ensemble_pod_vanilla_trunk_deeponet_cartesian_conv_vector_prod=False,
        ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_prod=False,
        ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_conv_prod=False,
        ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_conv_vector_prod=False,
        ensemble_pod_sparse_trunk_deeponet_cartesian_prod=False,
        ensemble_pod_sparse_trunk_deeponet_cartesian_conv_prod=False,
        ensemble_pod_sparse_trunk_deeponet_cartesian_conv_vector_prod=False,
    )

    if model_name not in _MODELS:
        # NameError is kept for backward compatibility with existing callers.
        raise NameError(f"Unknown model {model_name!r}. Available keys: {sorted(_MODELS)}")

    net_fn = _MODELS[model_name](config)

    # A model missing from _USE_STATE defaults to a stateless transform.
    if _USE_STATE.get(model_name, False):
        net = hk.transform_with_state(net_fn)
    else:
        net = hk.transform(net_fn)

    return net.init, net.apply
