import torch
from scipy.io import savemat, loadmat
from matplotlib.patches import Circle
import copy
from scipy import io
from mpl_toolkits.axes_grid1 import make_axes_locatable
import jax
import jax.numpy as jnp
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from matplotlib.tri import Triangulation

sys.path.append((os.path.dirname(os.path.dirname(__file__))))
from plot_utils import *
from train_utils import train_func
from utils import fstr, save_results, form_pod_basis, PU
from jax_networks import get_model


def plot_vanilla():
    """Plot the seed-averaged MSE or gradient field of the vanilla DeepONet.

    For every ``p`` in ``ps`` the saved parameters of each seed in ``seeds``
    are loaded, the model is evaluated on the split chosen by
    ``plot_train_test`` ("train" or "test"), and the quantity chosen by
    ``to_plot`` ("mse" or "gradient") is averaged over the seeds and passed
    to ``plot_helper_3d``.

    All settings (``ps``, ``seeds``, ``save_folder``, ``dataset_folder``,
    layer sizes, plot options, ...) are read from names defined outside this
    function; nothing is returned.

    Raises:
        ValueError: if ``plot_train_test`` or ``to_plot`` is not one of the
            supported values.
    """
    if plot_train_test not in ("train", "test"):
        raise ValueError(
            f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}"
        )
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): fstr appears to fill its placeholders lazily, at str() time,
    # from the calling frame -- so `project`, `p` and `seed` presumably must stay
    # local names of this function. Confirm against utils.fstr before renaming.
    project = "vanilla_deeponet"
    model_name = "deeponet_cartesian_prod"
    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("vanilla-deeponet on {problem} with {opt_choice} p={p}")

    for p in ps:
        arr = []  # one field per seed; averaged once the seed loop is done
        print(log_str)
        for seed in seeds:
            trunk_output = p
            branch_output = p

            trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,
                layer_sizes=trunk_layer_sizes,
                output_dim=trunk_output,
                name="trunk",
            )
            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch",
            )
            model_config = dict(
                branch_config=branch_config,
                trunk_config=trunk_config,
                bias_config=dict(name="bias"),
            )

            # exist_ok avoids the check-then-create race of isdir + makedirs.
            os.makedirs(str(plot_save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            # Only the first two entries of each input are forwarded to the model.
            train_input = (train_input[0], train_input[1])
            test_input = (test_input[0], test_input[1])

            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only load parameter files from a trusted source.
            with open(str(param_save_file_name), "rb") as f:
                weights = torch.load(f)

            # Coordinates handed to the plot helper always come from the train input.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            model_forward = jax.jit(model_forward)

            if plot_train_test == "train":
                eval_input, eval_target = train_input, Y
            else:
                eval_input, eval_target = test_input, Y_test

            # Compute only the quantity that is going to be plotted.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, eval_input, 1))
            else:
                pred = model_forward(weights, None, eval_input)
                arr.append(jnp.power(pred - eval_target, 2).mean(0))

        plot_arr = jnp.asarray(arr).mean(0)
        title = "Spatial Gradient" if to_plot == "gradient" else "Vanilla"
        plot_helper_3d(X, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def plot_pod():
    """Plot the seed-averaged MSE or gradient field of the POD DeepONet.

    For every ``p`` in ``pod_ps`` and every ``mean_bool`` in ``mean_bools``
    the saved parameters of each seed in ``seeds`` are loaded, the model is
    evaluated on the split chosen by ``plot_train_test`` ("train" or "test"),
    and the quantity chosen by ``to_plot`` ("mse" or "gradient") is averaged
    over the seeds and passed to ``plot_helper_3d``. The trunk input of both
    splits is replaced by the basis returned from ``form_pod_basis``.

    All settings are read from names defined outside this function; nothing
    is returned.

    Raises:
        ValueError: if ``plot_train_test`` or ``to_plot`` is not one of the
            supported values.
    """
    if plot_train_test not in ("train", "test"):
        raise ValueError(
            f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}"
        )
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): fstr appears to fill its placeholders lazily, at str() time,
    # from the calling frame -- so `project`, `p` and `seed` presumably must stay
    # local names of this function. Confirm against utils.fstr before renaming.
    project = "pod_deeponet"
    model_name = "pod_deeponet_cartesian_prod"
    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("pod-deeponet on {problem} with {opt_choice} p={p}")

    for p in pod_ps:
        # NOTE(review): neither file-name template has a visible mean_bool
        # placeholder, so both settings appear to load the same weights and
        # write to the same output name -- confirm this is intended.
        for mean_bool in mean_bools:
            arr = []  # one field per seed; averaged once the seed loop is done
            print(log_str)
            for seed in seeds:
                trunk_output = p
                branch_output = p

                trunk_config = dict(
                    activation=activation_choice,
                    last_layer_activate=True,
                    layer_sizes=trunk_layer_sizes,
                    output_dim=trunk_output,
                    name="trunk",
                )
                branch_config = dict(
                    activation=activation_choice,
                    layer_sizes=branch_layer_sizes,
                    output_dim=branch_output,
                    name="branch",
                )
                bias_config = dict(name="bias")

                # exist_ok avoids the check-then-create race of isdir + makedirs.
                os.makedirs(str(save_folder), exist_ok=True)

                train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

                # Keep the original trunk input for plotting before it is
                # swapped for the POD basis below.
                X = train_input[1]

                train_trunk_input, x_mean = form_pod_basis(Y, p, not mean_bool)
                x_mean = x_mean.T

                # Both splits use the basis built from the training outputs.
                train_input = (train_input[0], train_trunk_input)
                test_input = (test_input[0], train_trunk_input)

                model_config = dict(
                    branch_config=branch_config,
                    trunk_config=trunk_config,
                    bias_config=bias_config,
                    pod_mean=x_mean,
                    mean_bool=mean_bool,
                    p=p,
                )

                # NOTE(review): torch.load unpickles the file, which can execute
                # arbitrary code -- only load parameter files from a trusted source.
                with open(str(param_save_file_name), "rb") as f:
                    print(str(param_save_file_name))
                    weights = torch.load(f)

                _, model_forward = get_model(model_name, model_config)
                model_forward = jax.jit(model_forward)

                if plot_train_test == "train":
                    eval_input, eval_target = train_input, Y
                else:
                    eval_input, eval_target = test_input, Y_test

                # Compute only the quantity that is going to be plotted.
                if to_plot == "gradient":
                    arr.append(get_max_grad_over_funcs(model_forward, weights, eval_input, 1))
                else:
                    pred = model_forward(weights, None, eval_input)
                    arr.append(jnp.power(pred - eval_target, 2).mean(0))

            plot_arr = jnp.asarray(arr).mean(0)
            title = "POD" if mean_bool else "Modified POD"
            plot_helper_3d(X, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def ensemble_vanilla_trunk_plot():
    """Plot the seed-averaged field of the ensemble-of-vanilla-trunks DeepONet.

    For every ``p`` in ``ps`` the saved parameters of each seed in ``seeds``
    are loaded, the model (``num_trunks`` trunks, branch output of size
    ``p * num_trunks``) is evaluated on the split chosen by
    ``plot_train_test`` ("train" or "test"), and the quantity chosen by
    ``to_plot`` ("mse" or "gradient") is averaged over the seeds and passed
    to ``plot_helper_3d``.

    All settings are read from names defined outside this function; nothing
    is returned.

    Raises:
        ValueError: if ``plot_train_test`` or ``to_plot`` is not one of the
            supported values.
    """
    if plot_train_test not in ("train", "test"):
        raise ValueError(
            f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}"
        )
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): fstr appears to fill its placeholders lazily, at str() time,
    # from the calling frame -- so `project`, `p` and `seed` presumably must stay
    # local names of this function. Confirm against utils.fstr before renaming.
    project = "ensemble_vanilla_trunk_deeponet"
    model_name = "ensemble_vanilla_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-trunk-deeponet on {problem} with {opt_choice} p={p}")

    for p in ps:
        arr = []  # one field per seed; averaged once the seed loop is done
        print(log_str)
        for seed in seeds:
            trunk_output = p
            branch_output = p * num_trunks

            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch",
            )
            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race of isdir + makedirs.
            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            # Only the first two entries of each input are forwarded to the model.
            train_input = (train_input[0], train_input[1])
            test_input = (test_input[0], test_input[1])

            trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # in this case, these trunks should be activated for the stacking
                layer_sizes=trunk_layer_sizes,
                output_dim=trunk_output,
                name="trunk",
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config=trunk_config,
                bias_config=bias_config,
                num_trunks=num_trunks,
            )

            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only load parameter files from a trusted source.
            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                weights = torch.load(f)

            # Coordinates handed to the plot helper always come from the train input.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            model_forward = jax.jit(model_forward)

            if plot_train_test == "train":
                eval_input, eval_target = train_input, Y
            else:
                eval_input, eval_target = test_input, Y_test

            # Compute only the quantity that is going to be plotted.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, eval_input, 1))
            else:
                pred = model_forward(weights, None, eval_input)
                arr.append(jnp.power(pred - eval_target, 2).mean(0))

        plot_arr = jnp.asarray(arr).mean(0)
        title = "(P+1)-Vanilla"
        plot_helper_3d(X, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def plot_ensemble_vanilla_sparse():
    """Plot the seed-averaged field of the (vanilla, PoU) ensemble DeepONet.

    For every ``p`` in ``ps`` and each seed in ``seeds``: the trunk inputs of
    the train and test split are partitioned with a ``PU`` object, one trunk
    config per partition is built next to a vanilla trunk config, the saved
    parameters are loaded, and the model is evaluated on the split chosen by
    ``plot_train_test`` ("train" or "test"). The quantity chosen by
    ``to_plot`` ("mse" or "gradient") is averaged over the seeds and passed
    to ``plot_helper_3d``.

    All settings (``M``, ``centers_list``, ``radius_list``, layer sizes,
    plot options, ...) are read from names defined outside this function;
    nothing is returned.

    Raises:
        ValueError: if ``plot_train_test`` or ``to_plot`` is not one of the
            supported values.
    """
    if plot_train_test not in ("train", "test"):
        raise ValueError(
            f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}"
        )
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): fstr appears to fill its placeholders lazily, at str() time,
    # from the calling frame -- so `project`, `p` and `seed` presumably must stay
    # local names of this function. Confirm against utils.fstr before renaming.
    project = "ensemble_vanilla_sparse_trunk_deeponet"
    model_name = "ensemble_vanilla_sparse_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p}")

    trunk_input = 2  # passed as input_dim to every partition trunk

    for p in ps:
        arr = []  # one field per seed; averaged once the seed loop is done
        print(log_str)
        for seed in seeds:
            trunk_output = p
            branch_output = p * 2

            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch",
            )
            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race of isdir + makedirs.
            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            train_input = (train_input[0], train_input[1])
            test_input = (test_input[0], test_input[1])

            # partitioning ============================================================
            # The PU calls below are kept in their original order because the
            # object may carry state between them.
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1],
            )
            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            _, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # Append a trailing axis of length p to every 2-D weight array.
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk",
            )

            # One trunk config per partition; the first `trunk_special_index`
            # trunks get the special layer sizes.
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                else:
                    temp_config["layer_sizes"] = [trunk_width for _ in range(trunk_depth)]
                trunk_config_list.append(temp_config)

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # in stacking, this should be activated
                layer_sizes=trunk_layer_sizes,
                output_dim=trunk_output,
                name="vanilla_trunk",
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=bias_config,
                num_partitions=M_changed,
            )

            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only load parameter files from a trusted source.
            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                weights = torch.load(f)

            # Coordinates handed to the plot helper always come from the train input.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            # Group and partition counts are passed as static (hashable) arguments.
            model_forward = jax.jit(model_forward, static_argnums=(3, 4))

            if plot_train_test == "train":
                eval_input, eval_target = train_input, Y
                eval_groups, eval_partitions = num_groups, M_changed
            else:
                eval_input, eval_target = test_input, Y_test
                eval_groups, eval_partitions = test_num_groups, M_test

            # Compute only the quantity that is going to be plotted.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, eval_input, 1, eval_groups, eval_partitions))
            else:
                pred = model_forward(weights, None, eval_input, eval_groups, eval_partitions)
                arr.append(jnp.power(pred - eval_target, 2).mean(0))

        plot_arr = jnp.asarray(arr).mean(0)
        title = "(Vanilla, PoU)"
        plot_helper_3d(X, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def plot_ensemble_pod_sparse():
    """Plot the seed-averaged field of the (POD, PoU) ensemble DeepONet.

    For every ``p`` in ``ps`` and each seed in ``seeds``: a POD basis with
    ``pod_p`` modes is formed from the training outputs, the trunk inputs of
    the train and test split are partitioned with a ``PU`` object, one trunk
    config per partition is built, the saved parameters are loaded, and the
    model is evaluated on the split chosen by ``plot_train_test`` ("train"
    or "test"). The quantity chosen by ``to_plot`` ("mse" or "gradient") is
    averaged over the seeds and passed to ``plot_helper_3d``.

    All settings (``M``, ``pod_p``, ``centers_list``, ``radius_list``, layer
    sizes, plot options, ...) are read from names defined outside this
    function; nothing is returned.

    Raises:
        ValueError: if ``plot_train_test`` or ``to_plot`` is not one of the
            supported values.
    """
    if plot_train_test not in ("train", "test"):
        raise ValueError(
            f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}"
        )
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): fstr appears to fill its placeholders lazily, at str() time,
    # from the calling frame -- so `project`, `p` and `seed` presumably must stay
    # local names of this function. Confirm against utils.fstr before renaming.
    project = "ensemble_pod_sparse_trunk_deeponet"
    model_name = "ensemble_pod_sparse_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-POD-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p}")

    trunk_input = 2  # passed as input_dim to every partition trunk

    for p in ps:
        arr = []  # one field per seed; averaged once the seed loop is done
        print(log_str)
        for seed in seeds:
            trunk_output = p
            branch_output = p + pod_p

            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch",
            )
            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race of isdir + makedirs.
            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            train_input = (train_input[0], train_input[1])
            test_input = (test_input[0], test_input[1])

            # partitioning ============================================================
            # The PU calls below are kept in their original order because the
            # object may carry state between them.
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1],
            )
            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            _, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # Append a trailing axis of length p to every 2-D weight array.
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk",
            )

            # One trunk config per partition; the first `trunk_special_index`
            # trunks get the special layer sizes.
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                else:
                    temp_config["layer_sizes"] = [trunk_width for _ in range(trunk_depth)]
                trunk_config_list.append(temp_config)

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                bias_config=bias_config,
                num_partitions=M_changed,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p,
            )

            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only load parameter files from a trusted source.
            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                weights = torch.load(f)

            # Coordinates handed to the plot helper always come from the train input.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            # Group and partition counts are passed as static (hashable) arguments.
            model_forward = jax.jit(model_forward, static_argnums=(3, 4))

            if plot_train_test == "train":
                eval_input, eval_target = train_input, Y
                eval_groups, eval_partitions = num_groups, M_changed
            else:
                eval_input, eval_target = test_input, Y_test
                eval_groups, eval_partitions = test_num_groups, M_test

            # Compute only the quantity that is going to be plotted.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, eval_input, 1, eval_groups, eval_partitions))
            else:
                pred = model_forward(weights, None, eval_input, eval_groups, eval_partitions)
                arr.append(jnp.power(pred - eval_target, 2).mean(0))

        plot_arr = jnp.asarray(arr).mean(0)
        title = "(POD, PoU)"
        plot_helper_3d(X, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def ensemble_pod_vanilla_trunk_plot():
    """Plot the seed-averaged field of the (POD, vanilla) ensemble DeepONet.

    For every ``p`` in ``ps`` and each seed in ``seeds``: a POD basis with
    ``pod_p`` modes is formed from the training outputs, the saved parameters
    are loaded, and the model (branch output of size ``p + pod_p``) is
    evaluated on the split chosen by ``plot_train_test`` ("train" or "test").
    The quantity chosen by ``to_plot`` ("mse" or "gradient") is averaged over
    the seeds and passed to ``plot_helper_3d``.

    All settings are read from names defined outside this function; nothing
    is returned.

    Raises:
        ValueError: if ``plot_train_test`` or ``to_plot`` is not one of the
            supported values.
    """
    if plot_train_test not in ("train", "test"):
        raise ValueError(
            f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}"
        )
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): fstr appears to fill its placeholders lazily, at str() time,
    # from the calling frame -- so `project`, `p` and `seed` presumably must stay
    # local names of this function. Confirm against utils.fstr before renaming.
    project = "ensemble_pod_vanilla_trunk_deeponet"
    model_name = "ensemble_pod_vanilla_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-POD-vanilla-trunk-deeponet on {problem} with {opt_choice} p={p}")

    for p in ps:
        arr = []  # one field per seed; averaged once the seed loop is done
        print(log_str)
        for seed in seeds:
            trunk_output = p
            branch_output = p + pod_p

            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch",
            )
            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race of isdir + makedirs.
            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            # Only the first two entries of each input are forwarded to the model.
            train_input = (train_input[0], train_input[1])
            test_input = (test_input[0], test_input[1])

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # should be activated when stacking
                layer_sizes=trunk_layer_sizes,
                output_dim=trunk_output,
                name="trunk",
            )

            model_config = dict(
                branch_config=branch_config,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=bias_config,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p,
            )

            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code -- only load parameter files from a trusted source.
            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                weights = torch.load(f)

            # Coordinates handed to the plot helper always come from the train input.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            model_forward = jax.jit(model_forward)

            if plot_train_test == "train":
                eval_input, eval_target = train_input, Y
            else:
                eval_input, eval_target = test_input, Y_test

            # Compute only the quantity that is going to be plotted.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, eval_input, 1))
            else:
                pred = model_forward(weights, None, eval_input)
                arr.append(jnp.power(pred - eval_target, 2).mean(0))

        plot_arr = jnp.asarray(arr).mean(0)
        title = "(POD, Vanilla)"
        plot_helper_3d(X, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def ensemble_vanilla_pod_sparse_trunk_plot():
    """Plot the seed-averaged field of the vanilla + POD + PoU ensemble DeepONet.

    For every ``p`` in the module-level ``ps`` and every ``seed`` in ``seeds``
    the function reloads the data, rebuilds the partition-of-unity (PoU)
    grouping of the trunk points and the model configuration, loads the saved
    weights and evaluates, on the split selected by ``plot_train_test``:

    * ``to_plot == "mse"``: the squared prediction error averaged over axis 0;
    * ``to_plot == "gradient"``: the value returned by ``get_max_grad_over_funcs``.

    The per-seed arrays are averaged and passed to ``plot_helper_3d``.  All
    settings are read from module-level names; nothing is returned.

    Raises:
        ValueError: if ``plot_train_test`` or ``to_plot`` holds an unsupported value.
    """
    # Fail fast on a misconfigured run instead of hitting a NameError (or
    # averaging an empty list) only after data and weights have been loaded.
    if plot_train_test not in ("train", "test"):
        raise ValueError(f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}")
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): the fstr templates below appear to be filled in lazily, at
    # str()/print() time, from the names visible in this function (project, p,
    # seed, plus module-level M, problem, opt_choice).  Keep these local names
    # and keep the str() calls inside this function; confirm against utils.fstr.
    project = "ensemble_vanilla_pod_sparse_trunk_deeponet"
    model_name = "ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-POD-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p}")

    trunk_input = 2

    for p in ps:
        arr = []  # one evaluated field per seed
        print(log_str)
        for seed in seeds:
            trunk_output = p
            # the branch output is sized p*2 + pod_p for this ensemble
            branch_output = p*2 + pod_p

            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch"
            )

            bias_config = dict(name="bias")

            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1],
            )

            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            # check_partioning may change the number of partitions (change_M=True);
            # M_changed is the count the model is built with.
            _, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # repeat every PoU weight along a new trailing axis of length p
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test tuples with the PoU bookkeeping
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk"
            )

            # one trunk config per partition; the first trunk_special_index
            # trunks use the special layer sizes
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                else:
                    temp_config["layer_sizes"] = [trunk_width for _ in range(trunk_depth)]
                trunk_config_list.append(temp_config)

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # should be activated when stacking
                layer_sizes=trunk_layer_sizes,
                output_dim=trunk_output,
                name="trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=bias_config,
                num_partitions=M_changed,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p
            )

            # NOTE: torch.load unpickles the file; only load checkpoints you trust.
            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                weights = torch.load(f)

            # trunk coordinates handed to the plotting helper
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            # group count and partition count are passed as static arguments
            model_forward = jax.jit(model_forward, static_argnums=(3, 4))

            if plot_train_test == "train":
                eval_input, eval_Y, eval_groups, eval_M = train_input, Y, num_groups, M_changed
            else:
                eval_input, eval_Y, eval_groups, eval_M = test_input, Y_test, test_num_groups, M_test

            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, eval_input, 1, eval_groups, eval_M))
            else:
                pred = model_forward(weights, None, eval_input, eval_groups, eval_M)
                arr.append(jnp.power(pred - eval_Y, 2).mean(0))

        plot_arr = jnp.asarray(arr).mean(0)  # average over seeds
        title = "(Vanilla, POD, PoU)"
        plot_helper_3d(X, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def get_data(filename):
    """Load the branch/trunk inputs and targets and split them into train and test sets.

    Please refer to the paper for the datasets used.

    Args:
        filename: path prefix (folder including its trailing separator) that is
            concatenated directly with ``c_initial.mat``, ``X.mat`` and ``c1.mat``.

    Returns:
        ``(train_input, y_train, test_input, y_test)`` where each ``*_input`` is
        a ``(branch_input, trunk_input)`` tuple.  The first ``N`` rows of the
        transposed ``c_initial`` / ``c1`` arrays form the training split and the
        following ``N_test`` rows the test split; the trunk input ``X`` is
        shared by both splits.  Every array is cast to the module-level ``dtype``.

    Raises:
        FileNotFoundError: if ``c_initial.mat`` cannot be opened.
    """
    branch_path = filename + "c_initial.mat"
    try:
        raw_branch = io.loadmat(branch_path)
    except OSError as err:
        # Only a failed open is translated; the original bare ``except`` also
        # swallowed KeyboardInterrupt and re-raised a raw BaseException.
        raise FileNotFoundError(
            f"\n\nCould not open {branch_path!r}. "
            "Please place the c_initial.mat file in the dataset folder."
        ) from err

    # .T puts the samples on axis 0 so that the row slices below select samples
    gamma = jnp.asarray(raw_branch["c_initial"]).T.astype(dtype)
    trunk_input = jnp.asarray(io.loadmat(filename + "X.mat")["X"]).astype(dtype)
    c_end = jnp.asarray(io.loadmat(filename + "c1.mat")["c1"]).T.astype(dtype)

    # slicing keeps the dtype, so no further casts are needed
    x_branch_train = gamma[:N, :]
    x_branch_test = gamma[N:N + N_test, :]
    y_train = c_end[:N, :]
    y_test = c_end[N:N + N_test, :]

    train_input = (x_branch_train, trunk_input)
    test_input = (x_branch_test, trunk_input)

    return train_input, y_train, test_input, y_test


def vanilla_deeponet_run():
    """Train the vanilla DeepONet for every (seed, p) pair and optionally save it.

    The whole configuration is read from module-level names (``seeds``, ``ps``,
    layer sizes, optimiser settings, ...).  For each combination the data are
    loaded with ``get_data``, the run is handed to ``train_func`` and, when
    ``save_results_bool`` is truthy, the trained parameters and the logged
    results are written to the paths built from ``save_folder``.
    """
    # NOTE(review): the fstr templates look like they are filled in lazily from
    # the names visible in this function when str()/print() is called, so
    # `project`, `p` and `seed` must keep their names; confirm against utils.fstr.
    project = "vanilla_deeponet"
    model_name = "deeponet_cartesian_prod"
    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("vanilla-deeponet on {problem} with {opt_choice} p={p} seed={seed}")

    trunk_input = 2
    branch_input = m

    for seed in seeds:
        for p in ps:
            # branch and trunk share the same output width p
            trunk_output = branch_output = p

            trunk_config = {
                "activation": activation_choice,
                "last_layer_activate": True,
                "layer_sizes": trunk_layer_sizes,
                "output_dim": trunk_output,
                "name": "trunk",
            }
            branch_config = {
                "activation": activation_choice,
                "layer_sizes": branch_layer_sizes,
                "output_dim": branch_output,
                "name": "branch",
            }
            bias_config = {"name": "bias"}

            model_config = {
                "branch_config": branch_config,
                "trunk_config": trunk_config,
                "bias_config": bias_config,
            }

            # make sure the output folder exists before training starts
            out_dir = str(save_folder)
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)

            # run banner
            banner = "+" * 100
            print("\n\n")
            print(banner)
            print(log_str)
            print(out_dir)
            print(banner)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            # keep only (branch input, trunk input) of each split
            train_input, test_input = (train_input[0], train_input[1]), (test_input[0], test_input[1])
            dummy_input = (train_input[0], train_input[1])

            hparams = {
                "print_bool": print_bool,
                "print_interval": print_interval,
                "epochs": epochs,
                "model_config": model_config,
                "opt_choice": opt_choice,
                "schedule_choice": schedule_choice,
                "lr_dict": lr_dict,
                "problem": problem,
                "N": N,
                "N_test": N_test,
                "p": p,
                "m": m,
                "dtype": dtype,
                "seed": seed,
            }

            # data entries first, then every hyperparameter
            run_config = {
                "dummy_input": dummy_input,
                "train_input": train_input,
                "test_input": test_input,
                "Y": Y,
                "Y_test": Y_test,
                "model_name": model_name,
                **hparams,
            }

            logged_results, trained_params = train_func(run_config)
            logged_results = logged_results | hparams

            if not save_results_bool:
                continue
            torch.save(trained_params, str(param_save_file_name))
            save_results(logged_results, str(save_file_name))


def pod_deeponet_run():
    """Train the POD-DeepONet for every (seed, mean_bool, p) and optionally save it.

    Settings come from module-level names (``seeds``, ``mean_bools``,
    ``pod_ps``, ...).  For each combination the POD basis returned by
    ``form_pod_basis`` replaces the trunk input of both the train and the test
    tuple, the run is handed to ``train_func`` and, when ``save_results_bool``
    is truthy, parameters and logged results are saved.
    """
    # NOTE(review): the fstr templates look like they are filled in lazily from
    # the names visible in this function when str()/print() is called, so
    # `project`, `p` and `seed` must keep their names; confirm against utils.fstr.
    # NOTE(review): the file-name templates visible here do not contain
    # mean_bool; unless save_folder's template does, runs that differ only in
    # mean_bool would write to the same files. Verify.
    project = "pod_deeponet"
    model_name = "pod_deeponet_cartesian_prod"
    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("pod-deeponet on {problem} with {opt_choice} p={p} seed={seed}")

    trunk_input = 2
    branch_input = m

    for seed in seeds:
        for mean_bool in mean_bools:
            for p in pod_ps:
                # branch and trunk share the same output width p
                trunk_output = branch_output = p

                trunk_config = {
                    "activation": activation_choice,
                    "last_layer_activate": True,
                    "layer_sizes": trunk_layer_sizes,
                    "output_dim": trunk_output,
                    "name": "trunk",
                }
                branch_config = {
                    "activation": activation_choice,
                    "layer_sizes": branch_layer_sizes,
                    "output_dim": branch_output,
                    "name": "branch",
                }
                bias_config = {"name": "bias"}

                # make sure the output folder exists before training starts
                out_dir = str(save_folder)
                if not os.path.isdir(out_dir):
                    os.makedirs(out_dir)

                # run banner
                banner = "+" * 100
                print("\n\n")
                print(banner)
                print(log_str)
                print(out_dir)
                print(banner)
                print("\n\n")

                train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

                # POD basis computed from the training targets
                train_trunk_input, x_mean = form_pod_basis(Y, p, not mean_bool)
                x_mean = x_mean.T

                # the POD basis takes the place of the trunk input in every tuple
                train_input = (train_input[0], train_trunk_input)
                test_input = (test_input[0], train_trunk_input)
                dummy_input = (train_input[0], train_trunk_input)

                model_config = {
                    "branch_config": branch_config,
                    "trunk_config": trunk_config,
                    "bias_config": bias_config,
                    "pod_mean": x_mean,
                    "mean_bool": mean_bool,
                    "p": p,
                }

                hparams = {
                    "print_bool": print_bool,
                    "print_interval": print_interval,
                    "epochs": epochs,
                    "model_config": model_config,
                    "opt_choice": opt_choice,
                    "schedule_choice": schedule_choice,
                    "lr_dict": lr_dict,
                    "problem": problem,
                    "N": N,
                    "N_test": N_test,
                    "p": p,
                    "m": m,
                    "dtype": dtype,
                    "seed": seed,
                }

                # data entries first, then every hyperparameter
                run_config = {
                    "dummy_input": dummy_input,
                    "train_input": train_input,
                    "test_input": test_input,
                    "Y": Y,
                    "Y_test": Y_test,
                    "model_name": model_name,
                    **hparams,
                }

                logged_results, trained_params = train_func(run_config)

                # drop the array entry so that the results can be saved
                hparams["model_config"].pop("pod_mean")

                logged_results = logged_results | hparams

                if not save_results_bool:
                    continue
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


def ensemble_pod_vanilla_trunk_deeponet_run():
    """Train the POD + vanilla-trunk ensemble DeepONet for every (seed, p).

    Settings come from module-level names.  For each combination the data are
    loaded, a POD basis with ``pod_p`` modes is formed from the training
    targets and passed to the model through ``model_config``, the run is
    handed to ``train_func`` and, when ``save_results_bool`` is truthy,
    parameters and logged results are saved.
    """
    # NOTE(review): the fstr templates look like they are filled in lazily from
    # the names visible in this function when str()/print() is called, so
    # `project`, `p` and `seed` must keep their names; confirm against utils.fstr.
    project = "ensemble_pod_vanilla_trunk_deeponet"
    model_name = "ensemble_pod_vanilla_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-POD-vanilla-trunk-deeponet on {problem} with {opt_choice} p={p} seed={seed}")

    trunk_input = 2
    branch_input = m

    for seed in seeds:
        for p in ps:
            trunk_output = p
            # the branch output is sized p + pod_p for this ensemble
            branch_output = p + pod_p

            branch_config = {
                "activation": activation_choice,
                "layer_sizes": branch_layer_sizes,
                "output_dim": branch_output,
                "name": "branch",
            }
            bias_config = {"name": "bias"}

            # make sure the output folder exists before training starts
            out_dir = str(save_folder)
            if not os.path.isdir(out_dir):
                os.makedirs(out_dir)

            # run banner
            banner = "+" * 100
            print("\n\n")
            print(banner)
            print(log_str)
            print(out_dir)
            print(banner)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            # keep only (branch input, trunk input) of each split
            train_input, test_input = (train_input[0], train_input[1]), (test_input[0], test_input[1])
            dummy_input = (train_input[0], train_input[1])

            vanilla_trunk_config = {
                "activation": activation_choice,
                "last_layer_activate": True,  # should be activated when stacking
                "layer_sizes": trunk_layer_sizes,
                "output_dim": trunk_output,
                "name": "trunk",
            }

            model_config = {
                "branch_config": branch_config,
                "vanilla_trunk_config": vanilla_trunk_config,
                "bias_config": bias_config,
                "trunk_pod_basis": trunk_pod_basis,
                "pod_mean": x_mean,
                "p": p,
                "pod_p": pod_p,
            }

            hparams = {
                "print_bool": print_bool,
                "print_interval": print_interval,
                "epochs": epochs,
                "model_config": model_config,
                "opt_choice": opt_choice,
                "schedule_choice": schedule_choice,
                "lr_dict": lr_dict,
                "problem": problem,
                "N": N,
                "N_test": N_test,
                "p": p,
                "m": m,
                "dtype": dtype,
                "seed": seed,
            }

            # data entries first, then every hyperparameter
            run_config = {
                "dummy_input": dummy_input,
                "train_input": train_input,
                "test_input": test_input,
                "Y": Y,
                "Y_test": Y_test,
                "model_name": model_name,
                **hparams,
            }

            logged_results, trained_params = train_func(run_config)

            # drop the array entries so that the results can be saved
            for tensor_key in ("pod_mean", "trunk_pod_basis"):
                del hparams["model_config"][tensor_key]

            logged_results = logged_results | hparams

            if not save_results_bool:
                continue
            torch.save(trained_params, str(param_save_file_name))
            save_results(logged_results, str(save_file_name))


def ensemble_vanilla_sparse_trunk_deeponet_run():
    """Train the vanilla + sparse (PoU) trunk ensemble DeepONet for every (seed, p).

    Settings come from module-level names.  For each combination the function
    loads the data, builds the partition-of-unity (PoU) grouping of the trunk
    points for the train and the test set, assembles one trunk config per
    partition plus the vanilla trunk, trains through ``train_func`` and, when
    ``save_results_bool`` is truthy, saves the parameters and logged results.
    """
    # NOTE(review): the fstr templates below appear to be filled in lazily, at
    # str()/print() time, from the names visible in this function (project, p,
    # seed, plus module-level M, problem, opt_choice).  Keep these local names
    # and keep the str() calls inside this function; confirm against utils.fstr.
    project = "ensemble_vanilla_sparse_trunk_deeponet"
    model_name = "ensemble_vanilla_sparse_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p} seed={seed}")

    trunk_input = 2

    for seed in seeds:
        for p in ps:
            trunk_output = p
            # the branch output is sized p * 2 for this ensemble
            branch_output = p * 2

            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch"
            )

            bias_config = dict(name="bias")

            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print(str(save_folder))
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1],
            )

            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            print("Radius = ", pu_obj.radius)
            print(f"M={M}")
            print("Checking train partitioning")
            # check_partioning may change the number of partitions (change_M=True);
            # M_changed is the count the model is built with.
            K, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            # only reported below; the trunks themselves are built with trunk_width
            pu_trunk_width = int(1/onp.sqrt(K) * trunk_width)
            print(f"K={K}, PU trunk_width={pu_trunk_width}")
            print(f"new M={M_changed}")
            print("Forming points per groups")
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            print("Group sizes: ", [len(g) for g in points_per_group])
            print("Forming weights per groups")
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # repeat every PoU weight along a new trailing axis of length p
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            print("\nChecking test partitioning")
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            # report the partition count found for the test points (was printing M)
            print(f"M_test={M_test}")
            print("Forming points per groups")
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            print("Group sizes: ", [len(g) for g in test_points_per_group])
            print("Forming weights per groups")
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test and dummy tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)
            dummy_input = (train_input[0], train_input[1])

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk"
            )

            # one trunk config per partition; the first trunk_special_index
            # trunks use the special layer sizes
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                else:
                    temp_config["layer_sizes"] = [trunk_width for _ in range(trunk_depth)]
                trunk_config_list.append(temp_config)

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # in stacking, this should be activated
                layer_sizes=trunk_layer_sizes,
                output_dim=trunk_output,
                name="vanilla_trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=bias_config,
                num_partitions=M_changed
            )

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                M=M_changed,
                K=K,
                train_num_partitions=M_changed,
                test_num_partitions=M_test,
                train_num_groups=num_groups,
                test_num_groups=test_num_groups,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)
            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


def ensemble_pod_sparse_trunk_deeponet_run():
    """Train the POD + sparse (PoU) trunk ensemble DeepONet for every (seed, p).

    Settings come from module-level names.  For each combination the function
    loads the data, forms a POD basis with ``pod_p`` modes from the training
    targets, builds the partition-of-unity (PoU) grouping of the trunk points
    for the train and the test set, assembles one trunk config per partition,
    trains through ``train_func`` and, when ``save_results_bool`` is truthy,
    saves the parameters and logged results.
    """
    # NOTE(review): the fstr templates below appear to be filled in lazily, at
    # str()/print() time, from the names visible in this function (project, p,
    # seed, plus module-level M, problem, opt_choice).  Keep these local names
    # and keep the str() calls inside this function; confirm against utils.fstr.
    project = "ensemble_pod_sparse_trunk_deeponet"
    model_name = "ensemble_pod_sparse_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-POD-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p} seed={seed}")

    trunk_input = 2

    for seed in seeds:
        for p in ps:
            trunk_output = p
            # the branch output is sized p + pod_p for this ensemble
            branch_output = p + pod_p

            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch"
            )

            bias_config = dict(name="bias")

            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print(str(save_folder))
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1],
            )

            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            print("Radius = ", pu_obj.radius)
            print(f"M={M}")
            print("Checking train partitioning")
            # check_partioning may change the number of partitions (change_M=True);
            # M_changed is the count the model is built with.
            K, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            # only reported below; the trunks themselves are built with trunk_width
            pu_trunk_width = int(1/onp.sqrt(K) * trunk_width)
            print(f"K={K}, PU trunk_width={pu_trunk_width}")
            print(f"new M={M_changed}")
            print("Forming points per groups")
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            print("Group sizes: ", [len(g) for g in points_per_group])
            print("Forming weights per groups")
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # repeat every PoU weight along a new trailing axis of length p
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            print("\nChecking test partitioning")
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            # report the partition count found for the test points (was printing M)
            print(f"M_test={M_test}")
            print("Forming points per groups")
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            print("Group sizes: ", [len(g) for g in test_points_per_group])
            print("Forming weights per groups")
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test and dummy tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            dummy_input = (train_input[0], train_input[1])

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk"
            )

            # one trunk config per partition; the first trunk_special_index
            # trunks use the special layer sizes
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                else:
                    temp_config["layer_sizes"] = [trunk_width for _ in range(trunk_depth)]
                trunk_config_list.append(temp_config)

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                bias_config=bias_config,
                num_partitions=M_changed,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p
            )

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                M=M_changed,
                K=K,
                train_num_partitions=M_changed,
                test_num_partitions=M_test,
                train_num_groups=num_groups,
                test_num_groups=test_num_groups,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)

            # removing tensors from dict so results can be saved
            del hyperparameter_dict["model_config"]["pod_mean"]
            del hyperparameter_dict["model_config"]["trunk_pod_basis"]

            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


def ensemble_vanilla_pod_sparse_trunk_deeponet_run():
    """Train the ensemble vanilla + POD + sparse-trunk DeepONet for every seed and p.

    For each ``seed`` in ``seeds`` and each ``p`` in ``ps`` this function:

    1. loads the data with ``get_data`` and builds a POD basis from ``Y`` via
       ``form_pod_basis(Y, pod_p)``;
    2. partitions the train and test trunk inputs with a ``PU`` object and
       builds per-group points, indices and weights;
    3. assembles the branch / trunk / model / hyperparameter configuration;
    4. calls ``train_func`` and, when ``save_results_bool`` is true, stores the
       trained parameters (``torch.save``) and the logged results
       (``save_results``).

    All settings (``seeds``, ``ps``, ``pod_p``, ``M``, ``centers_list``,
    ``radius_list``, ``save_folder``, ``dataset_folder``, layer sizes, optimizer
    options, ...) are read from module-level globals that must be defined
    before the call. Nothing is returned.
    """
    # NOTE(review): `project`, `p` and `seed` look unused but are named inside the
    # fstr templates below; the templates are only rendered via str(...) later, so
    # fstr presumably resolves these names lazily from the calling scope — do not
    # rename them without checking utils.fstr.
    project = "ensemble_vanilla_pod_sparse_trunk_deeponet"
    model_name = "ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-POD-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p} seed={seed}")

    # NOTE(review): hard-coded trunk input dimension handed to the PU trunk configs,
    # while the PU partitioning below takes its dimension from the data
    # (train_trunk_input.shape[-1]) — confirm the two agree for this problem.
    trunk_input = 2

    for seed in seeds:
        for p in ps:
            trunk_output = p
            # Branch width is two p-sized blocks plus pod_p entries; presumably one
            # block each for the vanilla trunk and the PU trunks plus the POD modes
            # — TODO confirm against the model definition.
            branch_output = p*2 + pod_p

            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch"
            )

            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race of isdir() + makedirs().
            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print(str(save_folder))
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            # Keep only the first two entries of each input tuple.
            train_input = (train_input[0], train_input[1])
            test_input = (test_input[0], test_input[1])

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1],
            )

            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            print("Radius = ", pu_obj.radius)
            print(f"M={M}")
            print("Checking train partitioning")
            K, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            # NOTE(review): `onp` is not among the imports at the top of this file;
            # presumably it is provided by `from plot_utils import *` — confirm.
            # pu_trunk_width is only reported here; the trunk configs below use
            # trunk_width / special_pu_trunk_layer_sizes.
            pu_trunk_width = int(1/onp.sqrt(K) * trunk_width)
            print(f"K={K}, PU trunk_width={pu_trunk_width}")
            print(f"new M={M_changed}")
            print("Forming points per groups")
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            print("Group sizes: ", [len(g) for g in points_per_group])
            print("Forming weights per groups")
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # Append a trailing axis to every weight array and broadcast it to length p.
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            print("\nChecking test partitioning")
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            # Report the value returned for the test set (previously this printed M).
            print(f"M_test={M_test}")
            print("Forming points per groups")
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            print("Group sizes: ", [len(g) for g in test_points_per_group])
            print("Forming weights per groups")
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test and dummy tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            dummy_input = (train_input[0], train_input[1])

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk"
            )

            # One trunk config per (possibly adjusted) partition. Trunks with index
            # below trunk_special_index get the special layer sizes, the rest the
            # default trunk_width x trunk_depth stack.
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                else:
                    temp_config["layer_sizes"] = [trunk_width for _ in range(trunk_depth)]
                trunk_config_list.append(temp_config)

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # should be activated when stacking
                layer_sizes=trunk_layer_sizes,
                output_dim=trunk_output,
                name="trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=bias_config,
                num_partitions=M_changed,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p
            )

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                M=M_changed,
                K=K,
                train_num_partitions=M_changed,
                test_num_partitions=M_test,
                train_num_groups=num_groups,
                test_num_groups=test_num_groups,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)

            # removing tensors from dict so results can be saved
            # (this mutates model_config in place; it is rebuilt on the next iteration)
            del hyperparameter_dict["model_config"]["pod_mean"]
            del hyperparameter_dict["model_config"]["trunk_pod_basis"]

            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


def ensemble_vanilla_trunk_deeponet_run():
    """Train the ensemble vanilla-trunk DeepONet for every seed and p.

    For each ``seed`` in ``seeds`` and each ``p`` in ``ps`` this loads the data
    with ``get_data``, builds the branch / trunk / model configuration (the
    branch output is ``p * num_trunks`` wide), calls ``train_func`` and, when
    ``save_results_bool`` is true, stores the trained parameters
    (``torch.save``) and the logged results (``save_results``).

    All settings (``seeds``, ``ps``, ``num_trunks``, ``save_folder``,
    ``dataset_folder``, layer sizes, optimizer options, ...) are read from
    module-level globals that must be defined before the call. Nothing is
    returned.
    """
    # NOTE(review): `project`, `p` and `seed` look unused but are named inside the
    # fstr templates below; the templates are only rendered via str(...) later, so
    # fstr presumably resolves these names lazily from the calling scope — do not
    # rename them without checking utils.fstr.
    project = "ensemble_vanilla_trunk_deeponet"
    model_name = "ensemble_vanilla_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-trunk-deeponet on {problem} with {opt_choice} p={p} seed={seed}")

    for seed in seeds:
        for p in ps:
            trunk_output = p
            # The branch emits p values for each of the num_trunks trunks.
            branch_output = p * num_trunks

            branch_config = dict(
                activation=activation_choice,
                layer_sizes=branch_layer_sizes,
                output_dim=branch_output,
                name="branch"
            )

            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race of isdir() + makedirs().
            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print(str(save_folder))
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data(str(dataset_folder))

            # Keep only the first two entries of each input tuple.
            train_input = (train_input[0], train_input[1])
            test_input = (test_input[0], test_input[1])

            dummy_input = (train_input[0], train_input[1])

            trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # in this case, these trunks should be activated for the stacking
                layer_sizes=trunk_layer_sizes,
                output_dim=trunk_output,
                name="trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config=trunk_config,
                bias_config=bias_config,
                num_trunks=num_trunks
            )

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)
            # Store the hyperparameters next to the training logs.
            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


# Script entry point: defines the module-level settings that the *_run / plot_*
# functions in this file read as globals, then launches the selected experiment.
# Each section below rebinds `folder_string` / `save_folder` before its
# (possibly commented-out) run call, so statement order matters.
if __name__ == "__main__":
    # NOTE(review): this template names `problem`, which is only assigned further
    # down; presumably fstr defers formatting until str(...) is called — confirm
    # in utils.fstr.
    dataset_folder = fstr("../dataset/{problem}/")

    save_results_bool = True
    print_bool = True

    problem = "diffrec_3d"
    opt_choice = "adamw"
    schedule_choice = "inverse_time_decay"
    # activations
    activation_choice = "relu"  # activation everywhere

    print_interval = 1000
    dtype = "float32"

    N = 1000
    N_test = 200
    m = 20

    # Learning-rate settings, passed through unchanged to train_func via the
    # hyperparameter dict.
    lr_dict = dict(
        peak_lr=6e-3,
        weight_decay=1e-4,
        decay_steps=1,
        decay_rate=1e-4,
    )

    # Enable 64-bit mode in JAX only when the run is configured for float64.
    if dtype == "float64":
        jax.config.update("jax_enable_x64", True)
    else:
        jax.config.update("jax_enable_x64", False)

    epochs = 150000

    # Partition centers handed to PU.partition_domain by the sparse-trunk runs.
    centers_list = [
        [0.6, 0.5, -0.2],
        [0.6, -0.3, 0.3],
        [-0.7, 0.5, -0.4],
        [0.2, -0.7, 0.1],
        [0.1, -0.7, 0.4],
        [-0.3, 0.5, 0.7],
        [0., -0.4, -0.7],
        [-0.8, -.2, 0.3],
    ]

    # Number of partitions equals the number of centers.
    M = len(centers_list)

    # Every center gets the same radius (0.9). `cnt` is incremented but never
    # read within this block.
    radius_list = []
    cnt = 1
    for c in centers_list:
        radius_list.append(0.9)
        cnt += 1
    # Trunks with index < trunk_special_index use special_pu_trunk_layer_sizes in
    # the sparse-trunk runs; with the value M that covers indices 0..M-1.
    trunk_special_index = M

    # Plot settings; presumably consumed by the plot_* functions — TODO confirm.
    to_plot = "mse"
    plot_train_test = "test"
    clim_lower, clim_upper = 7e-7, 2e-6  # mse
    # clim_lower, clim_upper = 0.2, 2 # gradient
    title_font = 110
    axis_font = 110
    tick_fontsize = 70
    figsize = (22, 20)

    # sub project folder
    all_run_folder = "all_results"

    plot_save_folder = f"../figures/{problem}/"  # for heatmaps

    # architectures ===========================================================================
    trunk_depth = 3
    trunk_width = 128

    trunk_layer_sizes = [trunk_width for _ in range(trunk_depth)]

    branch_width = 128
    branch_depth = 3

    branch_layer_sizes = [branch_width for _ in range(branch_depth)]

    # Sweep values iterated by the run functions.
    seeds = [0, 1, 2, 3, 4]
    ps = [100]
    # NOTE(review): `pod_ps` is not read anywhere in the visible part of the file;
    # `pod_p` is the value passed to form_pod_basis.
    pod_ps = [20]
    pod_p = 20
    # ============================================================================================================
    # vanilla deeponet

    folder_string = f"vanilla_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # This is the only run call that is currently active in this block.
    vanilla_deeponet_run()
    # plot_vanilla()

    # ============================================================================================================
    # POD deeponet

    mean_bools = [False]
    # This folder template names `mean_bool`, which is not assigned in this block;
    # presumably it is bound inside pod_deeponet_run — TODO confirm.
    folder_string = fstr("pod_deeponet_results_mean_bool={mean_bool}")
    save_folder = fstr("../{all_run_folder}/" + folder_string.payload + "/{problem}/")
    # pod_deeponet_run()
    # if not to_plot == "gradient":
    #     plot_pod()

    # ============================================================================================================
    # ensemble vanilla trunk deeponet

    # ensemble_vanilla_trunk_deeponet_run sizes its branch output as p * num_trunks.
    num_trunks = M + 1
    folder_string = f"ensemble_vanilla_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_vanilla_trunk_deeponet_run()
    # if not to_plot == "gradient":
    #     ensemble_vanilla_trunk_plot()

    # ============================================================================================================
    # ensemble sparse trunk deeponet

    min_partitions_per_point = 1
    # Dividing by 1 leaves the width at trunk_width; the divisor is presumably a
    # knob for shrinking the special PU trunks — TODO confirm.
    special_pu_trunk_layer_sizes = [int(trunk_width/1) for _ in range(trunk_depth)]
    folder_string = f"ensemble_vanilla_sparse_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_vanilla_sparse_trunk_deeponet_run()
    # if not to_plot == "gradient":
    #     plot_ensemble_vanilla_sparse()

    # ============================================================================================================
    # ensemble POD sparse trunk deeponet

    # Same values as in the previous section, re-assigned here.
    min_partitions_per_point = 1
    special_pu_trunk_layer_sizes = [int(trunk_width/1) for _ in range(trunk_depth)]
    folder_string = f"ensemble_pod_sparse_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_pod_sparse_trunk_deeponet_run()
    # if not to_plot == "gradient":
    #     plot_ensemble_pod_sparse()

    # ============================================================================================================
    # ensemble POD vanilla trunk deeponet

    folder_string = f"ensemble_pod_vanilla_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_pod_vanilla_trunk_deeponet_run()
    # if not to_plot == "gradient":
    #     ensemble_pod_vanilla_trunk_plot()

    # ============================================================================================================
    # ensemble vanilla POD sparse trunk deeponet

    folder_string = f"ensemble_vanilla_pod_sparse_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_vanilla_pod_sparse_trunk_deeponet_run()
    # if not to_plot == "gradient":
    #     ensemble_vanilla_pod_sparse_trunk_plot()
