import jax.numpy as jnp
import matplotlib.pyplot as plt
import flax.linen as nn
import jax
import optax
import os
# import orthax
import itertools
import warnings
import numpy as np

# Suppress every Python warning for the rest of the run (this also hides
# deprecation notices raised by the libraries imported below).
warnings.filterwarnings("ignore")

from jax import random, jit, grad, vmap, lax
from functools import partial
from tqdm import trange
from typing import Sequence, Callable
import time

# Global JAX switches: 64-bit floating point and JAX's NaN debugging mode.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)

# Select the device to report: the first GPU when `GPU` is set, else the first CPU.
# NOTE(review): in the visible code `device` is only printed; it is not handed
# to any JAX placement API here.
GPU = True
if not GPU:
    device = jax.devices('cpu')[0]
else:
    devices = jax.devices('gpu')
    if not devices:
        raise RuntimeError("No GPU devices available")
    device = devices[0]
    print("devices=", devices)
print("device=", device)

# Results go to a per-job directory; "nojobid" is used when SLURM_JOB_ID is
# not set (e.g. a local test run).
job_id = os.getenv("SLURM_JOB_ID", "nojobid")
output_dir = f"outputs/run_{job_id}_1Dpinns"
os.makedirs(output_dir, exist_ok=True)

class BoxDense(nn.Module):
    """Dense layer whose kernel and bias are drawn jointly by ``box_init``.

    The forward pass computes ``activation(x @ kernel - bias)``. Kernel and
    bias are stored together as a single parameter named ``"layer_weights"``.

    Attributes:
        features: Number of output units.
        activation: Callable applied to the affine output.
        depth: Network depth used by the "resnet" initialization.
        layer: Index of this layer, used by the "resnet" initialization.
        arch_type: Either "plain" or "resnet"; selects the initialization.
    """
    # NOTE(review): `name` is also a built-in flax.linen.Module attribute;
    # confirm that overriding its default here is intended.
    name: str = "BoxDense"
    features: int
    activation: Callable
    depth: int
    layer: int
    arch_type: str

    @staticmethod
    def box_init(rng, shape, arch_type, depth, layer, dtype=jnp.float32):
        """Draw a ``(kernel, bias)`` pair for a layer of the given ``shape``.

        Args:
            rng: JAX PRNG key; split into one key for the points and one for
                the direction vectors.
            shape: Kernel shape, ``(in_features, out_features)`` as passed by
                ``__call__``.
            arch_type: "plain" samples points in ``[0, 1)``; "resnet" samples
                them in ``[0, m)`` with ``m = (1 + 1/depth) ** layer`` and
                additionally divides the scale by ``depth``.
            depth: Network depth (only used for "resnet").
            layer: Layer index (only used for "resnet").
            dtype: dtype of the sampled arrays.

        Returns:
            Tuple ``(kernel, bias)``; ``kernel`` has shape ``shape`` and
            ``bias`` is flattened to 1-D.

        Raises:
            ValueError: If ``arch_type`` is neither "plain" nor "resnet".
        """
        # Fail early with a clear message; previously an unknown arch_type
        # reached the code below with `k` and `points` unbound.
        if arch_type not in ("plain", "resnet"):
            raise ValueError(f"Unsupported architecture type: {arch_type}")

        rng_points, rng_norms = random.split(rng, 2)
        norms = random.normal(rng_norms, shape=shape, dtype=dtype)
        # NOTE(review): the normalisation runs over axis 1 while the
        # projections below reduce over axis 0 -- confirm which axis is meant
        # to index the units.
        norms = jnp.divide(norms, jnp.linalg.norm(norms, axis=1, keepdims=True))
        if arch_type == "plain":
            # Corner of the unit box selected by the sign of each direction entry.
            p_max = jnp.maximum(0, jnp.sign(norms))
            points = random.uniform(rng_points, shape=shape, dtype=dtype)
            k = 1. / jnp.sum((p_max - points) * norms, axis=0, keepdims=True)
        else:  # "resnet"
            # The box edge length grows with the layer index.
            m = (1. + 1. / depth)**layer
            p_max = m * jnp.maximum(0, jnp.sign(norms))
            points = random.uniform(rng_points, shape=shape, dtype=dtype, minval=0., maxval=m)
            k = 1. / depth / jnp.sum((p_max - points) * norms, axis=0, keepdims=True)

        kernel = k * norms
        bias = k * jnp.sum(points * norms, axis=0)
        return kernel, bias.ravel()

    @nn.compact
    def __call__(self, x):
        """Apply the layer to ``x``; the last axis of ``x`` is contracted with the kernel.

        Raises:
            ValueError: If ``self.arch_type`` is neither "plain" nor "resnet".
        """
        if self.arch_type not in ("plain", "resnet"):
            raise ValueError(f"Unsupported architecture type: {self.arch_type}")

        init_fn = partial(self.box_init, arch_type=self.arch_type, depth=self.depth, layer=self.layer)
        kernel, bias = self.param("layer_weights", init_fn, (x.shape[-1], self.features))
        return self.activation(jnp.tensordot(x, kernel, (-1, 0)) - bias)

def embedding(x, num_frequencies=5):
    """
    Build a sinusoidal (Fourier-feature) embedding of a single 1-D input.

    For every entry ``x_i`` and every integer ``k`` in ``1..num_frequencies``
    the features ``sin(k * pi * x_i)`` and ``cos(k * pi * x_i)`` are computed.
    Per entry, all sine features come first, followed by all cosine features;
    the per-entry groups are then laid out one after another.

    Args:
        x (ndarray): Input array of shape (d,).
        num_frequencies (int): Number of integer frequencies to use.

    Returns:
        ndarray: Embedded features of shape (2 * d * num_frequencies,).
    """
    harmonics = jnp.arange(1, num_frequencies + 1)  # k = 1, ..., num_frequencies
    phase = harmonics * (jnp.pi * x[:, None])  # Shape: (d, num_frequencies)

    # Sine block followed by cosine block for each entry, then one flat vector.
    stacked = jnp.concatenate((jnp.sin(phase), jnp.cos(phase)), axis=-1)
    return stacked.reshape(-1)


class GatingNetwork(nn.Module):
    """
    A feedforward neural network that acts as a gating mechanism for partitioning input space.

    The input is passed through the sinusoidal ``embedding``, one plain
    ``BoxDense`` layer, ``n_layers`` residual ``BoxDense`` layers and a linear
    ``BoxDense`` output layer; the result is softmax-normalized so the outputs
    sum to 1.

    Attributes:
        n_hidden (int): Number of hidden units in each layer.
        n_layers (int): Number of residual hidden layers in the network.
        num_partitions (int): Number of partitions (output dimension).
        activation (Callable): Activation function to use in hidden layers.
    """

    n_hidden: int
    n_layers: int
    num_partitions: int
    activation: Callable

    def setup(self):
        """
        Set up the layers of the network.

        This method creates the input layer, the residual hidden layers and the
        output layer. All submodules are declared here, so ``__call__`` is a
        plain method and does not need ``@nn.compact``.
        """
        self.layer_0 = BoxDense(
            features=self.n_hidden,
            activation=self.activation,
            depth=self.n_layers,
            layer=0,
            arch_type="plain",
        )
        self.layers = [
            BoxDense(
                features=self.n_hidden,
                activation=self.activation,
                depth=self.n_layers,
                layer=idx + 1,
                arch_type="resnet",
            )
            for idx in range(self.n_layers)
        ]
        # Identity activation: the output layer is purely affine.
        self.output = BoxDense(
            features=self.num_partitions,
            activation=lambda x: x,
            depth=self.n_layers,
            layer=self.n_layers,
            arch_type="plain",
        )

    def __call__(self, x):
        """
        Forward pass through the network.

        Args:
            x (jnp.ndarray): A single 1-D input point (it is fed to ``embedding``).

        Returns:
            jnp.ndarray: Softmax-normalized output representing partition probabilities.
        """
        x = embedding(x)
        x = self.layer_0(x)
        for layer in self.layers:
            x = x + layer(x)  # Residual connection
        return nn.softmax(self.output(x))


class BasisNetwork(nn.Module):
    """
    A residual MLP mapping a single input point to one scalar output.

    The input is passed through the sinusoidal ``embedding``, one plain
    ``BoxDense`` layer, ``n_layers`` residual ``BoxDense`` layers and a linear
    ``BoxDense`` output layer with a single feature.

    Attributes:
        n_hidden (int): Number of hidden units in each layer.
        n_layers (int): Number of residual hidden layers in the network.
        basis_size (int): Accepted for configuration compatibility; not used
            by this class.
        activation (Callable): Activation function to use in hidden layers.
    """

    n_hidden: int
    n_layers: int
    basis_size: int
    activation: Callable

    def setup(self):
        """
        Set up the layers of the network.

        This method creates the input layer, the residual hidden layers and the
        output layer. All submodules are declared here, so ``__call__`` is a
        plain method and does not need ``@nn.compact``.
        """
        self.layer_0 = BoxDense(
            features=self.n_hidden,
            activation=self.activation,
            depth=self.n_layers,
            layer=0,
            arch_type="plain",
        )
        self.layers = [
            BoxDense(
                features=self.n_hidden,
                activation=self.activation,
                depth=self.n_layers,
                layer=idx + 1,
                arch_type="resnet",
            )
            for idx in range(self.n_layers)
        ]
        # Identity activation: the output layer is purely affine.
        self.output = BoxDense(
            features=1,
            activation=lambda x: x,
            depth=self.n_layers,
            layer=self.n_layers,
            arch_type="plain",
        )

    def __call__(self, x):
        """
        Forward pass through the network.

        Args:
            x (jnp.ndarray): A single 1-D input point (it is fed to ``embedding``).

        Returns:
            jnp.ndarray: Raw output of the final layer (one feature); no
            normalization is applied.
        """
        x = embedding(x)
        x = self.layer_0(x)
        for layer in self.layers:
            x = x + layer(x)  # Residual connection
        return self.output(x)

class BlendedMLP_Regression:
    """Train a ``BasisNetwork`` on the residual ``u_x - eps * u_xx`` plus a boundary term.

    The residual is evaluated at the collocation points ``x``; the boundary
    term compares the network with ``u_test`` at the first and last point.
    ``u_test`` is also used to report the relative L2 error.

    ``self.params`` is the tuple
    ``(basis_params, sigma_comp, sigma_coop, sigma_bc_comp, sigma_bc_coop)``.
    Only ``sigma_coop`` and ``sigma_bc_coop`` enter the loss; the four sigmas
    are driven by ``net_setup_params["sigma_schedule"]``.
    """

    def __init__(self, key, net_setup_params, x, u_test, eps, lr):
        """Build the network, the boundary mask, the sigmas and the optimizer.

        Args:
            key: JAX PRNG key used to initialize the network.
            net_setup_params: Dict with a "basis" entry (keyword arguments for
                ``BasisNetwork``) and a "sigma_schedule" entry (dict of
                callables keyed "comp", "coop", "bc_comp", "bc_coop").
            x: Collocation points; each row is fed to the network separately.
            u_test: Reference solution at ``x``.
            eps: Coefficient of the second-derivative term.
            lr: Initial learning rate; decayed by a factor 0.9 every 1000 steps.
        """
        # Master random key
        self.key = key
        # Dictionary containing network setup parameters
        self.net_setup_params = net_setup_params
        # Set up the input data and the target function
        self.x = x
        self.u_test = u_test
        self.eps = eps

        self.basis_setup_params = net_setup_params["basis"]
        self.sigma_schedule = net_setup_params["sigma_schedule"]

        # Initialize the basis network
        basis_model = BasisNetwork(**self.basis_setup_params)
        self.basis_params = basis_model.init(key, jnp.ones(1,))
        self.basis_apply = basis_model.apply

        # BC mask: 1 at the first and last collocation point, 0 elsewhere
        beta = jnp.zeros_like(self.u_test.ravel())
        self.beta = beta.at[0].set(1.).at[-1].set(1.)

        # Initial sigma values, taken from the schedules at step 0
        sigmas = self._scheduled_sigmas(0)
        (
            self.sigma_comp,
            self.sigma_coop,
            self.sigma_bc_comp,
            self.sigma_bc_coop,
        ) = sigmas
        self.params = (self.basis_params, *sigmas)

        lr_schedule = optax.exponential_decay(lr, 1000, 0.9)
        self.optimizer = optax.adam(learning_rate=lr_schedule)
        self.opt_state = self.optimizer.init(self.params)

        self.itercount = itertools.count()
        self.l2_error_log = []
        self.loss_log = []
        self.loss_bcs_log = []

    def _scheduled_sigmas(self, int_count):
        """Return ``(comp, coop, bc_comp, bc_coop)`` sigmas for step ``int_count``."""
        return tuple(
            self.sigma_schedule[term](int_count)
            for term in ("comp", "coop", "bc_comp", "bc_coop")
        )

    def u_net(self, params, x):
        """Evaluate the network at a single point ``x``; only ``params[0]`` is used."""
        basis_params, *_ = params
        return self.basis_apply(basis_params, x)

    @partial(jit, static_argnums=(0,))
    def L_net(self, params):
        """
        Second derivative of the network output with respect to its input,
        evaluated at every row of ``self.x`` (forward-mode, applied twice).
        The two trailing singleton axes of the Hessian are squeezed away.
        """
        u_out_fn = lambda x: self.u_net(params, x)
        return jnp.squeeze(vmap(jax.jacfwd(jax.jacfwd(u_out_fn)))(self.x), axis=(-1, -2))

    @partial(jit, static_argnums=(0,))
    def D_net(self, params):
        """
        First derivative of the network output with respect to its input,
        evaluated at every row of ``self.x`` (forward-mode). The trailing
        singleton axis of the Jacobian is squeezed away.
        """
        u_out_fn = lambda x: self.u_net(params, x)
        return jnp.squeeze(vmap(jax.jacfwd(u_out_fn))(self.x), axis=(-1))

    @partial(jit, static_argnums=(0,))
    def loss_bcs(self, params):
        """Sum of squared mismatches with ``u_test`` at the first and last point."""
        # One forward pass serves both boundary values.
        u_pred = vmap(self.u_net, (None, 0))(params, self.x).ravel()
        u_ref = self.u_test.ravel()
        return (u_ref[0] - u_pred[0]) ** 2 + (u_ref[-1] - u_pred[-1]) ** 2

    @partial(jit, static_argnums=(0,))
    def loss(self, params):
        """Total training loss: residual term plus masked boundary term.

        The residual ``D_net - eps * L_net`` is weighted by
        ``1 / sigma_coop**2`` and the boundary mismatch (masked by
        ``self.beta``) by ``1 / sigma_bc_coop**2``; both carry a factor 0.5.
        ``sigma_comp`` and ``sigma_bc_comp`` do not enter the loss.
        """
        _, _, sigma_coop, _, sigma_bc_coop = params
        u_pred = vmap(self.u_net, (None, 0))(params, self.x)

        # Residual u_x - eps * u_xx at the collocation points
        residual = self.D_net(params) - self.eps * self.L_net(params)
        cooperative_loss = 0.5 * jnp.sum(residual ** 2) / sigma_coop**2

        # BC term
        bc_mismatch = jnp.expand_dims(self.beta, 1) * (self.u_test - u_pred.reshape(-1, 1))
        BC_loss_coop = 0.5 * jnp.sum(bc_mismatch ** 2) / sigma_bc_coop**2

        # Combine losses
        total_loss = cooperative_loss + BC_loss_coop
        assert total_loss.size == 1
        return total_loss

    @partial(jit, static_argnums=(0,))
    def compute_l2_error(self, params):
        """Relative L2 error of the network against ``u_test`` over ``self.x``."""
        out = vmap(self.u_net, (None, 0))(params, self.x)
        u_ref = self.u_test.ravel()
        return jnp.linalg.norm(out.ravel() - u_ref) / jnp.linalg.norm(u_ref)

    @partial(jit, static_argnums=(0,))
    def EM_step(self, params, opt_state, int_count):
        """Run one optimizer step with the sigmas set from their schedules.

        Args:
            params: Current parameter tuple (see class docstring).
            opt_state: Optimizer state matching ``params``.
            int_count: Step counter passed to the sigma schedules.

        Returns:
            ``(updated_params, opt_state)``. The network parameters are
            updated by the optimizer; the sigma slots hold the scheduled
            values for ``int_count``.
        """
        # Update sigma parameters
        sigmas = self._scheduled_sigmas(int_count)
        params = (params[0], *sigmas)
        grads = grad(self.loss)(params)
        updates, opt_state = self.optimizer.update(grads, opt_state)
        stepped = optax.apply_updates(params, updates)
        # The sigmas are schedule-driven, not learned: discard the gradient
        # step applied to them so the returned (and later logged) parameters
        # stay on schedule. The optimizer still carries state for those slots
        # so `opt_state` keeps the structure created in __init__.
        updated_params = (stepped[0], *sigmas)
        return updated_params, opt_state

    def train(self, nIter=10000):
        """Run ``nIter`` optimizer steps, logging loss, BC loss and L2 error every 50 steps."""
        pbar = trange(nIter)
        for it in pbar:
            self.int_count = next(self.itercount)
            self.params, self.opt_state = self.EM_step(self.params, self.opt_state, self.int_count)
            if it % 50 == 0:
                loss = self.loss(self.params)
                error = self.compute_l2_error(self.params)
                loss_bcs = self.loss_bcs(self.params)
                self.l2_error_log.append(error)
                self.loss_log.append(loss)
                self.loss_bcs_log.append(loss_bcs)
                pbar.set_postfix({"error": error, "loss": loss})

# Usage: shared experiment configuration
base_rng = random.PRNGKey(0)
net_setup_params = {
    # Keyword arguments forwarded to BasisNetwork
    "basis": {
        "n_hidden": 30,
        "n_layers": 4,
        "basis_size": 40,
        "activation": jax.nn.tanh,
    },
    # Constant sigma value for each loss-term key
    "sigma_schedule": {
        term: optax.constant_schedule(value)
        for term, value in (
            ("coop", 1.),
            ("comp", 1e5),
            ("bc_comp", 1e5),
            ("bc_coop", 1e-3),
        )
    },
}
# 500 equispaced collocation points on [0, 1], stored as a column
x = jnp.linspace(0, 1, 500)[:, None]
# Target function (closed form), shown here for a single fixed eps:
# eps = 0.05
# u = (jnp.exp(1/eps) - jnp.exp(x/eps)) / (jnp.exp(1/eps) - 1)

# Sweep over Peclet numbers, repeating every setting with several seeds.
Pe_list = [50, 40, 30, 25, 20, 15, 10, 5, 1]
n_repeats = 5            # independent seeds per Peclet number
n_train_iters = 100000   # optimizer steps per model
learning_rate = 1e-4     # initial learning rate passed to the model

# error_list[j] collects the final relative L2 error of every repeat for Pe_list[j]
error_list = [[] for _ in Pe_list]
for i in range(n_repeats):
    new_rng = random.fold_in(base_rng, i)
    for j, Pe in enumerate(Pe_list):
        # Distinct key per (repeat, Peclet number) pair
        rng = random.fold_in(new_rng, j)
        print("Peclet number:", Pe, flush=True)
        eps = 1 / Pe
        # Closed-form reference solution on the collocation grid
        u = (jnp.exp(1/eps) - jnp.exp(x/eps)) / (jnp.exp(1/eps) - 1)
        model = BlendedMLP_Regression(rng, net_setup_params, x, u, eps, learning_rate)
        model.train(n_train_iters)

        # Optional per-run diagnostics, disabled by default.
        # NOTE: if re-enabled, call plt.savefig before plt.show so the saved
        # figure is not empty on interactive backends.
        # plt.figure(figsize=(10, 4))
        # plt.subplot(1, 2, 1)
        # plt.semilogy(model.loss_bcs_log,)
        # # plt.semilogy(model.loss_bcs_log_f, label="BC Loss Fine")
        # plt.title("Boundary Condition Loss")
        # plt.xlabel("Iterations")
        # plt.xscale('log')
        # plt.legend()
        # plt.subplot(1, 2, 2)
        # plt.semilogy(model.l2_error_log)
        # # plt.semilogy(model.l2_error_log_f)
        # plt.title("Relative L2 Error")
        # plt.xlabel("Iterations")
        # plt.xscale('log')
        # plt.show()
        # path = os.path.join(output_dir, f"losses_{Pe}.png")
        # plt.savefig(path)
        # #print("Saved losses to outputs", flush=True)

        # u_pred = vmap(model.u_net, (None, 0))(model.params, x)
        # plt.figure(figsize=(6, 4))
        # plt.plot(x, u.ravel(), "b", label="Exact")
        # plt.plot(x, u_pred.ravel(), "r--", label="Prediction")
        # plt.legend()
        # plt.show()

        # path = os.path.join(output_dir, f"u_pred_c_{Pe}.png")
        # plt.savefig(path)
        # #print("Saved u prediction to outputs", flush=True)

        # path = os.path.join(output_dir, f"final_sol_{Pe}.npy")
        # np.save(path, u_pred)
        # #print("Saved solution to outputs", flush=True)

        # path = os.path.join(output_dir, f"l2_error_{Pe}.npy")
        # np.save(path, model.l2_error_log)
        # #print("Saved l2 error to outputs", flush=True)
        # print("Final L2 error:", model.l2_error_log[-1], flush=True)
        error_list[j].append(model.l2_error_log[-1])
        
# Aggregate the final errors: one row per Peclet number, one column per repeat.
error_list = jnp.asarray(error_list)
np.save(os.path.join(output_dir, "error_list.npy"), error_list)
mean = jnp.mean(error_list, axis=1)
std = jnp.std(error_list, axis=1)

# The label gives plt.legend() an artist to show.
plt.errorbar(Pe_list, mean, yerr=std, fmt='o-', capsize=5, label='mean +/- std over runs')
plt.xlabel('Pe number')
plt.ylabel('Relative L2 error')
plt.yscale('log')
plt.legend()
plt.grid(True)

# Save before showing: on interactive backends the figure may be closed once
# plt.show() returns, which would leave an empty image on disk.
path = os.path.join(output_dir, "error_list.png")
plt.savefig(path)
plt.show()
