import copy
import itertools
import logging
import os
import time
from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import orthax
from flax import struct
from flax.training import checkpoints
from jax import lax, vmap, jit, grad, random
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from tqdm import trange

from .blended_solver import BlendedLeastSquaresSolver
from .levels import LevelCoarse, LevelFine
from .custom_types import Array, Any



def compute_R_sigma(gate_out_c: jnp.ndarray,
                    gate_out_f: jnp.ndarray | None,
                    posterior: jnp.ndarray,
                    level: int) -> jnp.ndarray:
    """
    Computes the R_sigma estimate for a given level.

    Args:
        gate_out_c: Coarse gating network output.
                    Shape for level 0: (n_data, n_experts_coarse)
                    Shape for level 1: (n_data, n_experts_coarse)
        gate_out_f: Fine gating network output.
                    Shape for level 0: None
                    Shape for level 1: (n_data, n_experts_coarse, n_experts_fine)
        posterior: E-step responsibilities (w_id or w_ijd).
                   Shape for level 0: (n_data, n_experts_coarse)
                   Shape for level 1: (n_data, n_experts_coarse, n_experts_fine)
        level: 0 for coarse, 1 for fine.

    Returns:
        R_sigma_est: The estimated R_sigma value.
    """
    if gate_out_c is not None:
        dtype = gate_out_c.dtype
    elif posterior is not None:
        dtype = posterior.dtype
    else:
        dtype = jnp.float64
    eps = jnp.finfo(dtype).eps

    if level == 0:

        # pi_id = gate_out_c (n_data, n_experts_coarse)
        sum_pi_id = jnp.sum(gate_out_c, axis=0)  # Shape: (n_experts_coarse,)
        # w_id = posterior (n_data, n_experts_coarse)
        sum_w_id = jnp.sum(posterior, axis=0)    # Shape: (n_experts_coarse,)

        max_sum_pi = jnp.max(sum_pi_id)
        min_sum_w = jnp.min(sum_w_id)

        R_sigma_est = max_sum_pi / (min_sum_w + eps)

    elif level == 1:
        # pi_id (coarse gate out) = gate_out_c (n_data, n_experts_coarse)
        # pi_j|id (fine gate out, conditional) = gate_out_f (n_data, n_experts_coarse, n_experts_fine)
        # pi_ijd = pi_id * pi_j|id
        # w_ijd = posterior (n_data, n_experts_coarse, n_experts_fine)

        gate_out_c_expanded = jnp.expand_dims(gate_out_c, axis=-1) # (n_data, n_experts_coarse, 1)
        pi_ijd = gate_out_c_expanded * gate_out_f # (n_data, n_experts_coarse, n_experts_fine)

        sum_pi_ijd = jnp.sum(pi_ijd, axis=0)      # Shape: (n_experts_coarse, n_experts_fine)
        sum_w_ijd = jnp.sum(posterior, axis=0)    # Shape: (n_experts_coarse, n_experts_fine)

        max_sum_pi = jnp.max(sum_pi_ijd)
        min_sum_w = jnp.min(sum_w_ijd)

        R_sigma_est = max_sum_pi / (min_sum_w + eps)
    else:
        raise ValueError(f"Invalid level: {level}. Must be 0 or 1.")

    return R_sigma_est


@struct.dataclass
class TrainingState:
    """Bundle of the quantities that evolve during training.

    Produced by ``HierarchicalPOUBase.get_training_state`` and consumed by the
    jitted training step, which unpacks all six fields.
    """
    # Coarse-level parameters; set up elsewhere in this file as a
    # (gating_params, poly_params) tuple.
    params_c: Any
    # Fine-level parameters; same (gating_params, poly_params) layout.
    params_f: Any
    # Expansion coefficients of the coarse level.
    coeffs_c: Array
    # Expansion coefficients of the fine level.
    coeffs_f: Array
    # Optimizer state paired with params_c.
    opt_state_c: Any
    # Optimizer state paired with params_f.
    opt_state_f: Any


class HierarchicalPOUBase:
    def __init__(self,
                 key : Array,
                 net_setup_params : dict,
                 training_params : dict,
                 problem_params : dict,
                 mesh : Mesh,
                 replicate_sharding : NamedSharding,
                 logger=None,
                 load_state_dir : str | None = None,
                 debug=False):
        self.key = key
        net_setup_params["key"] = key
        self.net_setup_params = net_setup_params
        self.training_params = training_params
        self.problem_params = problem_params
        self.mesh = mesh

        self.replicate_sharding = replicate_sharding
        self.logger = logger
        self.load_state_dir = load_state_dir

        self.basis_type = self.net_setup_params["poly"]["coarse"]["basis_choice"]

        self.setup_data()
        self.setup_models() # init params / coeffs

        if self.load_state_dir:
            self.load_model_state(self.load_state_dir)

        self.setup_training()

    def get_training_state(self) -> TrainingState:
        """Returns the current training state packaged in a TrainingState object."""
        return TrainingState(
            params_c=self.params_c,
            params_f=self.params_f,
            coeffs_c=self.coeffs_c,
            coeffs_f=self.coeffs_f,
            opt_state_c=self.opt_state_c,
            opt_state_f=self.opt_state_f
        )


    def get_state_to_save(self) -> dict:
        """Collects parameters and coefficients into a dictionary for saving."""
        return {
            'params_c': self.params_c,
            'params_f': self.params_f,
            'coeffs_c': self.coeffs_c,
            'coeffs_f': self.coeffs_f,
            # 'opt_state_c': self.opt_state_c,
            # 'opt_state_f': self.opt_state_f,
        }

    def save_model_state(self, save_dir: str):
        """Saves the current model parameters and coefficients."""
        if jax.process_index() == 0:
            self.logger.info(f"Saving model state to {save_dir}...")
            try:
                state_to_save = self.get_state_to_save()
                state_to_save_cpu = jax.device_get(state_to_save) # mv to cpu
                checkpoints.save_checkpoint(
                    ckpt_dir=save_dir,
                    target=state_to_save_cpu,
                    step=0,
                    overwrite=True,
                )
                self.logger.info("Model state saved successfully.")
            except Exception as e:
                self.logger.error(f"Failed to save model state: {e}", exc_info=True)

        if self.mesh and len(self.mesh.devices) > 1:
            jax.block_until_ready(self.get_training_state())

    def load_model_state(self, load_dir: str):
        """Loads model parameters and coefficients from a file."""
        if jax.process_index() == 0:
            self.logger.info(f"Attempting to load model state from {load_dir}...")
        try:
            target_struct = self.get_state_to_save()

            restored_state_cpu = checkpoints.restore_checkpoint(
                 ckpt_dir=load_dir,
                 target=target_struct,
                 step=0
            )

            if restored_state_cpu:
                if jax.process_index() == 0:
                    self.logger.info("Model state loaded successfully from checkpoint.")
                restored_state_device = jax.tree_util.tree_map(
                    lambda x: jax.device_put(x, self.replicate_sharding),
                    restored_state_cpu
                )

                self.params_c = restored_state_device['params_c']
                self.params_f = restored_state_device['params_f']
                self.coeffs_c = restored_state_device['coeffs_c']
                self.coeffs_f = restored_state_device['coeffs_f']
                # if 'opt_state_c' in restored_state_device:
                #     self.opt_state_c = restored_state_device['opt_state_c']
                #     self.opt_state_f = restored_state_device['opt_state_f']

                if jax.process_index() == 0:
                    self.logger.info("Loaded state assigned to model attributes and placed on device.")
            else:
                if jax.process_index() == 0:
                     self.logger.warning(f"Checkpoint step 0 not found or failed to restore in {load_dir}. Using initial parameters.")
        except Exception as e:

            if jax.process_index() == 0:
                self.logger.error(f"Failed to load model state from {load_dir}: {e}. Using initial parameters.", exc_info=True)

    def setup_data(self):
        raise NotImplementedError("Must be implemented in subclass for specific dimensionality.")

    def setup_models(self):
        """Will likely call setup_gating_networks and setup_basis_functions"""
        raise NotImplementedError("Must be implemented in subclass for specific dimensionality.")

    def setup_gating_networks(self):
        raise NotImplementedError("Must be implemented in subclass for specific dimensionality.")

    def setup_basis_functions(self):
        raise NotImplementedError("Must be implemented in subclass for specific dimensionality.")

    def setup_training(self):


        # used for sigma_{coop} iterative solver scan
        self.test_sigma_coop_values_static = jnp.array([1e5,
                                                        5e4, 1e4,
                                                        5e3, 1e3,
                                                        7.5e2, 5e2, 2.5e2, 1e2,
                                                        7.5e1, 5e1, 2.5e1, 1e1,
                                                        5e0], dtype=self._dtype)

        if jax.process_index() == 0:
            self.logger.info(f"training_params: {self.training_params}")

        lr_c = optax.exponential_decay(self.training_params["lr"]["coarse"], 1000, 0.9)
        lr_f = optax.exponential_decay(self.training_params["lr"]["fine"], 1000, 0.9)
        self.optimizer_c = optax.adam(learning_rate=lr_c)
        self.optimizer_f = optax.adam(learning_rate=lr_f)

        self.opt_state_c = self.optimizer_c.init(self.params_c)
        self.opt_state_f = self.optimizer_f.init(self.params_f)

        self.opt_state_c = jax.device_put(self.opt_state_c, self.replicate_sharding)
        self.opt_state_f = jax.device_put(self.opt_state_f, self.replicate_sharding)

        self.coef_slv_params = self.net_setup_params.get("coef_slv_params",
                                                         {'slv_type' : 'direct',
                                                          'max_iter' : 1000,
                                                          'tol' : 1e-12,
                                                          'reg' : 1e-6,
                                                          'omega' : 1.})
        if jax.process_index() == 0:
            self.logger.info(f"Solver configuration: {self.coef_slv_params}")

        self.lstsq_solver = BlendedLeastSquaresSolver(
            net_setup_params=self.net_setup_params,
            level_c=self.level_c,
            level_f=self.level_f,
            x=self.x,
            u_test=self.u_test,
            mesh=self.mesh,
            replicate_sharding=self.replicate_sharding
        )
        self.coef_slv_params = self.lstsq_solver.coef_slv_params



class BlendedMLPRegression(HierarchicalPOUBase):

    def __init__(self,
              key : Array,
              net_setup_params : dict,
              training_params : dict,
              problem_params : dict,
              mesh : Mesh,
              replicate_sharding : NamedSharding,
              logger=None,
              load_state_dir : str | None = None, # Pass load file arg
              debug=False):
        """Forward every argument unchanged to ``HierarchicalPOUBase.__init__``.

        All construction work (data, models, optional checkpoint restore,
        training setup) is driven by the base-class constructor, which calls
        back into the ``setup_*`` methods defined on this class.
        """
        super().__init__(
            key=key,
            net_setup_params=net_setup_params,
            training_params=training_params,
            problem_params=problem_params,
            mesh=mesh,
            replicate_sharding=replicate_sharding,
            logger=logger,
            load_state_dir=load_state_dir,
            debug=debug,
        )

    def setup_data(self):
        """Read the dtype and place the inputs and targets on device.

        ``problem_params["x"]`` and ``problem_params["u_exact"]`` are placed
        with ``replicate_sharding``; the placed arrays are written back into
        ``problem_params`` and exposed as ``self.x`` and ``self.u_test``.
        """
        self._dtype = self.net_setup_params['dtype']

        # (problem_params key, attribute name on self)
        for source_key, attr_name in (("x", "x"), ("u_exact", "u_test")):
            placed = jax.device_put(self.problem_params[source_key], self.replicate_sharding)
            self.problem_params[source_key] = placed
            setattr(self, attr_name, placed)

    def setup_models(self):
        """Build both levels and initialise parameters, coefficients and sigmas.

        Sets ``level_c``/``level_f``, the jitted ``sigma_schedule`` (the
        ``"comp"`` and ``"coop"`` entries of the configured schedule dict are
        replaced by jitted versions in place), the initial ``sigma_comp`` and
        ``sigma_coop`` (schedule evaluated at 0), and device-placed
        ``params_*``, ``avg_params_*`` and ``coeffs_*``.
        """
        sharding = self.replicate_sharding
        coarse_key, fine_key = random.split(self.key)

        # Each level gets its own PRNG key but shares the configuration.
        self.level_c = LevelCoarse(coarse_key,
                                   self.net_setup_params,
                                   self.problem_params,
                                   sharding,
                                   logger=self.logger)
        self.level_f = LevelFine(fine_key,
                                 self.net_setup_params,
                                 self.problem_params,
                                 sharding,
                                 logger=self.logger)

        schedule = self.net_setup_params["sigma_schedule"]
        self.sigma_schedule = schedule
        schedule["comp"] = jax.jit(schedule["comp"])
        schedule["coop"] = jax.jit(schedule["coop"])

        self.sigma_comp = schedule["comp"](0)
        self.sigma_coop = schedule["coop"](0)

        # Parameters are (gating_params, poly_params) tuples per level.
        self.params_c = jax.device_put(
            (self.level_c.gating_params, self.level_c.poly_params), sharding
        )
        self.params_f = jax.device_put(
            (self.level_f.gating_params, self.level_f.poly_params), sharding
        )

        # The averaged parameters start out equal to the initial parameters.
        self.avg_params_c = jax.device_put(self.params_c, sharding)
        self.avg_params_f = jax.device_put(self.params_f, sharding)

        self.coeffs_c = jax.device_put(self.level_c.coeffs, sharding)
        self.coeffs_f = jax.device_put(self.level_f.coeffs, sharding)



    def setup_training(self):
        """Run the base-class training setup and create the iteration counter."""
        super().setup_training()
        # Unbounded counter starting at 0; each next() yields the next step index.
        self.itercount = itertools.count()

    def u_net(self, level, params_c, params_f, coeffs, x):
        """Evaluate the blended prediction at a single input ``x``.

        Args:
            level: 0 for the coarse model, 1 for the fine model. Must be a
                plain Python int (it selects the code path at trace time).
            params_c: Coarse ``(gate_params, basis_params)`` tuple.
            params_f: Fine ``(gate_params, basis_params)`` tuple; only read
                for level 1.
            coeffs: Expansion coefficients of the requested level, indexed
                ``[expert, basis]`` for level 0 and
                ``[coarse expert, fine expert, basis]`` for level 1.
            x: A single input point (callers vmap over the data axis).

        Returns:
            Scalar prediction: the gate-weighted sum of the per-expert
            basis expansions.

        Raises:
            ValueError: If ``level`` is neither 0 nor 1 (previously this
                silently returned None).
        """
        gate_net_c = self.level_c.gate_net
        poly_net_c = self.level_c.basis_net_c

        if level == 0:
            gate_params, basis_params = params_c
            gate_out = gate_net_c(gate_params, x).ravel()
            basis_out = poly_net_c(basis_params, x)
            # Per-expert expansion, then gate-weighted sum over experts.
            return jnp.sum(gate_out * jnp.einsum("ij,ij->i", coeffs, basis_out))

        if level == 1:
            gate_net_f = self.level_f.gate_net
            poly_net_f = self.level_f.basis_net_f
            gate_params, basis_params = params_f
            gate_out_c = gate_net_c(params_c[0], x).ravel()
            gate_out_f = gate_net_f(gate_params, x)
            # The fine basis also depends on the coarse basis parameters/net.
            basis_out = poly_net_f(params_c[1], basis_params, poly_net_c, x)
            # Implicit-mode einsum: every index repeats, so all are summed.
            return jnp.sum(jnp.einsum("i,ij,ijk,ijk", gate_out_c, gate_out_f, coeffs, basis_out))

        raise ValueError(f"Invalid level: {level}")


    @partial(jit, static_argnums=(0, 1))
    def E_step(self, level, params_c, params_f, coeffs):
        """Compute the E-step responsibilities over the whole data set.

        The responsibilities are the gate probabilities multiplied by a
        Gaussian likelihood of each expert's prediction (width
        ``self.sigma_comp``), shifted by machine epsilon and normalised per
        data point.

        NOTE(review): ``self`` is a static jit argument, so ``self.x``,
        ``self.u_test`` and ``self.sigma_comp`` are presumably captured at
        trace time - confirm that later changes to ``sigma_comp`` are picked
        up as intended.

        Args:
            level: 0 (coarse) or 1 (fine); static.
            params_c: Coarse ``(gate_params, basis_params)`` tuple.
            params_f: Fine ``(gate_params, basis_params)`` tuple; only read
                for level 1.
            coeffs: Expansion coefficients of the requested level.

        Returns:
            Posterior with the data axis first; normalised over the expert
            axis for level 0 and jointly over both expert axes for level 1.

        Raises:
            ValueError: If ``level`` is neither 0 nor 1 (previously this
                silently returned None).
        """
        gate_net_c = self.level_c.gate_net
        gate_net_f = self.level_f.gate_net
        poly_net_c = self.level_c.basis_net_c
        poly_net_f = self.level_f.basis_net_f

        if level == 0:
            gate_params, basis_params = params_c
            gate_out = vmap(gate_net_c, (None, 0))(gate_params, self.x)
            basis_out = vmap(poly_net_c, (None, 0))(basis_params, self.x)

            # Gaussian likelihood of each expert's own prediction.
            likelihood = jnp.exp(-0.5 * (self.u_test - jnp.einsum("ij,dij->di", coeffs, basis_out)) ** 2 / self.sigma_comp**2)

            # Posterior: prior * likelihood, eps keeps the normaliser non-zero.
            posterior = gate_out * likelihood
            posterior = posterior + jnp.finfo(posterior.dtype).eps
            posterior = posterior / (jnp.sum(posterior, axis=1, keepdims=True))
            return posterior

        if level == 1:
            gate_params, basis_params = params_f
            gate_out_c = vmap(gate_net_c, (None, 0))(params_c[0], self.x)
            gate_out_f = vmap(gate_net_f, (None, 0))(gate_params, self.x)
            basis_out = vmap(poly_net_f, (None, None, None, 0))(params_c[1], basis_params, poly_net_c, self.x)

            # Gaussian likelihood per (coarse, fine) expert pair.
            likelihood = jnp.exp(-0.5 * (jnp.expand_dims(self.u_test, axis=(-1)) - jnp.einsum("ijk,dijk->dij",coeffs, basis_out)) ** 2 / self.sigma_comp**2)

            # Joint prior pi_i * pi_{j|i} times likelihood, normalised over
            # both expert axes.
            posterior = jnp.einsum("di,dij,dij->dij", gate_out_c, gate_out_f, likelihood)
            posterior = posterior + jnp.finfo(posterior.dtype).eps
            posterior = posterior / (jnp.sum(posterior, axis=(-1, -2), keepdims=True))
            return posterior

        raise ValueError(f"Invalid level: {level}")

    @partial(jit, static_argnums=(0,))
    def loss_gate_c(self, params_c, params_f, coeffs, posterior):
        """Coarse-level training loss.

        Sum of three terms:
          * competitive gate term: negative posterior-weighted log of the
            coarse gate output,
          * competitive basis term: posterior-weighted squared error of each
            expert's own prediction, scaled by ``0.5 / sigma_comp**2``,
          * cooperative term: squared error of the blended level-0
            prediction, scaled by ``0.5 / sigma_coop**2``.

        Args:
            params_c: Coarse ``(gate_params, basis_params)`` tuple.
            params_f: Fine parameters, forwarded to ``u_net``.
            coeffs: Coarse expansion coefficients.
            posterior: Level-0 responsibilities from the E-step.

        Returns:
            Scalar total loss.
        """
        lvl = 0
        gate_params, basis_params = params_c
        gate_out = vmap(self.level_c.gate_net, (None, 0))(gate_params, self.x)
        basis_out = vmap(self.level_c.basis_net_c, (None, 0))(basis_params, self.x)

        # y_{POU}(x): blended prediction for every data point.
        predict = vmap(self.u_net, (None, None, None, None, 0))
        blended_pred = predict(lvl, params_c, params_f, coeffs, self.x)

        # Competitive terms.
        log_gate = jnp.log(gate_out + jnp.finfo(gate_out.dtype).eps)
        comp_loss_pi = -jnp.sum(posterior * log_gate)

        expert_residual = self.u_test - jnp.einsum("ij,dij->di", coeffs, basis_out)
        weighted_sq_err = jnp.sum(posterior * expert_residual ** 2)
        comp_loss_basis = 0.5 * weighted_sq_err / self.sigma_comp**2

        # Cooperative term.
        blended_sq_err = jnp.sum((self.u_test.ravel() - blended_pred) ** 2)
        cooperative_loss = 0.5 * blended_sq_err / self.sigma_coop**2

        total_loss = comp_loss_pi + comp_loss_basis + cooperative_loss
        assert total_loss.size == 1
        return total_loss

    @partial(jit, static_argnums=(0,))
    def loss_gate_f(self, params_f, params_c, coeffs, posterior):
        """Fine-level training loss.

        Mirrors ``loss_gate_c`` for level 1. Note the argument order: the
        fine parameters come first.

        Sum of three terms:
          * competitive gate term: negative posterior-weighted log of the
            fine gate output,
          * competitive basis term: posterior-weighted squared error of each
            (coarse, fine) expert's own prediction, scaled by
            ``0.5 / sigma_comp**2``,
          * cooperative term: squared error of the blended level-1
            prediction, scaled by ``0.5 / sigma_coop**2``.

        Args:
            params_f: Fine ``(gate_params, basis_params)`` tuple.
            params_c: Coarse ``(gate_params, basis_params)`` tuple.
            coeffs: Fine expansion coefficients.
            posterior: Level-1 responsibilities from the E-step.

        Returns:
            Scalar total loss.
        """
        lvl = 1
        coarse_basis_net = self.level_c.basis_net_c
        gate_params, basis_params = params_f
        gate_out_f = vmap(self.level_f.gate_net, (None, 0))(gate_params, self.x)
        basis_out = vmap(self.level_f.basis_net_f, (None, None, None, 0))(
            params_c[1], basis_params, coarse_basis_net, self.x
        )

        # y_{POU}(x): blended prediction for every data point.
        predict = vmap(self.u_net, (None, None, None, None, 0))
        blended_pred = predict(lvl, params_c, params_f, coeffs, self.x)

        # Competitive terms.
        log_gate = jnp.log(gate_out_f + jnp.finfo(gate_out_f.dtype).eps)
        comp_loss_pi = -(jnp.sum(posterior * log_gate))

        expert_residual = jnp.expand_dims(self.u_test, axis=-1) - jnp.einsum("ijk,dijk->dij", coeffs, basis_out)
        weighted_sq_err = jnp.sum(posterior * expert_residual ** 2)
        comp_loss_basis = 0.5 * weighted_sq_err / self.sigma_comp**2

        # Cooperative term.
        blended_sq_err = jnp.sum((self.u_test.ravel() - blended_pred) ** 2)
        cooperative_loss = 0.5 * blended_sq_err / self.sigma_coop**2

        total_loss = comp_loss_pi + comp_loss_basis + cooperative_loss
        assert total_loss.size == 1
        return total_loss


    @partial(jit, static_argnums=(0, 1))
    def compute_l2_error(self, level, params_c, params_f, coeffs):
        """Return the 2-norm of the blended prediction error on the data.

        The norm is not divided by the number of data points.

        Args:
            level: 0 (coarse) or 1 (fine); static.
            params_c: Coarse parameters.
            params_f: Fine parameters.
            coeffs: Expansion coefficients of the requested level.
        """
        predict = vmap(self.u_net, (None, None, None, None, 0))
        prediction = predict(level, params_c, params_f, coeffs, self.x)
        residual = prediction.ravel() - self.u_test.ravel()
        return jnp.linalg.norm(residual)



    @partial(jit, static_argnums=(0, 1)) # JIT if called frequently, level is static
    def get_basis_matrix_P(self, level: int, current_params_c: Any, current_params_f: Any) -> jnp.ndarray:
        """
        Evaluate the basis functions on all data points and flatten them
        into a 2-D matrix P of shape (n_data, n_total_basis_functions_for_level).

        Args:
            level: 0 (coarse) or 1 (fine); static.
            current_params_c: Coarse parameters; element [1] holds the
                coarse basis parameters.
            current_params_f: Fine parameters; element [1] holds the fine
                basis parameters (only read for level 1).

        Raises:
            ValueError: If ``level`` is neither 0 nor 1.
        """
        coarse_basis_net = self.level_c.basis_net_c
        fine_basis_net = self.level_f.basis_net_f

        if level == 0:
            # Per-point output: one row of basis values per coarse expert.
            evaluate = vmap(coarse_basis_net, (None, 0))
            basis_values = evaluate(current_params_c[1], self.x)
        elif level == 1:
            # The fine basis is built on top of the coarse basis, so it takes
            # the coarse basis parameters and network as extra inputs.
            evaluate = vmap(fine_basis_net, (None, None, None, 0))
            basis_values = evaluate(
                current_params_c[1], current_params_f[1], coarse_basis_net, self.x
            )
        else:
            raise ValueError(f"Invalid level: {level}")

        # Keep the data axis, collapse every expert/basis axis into columns.
        return basis_values.reshape((basis_values.shape[0], -1))


    @partial(jit, static_argnums=(0,1)) # level is static
    def check_basis_rank(self, level: int, current_params_c: Any, current_params_f: Any, tolerance: float = 1e-6):
        """
        Computes the basis matrix P for the given level and checks if it has full column rank.
        Also returns the condition number of P and of P^T P. JIT-compatible.

        Under jit the SVD does not raise on failure; a failed decomposition
        shows up as NaN singular values. ``'svd_error'`` therefore reports
        whether any singular value is NaN, in which case the rank is 0 and
        both condition numbers are ``inf``.

        Args:
            level: 0 (coarse) or 1 (fine); static.
            current_params_c: Coarse parameters.
            current_params_f: Fine parameters.
            tolerance: Singular values at or below this threshold are
                treated as zero.

        Returns:
            A dictionary containing:
                - 'P_shape': Shape of the basis matrix P.
                - 'singular_values': All singular values of P.
                - 'rank': Computed rank of P (number of singular values > tolerance).
                - 'is_full_column_rank': Boolean, True if P has full column rank.
                - 'n_singular_values': Total number of singular values.
                - 'singular_values_above_tol': Number of singular values > tolerance.
                - 'condition_number_P_direct': Ratio of the largest to the
                  smallest singular value above tolerance (inf if there is none).
                - 'condition_number_P_gramian': Square of the above, i.e. the
                  condition number of P^T P; inf unless P has full column rank.
                - 'svd_error': Boolean, True if the SVD produced NaNs.
        """
        P_matrix = self.get_basis_matrix_P(level, current_params_c, current_params_f)
        n_rows, n_cols = P_matrix.shape

        s = jnp.linalg.svd(P_matrix, compute_uv=False)
        # Failure is signalled through NaNs, not exceptions.
        svd_failed_indicator = jnp.any(jnp.isnan(s))

        # Rank: NaN > tolerance is False, so a failed SVD yields rank 0.
        above_tol = s > tolerance
        singular_values_above_tol = jnp.sum(above_tol)
        computed_rank = singular_values_above_tol

        # Shapes are static under jit, so this Python branch is fine.
        if n_rows >= n_cols:
            is_full_col_rank = (computed_rank == n_cols)
        else:
            is_full_col_rank = jnp.array(False)

        # Extremes over the singular values that count as non-zero; the
        # +/-inf fill values mark "no such singular value".
        min_s_positive = jnp.min(jnp.where(above_tol, s, jnp.inf))
        max_s_positive = jnp.max(jnp.where(above_tol, s, -jnp.inf))
        extremes_valid = (
            (min_s_positive > tolerance)
            & (min_s_positive != jnp.inf)
            & (max_s_positive != -jnp.inf)
        )
        cond_ratio = max_s_positive / min_s_positive

        # Condition number of P (direct).
        cond_P_direct = jnp.where(extremes_valid, cond_ratio, jnp.inf)

        # Condition number of P^T P = cond(P)^2, meaningful only when P has
        # full column rank.
        can_compute_cond_gramian = (
            (n_rows >= n_cols) & (computed_rank == n_cols) & extremes_valid
        )
        cond_P_gramian = jnp.where(can_compute_cond_gramian, cond_ratio**2, jnp.inf)

        return {
            'P_shape': (n_rows, n_cols),
            'singular_values' : s,
            'rank': computed_rank,
            'is_full_column_rank': is_full_col_rank,
            'n_singular_values': jnp.array(s.shape[0], dtype=jnp.int32),
            'singular_values_above_tol': singular_values_above_tol,
            'condition_number_P_direct': cond_P_direct,
            'condition_number_P_gramian': cond_P_gramian,
            'svd_error': svd_failed_indicator
        }


    @partial(jit, static_argnums=(0, 1))
    def compute_kappa_B_sq_P_from_solver_Mcomp(self,
                                             level: int,
                                             current_params_c: Any,
                                             current_params_f: Any,
                                             current_posterior: Array,
                                             current_sigma_comp: Array,
                                             tolerance: float = 1e-5):
        """
        Computes kappa_B^2(P) = lambda_max(P^T B P) / lambda_min_nz(P^T B P),
        where P^T B P is the posterior-weighted, expert-block-diagonal Gram
        matrix of the basis functions (assembled without regularisation).

        The matrix is built unscaled: ``current_sigma_comp`` is accepted for
        interface compatibility but not used. A positive scaling would cancel
        in the eigenvalue ratio, although it would shift which eigenvalues
        fall below ``tolerance``.

        Under jit the eigenvalue routine does not raise on failure; a failed
        decomposition shows up as NaN eigenvalues, which is what the returned
        error flag reports.

        Args:
            level: 0 (coarse) or 1 (fine); static.
            current_params_c: Coarse parameters; element [1] holds the
                coarse basis parameters.
            current_params_f: Fine parameters; element [1] holds the fine
                basis parameters (only read for level 1).
            current_posterior: E-step responsibilities of the requested level.
            current_sigma_comp: Unused (see above).
            tolerance: Eigenvalues at or below this threshold are treated
                as zero.

        Returns:
            kappa_B_sq: Ratio of the largest to the smallest eigenvalue above
                ``tolerance``; inf if no eigenvalue exceeds it.
            eig_error_occurred: Boolean array, True if any eigenvalue is NaN.
            eigvals: All eigenvalues of the assembled matrix.

        Raises:
            ValueError: If ``level`` is neither 0 nor 1.
        """
        def _block_mask(n_blocks, dtype):
            # Identity over expert indices: keeps only same-expert blocks.
            if n_blocks > 1:
                idx = jnp.arange(n_blocks)
                return (idx[:, None] == idx[None, :]).astype(dtype)
            return jnp.array([[1.0]], dtype=dtype)

        c_poly_partitions = self.net_setup_params["poly"]["coarse"]["num_partitions"]
        basis_net_c = self.level_c.basis_net_c
        basis_net_f = self.level_f.basis_net_f

        if level == 0:
            # basis_out is indexed (data, coarse expert, basis function).
            basis_out = vmap(basis_net_c, (None, 0))(current_params_c[1], self.x)
            mask_mk = _block_mask(c_poly_partitions, basis_out.dtype)

            M_comp_full_unscaled = jnp.einsum("dm,dmn,dkj,mk->mnkj",
                                       current_posterior,
                                       basis_out,
                                       basis_out,
                                       mask_mk)
            n_basis_coeffs_total = self.level_c.coeffs.size

        elif level == 1:
            f_poly_partitions = self.net_setup_params["poly"]["fine"]["num_partitions"]

            # basis_out is indexed (data, coarse expert, fine expert, basis function).
            basis_out = vmap(basis_net_f, (None, None, None, 0))(
                current_params_c[1], current_params_f[1], basis_net_c, self.x
            )
            mask_mk = _block_mask(c_poly_partitions, basis_out.dtype)
            mask_nl = _block_mask(f_poly_partitions, basis_out.dtype)

            # Both basis factors share the expert indices (m, n); the masks
            # place each block on the diagonal of the full matrix.
            M_comp_full_unscaled = jnp.einsum("dmn,dmni,dmnj,mk,nl->mniklj",
                                              current_posterior, basis_out, basis_out, mask_mk, mask_nl)
            n_basis_coeffs_total = self.level_f.coeffs.size
        else:
            raise ValueError(f"Invalid level: {level}")

        # Flatten (expert..., basis) index groups into a square matrix.
        PTBP_matrix_flat = M_comp_full_unscaled.reshape((n_basis_coeffs_total, n_basis_coeffs_total))

        eigvals = jnp.linalg.eigvalsh(PTBP_matrix_flat)
        # Failure is signalled through NaNs, not exceptions.
        eig_error_occurred = jnp.any(jnp.isnan(eigvals))

        # NaN > tolerance is False, so failed eigenvalues never count.
        eigvals_nz_bool = (eigvals > tolerance)
        num_eigvals_nz = jnp.sum(eigvals_nz_bool)

        min_eig_nz = jnp.min(jnp.where(eigvals_nz_bool, eigvals, jnp.inf))
        max_eig_nz = jnp.max(jnp.where(eigvals_nz_bool, eigvals, -jnp.inf))

        kappa_B_sq = jnp.where(
            num_eigvals_nz > 0,
            max_eig_nz / min_eig_nz,
            jnp.inf
        )
        return kappa_B_sq, eig_error_occurred, eigvals

    def _run_solver_side_tests(self,
                               level: int,
                               current_params_c: Any,
                               current_params_f: Any,
                               current_coeffs: Array,
                               current_posterior: Array,
                               current_sigma_comp: Array,
                               lstsq_solver_solve_method: callable,
                               test_sigma_coop_values: Array
                               ):
        """Run the solver once per trial sigma_coop and collect its second output.

        All inputs are wrapped in ``stop_gradient`` so that these diagnostic
        solves never contribute to gradients. The trial values are processed
        sequentially with ``lax.map``.

        Args:
            level: Level passed through to the solver.
            current_params_c: Coarse parameters.
            current_params_f: Fine parameters.
            current_coeffs: Current coefficients, passed to the solver.
            current_posterior: Current responsibilities.
            current_sigma_comp: Current competitive sigma.
            lstsq_solver_solve_method: Callable invoked as
                ``f(level, params_c, params_f, coeffs, posterior, sigma_comp,
                sigma_coop)`` and returning a pair; only the second element
                is kept (presumably the solver iteration count - confirm
                against the solver).
            test_sigma_coop_values: 1-D array of trial sigma_coop values.

        Returns:
            Array with one solver result (second element of the pair) per
            trial value, in the same order as ``test_sigma_coop_values``.
        """
        sg_params_c = jax.tree_util.tree_map(jax.lax.stop_gradient, current_params_c)
        sg_params_f = jax.tree_util.tree_map(jax.lax.stop_gradient, current_params_f)
        sg_coeffs = jax.lax.stop_gradient(current_coeffs)
        sg_posterior = jax.lax.stop_gradient(current_posterior)
        sg_sigma_comp = jax.lax.stop_gradient(current_sigma_comp)

        def mapped_fn(test_sigma_coop_val):
            sg_test_sigma_coop = jax.lax.stop_gradient(test_sigma_coop_val)
            # The solved coefficients are discarded; only the second output
            # is of interest here.
            _ignored_coeffs, test_iters = lstsq_solver_solve_method(
                level,
                sg_params_c,
                sg_params_f,
                sg_coeffs,
                sg_posterior,
                sg_sigma_comp,
                sg_test_sigma_coop
            )
            return test_iters

        side_test_slv_iters_array = lax.map(mapped_fn, test_sigma_coop_values)
        return side_test_slv_iters_array


    @partial(jit, static_argnames=("self", "log_spectral"))
    def _train_step(self, state: TrainingState, log_spectral):
        """Run one jitted training iteration over both levels.

        The statements below execute in this order, and later stages consume
        the outputs of earlier ones:

        1. Level 0: E-step, least-squares solve for new coarse coefficients,
           then one ``optimizer_c`` update of ``params_c`` using the gradient
           of ``loss_gate_c`` (taken w.r.t. its first argument).
        2. Level 1: the same sequence, evaluated with the intermediate coarse
           parameters from stage 1, producing new fine coefficients and one
           ``optimizer_f`` update of ``params_f``.
        3. "f2c" update: the level-1 E-step is re-evaluated with the updated
           fine parameters/coefficients, summed over its last axis, and used
           as the posterior for a second ``optimizer_c`` update.

        Args:
            state: ``TrainingState`` holding ``params_c``, ``params_f``,
                ``coeffs_c``, ``coeffs_f``, ``opt_state_c`` and ``opt_state_f``.
            log_spectral: Static (under ``jit``) flag. When true, the extra
                diagnostics (R_sigma estimates, rank check, kappa/eigenvalues
                and solver side tests) are computed and added to ``metrics``.

        Returns:
            ``(new_state, metrics)`` where ``metrics`` always contains
            ``loss_c``, ``loss_f``, ``slv_iters_c`` and ``slv_iters_f``, plus
            the diagnostic entries when ``log_spectral`` is true.
        """
        # Unpack state
        params_c, params_f, coeffs_c, coeffs_f, opt_state_c, opt_state_f = \
            state.params_c, state.params_f, state.coeffs_c, state.coeffs_f, state.opt_state_c, state.opt_state_f

        # EM Step 0
        lvl0 = 0
        with jax.named_scope("EM0_E_step"):
            posterior0 = self.E_step(lvl0, params_c, params_f, coeffs_c)


        # Diagnostics only; none of these values feed back into the update below.
        if log_spectral:
            with jax.named_scope("EM0_R_sigma"):
                gate_out_c_for_R0 = vmap(self.level_c.gate_net, (None, 0))(params_c[0], self.x)
                R_sigma_0 = compute_R_sigma(gate_out_c_for_R0, None, posterior0, lvl0)

            with jax.named_scope("EM0_Rank_Check"):
                rank_info_c = self.check_basis_rank(lvl0, params_c, params_f)
                # NOTE: eig_err_c is returned here but is not reported in `metrics`.
                kappa_B_sq_c, eig_err_c, eigvals_c = self.compute_kappa_B_sq_P_from_solver_Mcomp(
                    lvl0, params_c, params_f, posterior0, self.sigma_comp
                )

            with jax.named_scope("EM0_Solver_Side_Tests"):
                side_test_slv_iters_c_array = self._run_solver_side_tests(
                    lvl0, params_c, params_f, coeffs_c, posterior0, self.sigma_comp,
                    self.lstsq_solver.solve,
                    self.test_sigma_coop_values_static
                )

        with jax.named_scope("EM0_Solver"):
            coeffs_c_new, slv_iters_c = self.lstsq_solver.solve(
                lvl0, params_c, params_f, coeffs_c, posterior0,
                self.sigma_comp, self.sigma_coop
            )
        # The loss and its gradient are both evaluated at the freshly solved coefficients.
        with jax.named_scope("EM0_Loss"):
            loss_c = self.loss_gate_c(params_c, params_f, coeffs_c_new, posterior0)
        with jax.named_scope("EM0_Grad"):
            grads_c = grad(self.loss_gate_c)(params_c, params_f, coeffs_c_new, posterior0)

        with jax.named_scope("EM0_Optimizer"):
            updates_c, opt_state_c_new = self.optimizer_c.update(grads_c, opt_state_c)
            params_c_intermediate = optax.apply_updates(params_c, updates_c)

        # EM Step 1
        # Uses the coarse parameters already updated by EM step 0.
        lvl1 = 1
        with jax.named_scope("EM1_E_step"):
            posterior1 = self.E_step(lvl1, params_c_intermediate, params_f, coeffs_f)


        # Diagnostics only, mirroring the level-0 block above.
        if log_spectral:
            with jax.named_scope("EM1_R_sigma"):
                gate_out_c_for_R1 = vmap(self.level_c.gate_net, (None, 0))(params_c_intermediate[0], self.x)
                gate_out_f_cond_for_R1 = vmap(self.level_f.gate_net, (None, 0))(params_f[0], self.x)
                R_sigma_1 = compute_R_sigma(gate_out_c_for_R1, gate_out_f_cond_for_R1, posterior1, lvl1)

            with jax.named_scope("EM1_Rank_Check"):
                rank_info_f = self.check_basis_rank(lvl1, params_c_intermediate, params_f)
                # NOTE: eig_err_f is returned here but is not reported in `metrics`.
                kappa_B_sq_f, eig_err_f, eigvals_f  = self.compute_kappa_B_sq_P_from_solver_Mcomp(
                    lvl1, params_c_intermediate, params_f, posterior1, self.sigma_comp,
                )

            with jax.named_scope("EM1_Solver_Side_Tests"):
                side_test_slv_iters_f_array = self._run_solver_side_tests(
                    lvl1, params_c_intermediate, params_f, coeffs_f, posterior1, self.sigma_comp,
                    self.lstsq_solver.solve,
                    self.test_sigma_coop_values_static
                )

        with jax.named_scope("EM1_Solver"):
            coeffs_f_new, slv_iters_f = self.lstsq_solver.solve(
                lvl1, params_c_intermediate, params_f, coeffs_f, posterior1,
                self.sigma_comp, self.sigma_coop
            )
        # loss_gate_f takes params_f first, so grad() differentiates w.r.t. the fine parameters.
        with jax.named_scope("EM1_Loss"):
            loss_f = self.loss_gate_f(params_f, params_c_intermediate, coeffs_f_new, posterior1)
        with jax.named_scope("EM1_Grad"):
            grads_f = grad(self.loss_gate_f)(params_f, params_c_intermediate, coeffs_f_new, posterior1)

        with jax.named_scope("EM1_Optimizer"):
            updates_f, opt_state_f_new = self.optimizer_f.update(grads_f, opt_state_f)
            params_f_new = optax.apply_updates(params_f, updates_f)

        # Update f2c
        # The level-1 posterior is summed over its last axis and reused as the
        # posterior argument of the coarse gate loss.
        with jax.named_scope("Update_f2c_E_step_Sum"):
            posterior_f2c = jnp.sum(self.E_step(1, params_c_intermediate, params_f_new, coeffs_f_new), axis=-1)
        with jax.named_scope("Update_f2c_Grad"):
            grads_f2c = grad(self.loss_gate_c)(params_c_intermediate, params_f_new, coeffs_c_new, posterior_f2c)
        # Second optimizer_c update in this step: it continues from opt_state_c_new,
        # so the coarse optimizer state advances twice per call.
        with jax.named_scope("Update_f2c_Optimizer"):
            updates_f2c, opt_state_c_final = self.optimizer_c.update(grads_f2c, opt_state_c_new)
            params_c_final = optax.apply_updates(params_c_intermediate, updates_f2c)

        new_state = TrainingState(
            params_c=params_c_final,
            params_f=params_f_new,
            coeffs_c=coeffs_c_new,
            coeffs_f=coeffs_f_new,
            opt_state_c=opt_state_c_final,
            opt_state_f=opt_state_f_new
        )
        # loss_c / loss_f are the values computed before the respective optimizer updates.
        metrics = {
            "loss_c": loss_c, "loss_f": loss_f,
            "slv_iters_c": slv_iters_c, "slv_iters_f": slv_iters_f,
            }

        if log_spectral:
            metrics.update({
                "R_sigma_g" : self.sigma_coop**2/self.sigma_comp**2,
                "R_sigma_0": R_sigma_0, "R_sigma_1": R_sigma_1,


                "rank_c_is_full": rank_info_c['is_full_column_rank'],
                "rank_f_is_full": rank_info_f['is_full_column_rank'],

                "cond_PTP_c": rank_info_c['condition_number_P_gramian'],
                "cond_PTP_f": rank_info_f['condition_number_P_gramian'],

                "rank_c_val": rank_info_c['rank'],
                "rank_f_val": rank_info_f['rank'],

                "P_shape_c": rank_info_c['P_shape'],
                "P_shape_f": rank_info_f['P_shape'],

                "svd_error_c": rank_info_c['svd_error'],
                "svd_error_f": rank_info_f['svd_error'],

                "PTP_S_c": rank_info_c['singular_values'],
                "PTP_S_f": rank_info_f['singular_values'],

                "side_test_slv_iters_c": side_test_slv_iters_c_array,
                "side_test_slv_iters_f": side_test_slv_iters_f_array,

                "kappa_B_sq_P_f": kappa_B_sq_f,
                "kappa_B_sq_P_c": kappa_B_sq_c,

                # NOTE: both keys carry a "_P_f_" infix; the trailing suffix
                # (_f / _c) is what distinguishes fine from coarse.
                "kappa_B_sq_P_f_eigvals_f": eigvals_f,
                "kappa_B_sq_P_f_eigvals_c": eigvals_c,
                })
        return new_state, metrics



    def train(self, gpu=False, track_solver_iters=False):
        """Run the training loop, store the final state and accumulate logs.

        Reads ``nIter``, ``log_interval`` (default 100), ``log_spectral``
        (default False) and ``accuracy_bail`` (default 1e-12) from
        ``self.training_params``. Each iteration calls ``self._train_step``;
        every ``log_interval`` iterations the relative L2 errors of both
        levels are computed, logged, and compared against ``accuracy_bail``
        to stop early. A falsy ``log_interval`` disables periodic logging
        (and therefore the early stop).

        After the loop the final parameters, coefficients and optimizer
        states are written back to ``self`` and the collected logs are merged
        into ``self.logs``.

        Args:
            gpu: Currently unused; kept so existing callers keep working.
            track_solver_iters: When true, record the solver iteration counts
                of every step under ``"slv_iters"`` and include them in the
                periodic log line.
        """
        nIter = self.training_params["nIter"]
        log_interval = self.training_params.get("log_interval", 100)
        log_spectral = self.training_params.get("log_spectral", False)
        accuracy_bail = self.training_params.get("accuracy_bail", 1e-12)

        rel_norm = jnp.linalg.norm(self.u_test.ravel())

        def relative_errors(current_state):
            """Return the (coarse, fine) L2 errors divided by ``rel_norm``."""
            err_c = self.compute_l2_error(
                0, current_state.params_c, current_state.params_f, current_state.coeffs_c
            ) / rel_norm
            err_f = self.compute_l2_error(
                1, current_state.params_c, current_state.params_f, current_state.coeffs_f
            ) / rel_norm
            return err_c, err_f

        def recursive_merge(a, b):
            """Merge log structure ``b`` into ``a`` without losing keys of either."""
            if isinstance(a, dict) and isinstance(b, dict):
                # Runs may record different keys (e.g. "slv_iters" only when
                # track_solver_iters is set), so merge over the union.
                merged = {
                    k: recursive_merge(a[k], b[k]) if k in b else a[k]
                    for k in a
                }
                merged.update({k: v for k, v in b.items() if k not in a})
                return merged
            if isinstance(a, list) and isinstance(b, list):
                return a + b
            if isinstance(a, jnp.ndarray) and isinstance(b, jnp.ndarray):
                return jnp.concatenate([a, b])
            return b  # fallback: overwrite

        # Initial state
        state = self.get_training_state()

        # Logging setup. The spectral keys are always present (possibly as
        # empty lists) so the log schema does not depend on the flags.
        spectral_keys = (
            "R_sigma", "R_sigma_star", "side_test_slv_iters", "rank_is_full",
            "cond_PTP", "rank_val", "P_shape", "kappa_B_sq_P_f_eigvals", "PTP_S",
        )
        log_data = {}
        for lvl in ["coarse", "fine"]:
            entry = {"l2_error": [], "loss": [], "timings": []}
            if track_solver_iters:
                entry["slv_iters"] = []
            entry.update({key: [] for key in spectral_keys})
            log_data[lvl] = entry

        error_c = error_f = None
        metrics = None
        for it in range(nIter):
            state, metrics = self._train_step(state, log_spectral)

            if track_solver_iters:
                log_data["coarse"]["slv_iters"].append(metrics["slv_iters_c"])
                log_data["fine"]["slv_iters"].append(metrics["slv_iters_f"])

            if "timings" in metrics:
                log_data["coarse"]["timings"].append(metrics["timings"])
                log_data["fine"]["timings"].append(metrics["timings"])

            # A falsy log_interval (0 / None) disables periodic logging
            # instead of failing on the modulo.
            if not log_interval or it % log_interval != 0:
                continue

            error_c, error_f = relative_errors(state)
            log_data["coarse"]["l2_error"].append(error_c)
            log_data["coarse"]["loss"].append(metrics["loss_c"])
            log_data["fine"]["l2_error"].append(error_f)
            log_data["fine"]["loss"].append(metrics["loss_f"])

            if log_spectral:
                level_views = (
                    ("coarse", "c", metrics["R_sigma_0"]),
                    ("fine", "f", metrics["R_sigma_1"]),
                )
                for lvl, sfx, r_star in level_views:
                    entry = log_data[lvl]
                    # One append per list per logged step: the per-level
                    # estimate goes to "R_sigma_star", the global ratio to
                    # "R_sigma".
                    entry["R_sigma_star"].append(r_star)
                    entry["R_sigma"].append(metrics["R_sigma_g"])
                    entry["side_test_slv_iters"].append(metrics[f"side_test_slv_iters_{sfx}"])
                    entry["rank_is_full"].append(metrics[f"rank_{sfx}_is_full"])
                    entry["cond_PTP"].append(metrics[f"cond_PTP_{sfx}"])
                    entry["rank_val"].append(metrics[f"rank_{sfx}_val"])
                    entry["P_shape"].append(metrics[f"P_shape_{sfx}"])
                    entry["PTP_S"].append(metrics[f"PTP_S_{sfx}"])
                    entry["kappa_B_sq_P_f_eigvals"].append(metrics[f"kappa_B_sq_P_f_eigvals_{sfx}"])

            if jax.process_index() == 0:
                iter_info = ''
                if track_solver_iters:
                    iter_info = f"slv_iters(c,f): {metrics['slv_iters_c']}, {metrics['slv_iters_f']}"

                self.logger.info(f"""iters={it+1:6d}"""
                                 f"""\trel_error_f : {error_f:.2e}"""
                                 f"""\trel_error_c : {error_c:.2e}"""
                                 f"""|   {iter_info}"""
                                 )

            if error_f < accuracy_bail:
                # desired accuracy has been reached
                break

        # update self state
        self.params_c, self.params_f = state.params_c, state.params_f
        self.coeffs_c, self.coeffs_f = state.coeffs_c, state.coeffs_f
        self.opt_state_c, self.opt_state_f = state.opt_state_c, state.opt_state_f

        previous_logs = getattr(self, "logs", None)
        if previous_logs:
            self.logs = recursive_merge(previous_logs, log_data)
        else:
            self.logs = log_data

        # Final logging
        if nIter > 0:
            if error_f is None:
                # Periodic logging was disabled, so no error has been
                # evaluated yet; do it once for the final report.
                error_c, error_f = relative_errors(state)
            if jax.process_index() == 0:
                final_metrics = {
                    "error_f": f"{error_f:.2e}",
                    "error_c": f"{error_c:.2e}",
                    "loss_c": f"{metrics['loss_c']:.2e}",
                    "loss_f": f"{metrics['loss_f']:.2e}"
                }
                self.logger.info("Final Metrics: " + str(final_metrics))


