import torch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.patches import Circle
import copy
from scipy import io
import jax
import jax.numpy as jnp
import numpy as np
import os
import sys
from matplotlib.tri import Triangulation
import matplotlib.pyplot as plt

sys.path.append((os.path.dirname(os.path.dirname(__file__))))
from plot_utils import *
from train_utils import train_func
from utils import fstr, save_results, form_pod_basis, PU
from jax_networks import get_model


def plot_vanilla():
    """Plot a seed-averaged heatmap for the vanilla DeepONet.

    For every ``p`` in ``ps`` and every ``seed`` in ``seeds`` the saved
    parameters are loaded from ``save_folder``, the model is rebuilt with
    ``get_model`` and evaluated on the split chosen by ``plot_train_test``.
    Depending on ``to_plot`` either the squared error averaged over axis 0
    (``"mse"``) or the result of ``get_max_grad_over_funcs`` (``"gradient"``)
    is collected per seed. The per-seed arrays are averaged and handed to
    ``plot_helper`` together with a triangulation of the trunk coordinates.

    All settings are free names resolved from the enclosing module scope;
    the function takes no arguments and returns nothing.

    Raises:
        ValueError: If ``plot_train_test`` is not ``"train"``/``"test"`` or
            ``to_plot`` is not ``"gradient"``/``"mse"``.
    """
    # Fail early instead of hitting an unbound local (bad split) or plotting
    # the mean of an empty list (bad quantity) further down.
    if plot_train_test not in ("train", "test"):
        raise ValueError(f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}")
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): the fstr templates below name `project`, `p` and `seed`;
    # they appear to be resolved lazily at str() time from this function's
    # scope, so these locals must keep their names even if they look unused.
    project = "vanilla_deeponet"
    model_name = "deeponet_cartesian_prod"
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("vanilla-deeponet on {problem} with {opt_choice} p={p}")

    for p in ps:
        arr = []
        print(log_str)
        for seed in seeds:
            trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=p,
                name="trunk"
            )

            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=p,
                name="branch"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config=trunk_config,
                bias_config=dict(name="bias")
            )

            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data()

            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints that come from a trusted source.
                weights = torch.load(f)

            # Trunk coordinates of the training tuple define the plot mesh.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            model_forward = jax.jit(model_forward)

            tri = Triangulation(X[:, 0], X[:, 1])

            if plot_train_test == "train":
                inputs, targets = train_input, Y
            else:
                inputs, targets = test_input, Y_test

            # Only the requested quantity is computed.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, inputs, 1))
            else:
                pred = model_forward(weights, None, inputs)
                arr.append(jnp.power(pred - targets, 2).mean(0))

        # Average over seeds.
        plot_arr = jnp.asarray(arr).mean(0)
        save_file_name = str(fstr(plot_save_folder+"vanilla_{plot_train_test}_heatmap_{to_plot}_p={p}.png"))
        title = "Spatial Gradient" if to_plot == "gradient" else "Vanilla"
        plot_helper(tri, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)

def plot_pod():
    """Plot seed-averaged heatmaps for the POD DeepONet.

    Iterates over ``mean_bools`` and ``pod_ps``; for each combination and each
    ``seed`` in ``seeds`` the saved parameters are loaded, the trunk input of
    both the train and test tuples is replaced by the basis returned from
    ``form_pod_basis(Y, p, not mean_bool)``, and the model is evaluated on the
    split chosen by ``plot_train_test``. Depending on ``to_plot`` either the
    squared error averaged over axis 0 (``"mse"``) or the result of
    ``get_max_grad_over_funcs`` (``"gradient"``) is collected per seed, then
    averaged over seeds and handed to ``plot_helper``.

    All settings are free names resolved from the enclosing module scope;
    the function takes no arguments and returns nothing.

    Raises:
        ValueError: If ``plot_train_test`` is not ``"train"``/``"test"`` or
            ``to_plot`` is not ``"gradient"``/``"mse"``.
    """
    # Fail early instead of hitting an unbound local (bad split) or plotting
    # the mean of an empty list (bad quantity) further down.
    if plot_train_test not in ("train", "test"):
        raise ValueError(f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}")
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): the fstr templates below name `project`, `p`, `seed` and
    # `mean_bool`; they appear to be resolved lazily at str() time from this
    # function's scope, so these locals must keep their names.
    project = "pod_deeponet"
    model_name = "pod_deeponet_cartesian_prod"
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("pod-deeponet on {problem} with {opt_choice} p={p} mean_bool={mean_bool}")

    for mean_bool in mean_bools:
        for p in pod_ps:
            arr = []
            print(log_str)
            for seed in seeds:
                trunk_config = dict(
                    activation=activation_choice,
                    last_layer_activate=True,
                    num_hidden_layers=trunk_depth,
                    nodes=trunk_width,
                    output_dim=p,
                    name="trunk"
                )

                branch_config = dict(
                    activation=activation_choice,
                    num_hidden_layers=branch_depth,
                    nodes=branch_width,
                    output_dim=p,
                    name="branch"
                )

                os.makedirs(str(save_folder), exist_ok=True)

                train_input, Y, test_input, Y_test = get_data()

                # Grab the trunk coordinates for the plot mesh before the
                # trunk slot is overwritten with the POD basis below.
                X = train_input[1]

                # applying POD
                train_trunk_input, x_mean = form_pod_basis(Y, p, not mean_bool)
                x_mean = x_mean.T

                # putting pod basis back in tuples (the basis formed from the
                # training targets is used for the test tuple as well)
                train_input = (train_input[0], train_trunk_input)
                test_input = (test_input[0], train_trunk_input)

                model_config = dict(
                    branch_config=branch_config,
                    trunk_config=trunk_config,
                    bias_config=dict(name="bias"),
                    pod_mean=x_mean,
                    mean_bool=mean_bool,
                    p=p
                )

                with open(str(param_save_file_name), "rb") as f:
                    print(str(param_save_file_name))
                    # NOTE(review): torch.load unpickles the file; only load
                    # checkpoints that come from a trusted source.
                    weights = torch.load(f)

                _, model_forward = get_model(model_name, model_config)
                model_forward = jax.jit(model_forward)

                tri = Triangulation(X[:, 0], X[:, 1])

                if plot_train_test == "train":
                    inputs, targets = train_input, Y
                else:
                    inputs, targets = test_input, Y_test

                # Only the requested quantity is computed.
                if to_plot == "gradient":
                    arr.append(get_max_grad_over_funcs(model_forward, weights, inputs, 1))
                else:
                    pred = model_forward(weights, None, inputs)
                    arr.append(jnp.power(pred - targets, 2).mean(0))

            # Average over seeds.
            plot_arr = jnp.asarray(arr).mean(0)
            save_file_name = str(fstr(plot_save_folder+"pod_mean_bool={mean_bool}_{plot_train_test}_heatmap_{to_plot}_p={p}.png"))
            if to_plot == "gradient":
                title = "Spatial Gradient"
            else:
                title = "POD" if mean_bool else "Modified POD"
            plot_helper(tri, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def plot_ensemble_vanilla_sparse():
    """Plot a seed-averaged heatmap for the vanilla + sparse-trunk ensemble.

    For every ``p`` in ``ps`` and every ``seed`` in ``seeds`` the trunk
    coordinates are partitioned with a ``PU`` object (``centers_list`` /
    ``radius_list``), per-group points and weights are built for both the
    train and test coordinates, the saved parameters are loaded, and the model
    is evaluated on the split chosen by ``plot_train_test``. Depending on
    ``to_plot`` either the squared error averaged over axis 0 (``"mse"``) or
    the result of ``get_max_grad_over_funcs`` (``"gradient"``) is collected
    per seed, then averaged over seeds and handed to ``plot_helper``.

    All settings are free names resolved from the enclosing module scope;
    the function takes no arguments and returns nothing.

    Raises:
        ValueError: If ``plot_train_test`` is not ``"train"``/``"test"`` or
            ``to_plot`` is not ``"gradient"``/``"mse"``.
    """
    # Fail early instead of hitting an unbound local (bad split) or plotting
    # the mean of an empty list (bad quantity) further down.
    if plot_train_test not in ("train", "test"):
        raise ValueError(f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}")
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): the fstr templates below name `project`, `p` and `seed`
    # (plus the outer `M`); they appear to be resolved lazily at str() time
    # from this function's scope, so these locals must keep their names.
    project = "ensemble_vanilla_sparse_trunk_deeponet"
    model_name = "ensemble_vanilla_sparse_trunk_deeponet_cartesian_prod"

    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p}")

    trunk_input = 2

    for p in ps:
        arr = []
        print(log_str)
        print("+"*100)

        for seed in seeds:
            # The branch feeds two stacked trunk outputs of width p each.
            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=p*2,
                name="branch"
            )

            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data()

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_obj = PU(dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1]
            ))

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            # The partition count actually used may differ from M (change_M=True).
            _, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # Append a trailing axis of size p to every 2-D weight array.
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=p,
                name="pu_trunk"
            )

            # One trunk config per partition; the first `trunk_special_index`
            # trunks additionally get explicit layer sizes.
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                trunk_config_list.append(temp_config)

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # in stacking, this should be activated
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=p,
                name="vanilla_trunk"
            )

            linear_config = dict(
                activation=activation_choice,
                num_hidden_layers=0,
                nodes=0,
                output_dim=p,
                name="final_linear"
            )

            model_config = dict(
                branch_config=branch_config,
                linear_config=linear_config,
                trunk_config_list=trunk_config_list,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=dict(name="bias"),
                num_partitions=M_changed
            )

            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints that come from a trusted source.
                weights = torch.load(f)

            # Trunk coordinates of the training tuple define the plot mesh.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            # Group count and partition count (positions 3 and 4) are static.
            model_forward = jax.jit(model_forward, static_argnums=(3, 4))

            tri = Triangulation(X[:, 0], X[:, 1])

            if plot_train_test == "train":
                inputs, targets, n_groups, n_partitions = train_input, Y, num_groups, M_changed
            else:
                inputs, targets, n_groups, n_partitions = test_input, Y_test, test_num_groups, M_test

            # Only the requested quantity is computed.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, inputs, 1, n_groups, n_partitions))
            else:
                pred = model_forward(weights, None, inputs, n_groups, n_partitions)
                arr.append(jnp.power(pred - targets, 2).mean(0))

        # Average over seeds.
        plot_arr = jnp.asarray(arr).mean(0)
        save_file_name = str(fstr(plot_save_folder+"ensemble_vanilla_sparse_trunk_{plot_train_test}_heatmap_{to_plot}_p={p}.png"))
        title = "Spatial Gradient" if to_plot == "gradient" else "(P+1)-Vanilla"
        plot_helper(tri, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def plot_ensemble_pod_sparse():
    """Plot a seed-averaged heatmap for the POD + sparse-trunk ensemble.

    For every ``p`` in ``ps`` and every ``seed`` in ``seeds`` a POD basis is
    formed with ``form_pod_basis(Y, pod_p)``, the trunk coordinates are
    partitioned with a ``PU`` object (``centers_list`` / ``radius_list``),
    per-group points and weights are built for both the train and test
    coordinates, the saved parameters are loaded, and the model is evaluated
    on the split chosen by ``plot_train_test``. Depending on ``to_plot``
    either the squared error averaged over axis 0 (``"mse"``) or the result of
    ``get_max_grad_over_funcs`` (``"gradient"``) is collected per seed, then
    averaged over seeds and handed to ``plot_helper``.

    All settings are free names resolved from the enclosing module scope;
    the function takes no arguments and returns nothing.

    Raises:
        ValueError: If ``plot_train_test`` is not ``"train"``/``"test"`` or
            ``to_plot`` is not ``"gradient"``/``"mse"``.
    """
    # Fail early instead of hitting an unbound local (bad split) or plotting
    # the mean of an empty list (bad quantity) further down.
    if plot_train_test not in ("train", "test"):
        raise ValueError(f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}")
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): the fstr templates below name `project`, `p` and `seed`
    # (plus the outer `M`); they appear to be resolved lazily at str() time
    # from this function's scope, so these locals must keep their names.
    project = "ensemble_pod_sparse_trunk_deeponet"
    model_name = "ensemble_pod_sparse_trunk_deeponet_cartesian_prod"

    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-pod-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p}")

    trunk_input = 2

    for p in ps:
        arr = []
        print("+"*100)
        print(log_str)
        for seed in seeds:
            # Branch width covers the sparse trunks (p) plus the POD modes (pod_p).
            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=p + pod_p,
                name="branch"
            )

            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data()

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_obj = PU(dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1]
            ))

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            # The partition count actually used may differ from M (change_M=True).
            _, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # Append a trailing axis of size p to every 2-D weight array.
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=p,
                name="pu_trunk"
            )

            # One trunk config per partition; the first `trunk_special_index`
            # trunks additionally get explicit layer sizes.
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                trunk_config_list.append(temp_config)

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                bias_config=dict(name="bias"),
                num_partitions=M_changed,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p
            )

            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints that come from a trusted source.
                weights = torch.load(f)

            # Trunk coordinates of the training tuple define the plot mesh.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            # Group count and partition count (positions 3 and 4) are static.
            model_forward = jax.jit(model_forward, static_argnums=(3, 4))

            tri = Triangulation(X[:, 0], X[:, 1])

            if plot_train_test == "train":
                inputs, targets, n_groups, n_partitions = train_input, Y, num_groups, M_changed
            else:
                inputs, targets, n_groups, n_partitions = test_input, Y_test, test_num_groups, M_test

            # Only the requested quantity is computed.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, inputs, 1, n_groups, n_partitions))
            else:
                pred = model_forward(weights, None, inputs, n_groups, n_partitions)
                arr.append(jnp.power(pred - targets, 2).mean(0))

        # Average over seeds.
        plot_arr = jnp.asarray(arr).mean(0)
        save_file_name = str(fstr(plot_save_folder+"ensemble_pod_sparse_trunk_{plot_train_test}_heatmap_{to_plot}_p={p}.png"))
        title = "Spatial Gradient" if to_plot == "gradient" else "(POD, PoU)"
        plot_helper(tri, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def ensemble_vanilla_trunk_plot():
    """Plot a seed-averaged heatmap for the multi-vanilla-trunk ensemble.

    For every ``p`` in ``ps`` and every ``seed`` in ``seeds`` the saved
    parameters are loaded from ``save_folder``, a model with ``num_trunks``
    trunks is rebuilt with ``get_model`` and evaluated on the split chosen by
    ``plot_train_test``. Depending on ``to_plot`` either the squared error
    averaged over axis 0 (``"mse"``) or the result of
    ``get_max_grad_over_funcs`` (``"gradient"``) is collected per seed, then
    averaged over seeds and handed to ``plot_helper``.

    All settings are free names resolved from the enclosing module scope;
    the function takes no arguments and returns nothing.

    Raises:
        ValueError: If ``plot_train_test`` is not ``"train"``/``"test"`` or
            ``to_plot`` is not ``"gradient"``/``"mse"``.
    """
    # Fail early instead of hitting an unbound local (bad split) or plotting
    # the mean of an empty list (bad quantity) further down.
    if plot_train_test not in ("train", "test"):
        raise ValueError(f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}")
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): the fstr templates below name `project`, `p` and `seed`;
    # they appear to be resolved lazily at str() time from this function's
    # scope, so these locals must keep their names even if they look unused.
    project = "ensemble_vanilla_trunk_deeponet"
    model_name = "ensemble_vanilla_trunk_deeponet_cartesian_prod"

    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-trunk-deeponet on {problem} with {opt_choice} p={p}")

    for p in ps:
        arr = []
        print("+"*100)
        print(log_str)

        for seed in seeds:
            # The branch output is p wide per trunk.
            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=p * num_trunks,
                name="branch"
            )

            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data()

            trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # in this case, these trunks should be activated for the stacking
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=p,
                name="trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config=trunk_config,
                bias_config=dict(name="bias"),
                num_trunks=num_trunks
            )

            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints that come from a trusted source.
                weights = torch.load(f)

            # Trunk coordinates of the training tuple define the plot mesh.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            model_forward = jax.jit(model_forward)

            tri = Triangulation(X[:, 0], X[:, 1])

            if plot_train_test == "train":
                inputs, targets = train_input, Y
            else:
                inputs, targets = test_input, Y_test

            # Only the requested quantity is computed.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, inputs, 1))
            else:
                pred = model_forward(weights, None, inputs)
                arr.append(jnp.power(pred - targets, 2).mean(0))

        # Average over seeds.
        plot_arr = jnp.asarray(arr).mean(0)
        save_file_name = str(fstr(plot_save_folder+"ensemble_vanilla_trunk_{plot_train_test}_heatmap_{to_plot}_p={p}.png"))
        title = "Spatial Gradient" if to_plot == "gradient" else "(Vanilla, Vanilla)"
        plot_helper(tri, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def ensemble_pod_vanilla_trunk_plot():
    """Plot a seed-averaged heatmap for the vanilla + POD trunk ensemble.

    For every ``p`` in ``ps`` and every ``seed`` in ``seeds`` a POD basis is
    formed with ``form_pod_basis(Y, pod_p)``, the saved parameters are loaded
    from ``save_folder``, the model is rebuilt with ``get_model`` and
    evaluated on the split chosen by ``plot_train_test``. Depending on
    ``to_plot`` either the squared error averaged over axis 0 (``"mse"``) or
    the result of ``get_max_grad_over_funcs`` (``"gradient"``) is collected
    per seed, then averaged over seeds and handed to ``plot_helper``.

    All settings are free names resolved from the enclosing module scope;
    the function takes no arguments and returns nothing.

    Raises:
        ValueError: If ``plot_train_test`` is not ``"train"``/``"test"`` or
            ``to_plot`` is not ``"gradient"``/``"mse"``.
    """
    # Fail early instead of hitting an unbound local (bad split) or plotting
    # the mean of an empty list (bad quantity) further down.
    if plot_train_test not in ("train", "test"):
        raise ValueError(f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}")
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    # NOTE(review): the fstr templates below name `project`, `p` and `seed`;
    # they appear to be resolved lazily at str() time from this function's
    # scope, so these locals must keep their names even if they look unused.
    project = "ensemble_pod_vanilla_trunk_deeponet"
    model_name = "ensemble_pod_vanilla_trunk_deeponet_cartesian_prod"

    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-POD-vanilla-trunk-deeponet on {problem} with {opt_choice} p={p}")

    for p in ps:
        arr = []
        print("+"*100)
        print(log_str)
        for seed in seeds:
            # Branch width covers the vanilla trunk (p) plus the POD modes (pod_p).
            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=p + pod_p,
                name="branch"
            )

            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data()

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # should be activated when stacking
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=p,
                name="trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=dict(name="bias"),
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p
            )

            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                # NOTE(review): torch.load unpickles the file; only load
                # checkpoints that come from a trusted source.
                weights = torch.load(f)

            # Trunk coordinates of the training tuple define the plot mesh.
            X = train_input[1]

            _, model_forward = get_model(model_name, model_config)
            model_forward = jax.jit(model_forward)

            tri = Triangulation(X[:, 0], X[:, 1])

            if plot_train_test == "train":
                inputs, targets = train_input, Y
            else:
                inputs, targets = test_input, Y_test

            # Only the requested quantity is computed.
            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, inputs, 1))
            else:
                pred = model_forward(weights, None, inputs)
                arr.append(jnp.power(pred - targets, 2).mean(0))

        # Average over seeds.
        plot_arr = jnp.asarray(arr).mean(0)
        save_file_name = str(fstr(plot_save_folder+"ensemble_pod_vanilla_{plot_train_test}_heatmap_{to_plot}_p={p}.png"))
        # Title spelling fixed ("Sptial" -> "Spatial") to match the other plots.
        title = "Spatial Gradient" if to_plot == "gradient" else "(Vanilla, POD)"
        plot_helper(tri, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def ensemble_vanilla_pod_sparse_trunk_plot():
    """Plot a seed-averaged heatmap for the ensemble (vanilla + POD + PoU) DeepONet.

    For each ``p`` in ``ps`` and each ``seed`` in ``seeds`` this reloads the
    saved parameters, rebuilds the forward function with ``get_model`` and
    evaluates, on the split chosen by ``plot_train_test``, either

    * the squared error averaged over axis 0 (``to_plot == "mse"``), or
    * the result of ``get_max_grad_over_funcs`` (``to_plot == "gradient"``).

    The per-seed arrays are averaged and handed to ``plot_helper`` together
    with a triangulation of the trunk coordinates.  All settings are read
    from module-level globals; nothing is returned.

    Raises:
        ValueError: if ``plot_train_test`` is not ``"train"``/``"test"`` or
            ``to_plot`` is not ``"gradient"``/``"mse"``.
    """
    # Fail fast: an unknown option used to surface only after data loading,
    # partitioning and model evaluation (NameError / averaging an empty list).
    if plot_train_test not in ("train", "test"):
        raise ValueError(f"plot_train_test must be 'train' or 'test', got {plot_train_test!r}")
    if to_plot not in ("gradient", "mse"):
        raise ValueError(f"to_plot must be 'gradient' or 'mse', got {to_plot!r}")

    project = "ensemble_vanilla_pod_sparse_trunk_deeponet"
    model_name = "ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_prod"

    # These templates are created before `p` and `seed` exist, so fstr
    # presumably resolves the {placeholders} lazily when str()/print() is
    # applied; keep the names `project`, `p` and `seed` unchanged.
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-POD-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p}")

    trunk_input = 2
    branch_input = m

    for p in ps:
        print("+"*100)
        print(log_str)
        arr = []
        for seed in seeds:
            trunk_output = p
            branch_output = p

            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=branch_output*2 + pod_p,
                name="branch"
            )

            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race between concurrent jobs
            os.makedirs(str(save_folder), exist_ok=True)

            train_input, Y, test_input, Y_test = get_data()

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1]
            )

            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            _, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # append a trailing axis of size p to every per-group weight array
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST (done for both splits to keep the original call order on pu_obj)
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk"
            )

            # one trunk configuration per partition
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                else:
                    temp_config["layer_sizes"] = [trunk_width for _ in range(trunk_depth)]
                trunk_config_list.append(temp_config)

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # should be activated when stacking
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=trunk_output,
                name="trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=bias_config,
                num_partitions=M_changed,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p
            )

            # NOTE: torch.load unpickles the file; only load parameter files you trust.
            with open(str(param_save_file_name), "rb") as f:
                print(str(param_save_file_name))
                weights = torch.load(f)

            _, model_forward = get_model(model_name, model_config)
            # positional args 3 and 4 (group count, partition count) are static for jit
            model_forward = jax.jit(model_forward, static_argnums=(3, 4))

            tri = Triangulation(train_trunk_input[:, 0], train_trunk_input[:, 1])

            # NOTE(review): the test split is evaluated with M_test although the
            # model is configured with M_changed — confirm these always agree.
            if plot_train_test == "train":
                eval_input, eval_target = train_input, Y
                eval_groups, eval_partitions = num_groups, M_changed
            else:
                eval_input, eval_target = test_input, Y_test
                eval_groups, eval_partitions = test_num_groups, M_test

            if to_plot == "gradient":
                arr.append(get_max_grad_over_funcs(model_forward, weights, eval_input, 1, eval_groups, eval_partitions))
            else:
                pred = model_forward(weights, None, eval_input, eval_groups, eval_partitions)
                arr.append(jnp.power(pred - eval_target, 2).mean(0))

        # average the per-seed arrays
        plot_arr = jnp.asarray(arr).mean(0)
        save_file_name = str(fstr(plot_save_folder+"ensemble_vanilla_pod_sparse_trunk_{plot_train_test}_heatmap_{to_plot}_p={p}.png"))
        title = "Spatial Gradient" if to_plot == "gradient" else "(Vanilla, POD, PoU)"
        plot_helper(tri, plot_arr, title, save_file_name, clim_lower, clim_upper, axis_font, tick_fontsize, title_font, figsize)


def get_data():
    """Load ``Darcy_Triangular.mat`` and return normalised train/test splits.

    Please refer to the paper for the datasets used.

    Reads the module-level globals ``N`` (number of training samples), ``m``
    (size of each ``f_bc`` sample), ``dtype`` and ``dataset_folder``.  The
    first ``N`` rows of ``f_bc``/``u_field`` form the training split and
    every remaining row forms the test split (``N_test`` is not consulted
    here).  Both splits are standardised with the training mean and
    standard deviation.

    Returns:
        tuple: ``(train_input, U_train, test_input, U_test)`` where each
        ``*_input`` is ``(F, X)``: ``F`` has shape ``(num_samples, m)``,
        ``X`` stacks the ``xx``/``yy`` coordinates as two columns, and each
        ``U`` has shape ``(num_samples, 861)``.

    Raises:
        FileNotFoundError: if the ``.mat`` file is not in ``dataset_folder``.
    """
    num_train = N

    s = m
    # presumably the number of spatial points per solution sample — TODO confirm
    r = 861

    mat_path = str(dataset_folder)+"Darcy_Triangular.mat"
    try:
        data = io.loadmat(mat_path)
    except FileNotFoundError as err:
        # Only the "file missing" case gets the placement hint; other loader
        # errors (corrupt file, permissions, ...) propagate unchanged.
        raise FileNotFoundError(
            "\n\nPlease place the Darcy_Triangular.mat file in the dataset/Darcy_triangular folder."
        ) from err

    f_train = data['f_bc'][:num_train, :]
    u_train = data['u_field'][:num_train, :]

    f_test = data['f_bc'][num_train:, :]
    u_test = data['u_field'][num_train:, :]

    xx = np.reshape(data['xx'], (-1, 1))
    yy = np.reshape(data['yy'], (-1, 1))
    X = np.hstack((xx, yy))

    # statistics come from the training split only and are reused for the test split
    f_train_mean = np.reshape(np.mean(np.reshape(f_train, (-1, s)), 0), (-1, 1, s))
    f_train_std = np.reshape(np.std(np.reshape(f_train, (-1, s)), 0), (-1, 1, s))
    u_train_mean = np.reshape(np.mean(np.reshape(u_train, (-1, r)), 0), (-1, r, 1))
    u_train_std = np.reshape(np.std(np.reshape(u_train, (-1, r)), 0), (-1, r, 1))

    # the small constant guards against division by a zero standard deviation
    F_train = (np.reshape(f_train, (-1, 1, s)) - f_train_mean)/(f_train_std + 1.0e-9)
    U_train = (np.reshape(u_train, (-1, r, 1)) - u_train_mean)/(u_train_std + 1.0e-9)

    F_test = (np.reshape(f_test, (-1, 1, s)) - f_train_mean)/(f_train_std + 1.0e-9)
    U_test = (np.reshape(u_test, (-1, r, 1)) - u_train_mean)/(u_train_std + 1.0e-9)

    F_train = F_train.astype(dtype)
    F_test = F_test.astype(dtype)
    U_train = U_train.astype(dtype)
    U_test = U_test.astype(dtype)
    X = X.astype(dtype)

    # Squeeze only the known singleton axis so a split with a single sample
    # keeps its leading sample axis.
    train_input = (F_train.squeeze(axis=1), X)
    test_input = (F_test.squeeze(axis=1), X)

    return train_input, U_train.squeeze(axis=-1), test_input, U_test.squeeze(axis=-1)


def vanilla_deeponet_run():
    """Train the vanilla DeepONet for every ``(seed, p)`` pair.

    For each seed in ``seeds`` and each ``p`` in ``ps`` this builds the
    branch/trunk/bias configuration (both networks output ``p`` values),
    loads the data with ``get_data``, calls ``train_func`` and, when
    ``save_results_bool`` is set, stores the trained parameters with
    ``torch.save`` and the logged results with ``save_results``.

    All hyperparameters are read from module-level globals; nothing is
    returned.
    """
    project = "vanilla_deeponet"
    model_name = "deeponet_cartesian_prod"
    # These templates are created before `p` and `seed` exist, so fstr
    # presumably resolves the {placeholders} lazily when str()/print() is
    # applied; keep the names `project`, `p` and `seed` unchanged.
    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("vanilla-deeponet on {problem} with {opt_choice} p={p} seed={seed}")

    trunk_input = 2
    branch_input = m

    for seed in seeds:
        for p in ps:
            trunk_output = p
            branch_output = p

            trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=trunk_output,
                name="trunk"
            )

            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=branch_output,
                name="branch"
            )

            bias_config = dict(name="bias")

            model_config = dict(
                branch_config=branch_config,
                trunk_config=trunk_config,
                bias_config=bias_config
            )

            # exist_ok avoids the check-then-create race between concurrent jobs
            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data()

            # a single (branch sample, trunk point) pair, each with a leading axis of size 1
            dummy_input = (jnp.expand_dims(train_input[0][0, :], 0), jnp.expand_dims(train_input[1][0, :], 0))

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)
            # store the hyperparameters alongside the logged results
            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


def pod_deeponet_run():
    """Train the POD-DeepONet for every ``(seed, mean_bool, p)`` combination.

    For each seed in ``seeds``, each ``mean_bool`` in ``mean_bools`` and each
    ``p`` in ``pod_ps`` this computes a POD basis from the training targets
    with ``form_pod_basis``, uses that basis as the trunk input for both the
    training and test tuples, calls ``train_func`` and, when
    ``save_results_bool`` is set, stores the trained parameters and logged
    results.

    All hyperparameters are read from module-level globals; nothing is
    returned.
    """
    project = "pod_deeponet"
    model_name = "pod_deeponet_cartesian_prod"
    # These templates are created before the loop variables exist, so fstr
    # presumably resolves the {placeholders} lazily when str()/print() is
    # applied; keep the names `project`, `p`, `seed` and `mean_bool` unchanged.
    # NOTE(review): the file names do not contain mean_bool, so runs that differ
    # only in mean_bool appear to write to the same files unless save_folder's
    # template distinguishes them — confirm.
    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("pod-deeponet on {problem} with {opt_choice} p={p} seed={seed}, mean_bool={mean_bool}")

    trunk_input = 2
    branch_input = m

    for seed in seeds:
        for mean_bool in mean_bools:
            for p in pod_ps:
                trunk_output = p
                branch_output = p

                trunk_config = dict(
                    activation=activation_choice,
                    last_layer_activate=True,
                    num_hidden_layers=trunk_depth,
                    nodes=trunk_width,
                    output_dim=trunk_output,
                    name="trunk"
                )

                branch_config = dict(
                    activation=activation_choice,
                    num_hidden_layers=branch_depth,
                    nodes=branch_width,
                    output_dim=branch_output,
                    name="branch"
                )

                bias_config = dict(name="bias")

                # exist_ok avoids the check-then-create race between concurrent jobs
                os.makedirs(str(save_folder), exist_ok=True)

                print("\n\n")
                print("+"*100)
                print(log_str)
                print("+"*100)
                print("\n\n")

                train_input, Y, test_input, Y_test = get_data()

                # applying POD
                train_trunk_input, x_mean = form_pod_basis(Y, p, not mean_bool)
                x_mean = x_mean.T

                # putting pod basis back in tuples (the test tuple reuses the training basis)
                train_input = (train_input[0], train_trunk_input)
                test_input = (test_input[0], train_trunk_input)

                dummy_input = (train_input[0], train_input[1])

                model_config = dict(
                    branch_config=branch_config,
                    trunk_config=trunk_config,
                    bias_config=bias_config,
                    pod_mean=x_mean,
                    mean_bool=mean_bool,
                    p=p
                )

                hyperparameter_dict = dict(
                    print_bool=print_bool,
                    print_interval=print_interval,
                    epochs=epochs,
                    model_config=model_config,
                    opt_choice=opt_choice,
                    schedule_choice=schedule_choice,
                    lr_dict=lr_dict,
                    problem=problem,
                    N=N,
                    N_test=N_test,
                    p=p,
                    m=m,
                    dtype=dtype,
                    seed=seed
                )

                train_config = dict(
                    dummy_input=dummy_input,
                    train_input=train_input,
                    test_input=test_input,
                    Y=Y,
                    Y_test=Y_test,
                    model_name=model_name,
                ) | hyperparameter_dict

                logged_results, trained_params = train_func(train_config)

                # removing tensors from dict so results can be saved
                del hyperparameter_dict["model_config"]["pod_mean"]

                logged_results = logged_results | hyperparameter_dict

                if save_results_bool:
                    torch.save(trained_params, str(param_save_file_name))
                    save_results(logged_results, str(save_file_name))


def ensemble_vanilla_trunk_deeponet_run():
    """Train the ensemble-of-vanilla-trunks DeepONet for every ``(seed, p)`` pair.

    The trunk configuration outputs ``p`` values and the branch outputs
    ``p * num_trunks`` values.  For each combination the data is loaded with
    ``get_data``, ``train_func`` is called and, when ``save_results_bool`` is
    set, the trained parameters and logged results are stored.

    All hyperparameters are read from module-level globals; nothing is
    returned.
    """
    project = "ensemble_vanilla_trunk_deeponet"
    model_name = "ensemble_vanilla_trunk_deeponet_cartesian_prod"

    # These templates are created before `p` and `seed` exist, so fstr
    # presumably resolves the {placeholders} lazily when str()/print() is
    # applied; keep the names `project`, `p` and `seed` unchanged.
    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-trunk-deeponet on {problem} with {opt_choice} p={p} seed={seed}")

    trunk_input = 2
    branch_input = m

    for seed in seeds:
        for p in ps:
            trunk_output = p
            # one block of p branch outputs per trunk
            branch_output = p * num_trunks

            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=branch_output,
                name="branch"
            )

            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race between concurrent jobs
            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data()

            # a single (branch sample, trunk point) pair, each with a leading axis of size 1
            dummy_input = (jnp.expand_dims(train_input[0][0, :], 0), jnp.expand_dims(train_input[1][0, :], 0))

            trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # in this case, these trunks should be activated for the stacking
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=trunk_output,
                name="trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config=trunk_config,
                bias_config=bias_config,
                num_trunks=num_trunks
            )

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)
            # store the hyperparameters alongside the logged results
            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


def ensemble_pod_vanilla_trunk_deeponet_run():
    """Train the ensemble (POD + vanilla trunk) DeepONet for every ``(seed, p)`` pair.

    A POD basis with ``pod_p`` modes is computed from the training targets
    via ``form_pod_basis`` and passed to the model configuration together
    with a vanilla trunk configuration; the branch outputs ``p + pod_p``
    values.  After ``train_func`` returns, the POD arrays are removed from
    the configuration so the results can be saved, and, when
    ``save_results_bool`` is set, parameters and logged results are stored.

    All hyperparameters are read from module-level globals; nothing is
    returned.
    """
    project = "ensemble_pod_vanilla_trunk_deeponet"
    model_name = "ensemble_pod_vanilla_trunk_deeponet_cartesian_prod"

    # These templates are created before `p` and `seed` exist, so fstr
    # presumably resolves the {placeholders} lazily when str()/print() is
    # applied; keep the names `project`, `p` and `seed` unchanged.
    save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-POD-vanilla-trunk-deeponet on {problem} with {opt_choice} p={p} seed={seed}")

    trunk_input = 2
    branch_input = m

    for seed in seeds:
        for p in ps:
            trunk_output = p
            branch_output = p

            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=branch_output + pod_p,
                name="branch"
            )

            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race between concurrent jobs
            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data()

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            dummy_input = (train_input[0], train_input[1])

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # should be activated when stacking
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=trunk_output,
                name="trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=bias_config,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p
            )

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)

            # removing tensors from dict so results can be saved
            del hyperparameter_dict["model_config"]["pod_mean"]
            del hyperparameter_dict["model_config"]["trunk_pod_basis"]

            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


def ensemble_vanilla_sparse_trunk_deeponet_run():
    """Train the ensemble (vanilla + partition-of-unity trunks) DeepONet.

    For every ``(seed, p)`` pair the trunk coordinates are partitioned with
    a ``PU`` object, per-group points/weights/indices are formed for the
    training and test coordinates, one trunk configuration is created per
    partition plus a vanilla trunk, and ``train_func`` is called.  The
    branch outputs ``2 * p`` values.  When ``save_results_bool`` is set the
    trained parameters and logged results are stored.

    All hyperparameters are read from module-level globals; nothing is
    returned.
    """
    project = "ensemble_vanilla_sparse_trunk_deeponet"
    model_name = "ensemble_vanilla_sparse_trunk_deeponet_cartesian_prod"

    # These templates are created before `p` and `seed` exist, so fstr
    # presumably resolves the {placeholders} lazily when str()/print() is
    # applied; keep the names `project`, `p` and `seed` unchanged.
    save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p} seed={seed}")

    trunk_input = 2
    branch_input = m

    for seed in seeds:
        for p in ps:
            trunk_output = p
            branch_output = p

            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                output_dim=branch_output*2,
                name="branch"
            )

            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race between concurrent jobs
            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data()

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1]
            )

            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            print("Radius = ", pu_obj.radius)
            print(f"M={M}")
            print("Checking train partitioning")
            K, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            # only reported below; the trunk configurations do not use this width
            pu_trunk_width = int(1/np.sqrt(K) * trunk_width)
            print(f"K={K}, PU trunk_width={pu_trunk_width}")
            print(f"new M={M_changed}")
            print("Forming points per groups")
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            print("Group sizes: ", [len(g) for g in points_per_group])
            print("Forming weights per groups")
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # append a trailing axis of size p to every per-group weight array
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            print("\nChecking test partitioning")
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            # report the value returned for the test coordinates (the log used to print M)
            print(f"M_test={M_test}")
            print("Forming points per groups")
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            print("Group sizes: ", [len(g) for g in test_points_per_group])
            print("Forming weights per groups")
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test and dummy tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)
            dummy_input = (jnp.expand_dims(train_input[0][0, :], 0), jnp.expand_dims(train_input[1][0, :], 0))

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk"
            )

            # one trunk configuration per partition
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                # NOTE(review): unlike ensemble_vanilla_pod_sparse_trunk_plot, no
                # "layer_sizes" is set when trunk_idx >= trunk_special_index —
                # confirm the model supplies a default for those trunks.
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                trunk_config_list.append(temp_config)

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # in stacking, this should be activated
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=trunk_output,
                name="vanilla_trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=bias_config,
                num_partitions=M_changed
            )

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                M=M_changed,
                K=K,
                train_num_partitions=M_changed,
                test_num_partitions=M_test,
                train_num_groups=num_groups,
                test_num_groups=test_num_groups,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)
            # store the hyperparameters alongside the logged results
            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


def ensemble_vanilla_pod_sparse_trunk_deeponet_run():
    """Train the ensemble vanilla + POD + sparse (PU) trunk DeepONet for every seed and ``p``.

    For each ``(seed, p)`` combination this function

    1. loads the data with ``get_data()`` and forms a POD basis of size
       ``pod_p`` from the training targets,
    2. partitions the trunk input points with a ``PU`` object (centers and
       radii come from ``centers_list`` / ``radius_list``) and builds the
       per-group points, indices and weights for the train and test sets,
    3. assembles the branch / trunk / model configuration and calls
       ``train_func``,
    4. if ``save_results_bool`` is set, stores the trained parameters and the
       logged results under ``save_folder``.

    All settings (``seeds``, ``ps``, ``pod_p``, ``M``, network widths and
    depths, optimiser options, ``save_folder`` ...) are read from module-level
    globals that the ``__main__`` section assigns before calling this function.

    Returns:
        None. Progress is printed and results are optionally written to disk.
    """
    # NOTE(review): `project`, `p` and `seed` may look unused, but the `fstr`
    # templates below name them as placeholders; presumably they are looked up
    # in this function's scope when a template is rendered -- confirm against
    # `utils.fstr` before renaming or removing any of them.
    project = "ensemble_vanilla_pod_sparse_trunk_deeponet"
    model_name = "ensemble_vanilla_pod_sparse_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-vanilla-POD-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p} seed={seed}")

    trunk_input = 2

    for seed in seeds:
        for p in ps:
            trunk_output = p
            branch_output = p

            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                # the branch emits 2*p + pod_p coefficients in this ensemble
                output_dim=branch_output*2 + pod_p,
                name="branch"
            )

            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race of isdir() + makedirs()
            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data()

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1]
            )

            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            print("Radius = ", pu_obj.radius)
            print(f"M={M}")
            print("Checking train partitioning")
            K, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            # only reported; the trunk configs below use trunk_width / special_pu_trunk_layer_sizes
            pu_trunk_width = int(1/np.sqrt(K) * trunk_width)
            print(f"K={K}, PU trunk_width={pu_trunk_width}")
            print(f"new M={M_changed}")
            print("Forming points per groups")
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            print("Group sizes: ", [len(g) for g in points_per_group])
            print("Forming weights per groups")
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # add a trailing axis of length p so each weight is repeated per trunk output
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            print("\nChecking test partitioning")
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            # report the value returned for the test points (previously printed M)
            print(f"M_test={M_test}")
            print("Forming points per groups")
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            print("Group sizes: ", [len(g) for g in test_points_per_group])
            print("Forming weights per groups")
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test and dummy tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            dummy_input = (train_input[0], train_input[1])

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk"
            )

            # one trunk config per partition that survived check_partioning
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                else:
                    temp_config["layer_sizes"] = [trunk_width for _ in range(trunk_depth)]
                trunk_config_list.append(temp_config)

            vanilla_trunk_config = dict(
                activation=activation_choice,
                last_layer_activate=True,  # should be activated when stacking
                num_hidden_layers=trunk_depth,
                nodes=trunk_width,
                output_dim=trunk_output,
                name="trunk"
            )

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                vanilla_trunk_config=vanilla_trunk_config,
                bias_config=bias_config,
                num_partitions=M_changed,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p
            )

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                M=M_changed,
                K=K,
                train_num_partitions=M_changed,
                test_num_partitions=M_test,
                train_num_groups=num_groups,
                test_num_groups=test_num_groups,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)

            # removing tensors from dict so results can be saved
            # (model_config is shared with train_config, so this must stay after training)
            del hyperparameter_dict["model_config"]["pod_mean"]
            del hyperparameter_dict["model_config"]["trunk_pod_basis"]

            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


def ensemble_pod_sparse_trunk_deeponet_run():
    """Train the ensemble POD + sparse (PU) trunk DeepONet for every seed and ``p``.

    For each ``(seed, p)`` combination this function

    1. loads the data with ``get_data()`` and forms a POD basis of size
       ``pod_p`` from the training targets,
    2. partitions the trunk input points with a ``PU`` object (centers and
       radii come from ``centers_list`` / ``radius_list``) and builds the
       per-group points, indices and weights for the train and test sets,
    3. assembles the branch / trunk / model configuration and calls
       ``train_func``,
    4. if ``save_results_bool`` is set, stores the trained parameters and the
       logged results under ``save_folder``.

    All settings (``seeds``, ``ps``, ``pod_p``, ``M``, network widths and
    depths, optimiser options, ``save_folder`` ...) are read from module-level
    globals that the ``__main__`` section assigns before calling this function.

    Returns:
        None. Progress is printed and results are optionally written to disk.
    """
    # NOTE(review): `project`, `p` and `seed` may look unused, but the `fstr`
    # templates below name them as placeholders; presumably they are looked up
    # in this function's scope when a template is rendered -- confirm against
    # `utils.fstr` before renaming or removing any of them.
    project = "ensemble_pod_sparse_trunk_deeponet"
    model_name = "ensemble_pod_sparse_trunk_deeponet_cartesian_prod"

    save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.json")
    param_save_file_name = fstr(save_folder.payload + "{project}_M={M}_p={p}_seed={seed}.pickle")
    log_str = fstr("ensemble-POD-sparse-trunk-deeponet on {problem} with {opt_choice} M={M} p={p} seed={seed}")

    trunk_input = 2

    for seed in seeds:
        for p in ps:
            trunk_output = p
            branch_output = p

            branch_config = dict(
                activation=activation_choice,
                num_hidden_layers=branch_depth,
                nodes=branch_width,
                # the branch emits p + pod_p coefficients in this ensemble
                output_dim=branch_output + pod_p,
                name="branch"
            )

            bias_config = dict(name="bias")

            # exist_ok avoids the check-then-create race of isdir() + makedirs()
            os.makedirs(str(save_folder), exist_ok=True)

            print("\n\n")
            print("+"*100)
            print(log_str)
            print("+"*100)
            print("\n\n")

            train_input, Y, test_input, Y_test = get_data()

            trunk_pod_basis, x_mean = form_pod_basis(Y, pod_p)
            x_mean = x_mean.T

            # partitioning ============================================================
            # TRAINING
            train_trunk_input = train_input[1]
            test_trunk_input = test_input[1]

            pu_config = dict(
                num_partitions=M,
                dim=train_trunk_input.shape[-1]
            )

            pu_obj = PU(pu_config)

            pu_obj.partition_domain(train_trunk_input, centers=centers_list, radius=radius_list)
            print("Radius = ", pu_obj.radius)
            print(f"M={M}")
            print("Checking train partitioning")
            K, M_changed, _ = pu_obj.check_partioning(train_trunk_input, change_M=True, min_partitions_per_point=min_partitions_per_point)
            # only reported; the trunk configs below use trunk_width / special_pu_trunk_layer_sizes
            pu_trunk_width = int(1/np.sqrt(K) * trunk_width)
            print(f"K={K}, PU trunk_width={pu_trunk_width}")
            print(f"new M={M_changed}")
            print("Forming points per groups")
            num_groups, participation_idx, points_per_group, indices_per_group, _, radius_arrs = pu_obj.form_points_per_group(train_trunk_input)
            print("Group sizes: ", [len(g) for g in points_per_group])
            print("Forming weights per groups")
            weights_per_group = pu_obj.form_weights_per_group(points_per_group=points_per_group, participation_idx=participation_idx, radius_arrs=radius_arrs)
            # add a trailing axis of length p so each weight is repeated per trunk output
            weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in weights_per_group]

            # TEST
            print("\nChecking test partitioning")
            _, M_test, _ = pu_obj.check_partioning(test_trunk_input, False, min_partitions_per_point=min_partitions_per_point)
            # report the value returned for the test points (previously printed M)
            print(f"M_test={M_test}")
            print("Forming points per groups")
            test_num_groups, test_participation_idx, test_points_per_group, test_indices_per_group, _, test_radius_arrs = pu_obj.form_points_per_group(test_trunk_input)
            print("Group sizes: ", [len(g) for g in test_points_per_group])
            print("Forming weights per groups")
            test_weights_per_group = pu_obj.form_weights_per_group(points_per_group=test_points_per_group, participation_idx=test_participation_idx, radius_arrs=test_radius_arrs)
            test_weights_per_group = [jnp.broadcast_to(w.reshape(*w.shape, 1), (w.shape[0], w.shape[1], p)) for w in test_weights_per_group]

            # reassembling training and test and dummy tuples
            train_input = (train_input[0], train_input[1], points_per_group, weights_per_group, indices_per_group, participation_idx, num_groups)
            test_input = (test_input[0], test_input[1], test_points_per_group, test_weights_per_group, test_indices_per_group, test_participation_idx, test_num_groups)

            dummy_input = (train_input[0], train_input[1])

            pu_trunk_config = dict(
                last_layer_activate=False,
                input_dim=trunk_input,
                output_dim=trunk_output,
                name="pu_trunk"
            )

            # one trunk config per partition that survived check_partioning
            trunk_config_list = []
            for trunk_idx in range(M_changed):
                temp_config = copy.deepcopy(pu_trunk_config)
                temp_config["activation"] = activation_choice
                temp_config["name"] = f"pu_{trunk_idx}_trunk"
                if trunk_idx < trunk_special_index:
                    temp_config["layer_sizes"] = special_pu_trunk_layer_sizes
                else:
                    # Same fallback as ensemble_vanilla_pod_sparse_trunk_deeponet_run:
                    # without it a trunk at or beyond trunk_special_index would be
                    # configured with no "layer_sizes" entry at all.
                    temp_config["layer_sizes"] = [trunk_width for _ in range(trunk_depth)]
                trunk_config_list.append(temp_config)

            model_config = dict(
                branch_config=branch_config,
                trunk_config_list=trunk_config_list,
                bias_config=bias_config,
                num_partitions=M_changed,
                trunk_pod_basis=trunk_pod_basis,
                pod_mean=x_mean,
                p=p,
                pod_p=pod_p
            )

            hyperparameter_dict = dict(
                print_bool=print_bool,
                print_interval=print_interval,
                epochs=epochs,
                model_config=model_config,
                opt_choice=opt_choice,
                schedule_choice=schedule_choice,
                lr_dict=lr_dict,
                problem=problem,
                N=N,
                N_test=N_test,
                p=p,
                M=M_changed,
                K=K,
                train_num_partitions=M_changed,
                test_num_partitions=M_test,
                train_num_groups=num_groups,
                test_num_groups=test_num_groups,
                m=m,
                dtype=dtype,
                seed=seed
            )

            train_config = dict(
                dummy_input=dummy_input,
                train_input=train_input,
                test_input=test_input,
                Y=Y,
                Y_test=Y_test,
                model_name=model_name,
            ) | hyperparameter_dict

            logged_results, trained_params = train_func(train_config)

            # removing tensors from dict so results can be saved
            # (model_config is shared with train_config, so this must stay after training)
            del hyperparameter_dict["model_config"]["pod_mean"]
            del hyperparameter_dict["model_config"]["trunk_pod_basis"]

            logged_results = logged_results | hyperparameter_dict

            if save_results_bool:
                torch.save(trained_params, str(param_save_file_name))
                save_results(logged_results, str(save_file_name))


if __name__ == "__main__":
    # Script entry point. Every name assigned in this section becomes a
    # module-level global; the *_run functions in this file read their
    # settings (seeds, ps, M, widths, save_folder, ...) from these globals, so
    # names and assignment order matter.

    # NOTE(review): `problem` is only assigned further down, so `fstr`
    # templates are presumably rendered lazily (when converted with str()) --
    # confirm against utils.fstr.
    dataset_folder = fstr("../dataset/{problem}/")

    # whether the run functions write parameters/results to disk and print progress
    save_results_bool = True
    print_bool = True

    problem = "Darcy_triangular"
    opt_choice = "adam"
    schedule_choice = "inverse_time_decay"

    print_interval = 10000
    epochs = 150000
    dtype = "float32"
    # activations
    activation_choice = "leaky_relu"

    # m is recorded with the run hyperparameters together with N and N_test
    m = 101
    N = 1900
    N_test = 100

    # learning-rate settings passed to train_func through the hyperparameter dict
    lr_dict = dict(
        peak_lr=1e-3,
        decay_steps=epochs // 5,
        decay_rate=0.5
    )

    # enable 64-bit JAX only when float64 was requested above
    if dtype == "float64":
        jax.config.update("jax_enable_x64", True)
    else:
        jax.config.update("jax_enable_x64", False)

    # partition centers handed to PU.partition_domain by the sparse-trunk runs
    centers_list = [
        [0.24, 0.13],
        [0.76, 0.13],
        [0.50, 0.57],
    ]

    # one partition per center
    M = len(centers_list)
    # PU trunks with an index below this value use special_pu_trunk_layer_sizes
    trunk_special_index = M

    # one radius per center; both branches currently append the same value
    # (0.3), so the split on cnt has no effect as written
    radius_list = []
    cnt = 1
    for c in centers_list:
        if cnt <= 2:
            radius_list.append(0.3)
        else:
            radius_list.append(0.3)
        cnt += 1

    assert len(radius_list) == len(centers_list)

    # plotting settings (not read by the run functions visible in this part of
    # the file; presumably consumed by the plot_* functions)
    to_plot = "mse"
    plot_train_test = "test"
    clim_lower, clim_upper = 8e-6, 4e-5  # mse
    # clim_lower, clim_upper = 2, 18 # gradient
    title_font = 90
    axis_font = 88
    tick_fontsize = 60
    figsize = (22, 20)

    # sub project folder
    all_run_folder = "all_results"

    plot_save_folder = f"../figures/{problem}/"  # for heatmaps

    # ============================================================================================================
    # sweep and architecture settings shared by the runs below
    seeds = [0, 1, 2, 3, 4]
    ps = [100]
    pod_p = 20
    pod_ps = [20]
    trunk_width = 64
    branch_width = 128

    branch_depth = 3
    trunk_depth = 3
    shallow_trunk_depth = trunk_depth - 1
    # ============================================================================================================
    # vanilla deeponet
    # Each section below re-assigns folder_string / save_folder before its
    # (possibly commented-out) run call; only the vanilla run is active.

    folder_string = f"vanilla_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    vanilla_deeponet_run()
    # plot_vanilla()

    # ============================================================================================================
    # POD deeponet

    # True corresponds to adding mean to prediction and not using it as trainable basis, False corresponds to our proposed way
    mean_bools = [True, False]
    # folder_string is itself an fstr here, so its raw template (.payload) is
    # spliced into the save_folder template
    folder_string = fstr("pod_deeponet_results_mean_bool={mean_bool}")
    save_folder = fstr("../{all_run_folder}/" + folder_string.payload + "/{problem}/")
    # pod_deeponet_run()
    # if not to_plot == "gradient":
        # plot_pod()

    # ============================================================================================================
    # ensemble vanilla trunk deeponet

    num_trunks = M + 1
    folder_string = f"ensemble_vanilla_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_vanilla_trunk_deeponet_run()
    # if not to_plot == "gradient":
        # ensemble_vanilla_trunk_plot()

    # ============================================================================================================
    # ensemble vanilla sparse trunk deeponet

    min_partitions_per_point = 1
    # int(trunk_width/1) keeps the special PU trunks at the full trunk_width
    special_pu_trunk_layer_sizes = [int(trunk_width/1) for _ in range(trunk_depth)]
    folder_string = f"ensemble_vanilla_sparse_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_vanilla_sparse_trunk_deeponet_run()
    # if not to_plot == "gradient":
        # plot_ensemble_vanilla_sparse()

    # ============================================================================================================
    # ensemble POD sparse trunk deeponet

    min_partitions_per_point = 1
    special_pu_trunk_layer_sizes = [int(trunk_width/1) for _ in range(trunk_depth)]
    folder_string = f"ensemble_pod_sparse_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_pod_sparse_trunk_deeponet_run()
    # if not to_plot == "gradient":
    #     plot_ensemble_pod_sparse()

    # ============================================================================================================
    # ensemble POD vanilla trunk deeponet

    folder_string = f"ensemble_pod_vanilla_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_pod_vanilla_trunk_deeponet_run()
    # if not to_plot == "gradient":
        # ensemble_pod_vanilla_trunk_plot()

    # ============================================================================================================
    # ensemble vanilla POD sparse trunk deeponet

    folder_string = f"ensemble_vanilla_pod_sparse_trunk_deeponet_results"
    save_folder = fstr("../{all_run_folder}/{folder_string}/{problem}/")
    # ensemble_vanilla_pod_sparse_trunk_deeponet_run()
    # if not to_plot == "gradient":
    #     ensemble_vanilla_pod_sparse_trunk_plot()
