"""
This trains the DeepONet for the paper results. It relies on the deepxde library.

Executing it will output two pickle files: one with the results and one with the
models.
"""
# Force the DeepXDE backend to be PyTorch. This must be set before deepxde is imported below.
import os
# os.environ['DDE_BACKEND'] = 'jax'
os.environ['DDE_BACKEND'] = 'pytorch'

# Import this first due to a circular import issue in the deepxde library
import torch._dynamo.decorators

import datetime
import itertools
import pickle
import numpy as np

import torch
import deepxde as dde

# import spatial_interp.poisson_solutions as poisson_solutions
import poisson_solutions

# Select the compute device. This was last run on a MacBook Pro ('mps'). When
# MPS is not available, fall back to 'cuda' and then 'cpu' so the script still
# runs on other machines instead of failing at the first tensor allocation.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# One precision setting shared by torch, numpy and DeepXDE.
dtype = "float32"
torch_dtype = torch.float32 if dtype == "float32" else torch.float64
np_dtype = np.float32 if dtype == "float32" else np.float64
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
dde.config.set_default_float(dtype)


#
# Set up the data.
#
# NOTE(review): N_data is not referenced in the visible part of this script.
N_data = 10_000
# Grid size; selects the entry of the dataset dict and the branch-net input width.
N_grid = 22
train_data = poisson_solutions.create_dataset_dict()


def convert_to_type(x):
    """Return a new dict with every value of ``x`` converted to a torch tensor.

    Each value is cast to ``np_dtype`` via ``.astype`` (so values are expected
    to be numpy arrays) and then placed on ``device`` as a ``torch_dtype``
    tensor. Keys are kept unchanged.
    """
    return {
        k: torch.as_tensor(v.astype(np_dtype), dtype=torch_dtype, device=device)
        for k, v in x.items()
    }


# Push the data to torch and onto the accelerator.
train_data_dim = {k: convert_to_type(v) for k, v in train_data[N_grid].items()}


#
# Do the training loop.
#
# One (trained function space, {evaluated function space: MSE}) entry per run.
test_results = []
# One (trained function space, trained dde.Model) entry per run.
models = []
# subset = ['fem', 1, 4, 7, 's1', 's4', 's7', 'c1', 'c4', 'c7']
# N_seeds = 4  # +1 from the other sweep.
# Defaults to running one seed on all function spaces.
# Named so that it does not shadow the builtin all().
all_function_spaces = list(train_data_dim)
N_seeds = 1
function_spaces_to_run = all_function_spaces

# Number of optimizer iterations per run; the learning-rate decay step below is derived from it.
N_iter = 20_000

# Train one DeepONet per (function space, repetition) pair and evaluate it on
# every function space in train_data_dim.
# NOTE(review): `seed` only counts repetitions; no random number generator is
# seeded from it, so repeated runs are independent but not reproducible.
for p_train, seed in itertools.product(function_spaces_to_run, range(N_seeds)):
    print("Training on", p_train)
    # Get the training data: 'f' feeds the branch net, the first entry of 'x'
    # (reshaped to a single column) feeds the trunk net, and 'u' is the target.
    branch = train_data_dim[p_train]['f']
    trunk = train_data_dim[p_train]['x'][0].reshape((-1, 1))
    X_train = (branch, trunk)
    y_train = train_data_dim[p_train]['u']

    # DeepXDE requires a test no matter what; just use train as a dummy. We perform evals after training.
    X_test = X_train
    y_test = y_train

    # Simple data structure in module; testing data is REQUIRED
    data = dde.data.TripleCartesianProd(
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
    )

    net = dde.nn.DeepONetCartesianProd(
        [N_grid, 256, 256],
        [1, 256, 256],
        activation="relu",
        kernel_initializer="Glorot normal",
    )

    # Compile and train
    model = dde.Model(data, net)
    models.append((p_train, model))
    # Adam with a "step" learning-rate decay configured with step size
    # N_iter // 4 (a quarter of training) and factor 0.1.
    model.compile("adam", loss='MSE', lr=0.001, verbose=0, decay=("step", N_iter // 4, 0.1))
    model.train(iterations=N_iter, batch_size=2048, verbose=1)

    # Now we compare this against all the other data sets.
    evals = {}
    # Inference only: no_grad avoids building an autograd graph for these
    # forward passes, so the outputs need no detach afterwards.
    with torch.no_grad():
        for p_test, test_set in train_data_dim.items():
            eval_inputs = (test_set['f'], test_set['x'][0].reshape((-1, 1)))
            prediction = model.net(eval_inputs)
            evals[p_test] = torch.mean((prediction - test_set['u']) ** 2).item()
    test_results.append((p_train, evals))


# Naively save the results and models to a pickle file.
savepath = "."  # Save elsewhere if you choose
# Create the target directory up front so that a custom savepath cannot make
# the writes below fail after all training has finished.
if savepath:
    os.makedirs(savepath, exist_ok=True)
# The timestamp keeps the two output files paired and avoids overwriting earlier runs.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = os.path.join(savepath, f"deeponet_results_{current_time}.pkl")
output_file_models = os.path.join(savepath, f"deeponet_models_{current_time}.pkl")
# The results are written first, so they are on disk before the models are pickled.
with open(output_file, 'wb') as f:
    pickle.dump(test_results, f)
print(f"Test Evals saved to {output_file}")
# NOTE(review): this pickles the full model objects; loading them presumably
# requires a compatible deepxde/torch installation - confirm before relying on it.
with open(output_file_models, 'wb') as f:
    pickle.dump(models, f)
print(f"Models saved to {output_file_models}")
