import jax.numpy as jnp
import matplotlib.pyplot as plt
import flax.linen as nn
import jax
import optax
import os
# import orthax
import itertools
import warnings
import numpy as np

warnings.filterwarnings("ignore")

from jax import random, jit, grad, vmap, lax
from functools import partial
from tqdm import trange
from typing import Sequence, Callable
import time

# Global JAX configuration: 64-bit floats and NaN checking.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)

# Pick the device reported by JAX; GPU is required when the flag is set.
GPU = True
if not GPU:
    device = jax.devices('cpu')[0]
else:
    devices = jax.devices('gpu')
    if not devices:
        raise RuntimeError("No GPU devices available")
    device = devices[0]
    print("devices=", devices)
print("device=", device)

# One output directory per SLURM job ("nojobid" when run outside SLURM).
job_id = os.environ.get("SLURM_JOB_ID", "nojobid")
output_dir = f"outputs/run_{job_id}_1Dmoes"
os.makedirs(output_dir, exist_ok=True)

class BoxDense(nn.Module):
    """Dense layer whose kernel and bias are drawn together by `box_init`.

    The forward pass computes ``activation(x @ kernel - bias)``.

    Attributes:
        features: Output width of the layer.
        activation: Callable applied to the affine output.
        depth: Depth value forwarded to `box_init` (scales "resnet" layers).
        layer: Layer index forwarded to `box_init`.
        arch_type: Either "plain" or "resnet".
    """
    # NOTE(review): `name` is also an attribute managed by flax.linen.Module;
    # confirm that this default is honoured and actually needed.
    name: str = "BoxDense"
    features: int
    activation: Callable
    depth: int
    layer: int
    arch_type: str

    @staticmethod
    def box_init(rng, shape, arch_type, depth, layer, dtype=jnp.float32):
        """Draw a (kernel, bias) pair for a layer of the given `shape`.

        Args:
            rng: PRNG key; split into one key for points and one for normals.
            shape: Kernel shape, passed by `__call__` as (in_features, features).
            arch_type: "plain" or "resnet".
            depth: Network depth; used only by the "resnet" scaling.
            layer: Layer index; used only by the "resnet" scaling.
            dtype: dtype of the sampled arrays.

        Returns:
            Tuple ``(kernel, bias)`` where bias is flattened to 1-D.

        Raises:
            ValueError: If `arch_type` is neither "plain" nor "resnet".
        """
        rng_points, rng_norms = random.split(rng, 2)
        norms = random.normal(rng_norms, shape=shape, dtype=dtype)
        # NOTE(review): normalisation runs over axis=1 while the sums below
        # reduce over axis=0 -- confirm this is the intended axis.
        norms = jnp.divide(norms, jnp.linalg.norm(norms, axis=1, keepdims=True))
        if arch_type == "plain":
            p_max = jnp.maximum(0, jnp.sign(norms))
            points = random.uniform(rng_points, shape=shape, dtype=dtype)
            k = 1. / jnp.sum((p_max - points) * norms, axis=0, keepdims=True)
        elif arch_type == "resnet":
            m = (1. + 1. / depth)**layer
            p_max = m * jnp.maximum(0, jnp.sign(norms))
            points = random.uniform(rng_points, shape=shape, dtype=dtype, minval=0., maxval=m)
            k = 1. / depth / jnp.sum((p_max - points) * norms, axis=0, keepdims=True)
        else:
            # Without this branch `k` and `points` would be unbound below.
            raise ValueError(f"Unsupported architecture type: {arch_type}")

        kernel = k * norms
        bias = k * jnp.sum(points * norms, axis=0)
        return kernel, bias.ravel()

    @nn.compact
    def __call__(self, x):
        """Apply the layer to `x`, contracting its last axis with the kernel."""
        if self.arch_type not in ("plain", "resnet"):
            raise ValueError(f"Unsupported architecture type: {self.arch_type}")
        init_fn = partial(self.box_init, arch_type=self.arch_type, depth=self.depth, layer=self.layer)

        # Kernel and bias live in a single parameter entry, as a tuple.
        layer_weights = self.param("layer_weights", init_fn, (x.shape[-1], self.features))
        kernel, bias = layer_weights
        return self.activation(jnp.tensordot(x, kernel, (-1, 0)) - bias)

def embedding(x, num_frequencies=5):
    """
    Build a sinusoidal (Fourier) embedding of a single 1-D input vector.

    For every component x_i and every k in 1..num_frequencies the features
    sin(k*pi*x_i) and cos(k*pi*x_i) are produced. Per component the layout is
    all sines followed by all cosines; components are laid out one after
    another.

    Args:
        x (ndarray): 1-D input of shape (d,). Callers in this file pass d == 1.
        num_frequencies (int): Number of integer frequencies to use.

    Returns:
        ndarray: 1-D array of shape (2 * d * num_frequencies,).
    """
    k = jnp.arange(1, num_frequencies + 1)
    # (d, 1) column broadcast against the (num_frequencies,) frequency row.
    phase = k * (jnp.pi * x[:, None])
    features = jnp.concatenate([jnp.sin(phase), jnp.cos(phase)], axis=-1)
    return features.reshape(-1)


class GatingNetwork(nn.Module):
    """
    Residual network that maps an input to a softmax distribution over partitions.

    The input is passed through `embedding`, a "plain" BoxDense input layer,
    `n_layers` residual "resnet" BoxDense layers and a linear "plain" BoxDense
    output layer whose result is softmax-normalised.

    Attributes:
        n_hidden (int): Width of the input and residual layers.
        n_layers (int): Number of residual layers.
        num_partitions (int): Output dimension (number of partitions).
        activation (Callable): Activation used in the hidden layers.
    """

    n_hidden: int
    n_layers: int
    num_partitions: int
    activation: Callable

    def setup(self):
        """Create the input layer, the residual layers and the output layer."""
        hidden = dict(features=self.n_hidden, activation=self.activation, depth=self.n_layers)
        self.layer_0 = BoxDense(layer=0, arch_type="plain", **hidden)
        self.layers = [
            BoxDense(layer=idx, arch_type="resnet", **hidden)
            for idx in range(1, self.n_layers + 1)
        ]
        # Identity activation: the softmax is applied in __call__.
        self.output = BoxDense(
            features=self.num_partitions,
            activation=lambda x: x,
            depth=self.n_layers,
            layer=self.n_layers,
            arch_type="plain",
        )


    @nn.compact
    def __call__(self, x):
        """
        Evaluate the gate at a single input.

        Args:
            x (jnp.ndarray): Input vector accepted by `embedding`.

        Returns:
            jnp.ndarray: Softmax-normalised partition weights.
        """
        hidden = self.layer_0(embedding(x))
        for block in self.layers:
            hidden = hidden + block(hidden)  # skip connection
        logits = self.output(hidden)
        return nn.softmax(logits)


class BasisNetwork(nn.Module):
    """
    Residual network that evaluates a vector of basis functions at an input.

    Same layer stack as the gating network, but the output of the final linear
    BoxDense layer is returned as is (no softmax).

    Attributes:
        n_hidden (int): Width of the input and residual layers.
        n_layers (int): Number of residual layers.
        basis_size (int): Output dimension (number of basis functions).
        activation (Callable): Activation used in the hidden layers.
    """

    n_hidden: int
    n_layers: int
    basis_size: int
    activation: Callable

    def setup(self):
        """Create the input layer, the residual layers and the output layer."""
        hidden = dict(features=self.n_hidden, activation=self.activation, depth=self.n_layers)
        self.layer_0 = BoxDense(layer=0, arch_type="plain", **hidden)
        self.layers = [
            BoxDense(layer=idx, arch_type="resnet", **hidden)
            for idx in range(1, self.n_layers + 1)
        ]
        # Identity activation: basis values are unconstrained.
        self.output = BoxDense(
            features=self.basis_size,
            activation=lambda x: x,
            depth=self.n_layers,
            layer=self.n_layers,
            arch_type="plain",
        )


    @nn.compact
    def __call__(self, x):
        """
        Evaluate the basis functions at a single input.

        Args:
            x (jnp.ndarray): Input vector accepted by `embedding`.

        Returns:
            jnp.ndarray: Raw output of the final layer (one value per basis function).
        """
        hidden = self.layer_0(embedding(x))
        for block in self.layers:
            hidden = hidden + block(hidden)  # skip connection
        return self.output(hidden)

class BlendedMLP_Regression:
    def __init__(self, key, net_setup_params, x, u_test, eps, lr_c, lr_f):
        """Build the coarse and fine gating/basis networks, optimizers and logs.

        Args:
            key: Master PRNG key. Independent sub-keys are derived from it so
                that every network starts from different parameters.
            net_setup_params: Dict with keys "gating" and "basis" (each holding
                "coarse" and "fine" constructor kwargs) and "sigma_schedule"
                (callables under "comp", "coop", "bc_comp", "bc_coop").
            x: Collocation points; rows are fed one by one to the networks.
            u_test: Reference solution used for error reporting.
            eps: Coefficient multiplying the second-derivative term.
            lr_c: Initial Adam learning rate for the coarse level.
            lr_f: Initial Adam learning rate for the fine level.
        """
        # Master random key
        self.key = key
        # Dictionary containing network setup parameters
        self.net_setup_params = net_setup_params
        # Input data and target function
        self.x = x
        self.u_test = u_test
        self.eps = eps

        self.gating_setup_params_c = net_setup_params["gating"]["coarse"]
        self.basis_setup_params_c = net_setup_params["basis"]["coarse"]
        self.gating_setup_params_f = net_setup_params["gating"]["fine"]
        self.basis_setup_params_f = net_setup_params["basis"]["fine"]
        self.sigma_schedule = net_setup_params["sigma_schedule"]
        self.N_coarse = self.gating_setup_params_c["num_partitions"]
        self.N_fine = self.gating_setup_params_f["num_partitions"]
        self.N_basis_c = self.basis_setup_params_c["basis_size"]
        # The fine basis is concatenated with the coarse basis (see basis_net).
        self.N_basis_f = self.basis_setup_params_f["basis_size"] + self.basis_setup_params_c["basis_size"]

        # Reusing one key for every init would give all experts of a level
        # identical initial parameters, so split it per network instead.
        key_gate_c, key_basis_c, key_gate_f, key_basis_f = random.split(key, 4)
        dummy_x = jnp.ones(1,)

        gating_model_c = GatingNetwork(**self.gating_setup_params_c)
        basis_model_c = [BasisNetwork(**self.basis_setup_params_c) for _ in range(self.N_coarse)]
        self.gating_params_c = gating_model_c.init(key_gate_c, dummy_x)
        self.basis_params_c = [
            model.init(sub_key, dummy_x)
            for model, sub_key in zip(basis_model_c, random.split(key_basis_c, self.N_coarse))
        ]
        self.coeffs_c = jnp.ones((self.N_coarse, self.N_basis_c))

        gating_model_f = [GatingNetwork(**self.gating_setup_params_f) for _ in range(self.N_coarse)]
        basis_model_f = [[BasisNetwork(**self.basis_setup_params_f) for _ in range(self.N_fine)]
                         for _ in range(self.N_coarse)]
        self.gating_params_f = [
            model.init(sub_key, dummy_x)
            for model, sub_key in zip(gating_model_f, random.split(key_gate_f, self.N_coarse))
        ]
        self.basis_params_f = [
            [model.init(sub_key, dummy_x) for model, sub_key in zip(row, random.split(row_key, self.N_fine))]
            for row, row_key in zip(basis_model_f, random.split(key_basis_f, self.N_coarse))
        ]
        self.coeffs_f = jnp.ones((self.N_coarse, self.N_fine, self.N_basis_f))

        # Initial sigma values taken from the schedules at step 0
        self.sigma_comp = self.sigma_schedule["comp"](0)
        self.sigma_coop = self.sigma_schedule["coop"](0)
        self.sigma_bc_comp = self.sigma_schedule["bc_comp"](0)
        self.sigma_bc_coop = self.sigma_schedule["bc_coop"](0)

        self.params_c = (
            self.gating_params_c,
            self.basis_params_c,
            self.coeffs_c,
            self.sigma_comp,
            self.sigma_coop,
            self.sigma_bc_comp,
            self.sigma_bc_coop,
        )
        self.avg_params_c = self.params_c

        self.params_f = (
            self.gating_params_f,
            self.basis_params_f,
            self.coeffs_f,
            self.sigma_comp,
            self.sigma_coop,
            self.sigma_bc_comp,
            self.sigma_bc_coop,
        )
        self.avg_params_f = self.params_f

        self.gate_apply_c = gating_model_c.apply
        self.basis_apply_c = [model.apply for model in basis_model_c]
        self.gate_apply_f = [model.apply for model in gating_model_f]
        self.basis_apply_f = [[model.apply for model in row] for row in basis_model_f]

        # Learning rates decay by a factor 0.9 every 1000 steps.
        schedule_c = optax.exponential_decay(lr_c, 1000, 0.9)
        schedule_f = optax.exponential_decay(lr_f, 1000, 0.9)
        self.optimizer_c = optax.adam(learning_rate=schedule_c)
        self.optimizer_f = optax.adam(learning_rate=schedule_f)
        self.opt_state_c = self.optimizer_c.init(self.params_c)
        self.opt_state_f = self.optimizer_f.init(self.params_f)

        self.itercount = itertools.count()
        self.l2_error_log_c = []
        self.loss_log_c = []
        self.loss_bcs_log_c = []
        self.l2_error_log_f = []
        self.loss_log_f = []
        self.loss_bcs_log_f = []
        self.bc_term_log_c = []
        self.bc_term_log_f = []

    def gate_net(self, params, x, level):
        if level == 0:
            return self.gate_apply_c(params, x)
        elif level == 1:
            gate_f = [self.gate_apply_f[i](params[i], x) for i in range(len(self.gate_apply_f))]
            return jnp.asarray(gate_f)

    def basis_net(self, params_c, params_f, x, level):
        if level == 0:
            basis_out = [self.basis_apply_c[i](params_c[i], x) for i in range(len(self.basis_apply_c))]
            return jnp.asarray(basis_out)
        elif level == 1:
            basis_f = [[self.basis_apply_f[i][j](params_f[i][j], x) for j in range(len(self.basis_apply_f[i]))]\
                       for i in range(len(self.basis_apply_f))]
            basis_f = jnp.asarray(basis_f)
            basis_c = self.basis_net(params_c, params_f, x, 0)[:,jnp.newaxis,:]
            basis_c = jnp.repeat(basis_c, repeats=self.N_fine, axis=1)
            return jnp.concatenate([basis_f, basis_c], axis=-1)

    def u_net(self, params_c, params_f, x, level):
        if level == 0:
            gate_params, basis_params, coeffs, _, _, _, _ = params_c
            gate_out = self.gate_net(gate_params, x, 0).ravel()
            basis_out = self.basis_net(basis_params, params_f[1], x, 0)
            return jnp.sum(gate_out * jnp.einsum("ij,ij->i", coeffs, basis_out))
        elif level == 1:
            gate_params, basis_params, coeffs, _, _, _, _ = params_f
            gate_out_c = self.gate_net(params_c[0], x, 0).ravel()
            gate_out_f = self.gate_net(gate_params, x, 1)
            basis_out = self.basis_net(params_c[1], basis_params, x, 1)
            return jnp.sum(jnp.einsum("i,ij,ijk,ijk", gate_out_c, gate_out_f, coeffs, basis_out))

    def u_pred(self, params_c, params_f, x, level):
        """
        This function computes u[y_POU](x)
        """
        return self.u_net(params_c, params_f, x, level) * x * (1-x) + (1-x)



    @partial(jit, static_argnums=(0, 3))
    def L_net(self, params_c, params_f, level):
        """
        This function computes L[y_POU](x)
        """
        if level == 0:
            u_out_fn = lambda x: self.u_net(params_c, params_f, x, 0)
            return jnp.squeeze(vmap(jax.jacfwd(jax.jacfwd(u_out_fn)))(self.x),axis=(-1,-2))
        elif level == 1:
            u_out_fn = lambda x: self.u_net(params_c, params_f, x, 1)
            return jnp.squeeze(vmap(jax.jacfwd(jax.jacfwd(u_out_fn)))(self.x),axis=(-1,-2))

    @partial(jit, static_argnums=(0, 3))
    def D_net(self, params_c, params_f, level):
        """
        This function computes D[y_POU](x)
        """
        if level == 0:
            u_out_fn = lambda x: self.u_net(params_c, params_f, x, 0)
            return jnp.squeeze(vmap((jax.jacfwd(u_out_fn)))(self.x),axis=(-1))
        elif level == 1:
            u_out_fn = lambda x: self.u_net(params_c, params_f, x, 1)
            return jnp.squeeze(vmap((jax.jacfwd(u_out_fn)))(self.x),axis=(-1))

    @partial(jit, static_argnums=(0, 3))
    def N_net(self, params_c, params_f, level):
        """
        This function computes N[y_POU](x)
        """
        if level == 0:
            u_out = vmap(self.u_net, (None, None, 0, None))(params_c, params_f, self.x, 0)
            L_out = self.L_net(params_c, params_f, 0)
            D_out = self.D_net(params_c, params_f, 0)
            return (1-2*self.x)*u_out + (self.x-self.x**2)*D_out - self.eps*(-2*u_out + (2.-4*self.x)*D_out + (self.x-self.x**2)*L_out)

    @partial(jit, static_argnums=(0, 4))
    def L_basis(self, params_c, params_f, x, level):
        """
        This function calculates L[\phi_m](x)
        """
        if level == 0:
            _, basis_params, coeffs, _, _, _, _ = params_c
            basis_fn = lambda x: self.basis_net(basis_params, params_f[1], x, level)
            return jax.jacfwd(jax.jacfwd(basis_fn))(x)
        elif level == 1:
            _, basis_params, coeffs, _, _, _, _ = params_f
            basis_fn = lambda x: self.basis_net(params_c[1], basis_params, x, level)
            return jax.jacfwd(jax.jacfwd(basis_fn))(x)

    @partial(jit, static_argnums=(0, 4))
    def D_basis(self, params_c, params_f, x, level):
        """
        This function calculates D[\phi_m](x)
        """
        if level == 0:
            _, basis_params, coeffs, _, _, _, _ = params_c
            basis_fn = lambda x: self.basis_net(basis_params, params_f[1], x, level)
            return jax.jacfwd(basis_fn)(x)
        elif level == 1:
            _, basis_params, coeffs, _, _, _, _ = params_f
            basis_fn = lambda x: self.basis_net(params_c[1], basis_params, x, level)
            return jax.jacfwd(basis_fn)(x)

    @partial(jit, static_argnums=(0, 4))
    def N_basis(self, params_c, params_f, x, level):
        """
        This function calculates N[\phi_m](x)
        """
        if level == 0:
            _, basis_params, coeffs, _, _, _, _ = params_c
            basis_out = self.basis_net(basis_params, params_f[1], x, level)
            L_basis = jnp.squeeze(self.L_basis(params_c, params_f, x, level),axis=(-1,-2))
            D_basis = jnp.squeeze(self.D_basis(params_c, params_f, x, level),axis=-1)
            return (1-2*x)*basis_out + (x-x**2)*D_basis - self.eps*(-2*basis_out + (2.-4*x)*D_basis + (x-x**2)*L_basis)

    @partial(jit, static_argnums=(0, 4))
    def L_gate_basis(self, params_c, params_f, x, level):
        if level == 0:
            """
            Computes L[\pi_m \phi_m](x)
            """
            gate_params, basis_params, coeffs, _, _, _, _ = params_c
            gate_poly_fn = lambda x: jnp.einsum("i,ij->ij",self.gate_net(gate_params, x, 0), self.basis_net(basis_params, params_f[1], x, 0))
            return jax.jacfwd(jax.jacfwd(gate_poly_fn))(x)

        elif level == 1:
            """
            Computes L[\pi_m\pi_mn\phi_mn](x)
            """
            gate_params, basis_params, coeffs, _, _, _, _ = params_f
            def gate_poly_fn(x):
                return jnp.einsum("m,mn,mnk->mnk",
                                  self.gate_net(params_c[0], x, 0),
                                  self.gate_net(gate_params, x, 1),
                                  self.basis_net(params_c[1], basis_params, x, 1),
                                  )

            return jax.jacfwd(jax.jacfwd(gate_poly_fn))(x)

    @partial(jit, static_argnums=(0, 4))
    def D_gate_basis(self, params_c, params_f, x, level):
        if level == 0:
            """
            Computes D[\pi_m \phi_m](x)
            """
            gate_params, basis_params, coeffs, _, _, _, _ = params_c
            gate_poly_fn = lambda x: jnp.einsum("i,ij->ij",self.gate_net(gate_params, x, 0), self.basis_net(basis_params, params_f[1], x, 0))
            return jax.jacfwd(gate_poly_fn)(x)

        elif level == 1:
            """
            Computes D[\pi_m\pi_mn\phi_mn](x)
            """
            gate_params, basis_params, coeffs, _, _, _, _ = params_f
            def gate_poly_fn(x):
                return jnp.einsum("m,mn,mnk->mnk",
                                  self.gate_net(params_c[0], x, 0),
                                  self.gate_net(gate_params, x, 1),
                                  self.basis_net(params_c[1], basis_params, x, 1),
                                  )

            return jax.jacfwd(gate_poly_fn)(x)

    @partial(jit, static_argnums=(0, 4))
    def N_gate_basis(self, params_c, params_f, x, level):
        if level == 0:
            """
            Computes N[\pi_m \phi_m](x)
            """
            gate_params, basis_params, coeffs, _, _, _, _ = params_c
            gate_basis = jnp.einsum("i,ij->ij",self.gate_net(gate_params, x, 0), self.basis_net(basis_params, params_f[1], x, 0))
            L_gate_basis = jnp.squeeze(self.L_gate_basis(params_c, params_f, x, 0),axis=(-1,-2))
            D_gate_basis = jnp.squeeze(self.D_gate_basis(params_c, params_f, x, 0),axis=(-1))
            return (1-2*x)*gate_basis + (x-x**2)*D_gate_basis - self.eps*(-2*gate_basis + (2.-4*x)*D_gate_basis + (x-x**2)*L_gate_basis)

    @partial(jit, static_argnums=(0, 3))
    def E_step(self, params_c, params_f, level):
        if level == 0:
            gate_params, basis_params, coeffs, sigma_comp, sigma_coop, sigma_bc_comp, sigma_bc_coop = params_c
            gate_out = vmap(self.gate_net, (None, 0, None))(gate_params, self.x, 0)
            # basis_out = vmap(self.basis_net, (None, None, 0, None))(basis_params, params_f[1], self.x, 0)
            # u_expert = jnp.einsum("ij,dij->di",coeffs, basis_out)
            # L_basis = jnp.squeeze(vmap(self.L_basis, (None, None, 0, None))(params_c, params_f, self.x, 0), axis=(-1,-2))
            # D_basis = jnp.squeeze(vmap(self.D_basis, (None, None, 0, None))(params_c, params_f, self.x, 0), axis=(-1))
            N_basis = vmap(self.N_basis, (None, None, 0, None))(params_c, params_f, self.x, 0)

            # Component likelihood
            # lhs = (1.-2*self.x)*u_expert + (self.x - self.x**2)*jnp.einsum("ij,dij->di",coeffs,D_basis) - self.eps * (-2*u_expert + (2. - 4*self.x)*jnp.einsum("ij,dij->di",coeffs,D_basis) + \
            #                                                                                                       (self.x - self.x**2)*jnp.einsum("ij,dij->di",coeffs,D_basis))
            likelihood_comp = jnp.exp(-0.5 * ((1. - jnp.einsum("ij,dij->di",coeffs,N_basis))/ sigma_comp) ** 2)

            # BC likelihood
            # likelihood_bc = jnp.exp(-0.5 * (jnp.expand_dims(self.beta, 1) * (self.u_test - jnp.einsum("ij,dij->di",coeffs, basis_out)) / sigma_bc_comp) ** 2)

            # Computing posterior
            posterior = gate_out * likelihood_comp #* likelihood_bc
            posterior = posterior + jnp.finfo(posterior.dtype).eps
            posterior = posterior / (jnp.sum(posterior, axis=1, keepdims=True))
            return posterior

    @partial(jit, static_argnums=(0, 3))
    def loss_bcs(self, params_c, params_f, level):
        bc0 = vmap(self.u_pred, (None, None, 0, None))(params_c, params_f, self.x, level).ravel()[0]
        bc1 = vmap(self.u_pred, (None, None, 0, None))(params_c, params_f, self.x, level).ravel()[-1]
        loss_bcs = ((self.u_test.ravel()[0] - bc0)**2 + (self.u_test.ravel()[-1] - bc1)**2)
        return loss_bcs

    @partial(jit, static_argnums=(0,))
    def loss_c(self, params_c, params_f, posterior):
            gate_params, basis_params, coeffs, sigma_comp, sigma_coop, sigma_bc_comp, sigma_bc_coop = params_c
            gate_out = vmap(self.gate_net, (None, 0, None))(gate_params, self.x, 0)
            blended_pred = self.N_net(params_c, params_f, 0)
            N_basis = vmap(self.N_basis, (None, None, 0, None))(params_c, params_f, self.x, 0)

            # Component term
            component_loss = -jnp.sum(posterior * jnp.log(gate_out + jnp.finfo(gate_out.dtype).eps))

            # Cooperative term
            cooperative_loss = (
                0.5
                * jnp.sum((1. - blended_pred) ** 2)
                / sigma_coop**2
            )

            # Component term
            comp_loss = (
                0.5
                * jnp.sum(posterior * (1. - jnp.einsum("ij,dij->di",coeffs,N_basis)) ** 2)
                / sigma_comp**2
            )

            # Combine losses
            total_loss = component_loss + cooperative_loss + comp_loss
            assert total_loss.size == 1
            return total_loss

    @partial(jit, static_argnums=(0, 3))
    def compute_l2_error(self, params_c, params_f, level):
        out = vmap(self.u_pred, (None, None, 0, None))(params_c, params_f, self.x, level)
        error = jnp.linalg.norm(out.ravel() - self.u_test.ravel()) / jnp.linalg.norm(
            self.u_test.ravel()
        )
        return error

    @partial(jit, static_argnums=(0, 4))
    def EM_step(self, params_c, params_f, opt_state, level, int_count):
        """
        Run one EM iteration at the coarse level.

        Steps: refresh the sigmas from their schedules, compute the posterior
        (E-step), take one optimizer step on `loss_c`, then refit the expert
        coefficients with `blended_lstsq_fit_iterative`.

        Args:
            params_c: Coarse-level parameter tuple.
            params_f: Fine-level parameter tuple (passed through unchanged).
            opt_state: State of `self.optimizer_c`.
            level: Must be 0; static under jit.
            int_count: Iteration counter handed to the sigma schedules.

        Returns:
            Tuple ``(params, opt_state)`` with the updated coarse parameters.

        Raises:
            NotImplementedError: If `level` is not 0.
        """
        if level != 0:
            # The fine-level update does not exist yet; fail loudly instead of
            # returning None, which callers would then try to unpack.
            raise NotImplementedError("EM_step is only implemented for level 0")

        # Update sigma parameters
        scheduled_sigmas = (
            self.sigma_schedule["comp"](int_count),
            self.sigma_schedule["coop"](int_count),
            self.sigma_schedule["bc_comp"](int_count),
            self.sigma_schedule["bc_coop"](int_count),
        )
        params = tuple(params_c[:3]) + scheduled_sigmas

        # E-step
        posterior = self.E_step(params, params_f, level)

        # M-step: gate and basis parameters are fit using gradient descent
        grads = grad(self.loss_c)(params, params_f, posterior)
        updates, opt_state = self.optimizer_c.update(grads, opt_state)
        updated_params = optax.apply_updates(params, updates)
        # The optimizer step also moves the sigma entries of the tuple. They
        # are prescribed by the schedules, so restore them before the solve.
        updated_params = tuple(updated_params[:3]) + scheduled_sigmas

        # Expert coefficients are fit using the blended least squares
        params = self.blended_lstsq_fit_iterative(updated_params, params_f, posterior, level, omega=1.)

        return params, opt_state

    @partial(jit, static_argnums=(0, 4))
    def blended_lstsq_fit_iterative(self, params_c, params_f, posterior, level, max_iter=1000, tol=1e-12, reg=1e-12, omega=1.):
        """
        Refit the coarse expert coefficients with a relaxed splitting iteration.

        Assembles a right-hand side `b` and two square matrices `M` and `N`
        of size (N_coarse * N_basis_c), then iterates
            coeffs <- omega * solve(M, b + N @ coeffs) + (1 - omega) * coeffs
        until the update norm drops below `tol` or `max_iter` is reached.

        Args:
            params_c: Coarse-level parameter tuple; its coefficients are the
                starting guess.
            params_f: Fine-level parameter tuple.
            posterior: Responsibilities from `E_step`, indexed (point, expert).
            level: 0 for the coarse fit; level 1 is a stub that returns None.
            max_iter: Iteration cap; also the length of the error history.
            tol: Convergence threshold on the update norm.
            reg: Multiple of the identity added to `M`.
            omega: Relaxation factor.

        Returns:
            The coarse parameter tuple with the coefficients replaced, shaped
            (N_coarse, N_basis_c).
        """
        # NOTE(review): max_iter is used as an array length below; passing it
        # explicitly would presumably make it a traced value under jit -- confirm.
        if level == 0:
            gate_params, basis_params, coeffs, sigma_comp, sigma_coop, sigma_bc_comp, sigma_bc_coop = params_c
            gate_out = vmap(self.gate_net, (None, 0, None))(gate_params, self.x, 0)
            # basis_out is not used further down in this method.
            basis_out = vmap(self.basis_net, (None, None, 0, None))(basis_params, params_f[1], self.x, 0)
            N_basis = vmap(self.N_basis, (None, None, 0, None))(params_c, params_f, self.x, 0)
            N_gate_basis = vmap(self.N_gate_basis, (None, None, 0, None))(params_c, params_f, self.x, 0)

            # mask_mk is the identity over expert indices; it selects the
            # expert-diagonal (m == k) blocks in the einsums below.
            m_indices, k_indices = jnp.arange(self.N_coarse), jnp.arange(self.N_coarse)
            mask_mk = (m_indices[:, None] == k_indices).astype(float)
            # x_expanded is not used further down in this method.
            x_expanded = jnp.expand_dims(self.x, axis=-1)

            # Right-hand side: per-expert term plus cooperative term.
            # NOTE(review): N_gate_basis already contains the gate factor (see
            # N_gate_basis), yet it is weighted by gate_out again here -- confirm.
            b_m = (jnp.einsum("dm,dmn->mn", posterior, N_basis) / sigma_comp**2 + jnp.einsum("dm,dmn->mn", gate_out, N_gate_basis) / sigma_coop**2)
            b = b_m.ravel()

            # M holds the expert-diagonal blocks, N the off-diagonal coupling.
            # NOTE(review): the cooperative diagonal block is subtracted from M
            # here -- confirm the sign against the derivation.
            M = jnp.einsum("dm,dmn,dmj,mk->mnkj ",posterior, N_basis, N_basis, mask_mk) / sigma_comp**2
            M = M - jnp.einsum("mk,dkj,dmn->mnkj", mask_mk, N_gate_basis, N_gate_basis) / sigma_coop**2
            N = jnp.einsum("mk, dkj,dmn->mnkj", 1. - mask_mk, N_gate_basis, N_gate_basis) / sigma_coop**2
            N = - N

            # Flatten (expert, basis) index pairs into matrix rows and columns;
            # reg * I is added to M only.
            M = M.reshape(
                self.N_coarse * self.N_basis_c,
                self.N_coarse * self.N_basis_c,
            ) + reg * jnp.eye(
                self.N_coarse * self.N_basis_c,
                self.N_coarse * self.N_basis_c,
            )
            N = N.reshape(
                self.N_coarse * self.N_basis_c,
                self.N_coarse * self.N_basis_c,
            )
            coeffs = coeffs.ravel()
            # M = jax.device_put(M, device=self.device)
            # N = jax.device_put(N, device=self.device)
            # b = jax.device_put(b, device=self.device)
            def body_fun(state):
                # One relaxed update; records the update norm in `history`.
                i, coeffs, converged, error, history = state
                coeffs_new = (
                    omega * jnp.linalg.solve(M, b + N @ coeffs) + (1 - omega) * coeffs
                )
                # coeffs_new = omega * jnp.linalg.lstsq(M, b + N @ coeffs, rcond=reg)[0] + (1 - omega) * coeffs
                error = jnp.linalg.norm(coeffs_new - coeffs)
                converged = error < tol
                history = history.at[i].set(error)
                return i + 1, coeffs_new, converged, error, history

            def cond_fun(state):
                # Continue while under the iteration cap and not converged.
                i, _, converged, _, _ = state
                return (i < max_iter) & (~converged)

            history_init = jnp.zeros(max_iter)
            final_state = lax.while_loop(
                cond_fun, body_fun, (0, coeffs, False, jnp.inf, history_init)
            )
            # Only the coefficients of the final state are used below.
            i, coeffs, converged, error, history = final_state
            # coeffs = coeffs.reshape(
            #     self.N_coarse,
            #     self.N_basis_c,
            # )
            # jax.debug.print("converged: {}", converged)
            # converged_bool = jnp.bool_(converged)
            # coeffs = jnp.where(converged_bool, coeffs, jnp.linalg.lstsq(M-N, b)[0])
            coeffs = coeffs.reshape((
                self.N_coarse,
                self.N_basis_c,
            ))

            return (gate_params, basis_params, coeffs, sigma_comp, sigma_coop, sigma_bc_comp, sigma_bc_coop)
        elif level == 1:
            # Fine-level fit is not implemented; this branch returns None.
            pass

    @partial(jit, static_argnums=(0,))
    def ema_update(self, params, avg_params):
        """Blend `params` into the running average `avg_params` with step size 0.001."""
        blended = optax.incremental_update(params, avg_params, step_size=0.001)
        return blended

    def train(self, nIter=10000):
        """
        Run `nIter` coarse-level EM iterations and log metrics periodically.

        Every iteration performs one `EM_step` at level 0 and updates the
        averaged parameters. The loss, relative L2 error and boundary loss are
        appended to the logs every 1000 iterations and also after the last
        iteration, so the final log entries describe the final parameters.

        Args:
            nIter: Number of EM iterations to run.
        """
        pbar = trange(nIter)
        for it in pbar:
            self.int_count = next(self.itercount)
            self.params_c, self.opt_state_c = self.EM_step(
                self.params_c, self.params_f, self.opt_state_c, 0, self.int_count
            )
            self.avg_params_c = self.ema_update(self.params_c, self.avg_params_c)

            # Without the last-iteration check the final log entry could be up
            # to 999 iterations old.
            if it % 1000 == 0 or it == nIter - 1:
                posterior = self.E_step(self.params_c, self.params_f, 0)
                loss_c = self.loss_c(self.params_c, self.params_f, posterior)
                error_c = self.compute_l2_error(self.params_c, self.params_f, 0)
                loss_bcs_c = self.loss_bcs(self.params_c, self.params_f, 0)
                self.l2_error_log_c.append(error_c)
                self.loss_log_c.append(loss_c)
                self.loss_bcs_log_c.append(loss_bcs_c)

# Usage: sweep over Peclet numbers with several random seeds and report the
# relative L2 error of the coarse-level model.
base_rng = random.PRNGKey(0)
net_setup_params = {
    "gating": {
        "coarse": {
            "n_hidden": 40,
            "n_layers": 1,
            "num_partitions": 4,
            "activation": jax.nn.tanh,
        },
        "fine": {
            "n_hidden": 20,
            "n_layers": 1,
            "num_partitions": 2,
            "activation": jax.nn.tanh,
        },
    },
    "basis": {
        "coarse": {
            "n_hidden": 10,
            "n_layers": 0,
            "basis_size": 10,
            "activation": jax.nn.tanh,
        },
        "fine": {
            "n_hidden": 6,
            "n_layers": 1,
            "basis_size": 6,
            "activation": jax.nn.tanh,
        },
    },
    "sigma_schedule": {
        "coop": optax.constant_schedule(1e3),
        "comp": optax.constant_schedule(1e-5),
        "bc_comp": optax.constant_schedule(1e-5),
        "bc_coop": optax.constant_schedule(1e-5),
    },
}
# Collocation points as a column so that vmap feeds one (1,)-shaped row at a time.
x = jnp.linspace(0,1,500).reshape(-1,1)

Pe_list = [50, 40, 30, 25, 20, 15, 10, 5, 1]
# error_list[j] collects the final errors of all seeds for Pe_list[j].
error_list = [[] for _ in Pe_list]
for i in range(5):
    new_rng = random.fold_in(base_rng, i)
    for j, Pe in enumerate(Pe_list):
        rng = random.fold_in(new_rng, j)
        print("Peclet number:", Pe, flush=True)
        eps = 1 / Pe
        # Reference solution; equals 1 at x = 0 and 0 at x = 1.
        u = (jnp.exp(1/eps) - jnp.exp(x/eps)) / (jnp.exp(1/eps) - 1)
        model = BlendedMLP_Regression(rng, net_setup_params, x, u, eps, 5e-5, 1e-3)
        model.train(30000)
        # Per-run diagnostic plots and .npy dumps are disabled in this sweep.
        error_list[j].append(model.l2_error_log_c[-1])
error_list = jnp.asarray(error_list)
np.save(os.path.join(output_dir, "error_list.npy"), error_list)
mean = jnp.mean(error_list, axis=1)
std = jnp.std(error_list, axis=1)
# The label gives plt.legend() an artist to show.
plt.errorbar(Pe_list, mean, yerr=std, fmt='o-', capsize=5, label='mean +/- std over seeds')
plt.xlabel('Pe number')
plt.ylabel('Relative L2 error')
plt.yscale('log')
plt.legend()
plt.grid(True)
# Save before showing: with an interactive backend the figure may be gone
# once the show() window is closed, which would write an empty image.
path = os.path.join(output_dir, "error_list.png")
plt.savefig(path)
plt.show()