import jax.numpy as jnp
import matplotlib.pyplot as plt
import flax.linen as nn
import jax
import optax
import os
# import orthax
import itertools
import warnings
import numpy as np

warnings.filterwarnings("ignore")

from jax import random, jit, grad, vmap, lax
from functools import partial
from tqdm import trange
from typing import Sequence, Callable
import time

# Global JAX configuration. Set before any device query or array creation so
# that everything built afterwards uses float64 and NaN checking.
# NOTE(review): jax_debug_nans is a debugging aid with runtime overhead;
# consider disabling it for production runs.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)

# Device selection. jax.devices('gpu') raises RuntimeError (it does not return
# an empty list) when no GPU backend is available, so the call itself must be
# guarded for the intended error message to ever be produced.
GPU = True
if GPU:
    try:
        devices = jax.devices('gpu')
    except RuntimeError as err:
        raise RuntimeError("No GPU devices available") from err
    if len(devices) == 0:
        raise RuntimeError("No GPU devices available")
    device = devices[0]
    print("devices=", devices)
else:
    device = jax.devices('cpu')[0]
print("device=", device)

# Pe is embedded in the output directory name together with the SLURM job id.
Pe = 20
job_id = os.environ.get("SLURM_JOB_ID", "nojobid")
output_dir = f"outputs/run_{job_id}_{Pe}_moes"
os.makedirs(output_dir, exist_ok=True)

class BoxDense(nn.Module):
    """Fully connected layer computing ``activation(x @ kernel - bias)``.

    Kernel and bias are stored together as one ``(kernel, bias)`` parameter
    named ``"layer_weights"`` and are drawn with a "box" initialization
    (``box_init``).

    Attributes:
        features: Number of output units.
        activation: Element-wise activation applied to the affine output.
        depth: Network depth used by the "resnet" initialization scaling.
        layer: Index of this layer; the "resnet" box has side
            ``(1 + 1/depth)**layer``.
        arch_type: ``"plain"`` or ``"resnet"``; selects the init variant.
    """
    # Flax reserves the ``name`` field (it is re-declared as a keyword-only
    # field defaulting to None), so a default declared here is discarded; it is
    # therefore not declared. Positional construction uses the fields below.
    features: int
    activation: Callable
    depth: int
    layer: int
    arch_type: str

    @staticmethod
    def box_init(rng, shape, arch_type, depth, layer, dtype=jnp.float32):
        """Return a box-initialized ``(kernel, bias)`` pair.

        Every output unit j gets a random unit normal ``n_j`` (column j of the
        kernel) and a random point ``p_j`` inside the box ``[0, s]^d``. The unit
        is scaled so its hyperplane passes through ``p_j`` and its maximum over
        the box equals ``scale``.

        Args:
            rng: JAX PRNG key.
            shape: ``(in_features, out_features)``.
            arch_type: ``"plain"`` (s = 1, scale = 1) or ``"resnet"``
                (s = (1 + 1/depth)**layer, scale = 1/depth).
            depth: Network depth (used by "resnet").
            layer: Layer index (used by "resnet").
            dtype: dtype of the sampled arrays.

        Returns:
            ``kernel`` with shape ``shape`` and ``bias`` with shape
            ``(out_features,)``.

        Raises:
            ValueError: If ``arch_type`` is not "plain" or "resnet".
        """
        if arch_type == "plain":
            box_size = 1.
            scale = 1.
        elif arch_type == "resnet":
            box_size = (1. + 1. / depth)**layer
            scale = 1. / depth
        else:
            raise ValueError(f"Unsupported architecture type: {arch_type}")

        rng_points, rng_norms = random.split(rng, 2)
        norms = random.normal(rng_norms, shape=shape, dtype=dtype)
        # One unit normal per output unit, i.e. normalize each kernel column
        # (the previous axis=1 normalization mixed different units together).
        norms = jnp.divide(norms, jnp.linalg.norm(norms, axis=0, keepdims=True))
        # Box corner that maximizes n_j . p for each unit.
        p_max = box_size * jnp.maximum(0, jnp.sign(norms))
        points = random.uniform(rng_points, shape=shape, dtype=dtype, minval=0., maxval=box_size)
        k = scale / jnp.sum((p_max - points) * norms, axis=0, keepdims=True)

        kernel = k * norms
        bias = k * jnp.sum(points * norms, axis=0)
        return kernel, bias.ravel()

    @nn.compact
    def __call__(self, x):
        """Apply the layer to ``x`` (last axis = input features)."""
        if self.arch_type not in ("plain", "resnet"):
            raise ValueError(f"Unsupported architecture type: {self.arch_type}")
        init_fn = partial(self.box_init, arch_type=self.arch_type, depth=self.depth, layer=self.layer)

        kernel, bias = self.param("layer_weights", init_fn, (x.shape[-1], self.features))
        return self.activation(jnp.tensordot(x, kernel, (-1, 0)) - bias)

def embedding(x, num_frequencies=5):
    """Map every coordinate of a 1-D input to sinusoidal Fourier features.

    For each entry ``x_i`` and each integer frequency ``k = 1..num_frequencies``
    the features ``sin(k*pi*x_i)`` and ``cos(k*pi*x_i)`` are produced. In this
    file the input is a single space-time point of shape (2,).

    Args:
        x (ndarray): 1-D input array of length ``n``.
        num_frequencies (int): Number of frequencies ``K`` per coordinate.

    Returns:
        ndarray: Flat array of length ``n * 2 * K``. The layout is coordinate by
        coordinate: ``[sin(1*pi*x_0) .. sin(K*pi*x_0), cos(1*pi*x_0) ..
        cos(K*pi*x_0), sin(1*pi*x_1) ..]``.
    """
    wave_numbers = jnp.arange(1, num_frequencies + 1)
    # angles[i, k-1] = k * (pi * x_i); shape (n, K)
    angles = wave_numbers * (jnp.pi * x[:, None])
    per_coordinate = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
    return per_coordinate.ravel()


class GatingNetwork(nn.Module):
    """Residual MLP mapping a point to a probability vector over partitions.

    Pipeline: ``embedding`` (sinusoidal features), a plain ``BoxDense`` layer,
    ``n_layers`` residual ``BoxDense`` layers, and a linear ``BoxDense`` output
    layer of width ``num_partitions``. The output is softmax-normalized so the
    entries are non-negative and sum to 1.

    Attributes:
        n_hidden (int): Width of every hidden layer.
        n_layers (int): Number of residual hidden layers (also the ``depth``
            passed to the box initialization).
        num_partitions (int): Number of partitions (output dimension).
        activation (Callable): Activation used in the hidden layers.
    """

    n_hidden: int
    n_layers: int
    num_partitions: int
    activation: Callable

    def setup(self):
        """Create the input layer, the residual stack and the output layer."""
        hidden_kwargs = dict(features=self.n_hidden, activation=self.activation, depth=self.n_layers)
        self.layer_0 = BoxDense(layer=0, arch_type="plain", **hidden_kwargs)
        self.layers = [
            BoxDense(layer=idx, arch_type="resnet", **hidden_kwargs)
            for idx in range(1, self.n_layers + 1)
        ]
        self.output = BoxDense(
            features=self.num_partitions,
            activation=lambda z: z,
            depth=self.n_layers,
            layer=self.n_layers,
            arch_type="plain",
        )

    @nn.compact
    def __call__(self, x):
        """Return the partition probabilities at a single point.

        Args:
            x (jnp.ndarray): 1-D input point (fed to ``embedding``).

        Returns:
            jnp.ndarray: Softmax-normalized vector of length ``num_partitions``.
        """
        hidden = self.layer_0(embedding(x))
        for residual_layer in self.layers:
            hidden = hidden + residual_layer(hidden)  # residual connection
        logits = self.output(hidden)
        return nn.softmax(logits)


class BasisNetwork(nn.Module):
    """Residual MLP returning ``basis_size`` basis-function values at a point.

    Pipeline: ``embedding`` (sinusoidal features), a plain ``BoxDense`` layer,
    ``n_layers`` residual ``BoxDense`` layers, and a linear (identity
    activation) ``BoxDense`` output layer of width ``basis_size``. The output is
    NOT normalized.

    Attributes:
        n_hidden (int): Width of every hidden layer.
        n_layers (int): Number of residual hidden layers (also the ``depth``
            passed to the box initialization).
        basis_size (int): Number of basis functions (output dimension).
        activation (Callable): Activation used in the hidden layers.
    """

    n_hidden: int
    n_layers: int
    basis_size: int
    activation: Callable

    def setup(self):
        """Create the input layer, the residual stack and the output layer."""
        hidden_kwargs = dict(features=self.n_hidden, activation=self.activation, depth=self.n_layers)
        self.layer_0 = BoxDense(layer=0, arch_type="plain", **hidden_kwargs)
        self.layers = [
            BoxDense(layer=idx, arch_type="resnet", **hidden_kwargs)
            for idx in range(1, self.n_layers + 1)
        ]
        self.output = BoxDense(
            features=self.basis_size,
            activation=lambda z: z,
            depth=self.n_layers,
            layer=self.n_layers,
            arch_type="plain",
        )

    @nn.compact
    def __call__(self, x):
        """Evaluate all basis functions at a single point.

        Args:
            x (jnp.ndarray): 1-D input point (fed to ``embedding``).

        Returns:
            jnp.ndarray: Unnormalized vector of length ``basis_size``.
        """
        hidden = self.layer_0(embedding(x))
        for residual_layer in self.layers:
            hidden = hidden + residual_layer(hidden)  # residual connection
        return self.output(hidden)

class BlendedMLP_Regression:
    def __init__(self, key, net_setup_params, x, t, h, D, center, width, lr_c, lr_f, device):
        """Build the coarse/fine mixture model, its collocation data and optimizers.

        Args:
            key: JAX PRNG key. Stored as ``self.key``. Independent sub-keys are
                split from it so every gate/basis network gets its own init.
            net_setup_params: dict with "gating" and "basis" entries (each a
                dict with "coarse"/"fine" keyword dicts for GatingNetwork /
                BasisNetwork) and "sigma_schedule" (dict of callables "comp",
                "coop", "bc", "ic", evaluated at 0 here).
            x, t: arrays of spatial / temporal grid coordinates (raveled before
                building the tensor-product collocation grid).
            h, D: coefficients of the operator applied in ``L_net`` /
                ``ic_correction_fn`` (they also parameterize ``u_test``).
            center, width: centre and width of the Gaussian ``ic_func``.
            lr_c, lr_f: initial learning rates of the coarse / fine Adam
                optimizers (decayed by 0.9 every 10000 steps).
            device: JAX device the collocation arrays are placed on.
        """
        self.device = device
        # Master random key
        self.key = key
        # Dictionary containing network setup parameters
        self.net_setup_params = net_setup_params
        # Set up the input data
        self.h = h
        self.D = D
        self.center = center
        self.width = width
        self.N_x = x.shape[0]
        self.N_t = t.shape[0]
        # Tensor-product collocation grid, flattened with x varying fastest.
        X, T = jnp.meshgrid(x.ravel(), t.ravel())
        self.x_res, self.t_res = X.ravel(), T.ravel()
        x0_indices = jnp.where(self.x_res == 0)
        self.x_bc0, self.t_bc0 = self.x_res[x0_indices], self.t_res[x0_indices]
        x1_indices = jnp.where(self.x_res == x[-1])
        self.x_bc1, self.t_bc1 = self.x_res[x1_indices], self.t_res[x1_indices]
        t0_indices = jnp.where(self.t_res == 0)
        self.x_ic, self.t_ic = self.x_res[t0_indices], self.t_res[t0_indices]
        # Reference solution evaluated on the collocation grid.
        self.u_test = 1. / jnp.sqrt(1. + 2*self.D*self.t_res/self.width**2) * jnp.exp(- 0.5 * (self.x_res - self.center - self.h*self.t_res)**2 / (self.width**2 + 2*self.D*self.t_res))

        # Place every point set on the target device (IC points included).
        self.x_res = jax.device_put(self.x_res, device)
        self.t_res = jax.device_put(self.t_res, device)
        self.x_bc0 = jax.device_put(self.x_bc0, device)
        self.t_bc0 = jax.device_put(self.t_bc0, device)
        self.x_bc1 = jax.device_put(self.x_bc1, device)
        self.t_bc1 = jax.device_put(self.t_bc1, device)
        self.x_ic = jax.device_put(self.x_ic, device)
        self.t_ic = jax.device_put(self.t_ic, device)
        self.u_test = jax.device_put(self.u_test, device)

        self.gating_setup_params_c = net_setup_params["gating"]["coarse"]
        self.basis_setup_params_c = net_setup_params["basis"]["coarse"]
        self.gating_setup_params_f = net_setup_params["gating"]["fine"]
        self.basis_setup_params_f = net_setup_params["basis"]["fine"]
        self.sigma_schedule = net_setup_params["sigma_schedule"]
        self.N_coarse = self.gating_setup_params_c["num_partitions"]
        self.N_fine = self.gating_setup_params_f["num_partitions"]
        self.N_basis_c = self.basis_setup_params_c["basis_size"]
        # Fine experts are concatenated with the coarse basis (see basis_net).
        self.N_basis_f = self.basis_setup_params_f["basis_size"] + self.basis_setup_params_c["basis_size"]

        dummy_input = jnp.ones(2,)
        # Independent keys per network: reusing one key would start every coarse
        # expert, every fine gate and every fine expert from identical weights.
        key_gate_c, key_basis_c, key_gate_f, key_basis_f = random.split(key, 4)

        gating_model_c = GatingNetwork(**self.gating_setup_params_c)
        basis_model_c = [BasisNetwork(**self.basis_setup_params_c) for _ in range(self.N_coarse)]
        self.gating_params_c = gating_model_c.init(key_gate_c, dummy_input)
        self.basis_params_c = [
            model.init(model_key, dummy_input)
            for model, model_key in zip(basis_model_c, random.split(key_basis_c, self.N_coarse))
        ]
        self.coeffs_c = jnp.ones((self.N_coarse, self.N_basis_c))

        gating_model_f = [GatingNetwork(**self.gating_setup_params_f) for _ in range(self.N_coarse)]
        basis_model_f = [[BasisNetwork(**self.basis_setup_params_f) for _ in range(self.N_fine)]
                         for _ in range(self.N_coarse)]
        self.gating_params_f = [
            model.init(model_key, dummy_input)
            for model, model_key in zip(gating_model_f, random.split(key_gate_f, self.N_coarse))
        ]
        self.basis_params_f = [
            [model.init(model_key, dummy_input)
             for model, model_key in zip(row_models, random.split(row_key, self.N_fine))]
            for row_models, row_key in zip(basis_model_f, random.split(key_basis_f, self.N_coarse))
        ]
        self.coeffs_f = jnp.ones((self.N_coarse, self.N_fine, self.N_basis_f))

        # 0/1 indicator masks of the x=0, x=x[-1] and t=0 collocation points.
        beta = jnp.zeros_like(self.x_res)
        self.bc0_mask = beta.at[x0_indices].set(1.)
        self.bc1_mask = beta.at[x1_indices].set(1.)
        self.ic_mask = beta.at[t0_indices].set(1.)

        # Right-hand side D*u0'' - h*u0' of the operator equation for u_pou.
        self.ic_correction = vmap(self.ic_correction_fn)(self.x_res, self.t_res)

        # Initialize sigma parameters for component and cooperative terms
        self.sigma_comp = self.sigma_schedule["comp"](0)
        self.sigma_coop = self.sigma_schedule["coop"](0)
        self.sigma_bc = self.sigma_schedule["bc"](0)
        self.sigma_ic = self.sigma_schedule["ic"](0)

        self.params_c = (
            self.gating_params_c,
            self.basis_params_c,
        )
        self.avg_params_c = self.params_c
        self.params_f2c = self.params_c

        self.params_f = (
            self.gating_params_f,
            self.basis_params_f,
        )
        self.avg_params_f = self.params_f

        self.gate_apply_c = gating_model_c.apply
        self.basis_apply_c = [i.apply for i in basis_model_c]
        self.gate_apply_f = [i.apply for i in gating_model_f]
        self.basis_apply_f = [[j.apply for j in i] for i in basis_model_f]

        lr_c = optax.exponential_decay(lr_c, 10000, 0.9)
        lr_f = optax.exponential_decay(lr_f, 10000, 0.9)
        # Adam with default settings (weight decay intentionally disabled).
        self.optimizer_c = optax.chain(
                                        optax.adam(learning_rate=lr_c),
                                      )
        self.optimizer_f = optax.chain(
                                        optax.adam(learning_rate=lr_f),
                                      )
        self.opt_state_c = self.optimizer_c.init(self.params_c)
        self.opt_state_f = self.optimizer_f.init(self.params_f)

        self.itercount = itertools.count()
        self.l2_error_log_c = []
        self.loss_log_c = []
        self.loss_bcs_log_c = []
        self.l2_error_log_f = []
        self.loss_log_f = []
        self.loss_bcs_log_f = []

    def gate_net(self, params, x, t, level):
        if level == 0:
            return self.gate_apply_c(params, jnp.asarray([x,t]))
        elif level == 1:
            gate_f = [self.gate_apply_f[i](params[i], jnp.asarray([x,t])) for i in range(len(self.gate_apply_f))]
            return jnp.asarray(gate_f)

    def basis_net(self, params_c, params_f, x, t, level):
        if level == 0:
            basis_out = [self.basis_apply_c[i](params_c[i], jnp.asarray([x,t])) for i in range(len(self.basis_apply_c))]
            return jnp.asarray(basis_out)
        elif level == 1:
            basis_f = [[self.basis_apply_f[i][j](params_f[i][j], jnp.asarray([x,t])) for j in range(len(self.basis_apply_f[i]))]\
                       for i in range(len(self.basis_apply_f))]
            basis_f = jnp.asarray(basis_f)
            basis_c = self.basis_net(params_c, params_f, x, t, 0)[:,jnp.newaxis,:]
            basis_c = jnp.repeat(basis_c, repeats=self.N_fine, axis=1)
            return jnp.concatenate([basis_f, basis_c], axis=-1)

    def u_pou(self, params_c, params_f, coeffs, x, t, level):
        if level == 0:
            gate_params, basis_params = params_c
            gate_out = self.gate_net(gate_params, x, t, 0).ravel()
            basis_out = self.basis_net(basis_params, params_f[1], x, t, 0)
            return jnp.sum(gate_out * jnp.einsum("ij,ij->i", coeffs, basis_out))
        elif level == 1:
            gate_params, basis_params = params_f
            gate_out_c = self.gate_net(params_c[0], x, t, 0).ravel()
            gate_out_f = self.gate_net(gate_params, x, t, 1)
            basis_out = self.basis_net(params_c[1], basis_params, x, t, 1)
            return jnp.sum(jnp.einsum("i,ij,ijk,ijk", gate_out_c, gate_out_f, coeffs, basis_out))

    def u_pred(self, params_c, params_f, coeffs, x, t, level):
        return t * self.u_pou(params_c, params_f, coeffs, x, t, level) + self.ic_func(x, t)

    def L_net(self, params_c, params_f, coeffs, x, t, level):
        """
        This function computes L[y_POU](x)
        """
        u_out_fn = lambda x, t: self.u_pou(params_c, params_f, coeffs, x, t, level)
        return u_out_fn(x,t) + t * (jax.jacfwd(u_out_fn,1)(x,t) + self.h * jax.jacfwd(u_out_fn,0)(x,t) - self.D * jax.jacfwd(jax.jacfwd(u_out_fn,0),0)(x,t))

    def L_basis(self, params_c, params_f, x, t, level):
        basis_fn = lambda x, t: self.basis_net(params_c[1], params_f[1], x, t, level)
        return basis_fn(x,t) + t * (jax.jacfwd(basis_fn,1)(x,t) + self.h * jax.jacfwd(basis_fn,0)(x,t) - self.D * jax.jacfwd(jax.jacfwd(basis_fn,0),0)(x,t))

    def L_gate_basis(self, params_c, params_f, x, t, level):
        """
        This function calculates L[\pi_m \phi_n](x)
        """
        if level == 0:
            gate_params, basis_params = params_c
            gate_basis_fn = lambda x, t: jnp.einsum("i,ij->ij",
                                                    self.gate_net(gate_params, x, t, 0),
                                                    self.basis_net(basis_params, params_f[1], x, t, 0))
        elif level == 1:
            gate_params, basis_params = params_f
            gate_basis_fn = lambda x, t: jnp.einsum("i,ij,ijk->ijk",
                                                    self.gate_net(params_c[0], x, t, 0),
                                                    self.gate_net(gate_params, x, t, 1),
                                                    self.basis_net(params_c[1], basis_params, x, t, 1))
        return gate_basis_fn(x,t) + t * (jax.jacfwd(gate_basis_fn,1)(x,t) + self.h * jax.jacfwd(gate_basis_fn,0)(x,t) - self.D * jax.jacfwd(jax.jacfwd(gate_basis_fn,0),0)(x,t))

    def ic_func(self, x, t):
        return jnp.exp(- 0.5 * (x - self.center)**2 / self.width**2)

    def ic_correction_fn(self, x, t):
        u_true0_fn = lambda x: self.ic_func(x, t)
        d_u_true0 = grad(u_true0_fn)(x)
        dd_u_true0 = grad(grad(u_true0_fn))(x)
        return self.D * dd_u_true0 - self.h * d_u_true0


    @partial(jit, static_argnums=(0,4))
    def E_step(self, params_c, params_f, coeffs, level):
        """E-step: posterior responsibility of each partition at every collocation point.

        The unnormalized likelihood of a partition is a product of Gaussian
        factors:
          * comp: residual between ``self.ic_correction`` and the partition's
            local operator output ``coeffs . L_basis`` (width ``sigma_comp``);
          * bc0 / bc1: mismatch between ``self.u_test`` and the partition's raw
            basis expansion ``coeffs . basis_out`` (width ``sigma_bc``). The BC
            masks zero the exponent away from the boundary points, so there the
            factor is exactly 1.
        The prior is the gate output (coarse gate, or coarse x fine gate at level 1).

        Args:
            params_c: (coarse gate params, coarse basis params).
            params_f: (fine gate params, fine basis params).
            coeffs: level 0 shape (n_coarse, n_basis); level 1 shape
                (n_coarse, n_fine, n_basis_f).
            level: 0 or 1; static under jit.

        Returns:
            Level 0: (n_points, n_coarse), each row summing to 1.
            Level 1: (n_points, n_coarse, n_fine), summing to 1 over the last two axes.
            None for any other level.

        NOTE(review): the BC factor compares u_test with the raw basis expansion,
        not with u_pred = t * u_pou + ic_func, while loss_bcs uses a ~0 target on
        u_pred. Confirm which BC formulation is intended.
        """
        if level == 0:
            gate_params, basis_params = params_c
            gate_out = vmap(self.gate_net, (None, 0, 0, None))(gate_params, self.x_res, self.t_res, 0)
            L_basis = vmap(self.L_basis, (None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 0)
            basis_out = vmap(self.basis_net, (None, None, 0, 0, None))(basis_params, params_f[1], self.x_res, self.t_res, 0)
            # ubc_true = 1e-12

            # Component likelihood: one Gaussian factor per (point, partition)
            likelihood_comp = jnp.exp(-0.5 * (self.ic_correction.reshape(-1,1) - jnp.einsum("ij,dij->di",coeffs, L_basis)) ** 2 / self.sigma_comp**2)

            # BC likelihood (mask makes the exponent 0 off the boundary)
            likelihood_bc0 = jnp.exp(-0.5 * self.bc0_mask.reshape(-1,1) * (self.u_test[:,None] - jnp.einsum("ij,dij->di",coeffs, basis_out))**2 / self.sigma_bc**2)
            likelihood_bc1 = jnp.exp(-0.5 * self.bc1_mask.reshape(-1,1) * (self.u_test[:,None] - jnp.einsum("ij,dij->di",coeffs, basis_out))**2 / self.sigma_bc**2)

            # Combined likelihood
            likelihood = likelihood_comp * likelihood_bc0 * likelihood_bc1


            # Computing posterior: prior (gate) x likelihood, normalized per point.
            # eps keeps the normalization finite when every factor underflows
            # (the posterior then becomes uniform over partitions).
            posterior = gate_out * likelihood
            posterior = posterior + jnp.finfo(posterior.dtype).eps
            posterior = posterior / (
                jnp.sum(posterior, axis=1, keepdims=True)
            )
            return posterior
        elif level == 1:
            gate_params, basis_params = params_f
            gate_out_c = vmap(self.gate_net, (None, 0, 0, None))(params_c[0], self.x_res, self.t_res, 0)
            gate_out_f = vmap(self.gate_net, (None, 0, 0, None))(gate_params, self.x_res, self.t_res, 1)
            L_basis = vmap(self.L_basis, (None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 1)
            basis_out = vmap(self.basis_net, (None, None, 0, 0, None))(params_c[1], basis_params, self.x_res, self.t_res, 1)
            # ubc_true = 1e-12

            # Component likelihood: one factor per (point, coarse, fine)
            likelihood_comp = jnp.exp(-0.5 * (self.ic_correction.reshape(-1,1,1) - jnp.einsum("ijk,dijk->dij",coeffs, L_basis)) ** 2 / self.sigma_comp**2)

            # BC likelihood (mask broadcast to (n_points, 1, 1))
            likelihood_bc0 = jnp.exp(-0.5 * jnp.expand_dims(self.bc0_mask.reshape(-1,1), -1) * (self.u_test[:,None,None] - jnp.einsum("ijk,dijk->dij",coeffs, basis_out))**2 / self.sigma_bc**2)
            likelihood_bc1 = jnp.exp(-0.5 * jnp.expand_dims(self.bc1_mask.reshape(-1,1), -1) * (self.u_test[:,None,None] - jnp.einsum("ijk,dijk->dij",coeffs, basis_out))**2 / self.sigma_bc**2)

            # Combined likelihood
            likelihood = likelihood_comp * likelihood_bc0 * likelihood_bc1

            # Computing posterior: prior is coarse gate x fine gate; normalize
            # jointly over the (coarse, fine) axes.
            posterior = jnp.einsum("di,dij,dij->dij", gate_out_c, gate_out_f, likelihood)
            posterior = posterior + jnp.finfo(posterior.dtype).eps
            posterior = posterior / (
                jnp.sum(posterior, axis=(1,2), keepdims=True)
            )
            return posterior

    @partial(jit, static_argnums=(0,4))
    def loss_bcs(self, params_c, params_f, coeffs, level):
        u_out0 = vmap(self.u_pred, (None, None, None, 0, 0, None))(params_c, params_f, coeffs, self.x_bc0, self.t_bc0, level)
        u0_true = 1e-12
        u_out1 = vmap(self.u_pred, (None, None, None, 0, 0, None))(params_c, params_f, coeffs, self.x_bc1, self.t_bc1, level)
        u1_true = 1e-12
        loss_bcs = jnp.linalg.norm((u_out0 - u0_true))**2 + jnp.linalg.norm((u_out1 - u1_true))**2
        return loss_bcs

    @partial(jit, static_argnums=(0,4))
    def loss_ics(self, params_c, params_f, coeffs, level):
        u_out = vmap(self.u_pred, (None, None, None, 0, 0, None))(params_c, params_f, coeffs, self.x_ic, self.t_ic, level)
        u_ic = vmap(self.ic_func, (0,0))(self.x_ic, self.t_ic)
        loss_bcs = jnp.linalg.norm((u_out - u_ic))**2
        return loss_bcs

    @partial(jit, static_argnums=(0,))
    def loss_c(self, params_c, params_f, coeffs, posterior):
        """M-step loss for the coarse (level-0) gate and basis parameters.

        Sum of:
          * component_loss: cross-entropy ``-sum(posterior * log(gate))``
            between the given posterior and the coarse gate;
          * cooperative_loss: ``0.5 * sum((ic_correction - L_net)**2) / sigma_coop**2``
            for the gate-blended model;
          * comp_loss: posterior-weighted squared residual of each partition's
            local operator output ``coeffs . L_basis`` (``/ sigma_comp**2``);
          * bc_loss: squared mismatch between ``u_test`` and each partition's raw
            basis expansion on the boundary points, once unweighted and once
            posterior-weighted (``/ sigma_bc**2``).

        Args:
            params_c: (coarse gate params, coarse basis params). ``update_f2c``
                differentiates with respect to this argument.
            params_f: fine parameters. The level-0 helpers called here do not
                read them (basis_net ignores them at level 0).
            coeffs: coarse coefficients, shape (n_coarse, n_basis).
            posterior: (n_points, n_coarse) responsibilities, held fixed.

        Returns:
            Scalar loss.
        """
        gate_params, basis_params = params_c
        gate_out = vmap(self.gate_net, (None, 0, 0, None))(gate_params, self.x_res, self.t_res, 0)
        basis_out = vmap(self.basis_net, (None, None, 0, 0, None))(basis_params, params_f[1], self.x_res, self.t_res, 0)
        # ubc_true = 1e-12
        L_basis = vmap(self.L_basis, (None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 0)
        # L[y_{POU}](x)
        blended_pred = vmap(self.L_net, (None, None, None, 0, 0, None))(params_c, params_f, coeffs, self.x_res, self.t_res, 0)

        # Gate term: cross-entropy of the gate against the fixed posterior (eps avoids log(0))
        component_loss = -jnp.sum((posterior * jnp.log(gate_out + jnp.finfo(gate_out.dtype).eps)))

        # Cooperative term
        cooperative_loss = (
            0.5
            * jnp.sum(((self.ic_correction - blended_pred) ** 2))
            / self.sigma_coop**2
        )

        # Per-partition (competitive) residual term, weighted by the posterior
        comp_loss = (
            0.5
            * jnp.sum((posterior * (self.ic_correction.reshape(-1,1) - jnp.einsum("ij,dij->di",coeffs, L_basis)) ** 2))
            / self.sigma_comp**2
        )

        # BC terms (unweighted, every partition individually)
        bc0_loss = jnp.sum((self.bc0_mask.reshape(-1,1) * (self.u_test[:,None] - jnp.einsum("ij,dij->di",coeffs, basis_out))**2 / self.sigma_bc**2)) / 2
        bc1_loss = jnp.sum((self.bc1_mask.reshape(-1,1) * (self.u_test[:,None] - jnp.einsum("ij,dij->di",coeffs, basis_out))**2 / self.sigma_bc**2)) / 2

        # BC terms (posterior-weighted)
        bc0_loss_ = jnp.sum((posterior * self.bc0_mask.reshape(-1,1) * (self.u_test[:,None] - jnp.einsum("ij,dij->di",coeffs, basis_out))**2 / self.sigma_bc**2)) / 2
        bc1_loss_ = jnp.sum((posterior * self.bc1_mask.reshape(-1,1) * (self.u_test[:,None] - jnp.einsum("ij,dij->di",coeffs, basis_out))**2 / self.sigma_bc**2)) / 2
        # Combine losses
        bc_loss = bc0_loss + bc1_loss + bc0_loss_ + bc1_loss_
        total_loss = component_loss + cooperative_loss + comp_loss + bc_loss
        assert total_loss.size == 1
        return total_loss

    @partial(jit, static_argnums=(0,))
    def loss_f(self, params_f, params_c, coeffs, posterior):
        """M-step loss for the fine (level-1) gate and basis parameters.

        Same structure as the coarse loss, evaluated at level 1:
          * component_loss: cross-entropy of the FINE gate outputs against the
            joint posterior (the coarse gate enters only through L_net);
          * cooperative_loss: squared residual of the blended two-level model's
            ``L_net`` against ``ic_correction`` (``/ sigma_coop**2``);
          * comp_loss: posterior-weighted residual of each (coarse, fine)
            expert's local operator output (``/ sigma_comp**2``);
          * bc_loss: unweighted and posterior-weighted squared mismatch between
            ``u_test`` and each expert's raw basis expansion on the boundary.

        Note the argument order: ``params_f`` comes first, presumably so that
        ``grad`` differentiates with respect to the fine parameters by default
        (caller not visible here; confirm).

        Args:
            params_f: (fine gate params, fine basis params).
            params_c: (coarse gate params, coarse basis params), held fixed.
            coeffs: shape (n_coarse, n_fine, n_basis_f).
            posterior: (n_points, n_coarse, n_fine) responsibilities, held fixed.

        Returns:
            Scalar loss.
        """
        gate_params, basis_params = params_f
        gate_out = vmap(self.gate_net, (None, 0, 0, None))(gate_params, self.x_res, self.t_res, 1)
        basis_out = vmap(self.basis_net, (None, None, 0, 0, None))(params_c[1], basis_params, self.x_res, self.t_res, 1)
        # ubc_true = 1e-12
        L_basis = vmap(self.L_basis, (None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 1)
        # L[y_{POU}](x)
        blended_pred = vmap(self.L_net, (None, None, None, 0, 0, None))(params_c, params_f, coeffs, self.x_res, self.t_res, 1)

        # Gate term: cross-entropy of the fine gates against the fixed posterior
        component_loss = -jnp.sum((posterior * jnp.log(gate_out + jnp.finfo(gate_out.dtype).eps)))

        # Cooperative term
        cooperative_loss = (
            0.5
            * jnp.sum(((self.ic_correction - blended_pred) ** 2))
            / self.sigma_coop**2
        )

        # Per-expert (competitive) residual term, weighted by the posterior
        comp_loss = (
            0.5
            * jnp.sum((posterior * (self.ic_correction.reshape(-1,1,1) - jnp.einsum("ijk,dijk->dij",coeffs, L_basis)) ** 2))
            / self.sigma_comp**2
        )

        # BC terms (unweighted, every expert individually)
        bc0_loss = jnp.sum((jnp.expand_dims(self.bc0_mask.reshape(-1,1), -1) * (self.u_test[:,None,None] - jnp.einsum("ijk,dijk->dij",coeffs, basis_out))**2 / self.sigma_bc**2)) / 2
        bc1_loss = jnp.sum((jnp.expand_dims(self.bc1_mask.reshape(-1,1), -1) * (self.u_test[:,None,None] - jnp.einsum("ijk,dijk->dij",coeffs, basis_out))**2 / self.sigma_bc**2)) / 2

        # BC terms (posterior-weighted)
        bc0_loss_ = jnp.sum((posterior * jnp.expand_dims(self.bc0_mask.reshape(-1,1), -1) * (self.u_test[:,None,None] - jnp.einsum("ijk,dijk->dij",coeffs, basis_out))**2 / self.sigma_bc**2)) /2
        bc1_loss_ = jnp.sum((posterior * jnp.expand_dims(self.bc1_mask.reshape(-1,1), -1) * (self.u_test[:,None,None] - jnp.einsum("ijk,dijk->dij",coeffs, basis_out))**2 / self.sigma_bc**2)) /2

        # Combine losses
        bc_loss = bc0_loss + bc1_loss + bc0_loss_ + bc1_loss_
        total_loss = component_loss + cooperative_loss + comp_loss + bc_loss
        assert total_loss.size == 1
        return total_loss

    @partial(jit, static_argnums=(0,4))
    def compute_l2_error(self, params_c, params_f, coeffs, level):
        out = vmap(self.u_pred, (None, None, None, 0, 0, None))(params_c, params_f, coeffs, self.x_res, self.t_res, level)
        error = jnp.linalg.norm(out.ravel() - self.u_test.ravel()) / jnp.linalg.norm(self.u_test)
        return error

    @partial(jit, static_argnums=(0,))
    def update_f2c(self, params_c, params_f, coeffs_c, coeffs_f, opt_state):
        """One fine-to-coarse gradient step on the coarse parameters.

        The level-1 posterior over (coarse, fine) partitions is summed over the
        fine axis to obtain a coarse posterior, which then weights ``loss_c``.

        Returns:
            (updated coarse params, coarse posterior, new optimizer state).
        """
        joint_posterior = self.E_step(params_c, params_f, coeffs_f, 1)
        coarse_posterior = jnp.sum(joint_posterior, axis=-1)
        # Gate and basis parameters are fit using gradient descent
        coarse_grads = grad(self.loss_c)(params_c, params_f, coeffs_c, coarse_posterior)
        updates, new_opt_state = self.optimizer_c.update(coarse_grads, opt_state, params_c)
        new_params = optax.apply_updates(params_c, updates)
        return new_params, coarse_posterior, new_opt_state

    @partial(jit, static_argnums=(0,5))
    def blended_lstsq_fit(self, params_c, params_f, coeffs, posterior, level):
        """Closed-form update of the expansion coefficients for fixed networks.

        Assembles the normal equations ``A c = b`` of a quadratic objective in
        the coefficients ``c`` and solves them with ``jnp.linalg.lstsq``
        (rcond=1e-12). Writing r = ic_correction, u = u_test, g = gate weights,
        phi = basis_out and L = the operator, the objective's terms are:
          * comp:  sum_d sum_m post[d,m] (r_d - c_m . L phi[d,m])^2 / sigma_comp^2
          * coop:  sum_d (r_d - sum_m c_m . L[g_m phi_m][d])^2 / sigma_coop^2
          * bc (per partition, posterior-weighted, each of bc0/bc1):
                   sum_d mask_d sum_m post[d,m] (u_d - c_m . phi[d,m])^2 / sigma_bc^2
          * bc (gate-blended, each of bc0/bc1):
                   sum_d mask_d (u_d - sum_m g[d,m] c_m . phi[d,m])^2 / sigma_bc^2
        At level 1 the partition index m is replaced by the (coarse, fine) pair,
        and g is the product of coarse and fine gates.

        Args:
            params_c: (coarse gate params, coarse basis params).
            params_f: (fine gate params, fine basis params).
            coeffs: not read; it is overwritten by the solution.
            posterior: level 0 (n_points, n_coarse); level 1
                (n_points, n_coarse, n_fine).
            level: 0 or 1; static under jit.

        Returns:
            Level 0: coefficients of shape (N_coarse, N_basis_c).
            Level 1: coefficients of shape (N_coarse, N_fine, N_basis_f).
            None for any other level.
        """
        if level == 0:
            gate_params, basis_params = params_c
            L_gate_basis = vmap(self.L_gate_basis, in_axes=(None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 0)
            # u_ic_out = vmap(self.ic_func, (0,0))(self.x_res, self.t_res) * self.ic_mask
            L_basis = vmap(self.L_basis, (None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 0)
            basis_out = vmap(self.basis_net, (None, None, 0, 0, None))(basis_params, params_f[1], self.x_res, self.t_res, 0)
            g_out = vmap(self.gate_net, (None, 0, 0, None))(gate_params, self.x_res, self.t_res, 0)
            # ubc_true = jnp.zeros_like(self.bc0_mask) + 1e-12

            # Right-hand side b[m, n], one row per partition m and basis index n
            b_m = (
                    + jnp.einsum("dm,d,d,dmn->mn", posterior, self.bc0_mask, self.u_test, basis_out) / self.sigma_bc**2
                    + jnp.einsum("dm,d,d,dmn->mn", posterior, self.bc1_mask, self.u_test, basis_out) / self.sigma_bc**2
                    + jnp.einsum("dm,d,d,dmn->mn", g_out, self.bc0_mask, self.u_test, basis_out) / self.sigma_bc**2
                    + jnp.einsum("dm,d,d,dmn->mn", g_out, self.bc1_mask, self.u_test, basis_out) / self.sigma_bc**2
                    + jnp.einsum("dm,d,dmn->mn", posterior, self.ic_correction, L_basis) / self.sigma_comp**2
                    + jnp.einsum("dmn,d->mn", L_gate_basis, self.ic_correction) / self.sigma_coop**2
                )
            b = b_m.flatten()

            # Identity over partitions: per-partition terms only couple a partition with itself
            m_indices, k_indices = jnp.arange(self.N_coarse), jnp.arange(self.N_coarse)
            mask_mk = (m_indices[:, None] == k_indices).astype(float)

            # A[m, j, n, k]: row (partition m, basis j), column (partition n, basis k)
            comp_A = jnp.einsum("dm,dmj,dnk,mn->mjnk", posterior, L_basis, L_basis, mask_mk) / self.sigma_comp**2
            coop_A = jnp.einsum("dmj,dnk->mjnk", L_gate_basis, L_gate_basis) / self.sigma_coop**2
            bc_A = (
                jnp.einsum("d,dm,dmj,mn,dnk->mjnk", self.bc0_mask, posterior, basis_out, mask_mk, basis_out) / self.sigma_bc**2
                + jnp.einsum("d,dm,dmj,mn,dnk->mjnk", self.bc1_mask, posterior, basis_out, mask_mk, basis_out) / self.sigma_bc**2
                + jnp.einsum("d,dm,dmj,dn,dnk->mjnk", self.bc1_mask, g_out, basis_out, g_out, basis_out) / self.sigma_bc**2
                + jnp.einsum("d,dm,dmj,dn,dnk->mjnk", self.bc0_mask, g_out, basis_out, g_out, basis_out) / self.sigma_bc**2
            )

            A = comp_A + coop_A + bc_A
            A = A.reshape(
                self.N_coarse * self.N_basis_c,
                self.N_coarse * self.N_basis_c,
            )
            coeffs = jnp.linalg.lstsq(A, b, rcond=1e-12)[0]
            coeffs = coeffs.reshape(
                self.N_coarse,
                self.N_basis_c,
            )
            return coeffs

        if level == 1:
            gate_params, basis_params = params_f
            L_gate_basis = vmap(self.L_gate_basis, in_axes=(None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 1)
            L_basis = vmap(self.L_basis, (None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 1)
            basis_out = vmap(self.basis_net, (None, None, 0, 0, None))(params_c[1], basis_params, self.x_res, self.t_res, 1)
            g_out_c = vmap(self.gate_net, (None, 0, 0, None))(params_c[0], self.x_res, self.t_res, 0)
            g_out_f = vmap(self.gate_net, (None, 0, 0, None))(gate_params, self.x_res, self.t_res, 1)
            # Effective two-level gate: coarse gate times fine gate
            g_out = jnp.einsum("dm,dmn->dmn", g_out_c, g_out_f)
            # ubc_true = jnp.zeros_like(self.bc0_mask) + 1e-12

            # Right-hand side b[m, n, i]: coarse m, fine n, basis index i
            b_m = (
                    + jnp.einsum("dmn,d,d,dmni->mni", posterior, self.bc0_mask, self.u_test, basis_out) / self.sigma_bc**2
                    + jnp.einsum("dmn,d,d,dmni->mni", posterior, self.bc1_mask, self.u_test, basis_out) / self.sigma_bc**2
                    + jnp.einsum("dmn,d,d,dmni->mni", g_out, self.bc0_mask, self.u_test, basis_out) / self.sigma_bc**2
                    + jnp.einsum("dmn,d,d,dmni->mni", g_out, self.bc1_mask, self.u_test, basis_out) / self.sigma_bc**2
                    + jnp.einsum("dmn,d,dmni->mni", posterior, self.ic_correction, L_basis) / self.sigma_comp**2
                    + jnp.einsum("dmni,d->mni", L_gate_basis, self.ic_correction) / self.sigma_coop**2
                )
            b = b_m.flatten()
            # Identities over coarse and fine indices (per-expert terms are block diagonal)
            m_indices, k_indices = jnp.arange(self.N_coarse), jnp.arange(self.N_coarse)
            mask_mk = (m_indices[:, None] == k_indices).astype(float)

            n_indices, l_indices = jnp.arange(self.N_fine), jnp.arange(self.N_fine)
            mask_nl = (n_indices[:, None] == l_indices).astype(float)

            # A[m, n, i, k, l, j]: row (coarse m, fine n, basis i), column (coarse k, fine l, basis j)
            comp_A = jnp.einsum("dmn,dmni,dmnj,mk,nl->mniklj", posterior, L_basis, L_basis, mask_mk, mask_nl) / self.sigma_comp**2
            coop_A = jnp.einsum("dmni,dklj->mniklj", L_gate_basis, L_gate_basis) / self.sigma_coop**2
            bc_A = (
                jnp.einsum("d,dmn,dmni,mk,nl,dmnj->mniklj", self.bc0_mask, posterior, basis_out, mask_mk, mask_nl, basis_out) / self.sigma_bc**2
                + jnp.einsum("d,dmn,dmni,mk,nl,dmnj->mniklj", self.bc1_mask, posterior, basis_out, mask_mk, mask_nl, basis_out) / self.sigma_bc**2
                + jnp.einsum("d,dmn,dmni,dkl,dklj->mniklj", self.bc1_mask, g_out, basis_out, g_out, basis_out) / self.sigma_bc**2
                + jnp.einsum("d,dmn,dmni,dkl,dklj->mniklj", self.bc0_mask, g_out, basis_out, g_out, basis_out) / self.sigma_bc**2
            )

            A = comp_A + coop_A + bc_A
            A = A.reshape((
                self.N_coarse * self.N_fine * self.N_basis_f,
                self.N_coarse * self.N_fine * self.N_basis_f,
            ))
            coeffs = jnp.linalg.lstsq(A, b, rcond=1e-12)[0]
            coeffs = coeffs.reshape((
                self.N_coarse,
                self.N_fine,
                self.N_basis_f,
            ))
            return coeffs

    @partial(jit, static_argnums=(0, 5))
    def blended_lstsq_fit_iterative(self, params_c, params_f, coeffs, posterior, level, max_iter=1000, tol=1e-12, reg=1e-12, omega=1.):
        """Fit one level's expert coefficients with a damped block-splitting iteration.

        The regularized normal equations ``(M - N) c = b`` are split into:

        * ``M``: terms that couple coefficients of the *same* expert. These are
          the posterior-weighted competitive residual term and the
          posterior-weighted boundary terms, both masked to the expert diagonal,
          plus a ``reg * I`` ridge.
        * ``N``: minus the terms that couple *different* experts. These are the
          cooperative residual term and the gate-weighted boundary terms.

        Starting from ``coeffs``, the update

            c <- omega * M^{-1} (b + N c) + (1 - omega) * c

        is repeated until ``||c_new - c|| < tol`` or ``max_iter`` iterations have
        run. With ``omega == 1`` this is a block-Jacobi-style iteration.

        Args:
            params_c: Coarse ``(gate_params, basis_params)`` pair.
            params_f: Fine ``(gate_params, basis_params)`` pair.
            coeffs: Initial guess. It is flattened before iterating.
            posterior: Responsibilities for this level, indexed ``[d, m]`` at
                level 0 and ``[d, m, n]`` at level 1 (see the einsum subscripts).
            level: 0 (coarse) or 1 (fine). Static under ``jit``.
            max_iter: Upper bound on the number of splitting iterations.
            tol: Stop once the norm of the update drops below this value.
            reg: Ridge added to the diagonal of ``M`` (at both levels).
            omega: Relaxation factor.

        Returns:
            The fitted coefficients, shaped ``(N_coarse, N_basis_c)`` at level 0
            or ``(N_coarse, N_fine, N_basis_f)`` at level 1.

        Raises:
            ValueError: If ``level`` is not 0 or 1.
        """
        from jax.scipy.linalg import lu_factor, lu_solve

        if level not in (0, 1):
            raise ValueError(f"level must be 0 or 1, got {level!r}")

        def relaxed_solve(M, N, b, coeffs0):
            """Iterate ``c <- omega * M^{-1}(b + N c) + (1 - omega) c`` from ``coeffs0``."""
            M = jax.device_put(M, device=self.device)
            N = jax.device_put(N, device=self.device)
            b = jax.device_put(b, device=self.device)
            # M is loop-invariant: factor it once and reuse the LU factors,
            # instead of re-factorizing inside every solve.
            M_lu = lu_factor(M)

            def body_fun(state):
                i, c, _, _ = state
                c_new = omega * lu_solve(M_lu, b + N @ c) + (1 - omega) * c
                error = jnp.linalg.norm(c_new - c)
                return i + 1, c_new, error < tol, error

            def cond_fun(state):
                i, _, converged, _ = state
                return (i < max_iter) & (~converged)

            _, c_final, _, _ = lax.while_loop(cond_fun, body_fun, (0, coeffs0, False, jnp.inf))
            return c_final

        if level == 0:
            gate_params, basis_params = params_c
            L_gate_basis = vmap(self.L_gate_basis, in_axes=(None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 0)
            L_basis = vmap(self.L_basis, (None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 0)
            basis_out = vmap(self.basis_net, (None, None, 0, 0, None))(basis_params, params_f[1], self.x_res, self.t_res, 0)
            g_out = vmap(self.gate_net, (None, 0, 0, None))(gate_params, self.x_res, self.t_res, 0)

            # Right-hand side: boundary data weighted by posterior and by gate,
            # plus the ic_correction forcing of the competitive/cooperative terms.
            b_m = (
                + jnp.einsum("dm,d,d,dmn->mn", posterior, self.bc0_mask, self.u_test, basis_out) / self.sigma_bc**2
                + jnp.einsum("dm,d,d,dmn->mn", posterior, self.bc1_mask, self.u_test, basis_out) / self.sigma_bc**2
                + jnp.einsum("dm,d,d,dmn->mn", g_out, self.bc0_mask, self.u_test, basis_out) / self.sigma_bc**2
                + jnp.einsum("dm,d,d,dmn->mn", g_out, self.bc1_mask, self.u_test, basis_out) / self.sigma_bc**2
                + jnp.einsum("dm,d,dmn->mn", posterior, self.ic_correction, L_basis) / self.sigma_comp**2
                + jnp.einsum("dmn,d->mn", L_gate_basis, self.ic_correction) / self.sigma_coop**2
            )
            b = b_m.flatten()

            # Identity mask restricting "diagonal" terms to a single expert.
            m_indices, k_indices = jnp.arange(self.N_coarse), jnp.arange(self.N_coarse)
            mask_mk = (m_indices[:, None] == k_indices).astype(float)

            comp_A = jnp.einsum("dm,dmj,dnk,mn->mjnk", posterior, L_basis, L_basis, mask_mk) / self.sigma_comp**2
            coop_A = jnp.einsum("dmj,dnk->mjnk", L_gate_basis, L_gate_basis) / self.sigma_coop**2
            bc_A_diag = (
                + jnp.einsum("d,dm,dmj,mn,dnk->mjnk", self.bc0_mask, posterior, basis_out, mask_mk, basis_out) / self.sigma_bc**2
                + jnp.einsum("d,dm,dmj,mn,dnk->mjnk", self.bc1_mask, posterior, basis_out, mask_mk, basis_out) / self.sigma_bc**2
            )
            bc_A_off = (
                + jnp.einsum("d,dm,dmj,dn,dnk->mjnk", self.bc1_mask, g_out, basis_out, g_out, basis_out) / self.sigma_bc**2
                + jnp.einsum("d,dm,dmj,dn,dnk->mjnk", self.bc0_mask, g_out, basis_out, g_out, basis_out) / self.sigma_bc**2
            )

            size = self.N_coarse * self.N_basis_c
            M = (comp_A + bc_A_diag).reshape(size, size) + reg * jnp.eye(size, size)
            N = -(coop_A + bc_A_off).reshape(size, size)
            coeffs = relaxed_solve(M, N, b, coeffs.ravel())
            return coeffs.reshape(self.N_coarse, self.N_basis_c)

        # level == 1
        gate_params, basis_params = params_f
        L_gate_basis = vmap(self.L_gate_basis, in_axes=(None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 1)
        L_basis = vmap(self.L_basis, (None, None, 0, 0, None))(params_c, params_f, self.x_res, self.t_res, 1)
        basis_out = vmap(self.basis_net, (None, None, 0, 0, None))(params_c[1], basis_params, self.x_res, self.t_res, 1)
        g_out_c = vmap(self.gate_net, (None, 0, 0, None))(params_c[0], self.x_res, self.t_res, 0)
        g_out_f = vmap(self.gate_net, (None, 0, 0, None))(gate_params, self.x_res, self.t_res, 1)
        # Effective fine gate: coarse gate value times the fine gate nested under it.
        g_out = jnp.einsum("dm,dmn->dmn", g_out_c, g_out_f)

        b_m = (
            + jnp.einsum("dmn,d,d,dmni->mni", posterior, self.bc0_mask, self.u_test, basis_out) / self.sigma_bc**2
            + jnp.einsum("dmn,d,d,dmni->mni", posterior, self.bc1_mask, self.u_test, basis_out) / self.sigma_bc**2
            + jnp.einsum("dmn,d,d,dmni->mni", g_out, self.bc0_mask, self.u_test, basis_out) / self.sigma_bc**2
            + jnp.einsum("dmn,d,d,dmni->mni", g_out, self.bc1_mask, self.u_test, basis_out) / self.sigma_bc**2
            + jnp.einsum("dmn,d,dmni->mni", posterior, self.ic_correction, L_basis) / self.sigma_comp**2
            + jnp.einsum("dmni,d->mni", L_gate_basis, self.ic_correction) / self.sigma_coop**2
        )
        b = b_m.flatten()

        m_indices, k_indices = jnp.arange(self.N_coarse), jnp.arange(self.N_coarse)
        mask_mk = (m_indices[:, None] == k_indices).astype(float)

        n_indices, l_indices = jnp.arange(self.N_fine), jnp.arange(self.N_fine)
        mask_nl = (n_indices[:, None] == l_indices).astype(float)

        comp_A = jnp.einsum("dmn,dmni,dmnj,mk,nl->mniklj", posterior, L_basis, L_basis, mask_mk, mask_nl) / self.sigma_comp**2
        coop_A = jnp.einsum("dmni,dklj->mniklj", L_gate_basis, L_gate_basis) / self.sigma_coop**2
        bc_A_diag = (
            + jnp.einsum("d,dmn,dmni,mk,nl,dmnj->mniklj", self.bc0_mask, posterior, basis_out, mask_mk, mask_nl, basis_out) / self.sigma_bc**2
            + jnp.einsum("d,dmn,dmni,mk,nl,dmnj->mniklj", self.bc1_mask, posterior, basis_out, mask_mk, mask_nl, basis_out) / self.sigma_bc**2
        )
        bc_A_off = (
            + jnp.einsum("d,dmn,dmni,dkl,dklj->mniklj", self.bc1_mask, g_out, basis_out, g_out, basis_out) / self.sigma_bc**2
            + jnp.einsum("d,dmn,dmni,dkl,dklj->mniklj", self.bc0_mask, g_out, basis_out, g_out, basis_out) / self.sigma_bc**2
        )

        size = self.N_coarse * self.N_fine * self.N_basis_f
        # Same ``reg * I`` ridge as level 0 (not an unscaled identity).
        M = (comp_A + bc_A_diag).reshape((size, size)) + reg * jnp.eye(size, size)
        N = -(coop_A + bc_A_off).reshape((size, size))
        coeffs = relaxed_solve(M, N, b, coeffs.ravel())
        return coeffs.reshape((self.N_coarse, self.N_fine, self.N_basis_f))

    @partial(jit, static_argnums=(0,))
    def ema_update(self, params, avg_params):
        """Fold ``params`` into the running average ``avg_params``.

        Jitted wrapper around ``optax.incremental_update``. ``params`` is passed
        as the new value and ``avg_params`` as the old value, with a fixed step
        size of 0.001.
        """
        ema_step = 0.001
        blended = optax.incremental_update(params, avg_params, step_size=ema_step)
        return blended

    @partial(jit, static_argnums=(0,))
    def EM_0(self, params_c, params_f, coeffs_c, opt_state_c):
        """One coarse-level EM step.

        The step runs in three stages:

        1. E-step: compute coarse responsibilities with ``E_step(..., 0)``.
        2. Refit the coarse coefficients with ``blended_lstsq_fit(..., 0)``.
        3. Take one ``optimizer_c`` step on ``params_c``, using the gradient of
           ``loss_c`` evaluated with the refit coefficients.

        Returns:
            ``(params_c, coeffs_c, opt_state_c)`` after the step.
        """
        resp = self.E_step(params_c, params_f, coeffs_c, 0)
        new_coeffs = self.blended_lstsq_fit(params_c, params_f, coeffs_c, resp, 0)

        loss_grads = grad(self.loss_c)(params_c, params_f, new_coeffs, resp)
        step, new_opt_state = self.optimizer_c.update(loss_grads, opt_state_c, params_c)
        new_params = optax.apply_updates(params_c, step)
        return new_params, new_coeffs, new_opt_state

    @partial(jit, static_argnums=(0,))
    def EM_1(self, params_c, params_f, coeffs_f, opt_state_f):
        """One fine-level EM step.

        The step runs in three stages:

        1. E-step: compute fine responsibilities with ``E_step(..., 1)``.
        2. Refit the fine coefficients with ``blended_lstsq_fit(..., 1)``.
        3. Take one ``optimizer_f`` step on ``params_f``, using the gradient of
           ``loss_f`` (note its ``(params_f, params_c, ...)`` argument order).

        Returns:
            ``(params_f, coeffs_f, opt_state_f)`` after the step.
        """
        resp = self.E_step(params_c, params_f, coeffs_f, 1)
        new_coeffs = self.blended_lstsq_fit(params_c, params_f, coeffs_f, resp, 1)

        loss_grads = grad(self.loss_f)(params_f, params_c, new_coeffs, resp)
        step, new_opt_state = self.optimizer_f.update(loss_grads, opt_state_f, params_f)
        new_params = optax.apply_updates(params_f, step)
        return new_params, new_coeffs, new_opt_state

    @partial(jit, static_argnums=(0,))
    def f2c(self, params_c, params_f, coeffs_c, coeffs_f, opt_state):
        """Fine-to-coarse update: train the coarse networks on fine-level responsibilities.

        The update runs in three stages:

        1. Compute the fine posterior and sum it over its last (fine-expert)
           axis to obtain a coarse posterior.
        2. Take one ``optimizer_c`` step on the coarse gate/basis parameters,
           using the gradient of ``loss_c``.
        3. Refit both levels' coefficients against the *updated* coarse
           parameters.

        Args:
            params_c: Coarse ``(gate_params, basis_params)``.
            params_f: Fine ``(gate_params, basis_params)``; not modified here.
            coeffs_c: Current coarse coefficients.
            coeffs_f: Current fine coefficients.
            opt_state: ``optimizer_c`` state.

        Returns:
            ``(updated_params_c, opt_state, coeffs_c, coeffs_f)``.
        """
        posterior_f = self.E_step(params_c, params_f, coeffs_f, 1)
        posterior_c = jnp.sum(posterior_f, axis=-1)

        # Gate and basis parameters are fit using gradient descent.
        grads = grad(self.loss_c)(params_c, params_f, coeffs_c, posterior_c)
        updates, opt_state = self.optimizer_c.update(grads, opt_state, params_c)
        updated_params = optax.apply_updates(params_c, updates)

        # Refit with the updated coarse networks so the returned coefficients
        # match the returned parameters. The fine basis/gates also depend on
        # params_c, so this applies to the fine coefficients too.
        coeffs_c = self.blended_lstsq_fit(updated_params, params_f, coeffs_c, posterior_c, 0)
        coeffs_f = self.blended_lstsq_fit(updated_params, params_f, coeffs_f, posterior_f, 1)
        return updated_params, opt_state, coeffs_c, coeffs_f

    def train(self, nIter=10000):
        """Eager, Python-loop training driver with per-iteration logging.

        Each of the ``nIter`` iterations does the following:

        1. Advance ``self.int_count`` and refresh ``sigma_comp``,
           ``sigma_coop`` and ``sigma_bc`` from ``self.sigma_schedule``.
        2. Coarse level: E-step, iterative coefficient fit, one ``optimizer_c``
           step.
        3. Fine level: E-step, iterative coefficient fit, one ``optimizer_f``
           step.
        4. Fine-to-coarse: ``self.update_f2c`` followed by coefficient refits
           at both levels.
        5. Append L2 errors and losses to the ``*_log_c`` / ``*_log_f`` lists
           and show the errors on the progress bar.
        """
        # NOTE(review): the visible fine-to-coarse method is ``f2c``, which returns
        # (params, opt_state, coeffs_c, coeffs_f); this loop expects an
        # ``update_f2c`` returning (params, posterior, opt_state) -- confirm it exists.
        # NOTE(review): jitted methods that take ``self`` as a static argument (e.g.
        # ``blended_lstsq_fit_iterative``) may not see sigma changes made below after
        # their first compilation -- confirm how the class hashes ``self``.
        progress = trange(nIter)
        for _ in progress:
            step = next(self.itercount)
            self.int_count = step
            schedule = self.sigma_schedule
            self.sigma_comp = schedule["comp"](step)
            self.sigma_coop = schedule["coop"](step)
            self.sigma_bc = schedule["bc"](step)

            # --- coarse level ---
            post_c = self.E_step(self.params_c, self.params_f, self.coeffs_c, 0)
            self.coeffs_c = self.blended_lstsq_fit_iterative(self.params_c, self.params_f, self.coeffs_c, post_c, 0)
            g = grad(self.loss_c)(self.params_c, self.params_f, self.coeffs_c, post_c)
            upd, self.opt_state_c = self.optimizer_c.update(g, self.opt_state_c, self.params_c)
            self.params_c = optax.apply_updates(self.params_c, upd)

            # --- fine level ---
            post_f = self.E_step(self.params_c, self.params_f, self.coeffs_f, 1)
            self.coeffs_f = self.blended_lstsq_fit_iterative(self.params_c, self.params_f, self.coeffs_f, post_f, 1)
            g = grad(self.loss_f)(self.params_f, self.params_c, self.coeffs_f, post_f)
            upd, self.opt_state_f = self.optimizer_f.update(g, self.opt_state_f, self.params_f)
            self.params_f = optax.apply_updates(self.params_f, upd)

            # --- fine to coarse ---
            self.params_c, post_c, self.opt_state_c = self.update_f2c(
                self.params_c, self.params_f, self.coeffs_c, self.coeffs_f, self.opt_state_c
            )
            self.coeffs_c = self.blended_lstsq_fit_iterative(self.params_c, self.params_f, self.coeffs_c, post_c, 0)
            self.coeffs_f = self.blended_lstsq_fit_iterative(self.params_c, self.params_f, self.coeffs_f, post_f, 1)

            # --- bookkeeping (every iteration) ---
            err_c = self.compute_l2_error(self.params_c, self.params_f, self.coeffs_c, 0)
            err_f = self.compute_l2_error(self.params_c, self.params_f, self.coeffs_f, 1)
            cur_loss_c = self.loss_c(self.params_c, self.params_f, self.coeffs_c, post_c)
            cur_loss_f = self.loss_f(self.params_f, self.params_c, self.coeffs_f, post_f)

            self.l2_error_log_c.append(err_c)
            self.loss_log_c.append(cur_loss_c)
            self.l2_error_log_f.append(err_f)
            self.loss_log_f.append(cur_loss_f)
            progress.set_postfix({"error_c": err_c, "error_f": err_f})

    def train_(self, nIter=10000, log_interval=100):
        """Run ``nIter`` training iterations inside compiled ``lax.scan`` loops.

        Each iteration applies ``EM_0`` (coarse), ``EM_1`` (fine) and ``f2c``
        (fine-to-coarse), in that order. Iterations are grouped into batches of
        ``log_interval``; after every batch the coarse and fine errors from
        ``compute_l2_error`` are recorded. Leftover iterations
        (``nIter % log_interval``) still run after the logged batches; they
        simply produce no log entry.

        Unlike ``train``, this driver does not refresh ``sigma_*`` from
        ``sigma_schedule``.

        Args:
            nIter: Total number of iterations to run.
            log_interval: Iterations per logged batch (default 100).

        Raises:
            ValueError: If ``log_interval < 1``, or if ``nIter < log_interval``
                (no error would ever be logged).

        Side effects:
            Updates ``params_c``, ``params_f``, ``opt_state_c``,
            ``opt_state_f``, ``coeffs_c`` and ``coeffs_f``. Prints the fine
            error log and the final errors. Stores
            ``self.logs = {"coarse": ..., "fine": ...}``.
        """
        if log_interval < 1:
            raise ValueError(f"log_interval must be >= 1, got {log_interval}")
        num_logs, remainder = divmod(nIter, log_interval)
        if num_logs == 0:
            raise ValueError(
                f"nIter ({nIter}) must be >= log_interval ({log_interval}) "
                "so that at least one error is logged"
            )

        self.params_c = jax.device_put(self.params_c, device=self.device)
        self.params_f = jax.device_put(self.params_f, device=self.device)
        self.opt_state_c = jax.device_put(self.opt_state_c, device=self.device)
        self.opt_state_f = jax.device_put(self.opt_state_f, device=self.device)
        EM_0, EM_1, f2c = self.EM_0, self.EM_1, self.f2c

        def single_iter_body(state, _):
            """One coarse EM step, one fine EM step, then one fine-to-coarse update."""
            params_c, params_f, opt_state_c, opt_state_f, coeffs_c, coeffs_f, int_count = state
            params_c, coeffs_c, opt_state_c = EM_0(params_c, params_f, coeffs_c, opt_state_c)
            params_f, coeffs_f, opt_state_f = EM_1(params_c, params_f, coeffs_f, opt_state_f)
            params_c, opt_state_c, coeffs_c, coeffs_f = f2c(params_c, params_f, coeffs_c, coeffs_f, opt_state_c)
            state = (params_c, params_f, opt_state_c, opt_state_f, coeffs_c, coeffs_f, int_count + 1)
            return state, None

        def logged_batch_body(carry, _):
            """Run ``log_interval`` iterations, then record both levels' L2 errors."""
            state, coarse_logs, fine_logs = carry
            state, _ = jax.lax.scan(single_iter_body, state, None, length=log_interval)
            params_c, params_f, _, _, coeffs_c, coeffs_f, int_count = state
            error_c = self.compute_l2_error(params_c, params_f, coeffs_c, 0)
            error_f = self.compute_l2_error(params_c, params_f, coeffs_f, 1)
            # int_count advances by log_interval per batch, so this is the batch index.
            log_idx = int_count // log_interval - 1
            coarse_logs = {**coarse_logs, "l2_error_log": coarse_logs["l2_error_log"].at[log_idx].set(error_c)}
            fine_logs = {**fine_logs, "l2_error_log": fine_logs["l2_error_log"].at[log_idx].set(error_f)}
            return (state, coarse_logs, fine_logs), None

        init_state = (
            self.params_c, self.params_f,
            self.opt_state_c, self.opt_state_f,
            self.coeffs_c, self.coeffs_f,
            0,
        )
        init_coarse_logs = {"l2_error_log": jnp.zeros(num_logs)}
        init_fine_logs = {"l2_error_log": jnp.zeros(num_logs)}
        (state, coarse_logs, fine_logs), _ = jax.lax.scan(
            logged_batch_body,
            (init_state, init_coarse_logs, init_fine_logs),
            None,
            length=num_logs,
        )
        if remainder:
            # Run the leftover iterations so that exactly nIter are performed.
            state, _ = jax.lax.scan(single_iter_body, state, None, length=remainder)
        (self.params_c, self.params_f, self.opt_state_c, self.opt_state_f,
         self.coeffs_c, self.coeffs_f, _) = state

        coarse_logs = {k: jax.device_get(v) for k, v in coarse_logs.items()}
        fine_logs = {k: jax.device_get(v) for k, v in fine_logs.items()}

        error_f = fine_logs["l2_error_log"][-1]
        error_c = coarse_logs["l2_error_log"][-1]
        print(fine_logs["l2_error_log"])
        print({"error_f": f"{error_f:1.2e}",
               "error_c": f"{error_c:1.2e}"})

        self.logs = {"coarse": coarse_logs, "fine": fine_logs}

# Usage
rng = random.PRNGKey(0)
net_setup_params = {
    "gating": {
        "coarse": {
            "n_hidden": 40,
            "n_layers": 0,
            "num_partitions": 4,
            "activation": jax.nn.tanh,
        },
        "fine": {
            "n_hidden": 20,
            "n_layers": 0,
            "num_partitions": 2,
            "activation": jax.nn.tanh,
        },
    },
    "basis": {
        "coarse": {
            "n_hidden": 10,
            "n_layers": 2,
            "basis_size": 10,
            "activation": jax.nn.tanh,
        },
        "fine": {
            "n_hidden": 10,
            "n_layers": 2,
            "basis_size": 10,
            "activation": jax.nn.tanh,
        },
    },
    "sigma_schedule": {
        "coop": optax.constant_schedule(1e-3),
        "comp": optax.constant_schedule(1e-3),
        "bc": optax.constant_schedule(1e-5),
        "ic": optax.constant_schedule(1.),
    },
}
x = jnp.linspace(0, 1, 100, dtype=jnp.float64)
t = jnp.linspace(0, 0.1, 100, dtype=jnp.float64)
D = 0.054
h = D * Pe
center = 0.4
width = 0.05

nIter = 50000

model = BlendedMLP_Regression(rng, net_setup_params, x, t, h, D, center, width, 1e-4, 1e-4, device)
tic = time.time()
model.train_(nIter)
toc = time.time()
elapsed = toc - tic
minutes, seconds = divmod(int(elapsed), 60)
print(f"Time = {minutes}m {seconds:02d}s", flush=True)
print(f"Iterations = {nIter}", flush=True)
print(f"sigmas = {model.sigma_comp:.2e}, {model.sigma_coop:.2e}, {model.sigma_bc:.2e}", flush=True)
print(f"width = {width}", flush=True)
print(f"Ndata = {len(x) * len(t)}", flush=True)
print(f"Pe number = {Pe}", flush=True)

plt.semilogy(model.logs['coarse']['l2_error_log'], label="Coarse")
plt.semilogy(model.logs['fine']['l2_error_log'], label="Fine")
plt.title("L2 relative error")
plt.xlabel("Iterations")
plt.legend()
path = os.path.join(output_dir, "L2error.png")
plt.savefig(path)
print("Saved plot to outputs", flush=True)

# Evaluation grid. meshgrid(x, t) returns arrays of shape (len(t), len(x)), so
# after ravel() the flat index is i_t * len(x) + i_x. Grid-shaped views must
# therefore be reshaped to (nt, nx), which is also the Z layout contourf(x, t, Z)
# expects.
nx, nt = len(x), len(t)
X, T = jnp.meshgrid(x, t)
X, T = X.ravel(), T.ravel()
g_pred_c = vmap(model.gate_net, (None, 0, 0, None))(model.params_c[0], X, T, 0)
p_pred_c = vmap(model.basis_net, (None, None, 0, 0, None))(model.params_c[1], model.params_f[1], X, T, 0)
u_pred_c = vmap(model.u_pred, (None, None, None, 0, 0, None))(model.params_c, model.params_f, model.coeffs_c, X, T, 0)
L_pred_c = vmap(model.L_net, (None, None, None, 0, 0, None))(model.params_c, model.params_f, model.coeffs_c, X, T, 0)
posterior_c = model.E_step(model.params_c, model.params_f, model.coeffs_c, 0)

g_pred_f = vmap(model.gate_net, (None, 0, 0, None))(model.params_f[0], X, T, 1)
p_pred_f = vmap(model.basis_net, (None, None, 0, 0, None))(model.params_c[1], model.params_f[1], X, T, 1)
u_pred_f = vmap(model.u_pred, (None, None, None, 0, 0, None))(model.params_c, model.params_f, model.coeffs_f, X, T, 1)
L_pred_f = vmap(model.L_net, (None, None, None, 0, 0, None))(model.params_c, model.params_f, model.coeffs_f, X, T, 1)
posterior_f = model.E_step(model.params_c, model.params_f, model.coeffs_f, 1)

# Exact solution: a Gaussian of initial width `width` centred at `center`, shifted
# by h*T and with variance growing as width**2 + 2*D*T. Both the amplitude and the
# exponent are evaluated at the same grid times T.
u = 1. / jnp.sqrt(1. + 2*D*T/width**2) * jnp.exp(- 0.5 * (X - center - h*T)**2 / (width**2 + 2*D*T))

u_ = jnp.einsum("di,dij,ijk,dijk->d", g_pred_c, g_pred_f, model.coeffs_f, p_pred_f)

plt.figure(figsize=(6, 18))
plt.subplot(3, 1, 1)
plt.contourf(x, t, u_pred_f.reshape(nt, nx))
plt.title("u_f Prediction")
plt.colorbar()
plt.subplot(3, 1, 2)
# NOTE(review): assumes model.ic_correction is laid out on the same (X, T) grid
# as L_pred_f -- confirm against how the model builds it.
plt.contourf(x, t, (model.ic_correction - L_pred_f).reshape(nt, nx))
plt.title("Residual")
plt.colorbar()
plt.subplot(3, 1, 3)
plt.contourf(x, t, u.reshape(nt, nx) - u_pred_f.reshape(nt, nx))
plt.title("u_true - u_pred_f")
plt.colorbar()
# Save before show(): with interactive backends, show() can close the figure,
# which would leave savefig() writing a blank canvas.
path = os.path.join(output_dir, "u_pred_f.png")
plt.savefig(path)
print("Saved plot to outputs", flush=True)
plt.show()

plt.figure(dpi=300, figsize=(6, 4))
plt.title("t = 0.1")
# The final time slice is the last row of the (nt, nx) grid, i.e. the last nx entries.
plt.plot(x, u[-nx:], "b", label="Exact")
plt.plot(x, u_pred_c[-nx:], "r--", label="Prediction")
plt.xlabel(r"$x$")
plt.legend()
path = os.path.join(output_dir, "final_time.png")
plt.savefig(path)
print("Saved plot to outputs", flush=True)
path = os.path.join(output_dir, "final_sol.npy")
np.save(path, u_pred_f)
print("Saved solution to outputs", flush=True)