import jax.numpy as jnp
import flax.linen as nn
import jax
from jax import lax, vmap, jit, grad
from jax import debug
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from functools import partial
import numpy as np
import time
import collections.abc

from .levels import LevelCoarse, LevelFine
from .custom_types import Array, Any


class BlendedLeastSquaresSolver:
    def __init__(self,
                 net_setup_params: dict,
                 level_c: LevelCoarse,
                 level_f: LevelFine,
                 x: Array,
                 u_test: Array,
                 mesh: Mesh,
                 replicate_sharding: NamedSharding):
        """Store the problem data and select the coefficient solver (``self.solve``).

        Args:
            net_setup_params: Configuration dict. This class reads ``["poly"]`` (partition
                counts / basis sizes, in the assembly methods) and the optional
                ``["coef_slv_params"]`` (here).
            level_c: Coarse level; provides ``gate_net`` and ``basis_net_c``.
            level_f: Fine level; provides ``gate_net``, ``basis_net_f`` and ``basis_size``.
            x: Inputs the gate/basis networks are vmapped over (axis 0).
            u_test: Target values; the assembly methods use column 0 only.
            mesh: Device mesh used by the distributed block solver (``shard_map``).
            replicate_sharding: Sharding used to replicate assembled systems and iterates.

        ``coef_slv_params`` keys (any missing key falls back to the default shown):
            slv_type ("iterative_monolithic"): one of "direct", "iterative_monolithic",
                "iterative_block", "iterative_block_distributed".
            max_iter (1000): iteration cap for the iterative solvers.
            tol (1e-12): stop when the norm of the coefficient update drops below this.
            reg (1e-6): value added to the diagonal of the assembled system.
            omega (1.0): relaxation factor, ``c <- omega * c_hat + (1 - omega) * c``.
            use_vmap (optional, False): vmap instead of lax.map for per-block solves.

        Raises:
            ValueError: if ``slv_type`` is not one of the recognised values.
        """
        self.net_setup_params = net_setup_params
        self.level_c = level_c
        self.level_f = level_f
        self.x = x
        self.u_test = u_test
        self.mesh = mesh
        self.replicate_sharding = replicate_sharding

        default_slv_params = {'slv_type': 'iterative_monolithic',
                              'max_iter': 1000,
                              'tol': 1e-12,
                              'reg': 1e-6,
                              'omega': 1.}
        # Merge into a fresh dict: user keys override the defaults, a partially specified
        # dict no longer raises KeyError below, and the fallback branch can rewrite
        # slv_type without mutating the caller's net_setup_params.
        self.coef_slv_params = {**default_slv_params,
                                **self.net_setup_params.get("coef_slv_params", {})}
        self.max_iter = self.coef_slv_params['max_iter']
        self.tol = self.coef_slv_params['tol']
        self.reg = self.coef_slv_params['reg']
        self.omega = self.coef_slv_params['omega']

        self.mesh_axis_name = 'nodes'
        self.block_partition_spec = P(self.mesh_axis_name)
        self.replicate_partition_spec = P()

        self._define_monolithic_iterative_fns()

        slv_type = self.coef_slv_params["slv_type"]
        if slv_type == "direct":
            self.solve = self._blended_lstsq_fit_direct
            if jax.process_index() == 0: print("Using blended_lstsq_fit_direct solver.")
        elif slv_type == "iterative_monolithic":
            self.solve = self._blended_lstsq_fit_iterative
            if jax.process_index() == 0: print("Using blended_lstsq_fit_iterative (monolithic) solver.")
        elif slv_type == "iterative_block":
            self._define_iterative_solver()
            self.solve = self._blended_lstsq_fit_iterative_block
            if jax.process_index() == 0: print("Using blended_lstsq_fit_iterative_block (single-device map) solver.")
        elif slv_type == "iterative_block_distributed":
            # The distributed solver shards blocks over the 'nodes' mesh axis, so that axis
            # must exist and span more than one device; otherwise use the local block solver.
            can_distribute = (isinstance(self.mesh, Mesh)
                              and self.mesh_axis_name in self.mesh.axis_names
                              and self.mesh.shape[self.mesh_axis_name] > 1)
            if not can_distribute:
                if jax.process_index() == 0: print(f"WARNING: Requested 'iterative_block_distributed' ... falling back to 'iterative_block'.")
                self._define_iterative_solver()
                self.solve = self._blended_lstsq_fit_iterative_block
                self.coef_slv_params["slv_type"] = "iterative_block"
            else:
                if jax.process_index() == 0: print(f"Defining distributed solver utilities for mesh: {self.mesh.devices}...")
                self.solve = self._blended_lstsq_fit_iterative_block_distributed
                if jax.process_index() == 0: print(f"Using blended_lstsq_fit_iterative_block_distributed solver.")
        else:
            raise ValueError(f"Unknown coef_slv_params['slv_type']: {slv_type}")

    # Assembly Functions
    def _assemble_normal_equation_terms(self, level, params_c, params_f, posterior, sigma_comp, sigma_coop):
        """Evaluate the networks and build the un-flattened normal-equation terms for ``level``.

        Shared by the direct and the monolithic-iterative assemblers.

        Returns:
            (b_flat, M_comp, N_coop, sys_shape, c_shape) where
            b_flat: flattened right-hand side;
            M_comp: posterior-weighted Gram term scaled by ``1 / sigma_comp**2``; it is
                block-diagonal over partitions (enforced by the identity masks);
            N_coop: negated gate-weighted Gram term scaled by ``1 / sigma_coop**2``; the
                system matrix is ``M_comp - N_coop``;
            sys_shape: square 2-D shape both terms flatten to;
            c_shape: shape of the coefficient array.

        Raises:
            ValueError: if ``level`` is not 0 or 1.
        """
        poly = self.net_setup_params["poly"]
        c_parts = poly["coarse"]["num_partitions"]
        c_basis = poly["coarse"]["basis_size"]
        gate_net_c, gate_net_f = self.level_c.gate_net, self.level_f.gate_net
        basis_net_c, basis_net_f = self.level_c.basis_net_c, self.level_f.basis_net_f
        u = self.u_test[:, 0]
        # Float identity over coarse partitions: keeps M_comp block-diagonal (m == k).
        mask_mk = (jnp.arange(c_parts)[:, None] == jnp.arange(c_parts)).astype(float)

        if level == 0:
            gate_params, basis_params = params_c
            g_out = vmap(gate_net_c, (None, 0))(gate_params, self.x)
            basis_out = vmap(basis_net_c, (None, 0))(basis_params, self.x)
            b = (jnp.einsum("dm,dmn,d->mn", posterior, basis_out, u) / sigma_comp**2
                 + jnp.einsum("dm,dmn,d->mn", g_out, basis_out, u) / sigma_coop**2)
            gate_basis = jnp.einsum("dm,dmn->dmn", g_out, basis_out)
            M_comp = jnp.einsum("dm,dmn,dkj,mk->mnkj", posterior, basis_out, basis_out, mask_mk) / sigma_comp**2
            N_coop = -jnp.einsum("dkj,dmn->mnkj", gate_basis, gate_basis) / sigma_coop**2
            c_shape = (c_parts, c_basis)
        elif level == 1:
            f_parts = poly["fine"]["num_partitions"]
            f_basis = self.level_f.basis_size
            gate_params, basis_params = params_f
            g_out_c = vmap(gate_net_c, (None, 0))(params_c[0], self.x)
            g_out_f = vmap(gate_net_f, (None, 0))(gate_params, self.x)
            g_out = jnp.einsum("dm,dmn->dmn", g_out_c, g_out_f)
            basis_out = vmap(basis_net_f, (None, None, None, 0))(params_c[1], basis_params, basis_net_c, self.x)
            b = (jnp.einsum("dmn,d,dmnp->mnp", posterior, u, basis_out) / sigma_comp**2
                 + jnp.einsum("dmn,d,dmnp->mnp", g_out, u, basis_out) / sigma_coop**2)
            # Float identity over fine partitions: keeps M_comp block-diagonal (n == l).
            mask_nl = (jnp.arange(f_parts)[:, None] == jnp.arange(f_parts)).astype(float)
            gate_basis = jnp.einsum("dmn,dmni->dmni", g_out, basis_out)
            M_comp = jnp.einsum("dmn,dmni,dmnj,mk,nl->mniklj", posterior, basis_out, basis_out, mask_mk, mask_nl) / sigma_comp**2
            N_coop = -jnp.einsum("dmni,dklj->mniklj", gate_basis, gate_basis) / sigma_coop**2
            c_shape = (c_parts, f_parts, f_basis)
        else:
            raise ValueError(f"Invalid level: {level}")

        n_unknowns = 1
        for extent in c_shape:
            n_unknowns *= extent
        return b.ravel(), M_comp, N_coop, (n_unknowns, n_unknowns), c_shape

    def _assemble_direct_system(self, level, params_c, params_f, posterior, sigma_comp, sigma_coop):
        """Assemble the dense regularized system ``A c = b`` for the direct solver.

        ``A = (M_comp - N_coop)`` flattened to 2-D plus ``self.reg`` on the diagonal.

        Returns:
            (A_flat, b_flat, c_shape); both arrays are placed with ``self.replicate_sharding``.

        Raises:
            ValueError: if ``level`` is not 0 or 1.
        """
        b_flat, M_comp, N_coop, A_shape, c_shape = self._assemble_normal_equation_terms(
            level, params_c, params_f, posterior, sigma_comp, sigma_coop)

        A = M_comp - N_coop
        A_flat = A.reshape(A_shape) + self.reg * jnp.eye(A_shape[0], dtype=A.dtype)

        A_flat = jax.device_put(A_flat, self.replicate_sharding)
        b_flat = jax.device_put(b_flat, self.replicate_sharding)
        return A_flat, b_flat, c_shape

    def _assemble_iterative_monolithic_system(self, level, params_c, params_f, posterior, sigma_comp, sigma_coop):
        """Assemble the splitting ``(M_comp + reg*I) c = b + N_coop c`` for the monolithic solver.

        ``D_part - R_neg_part`` equals the regularized system matrix used by the direct
        solver, so a converged fixed point of ``D c = b + R_neg c`` solves the same system.

        Returns:
            ((D_part, R_neg_part, b_system_flat), c_shape); arrays are placed with
            ``self.replicate_sharding``.

        Raises:
            ValueError: if ``level`` is not 0 or 1.
        """
        b_system_flat, M_comp_full, N_coop_full, M_shape, c_shape = self._assemble_normal_equation_terms(
            level, params_c, params_f, posterior, sigma_comp, sigma_coop)

        # Identity in the matrix dtype (as in the direct path) so D_part is not promoted to the
        # default float dtype, which could change the coefficient dtype in the while_loop carry.
        D_part = M_comp_full.reshape(M_shape) + self.reg * jnp.eye(M_shape[0], dtype=M_comp_full.dtype)
        R_neg_part = N_coop_full.reshape(M_shape)

        D_part = jax.device_put(D_part, self.replicate_sharding)
        R_neg_part = jax.device_put(R_neg_part, self.replicate_sharding)
        b_system_flat = jax.device_put(b_system_flat, self.replicate_sharding)
        return (D_part, R_neg_part, b_system_flat), c_shape

    def _assemble_iterative_block_system(self, level, params_c, params_f, posterior, sigma_comp, sigma_coop):
        """Assemble the block-Jacobi splitting of the regularized system for the block solvers.

        The system matrix ``A`` is viewed as a ``num_blocks x num_blocks`` grid of dense blocks
        of size ``block_size``: one block per coarse partition at level 0, one per
        (coarse, fine) partition pair at level 1 (flattened in C order).

        Returns:
            ((M_diag_blocks, N_offdiag_neg_blocks, b_flat_system), coeffs_shape) where
            M_diag_blocks: ``(num_blocks, block_size, block_size)`` diagonal blocks of ``A``
                plus ``self.reg * I``;
            N_offdiag_neg_blocks: ``-A`` with its diagonal blocks zeroed, in the structured
                shape ``(Nc, Nbc, Nc, Nbc)`` (level 0) or ``(Nc, Nf, Nbf, Nc, Nf, Nbf)``
                (level 1);
            b_flat_system: right-hand side with shape ``coeffs_shape`` (despite the name, it
                is not flattened).
            The arrays are placed with ``self.replicate_sharding``.

        Raises:
            ValueError: if ``level`` is not 0 or 1.
        """
        gate_net_c, gate_net_f = self.level_c.gate_net, self.level_f.gate_net
        basis_net_c, basis_net_f = self.level_c.basis_net_c, self.level_f.basis_net_f
        N_coarse, N_basis_c = self.net_setup_params["poly"]["coarse"]["num_partitions"], self.net_setup_params["poly"]["coarse"]["basis_size"]
        N_fine, N_basis_f = self.net_setup_params["poly"]["fine"]["num_partitions"], self.level_f.basis_size
        u = self.u_test[:, 0]

        if level == 0:
            gate_params, basis_params = params_c
            g_out = vmap(gate_net_c, (None, 0))(gate_params, self.x)
            basis_out = vmap(basis_net_c, (None, 0))(basis_params, self.x)
            b_term_sum = posterior / sigma_comp**2 + g_out / sigma_coop**2
            b_flat_system = jnp.einsum("dm,d,dmn->mn", b_term_sum, u, basis_out)
            mask_mk = jnp.eye(N_coarse)
            gate_basis = g_out[:, :, None] * basis_out
            M_c = jnp.einsum("dm,dmn,dkj,mk->mnkj", posterior, basis_out, basis_out, mask_mk) / sigma_comp**2
            N_c = -jnp.einsum("dkj,dmn->mnkj", gate_basis, gate_basis) / sigma_coop**2
            A_sys_full = (M_c - N_c).reshape(N_coarse, N_basis_c, N_coarse, N_basis_c)
            coeffs_shape = (N_coarse, N_basis_c)
            num_blocks, block_size = N_coarse, N_basis_c
        elif level == 1:
            gate_params, basis_params = params_f
            g_out_c = vmap(gate_net_c, (None, 0))(params_c[0], self.x)
            g_out_f = vmap(gate_net_f, (None, 0))(gate_params, self.x)
            g_out = jnp.einsum("dm,dmn->dmn", g_out_c, g_out_f)
            basis_out = vmap(basis_net_f, (None, None, None, 0))(params_c[1], basis_params, basis_net_c, self.x)
            b_term_sum = posterior / sigma_comp**2 + g_out / sigma_coop**2
            b_flat_system = jnp.einsum("dmn,dmni,d->mni", b_term_sum, basis_out, u)
            mask_mk, mask_nl = jnp.eye(N_coarse), jnp.eye(N_fine)
            gate_basis = jnp.einsum("dmn,dmni->dmni", g_out, basis_out)
            M_c = jnp.einsum("dmn,dmni,dklj,mk,nl->mniklj", posterior, basis_out, basis_out, mask_mk, mask_nl) / sigma_comp**2
            N_c = -jnp.einsum("dmni,dklj->mniklj", gate_basis, gate_basis) / sigma_coop**2
            A_sys_full = (M_c - N_c).reshape(N_coarse, N_fine, N_basis_f, N_coarse, N_fine, N_basis_f)
            coeffs_shape = (N_coarse, N_fine, N_basis_f)
            num_blocks, block_size = N_coarse * N_fine, N_basis_f
        else:
            raise ValueError(f"Invalid level: {level}")

        # View A as (block_row, i, block_col, j) so diagonal block b is A_blocks[b, :, b, :].
        A_blocks = A_sys_full.reshape(num_blocks, block_size, num_blocks, block_size)
        diag = jnp.arange(num_blocks)
        # Paired index arrays separated by a slice: the broadcast axis moves to the front,
        # giving (num_blocks, block_size, block_size) with entry [b, i, j] = A[b, i, b, j].
        M_diag_blocks = A_blocks[diag, :, diag, :] + self.reg * jnp.eye(block_size)
        # One vectorized scatter replaces a Python loop of num_blocks full-array scatters.
        N_offdiag_neg_blocks = (-(A_blocks.at[diag, :, diag, :].set(0))).reshape(A_sys_full.shape)

        M_diag_blocks = jax.device_put(M_diag_blocks, self.replicate_sharding)
        N_offdiag_neg_blocks = jax.device_put(N_offdiag_neg_blocks, self.replicate_sharding)
        b_flat_system = jax.device_put(b_flat_system, self.replicate_sharding)
        return (M_diag_blocks, N_offdiag_neg_blocks, b_flat_system), coeffs_shape

    def _assemble_iterative_block_distributed_system(self, level, params_c, params_f, posterior, sigma_comp, sigma_coop):
        """Assemble the block-Jacobi system in the flat block layout used by the SPMD solver.

        Delegates to ``_assemble_iterative_block_system`` and collapses the block grid so that
        the coupling operator is ``(total_blocks, block_size, total_blocks, block_size)``, which
        is the layout the ``"ikjl,jl->ik"`` einsum in the distributed solver expects. The
        right-hand side becomes ``(total_blocks, block_size)``. At level 0 both arrays already
        have these shapes.

        Returns:
            ((M_diag_blocks, N_offdiag_neg_blocks, b), coeffs_shape)
        """
        assembled, coeffs_shape = self._assemble_iterative_block_system(
            level, params_c, params_f, posterior, sigma_comp, sigma_coop)
        diag_blocks, coupling_structured, rhs_structured = assembled

        # coeffs_shape is (Nc, Nbc) or (Nc, Nf, Nbf): every axis but the last enumerates blocks.
        *grid_extents, block_size = coeffs_shape
        total_blocks = 1
        for extent in grid_extents:
            total_blocks *= extent

        flat_layout = (total_blocks, block_size)
        coupling_flat = coupling_structured.reshape(flat_layout + flat_layout)
        rhs_flat = rhs_structured.reshape(flat_layout)
        return (diag_blocks, coupling_flat, rhs_flat), coeffs_shape


    @partial(jit, static_argnums=(0, 1))
    def _blended_lstsq_fit_direct(self, level, params_c, params_f, coeffs, posterior, sigma_comp, sigma_coop):
        """Solve the dense regularized normal equations in a single ``jnp.linalg.solve``.

        ``coeffs`` is not read by this solver. Returns ``(coefficients reshaped to c_shape, 1)``;
        the reported iteration count is always 1.
        """
        system_matrix, rhs, coeff_shape = self._assemble_direct_system(
            level, params_c, params_f, posterior, sigma_comp, sigma_coop)
        solution = jnp.linalg.solve(system_matrix, rhs)
        return solution.reshape(coeff_shape), 1

    @partial(jit, static_argnums=(0, 1))
    def _blended_lstsq_fit_iterative(self, level, params_c, params_f, coeffs, posterior, sigma_comp, sigma_coop):
        """Relaxed fixed-point iteration on the monolithic splitting ``D c = b + R_neg c``.

        ``coeffs`` is the initial guess. The loop runs ``self._cond_fun_iter_mono`` /
        ``self._body_fun_iter_mono``. Returns ``(coefficients reshaped to c_shape, iterations)``.
        """
        system, coeff_shape = self._assemble_iterative_monolithic_system(
            level, params_c, params_f, posterior, sigma_comp, sigma_coop)

        start = jax.device_put(coeffs.ravel(), self.replicate_sharding)

        # Carry: (iteration, coeffs, converged, last_error, error_history, system_arrays)
        carry = (0, start, False, jnp.inf, jnp.zeros(self.max_iter), system)
        n_iters, solution, *_ = jax.lax.while_loop(
            self._cond_fun_iter_mono, self._body_fun_iter_mono, carry)
        return solution.reshape(coeff_shape), n_iters

    @partial(jit, static_argnums=(0, 1))
    def _blended_lstsq_fit_iterative_block(self, level: int, params_c: tuple, params_f: tuple, coeffs: Array, posterior: Array, sigma_comp: float, sigma_coop: float):
        """Relaxed block-Jacobi solve on a single device.

        ``coeffs`` is the initial guess; it is reshaped to the structured coefficient shape
        (``(Nc, Nbc)`` or ``(Nc, Nf, Nbf)``) used by the cond/body pair.

        Returns:
            (coefficients reshaped to coeffs_shape, number of iterations performed)
        """
        # The cond/body pairs are only built in __init__ for the (fallback) block solver
        # types. Build them on demand so this solver also works on instances configured
        # otherwise. _define_iterative_solver returns early if they already exist.
        self._define_iterative_solver()

        (M_diag_blocks, N_offdiag_neg_blocks, b_flat_system), coeffs_shape = \
            self._assemble_iterative_block_system(level, params_c, params_f, posterior, sigma_comp, sigma_coop)

        cond_fun, body_fun = self.iterative_solver_coarse if level == 0 else self.iterative_solver_fine

        coeffs_init = jax.device_put(coeffs.reshape(coeffs_shape), self.replicate_sharding)
        history_init = jnp.zeros(self.max_iter, dtype=coeffs.dtype)

        # state: i, M_diags, N_offdiag_neg, b_vec, coeffs_curr, converged, error, history
        initial_state = (0, M_diag_blocks, N_offdiag_neg_blocks, b_flat_system, coeffs_init, False, jnp.inf, history_init)

        final_state = lax.while_loop(cond_fun, body_fun, initial_state)
        return final_state[4].reshape(coeffs_shape), final_state[0]

    def _blended_lstsq_fit_iterative_block_distributed(self, level: int, params_c: tuple, params_f: tuple, coeffs: Array, posterior: Array, sigma_comp: float, sigma_coop: float):
        """Relaxed block-Jacobi solve with the diagonal-block solves split across ``self.mesh``.

        Every shard receives the full replicated system. Shard ``r`` along
        ``self.mesh_axis_name`` owns a contiguous slice of blocks and updates only that slice
        each sweep. The slices are re-assembled with ``all_gather`` to form the coupling term.
        Convergence uses the global update norm (``psum`` of per-shard squared differences).

        If the number of blocks is not divisible by the axis size, the system is padded with
        identity diagonal blocks, zero coupling and a zero right-hand side. The padded unknowns
        start at 0 and each update is ``solve(I, 0) = 0``, so they stay 0, never feed back into
        the real blocks, and are dropped from the result.

        NOTE(review): unlike the other solvers this method is not jitted; confirm whether
        per-call tracing of the shard_map'd closure matters for your workloads.

        Returns:
            (coefficients reshaped to coeffs_shape, number of iterations performed)

        Raises:
            ValueError: if the mesh has no ``self.mesh_axis_name`` axis.
        """
        (M_diag_blocks_rep, N_offdiag_neg_blocks_rep, b_flat_system_rep), coeffs_shape = \
            self._assemble_iterative_block_distributed_system(level, params_c, params_f, posterior, sigma_comp, sigma_coop)

        axis = self.mesh_axis_name
        if axis not in self.mesh.shape:
            raise ValueError(f"Mesh axes {tuple(self.mesh.shape)} do not include '{axis}' "
                             "required by the distributed block solver.")
        num_shards = self.mesh.shape[axis]

        total_blocks, block_size = M_diag_blocks_rep.shape[0], M_diag_blocks_rep.shape[-1]
        coeffs_blocks = coeffs.reshape(total_blocks, -1)  # (total_blocks, block_size)

        # Equal, static per-shard slice sizes require total_blocks % num_shards == 0.
        num_pad = (-total_blocks) % num_shards
        if num_pad:
            eye_pad = jnp.broadcast_to(jnp.eye(block_size, dtype=M_diag_blocks_rep.dtype),
                                       (num_pad, block_size, block_size))
            M_diag_blocks_rep = jnp.concatenate([M_diag_blocks_rep, eye_pad], axis=0)
            N_offdiag_neg_blocks_rep = jnp.pad(N_offdiag_neg_blocks_rep,
                                               ((0, num_pad), (0, 0), (0, num_pad), (0, 0)))
            b_flat_system_rep = jnp.pad(b_flat_system_rep, ((0, num_pad), (0, 0)))
            coeffs_blocks = jnp.pad(coeffs_blocks, ((0, num_pad), (0, 0)))
            M_diag_blocks_rep, N_offdiag_neg_blocks_rep, b_flat_system_rep = jax.device_put(
                (M_diag_blocks_rep, N_offdiag_neg_blocks_rep, b_flat_system_rep), self.replicate_sharding)
        blocks_per_dev = (total_blocks + num_pad) // num_shards

        coeffs_flat_init_rep = jax.device_put(coeffs_blocks, self.replicate_sharding)
        history_init_rep = jax.device_put(jnp.zeros(self.max_iter, dtype=coeffs_flat_init_rep.dtype),
                                          self.replicate_sharding)

        max_iter, tol, omega = self.max_iter, self.tol, self.omega
        if self.coef_slv_params.get("use_vmap", False):
            solve_blocks = jax.vmap(lambda M, b: jnp.linalg.solve(M, b))
        else:
            def solve_blocks(M_s, b_s):
                return lax.map(lambda i_m: jnp.linalg.solve(M_s[i_m], b_s[i_m]), jnp.arange(M_s.shape[0]))

        # Runs once per shard under shard_map; all inputs arrive replicated.
        def solve_one_instance_spmd(M_dr, N_cr, b_fr, c_ir, hist_ir):
            start_idx = lax.axis_index(axis) * blocks_per_dev
            M_loc = lax.dynamic_slice_in_dim(M_dr, start_idx, blocks_per_dev, axis=0)
            # Only this shard's rows of the coupling operator are ever needed.
            N_loc = lax.dynamic_slice_in_dim(N_cr, start_idx, blocks_per_dev, axis=0)
            b_loc = lax.dynamic_slice_in_dim(b_fr, start_idx, blocks_per_dev, axis=0)
            c_loc0 = lax.dynamic_slice_in_dim(c_ir, start_idx, blocks_per_dev, axis=0)

            def cond_fun_spmd(st):
                return (st[0] < max_iter) & jnp.logical_not(st[2])

            def body_fun_spmd(st):
                i_s, c_loc, _, _, hist = st
                c_all = lax.all_gather(c_loc, axis_name=axis, tiled=True)
                coupling_loc = jnp.einsum("ikjl,jl->ik", N_loc, c_all)
                c_hat = solve_blocks(M_loc, b_loc + coupling_loc)
                c_next = omega * c_hat + (1 - omega) * c_loc
                err = jnp.sqrt(lax.psum(jnp.sum((c_next - c_loc) ** 2), axis_name=axis))
                return (i_s + 1, c_next, err < tol, err, hist.at[i_s].set(err))

            # state: i, local coeffs, converged, error, history
            init_st = (0, c_loc0, False, jnp.inf, hist_ir)
            iters, c_final_loc, _, _, _ = lax.while_loop(cond_fun_spmd, body_fun_spmd, init_st)
            return lax.all_gather(c_final_loc, axis_name=axis, tiled=True), iters

        in_specs_map = (self.replicate_partition_spec,) * 5
        out_specs_map = (self.replicate_partition_spec,) * 2

        coeffs_res_flat_rep, iters_rep = shard_map(
            solve_one_instance_spmd, mesh=self.mesh, in_specs=in_specs_map, out_specs=out_specs_map, check_rep=False
        )(M_diag_blocks_rep, N_offdiag_neg_blocks_rep, b_flat_system_rep, coeffs_flat_init_rep, history_init_rep)

        # Drop padded blocks (a no-op slice when no padding was needed).
        return coeffs_res_flat_rep[:total_blocks].reshape(coeffs_shape), iters_rep

    def _define_monolithic_iterative_fns(self):
        """Build the cond/body pair for the monolithic relaxed fixed-point iteration.

        Loop carry: ``(i, coeffs, converged, error, history, (D_part, R_neg_part, b))``.
        Each step solves ``D_part c_hat = b + R_neg_part @ c``, relaxes
        ``c <- omega * c_hat + (1 - omega) * c``, and records the update norm in
        ``history[i]``. The loop stops after ``self.max_iter`` steps or once the update norm
        is below ``self.tol``. The results are stored as ``self._cond_fun_iter_mono`` and
        ``self._body_fun_iter_mono``; the body is jitted.
        """
        def keep_going(carry):
            step, _coeffs, done, _err, _hist, _system = carry
            return (step < self.max_iter) & (~done)

        def relax_step(carry):
            step, current, _, _, history, system = carry
            D_part, R_neg_part, rhs_const = system

            target = jnp.linalg.solve(D_part, rhs_const + R_neg_part @ current)
            updated = self.omega * target + (1 - self.omega) * current
            delta = jnp.linalg.norm(updated - current)

            return step + 1, updated, delta < self.tol, delta, history.at[step].set(delta), system

        self._cond_fun_iter_mono = keep_going
        self._body_fun_iter_mono = jit(relax_step)


    def _define_iterative_solver(self,  force_redefine: bool = False):
        """Build the (cond, body) pairs for the single-device block-Jacobi solver.

        This is a no-op when the pairs already exist, unless ``force_redefine`` is set. The
        choice between vmap and lax.map for the per-block solves is read from
        ``coef_slv_params["use_vmap"]`` at definition time. It sets
        ``self.iterative_solver_coarse`` and ``self.iterative_solver_fine``, which share one
        cond function.

        Loop carry: ``(i, M_diags, N_offdiag_neg, b, coeffs, converged, error, history)``.
        """
        if not force_redefine and hasattr(self, 'iterative_solver_coarse'):
            return

        if jax.process_index() == 0: print("Defining single-device iterative block solver functions (cond/body)...")

        def solve_one(M_block, rhs_block):
            return jnp.linalg.solve(M_block, rhs_block)

        if self.coef_slv_params.get("use_vmap", False):
            solve_all = jax.vmap(solve_one)
        else:
            def solve_all(M_stack, rhs_stack):
                return lax.map(lambda k: solve_one(M_stack[k], rhs_stack[k]), jnp.arange(M_stack.shape[0]))

        def not_done(carry):
            step, converged = carry[0], carry[5]
            return (step < self.max_iter) & (~converged)

        @jit
        def coarse_step(carry):
            step, M_diags, N_neg, b_vec, current, _, _, history = carry
            rhs = b_vec + jnp.einsum("minj,nj->mi", N_neg, current)
            updated = self.omega * solve_all(M_diags, rhs) + (1 - self.omega) * current
            delta = jnp.linalg.norm(updated - current)
            return step + 1, M_diags, N_neg, b_vec, updated, delta < self.tol, delta, history.at[step].set(delta)

        @jit
        def fine_step(carry):
            step, M_diags, N_neg, b_vec, current, _, _, history = carry
            block_size = current.shape[-1]
            rhs = b_vec + jnp.einsum("mniklj,klj->mni", N_neg, current)
            # Collapse the (coarse, fine) grid so each row is one block for solve_all.
            current_rows = current.reshape(-1, block_size)
            solved_rows = solve_all(M_diags, rhs.reshape(-1, block_size))
            updated_rows = self.omega * solved_rows + (1 - self.omega) * current_rows
            delta = jnp.linalg.norm(updated_rows - current_rows)
            return (step + 1, M_diags, N_neg, b_vec, updated_rows.reshape(current.shape),
                    delta < self.tol, delta, history.at[step].set(delta))

        self.iterative_solver_coarse = (not_done, coarse_step)
        self.iterative_solver_fine = (not_done, fine_step)


    # BENCHMARKING METHOD
    def benchmark_solver(self, level: int, model: Any, solver_type_to_benchmark: str, num_runs: int = 100, logger: Any = None):
        if level not in [0, 1]: raise ValueError("level must be 0 or 1")

        _use_vmap_info = ""
        # use_vma only works for iterative_block
        if "block" in solver_type_to_benchmark:
            use_vmap_for_benchmark = self.coef_slv_params.get("use_vmap", False)
            _use_vmap_info = f" use_vmap={use_vmap_for_benchmark}"
            if solver_type_to_benchmark == "iterative_block":
                 self._define_iterative_solver(force_redefine=True)


        final_log_prefix = f"[Solver Benchmark Level={level} SolverType={solver_type_to_benchmark}{_use_vmap_info}]"
        def log_info(msg):
            if jax.process_index() == 0:
                if logger is not None: logger.info(f"{final_log_prefix} {msg}")
                else: print(f"{final_log_prefix} {msg}")

        log_info(f"Starting benchmark for SOLVE KERNEL ({num_runs} runs)...")

        log_info("Preparing inputs and assembling matrices...")
        params_c, params_f = model.params_c, model.params_f
        coeffs_init_guess = model.coeffs_c if level == 0 else model.coeffs_f

        sigma_comp_val = model.sigma_schedule["comp"](0) if isinstance(model.sigma_schedule["comp"], collections.abc.Callable) else model.sigma_schedule["comp"]
        sigma_coop_val = model.sigma_schedule["coop"](0) if isinstance(model.sigma_schedule["coop"], collections.abc.Callable) else model.sigma_schedule["coop"]
        posterior_val = model.E_step(level, params_c, params_f, coeffs_init_guess)

        for arr_item in [params_c, params_f, coeffs_init_guess, sigma_comp_val, sigma_coop_val, posterior_val]:
             if hasattr(arr_item, 'block_until_ready'): arr_item.block_until_ready()
             elif isinstance(arr_item, (tuple, list)):
                 for item_leaf in jax.tree_util.tree_leaves(arr_item):
                     if hasattr(item_leaf, 'block_until_ready'): item_leaf.block_until_ready()

        assembled_data_for_kernel = None
        solve_kernel_fn = None
        coeffs_init_for_kernel = None
        _c_shape_for_kernel = None

        if solver_type_to_benchmark == "direct":
            A_flat, b_flat, c_shape_direct = self._assemble_direct_system(level, params_c, params_f, posterior_val, sigma_comp_val, sigma_coop_val)
            A_flat.block_until_ready(); b_flat.block_until_ready()
            assembled_data_for_kernel = (A_flat, b_flat)
            coeffs_init_for_kernel = coeffs_init_guess
            _c_shape_for_kernel = c_shape_direct

            def _direct_solve_kernel_static(kernel_data, _coeffs_dummy_init_kernel):
                A, b = kernel_data
                cs = _c_shape_for_kernel
                res = jnp.linalg.solve(A, b)
                return res.reshape(cs), 1
            solve_kernel_fn = _direct_solve_kernel_static

        elif solver_type_to_benchmark == "iterative_monolithic":
            matrices_tuple, c_shape_mono = self._assemble_iterative_monolithic_system(level, params_c, params_f, posterior_val, sigma_comp_val, sigma_coop_val)
            jax.tree_util.tree_map(lambda x: x.block_until_ready() if hasattr(x,'block_until_ready') else x, matrices_tuple)
            assembled_data_for_kernel = (matrices_tuple,)
            coeffs_init_for_kernel = coeffs_init_guess.ravel()
            _c_shape_for_kernel = c_shape_mono

            def _iter_mono_solve_kernel_static(kernel_data, c_init_flat_kernel):
                (mats_t,) = kernel_data
                cs = _c_shape_for_kernel
                _cond_f, _body_f = self._cond_fun_iter_mono, self._body_fun_iter_mono

                hist_i = jnp.zeros(self.max_iter, dtype=c_init_flat_kernel.dtype)
                init_st = (0, c_init_flat_kernel, False, jnp.inf, hist_i, mats_t)
                final_st = lax.while_loop(_cond_f, _body_f, init_st)
                iters, c_final_flat, _, _, _, _ = final_st
                return c_final_flat.reshape(cs), iters
            solve_kernel_fn = _iter_mono_solve_kernel_static

        elif solver_type_to_benchmark == "iterative_block":
            matrices_tuple, c_shape_block = self._assemble_iterative_block_system(level, params_c, params_f, posterior_val, sigma_comp_val, sigma_coop_val)
            jax.tree_util.tree_map(lambda x: x.block_until_ready() if hasattr(x,'block_until_ready') else x, matrices_tuple)
            assembled_data_for_kernel = (matrices_tuple,)
            coeffs_init_for_kernel = coeffs_init_guess.reshape(c_shape_block)
            _c_shape_for_kernel = c_shape_block

            def _iter_block_solve_kernel_static(kernel_data, c_init_structured_kernel):
                ((M_d, N_ond, b_fs),) = kernel_data
                cs = _c_shape_for_kernel
                _cond_f, _body_f = self.iterative_solver_coarse if level == 0 else self.iterative_solver_fine

                hist_i = jnp.zeros(self.max_iter, dtype=c_init_structured_kernel.dtype)
                init_st = (0, M_d, N_ond, b_fs, c_init_structured_kernel, False, jnp.inf, hist_i)
                final_st = lax.while_loop(_cond_f, _body_f, init_st)
                return final_st[4].reshape(cs), final_st[0]
            solve_kernel_fn = _iter_block_solve_kernel_static

        elif solver_type_to_benchmark == "iterative_block_distributed":
            matrices_tuple_rep, c_shape_dist = self._assemble_iterative_block_distributed_system(level, params_c, params_f, posterior_val, sigma_comp_val, sigma_coop_val)
            jax.tree_util.tree_map(lambda x: x.block_until_ready() if hasattr(x,'block_until_ready') else x, matrices_tuple_rep)
            M_diag_rep, N_coupling_rep, b_system_rep = matrices_tuple_rep
            total_blocks_dist = M_diag_rep.shape[0]
            coeffs_init_flat_rep = coeffs_init_guess.reshape(total_blocks_dist, -1)
            assembled_data_for_kernel = (M_diag_rep, N_coupling_rep, b_system_rep)
            coeffs_init_for_kernel = coeffs_init_flat_rep
            _c_shape_for_kernel = c_shape_dist
        else:
            raise ValueError(f"Unsupported solver_type_to_benchmark for refined benchmark: {solver_type_to_benchmark}")

        log_info("Assembly complete. Defining JITted benchmark loop for solve kernel...")

        if solver_type_to_benchmark != "iterative_block_distributed":
            coeffs_init_for_loop_body = jax.device_put(coeffs_init_for_kernel, self.replicate_sharding)

            @partial(jax.jit, static_argnames=("num_runs_loop", "static_solve_kernel_fn"))
            def _run_benchmark_loop_jitted(static_solve_kernel_fn,
                                           kernel_data_arg,
                                           init_coeffs_arg,
                                           num_runs_loop):
                def _loop_body_fn(loop_idx, carry_state):
                    s_coeffs, s_iters = carry_state
                    res_c, res_i = static_solve_kernel_fn(kernel_data_arg, init_coeffs_arg)
                    return (s_coeffs + jnp.sum(res_c), s_iters + res_i)

                sum_c_dtype = init_coeffs_arg.dtype if hasattr(init_coeffs_arg, 'dtype') else jnp.float32
                init_carry = (jnp.array(0.0, dtype=sum_c_dtype), jnp.array(0, dtype=jnp.int32))
                final_sum_coeffs, final_sum_iters = lax.fori_loop(0, num_runs_loop, _loop_body_fn, init_carry)
                return final_sum_coeffs, final_sum_iters

            log_info("Running warm-up iterations for solve kernel...")
            warmup_runs_count = min(5, num_runs) if num_runs > 0 else 0
            if warmup_runs_count > 0:
                warmup_sum, warmup_iters = _run_benchmark_loop_jitted(
                    solve_kernel_fn, assembled_data_for_kernel,
                    coeffs_init_for_loop_body, warmup_runs_count
                )
                warmup_sum.block_until_ready(); warmup_iters.block_until_ready()
            log_info("Warm-up complete.")

            if num_runs == 0:
                total_time_val, final_coeffs_sum_item, final_iters_sum_item = 0.0, 0.0, 0
            else:
                log_info(f"Starting timed benchmark for solve kernel ({num_runs} runs)...")
                t_start_val = time.perf_counter()
                final_coeffs_sum, final_iters_sum = _run_benchmark_loop_jitted(
                    solve_kernel_fn, assembled_data_for_kernel,
                    coeffs_init_for_loop_body, num_runs
                )
                final_coeffs_sum.block_until_ready(); final_iters_sum.block_until_ready()
                total_time_val = time.perf_counter() - t_start_val
                final_coeffs_sum_item, final_iters_sum_item = final_coeffs_sum.item(), final_iters_sum.item()

        else:
            M_dr, N_cr, b_sr = assembled_data_for_kernel
            c_if = coeffs_init_for_kernel

            def _spmd_benchmark_fori_loop_wrapper( M_diag_rep_smap, N_coupling_rep_smap, b_system_rep_smap,
                                                   coeffs_init_flat_rep_smap,
                                                   num_runs_smap):
                """Per-device body (for ``shard_map``) that repeats the distributed block solve.

                The caller passes replicated (``P()``) inputs. Each device owns a contiguous
                slice of ``blocks_per_dev`` block rows. It runs a relaxed (``self.omega``)
                block-Jacobi iteration on that slice and exchanges coefficients with a tiled
                ``all_gather`` every sweep. The whole solve is repeated ``num_runs_smap``
                times inside a ``fori_loop``.

                Returns:
                    ``(sum_coeffs, sum_iters)`` accumulated over all runs.

                Raises:
                    ValueError: if the number of blocks is not divisible by the number of
                        devices along ``self.mesh_axis_name``. Otherwise the per-device
                        slicing would silently drop trailing blocks and fail later with
                        an opaque einsum shape error.
                """
                # lax.psum of a Python constant yields a concrete int here; it has to,
                # since blocks_per_dev is used as a static slice size below.
                num_devices = lax.psum(1, axis_name=self.mesh_axis_name)
                total_blocks = M_diag_rep_smap.shape[0]
                if total_blocks % num_devices != 0:
                    raise ValueError(
                        f"Distributed block benchmark requires the number of blocks ({total_blocks}) "
                        f"to be divisible by the number of devices ({num_devices})."
                    )
                blocks_per_dev = total_blocks // num_devices
                use_vmap = self.coef_slv_params.get("use_vmap", False)

                def solve_one_instance_spmd_kernel(M_dr_kernel, N_cr_kernel, b_fr_kernel,
                                                   c_ir_kernel, hist_ir_kernel,
                                                   blocks_per_dev_kernel, start_idx_kernel):
                    """Run one distributed solve; return (gathered final coeffs, iteration count)."""
                    # Slices of the replicated diagonal blocks, initial guess and RHS owned by this device.
                    M_block_shard_k = lax.dynamic_slice_in_dim(M_dr_kernel, start_idx_kernel, blocks_per_dev_kernel, axis=0)
                    coeffs_curr_local_k = lax.dynamic_slice_in_dim(c_ir_kernel, start_idx_kernel, blocks_per_dev_kernel, axis=0)
                    b_local_shard_k = lax.dynamic_slice_in_dim(b_fr_kernel, start_idx_kernel, blocks_per_dev_kernel, axis=0)

                    if use_vmap:
                        solve_fn_k = jax.vmap(lambda M, b: jnp.linalg.solve(M, b))
                    else:
                        # Sequential per-block solves via lax.map.
                        def solve_fn_k(M_s, b_s):
                            return lax.map(lambda i_m: jnp.linalg.solve(M_s[i_m], b_s[i_m]),
                                           jnp.arange(M_s.shape[0]))

                    # Carry: (iter, local coeffs, converged, update norm, norm history).
                    def cond_fun_k(st_k):
                        return (st_k[0] < self.max_iter) & jnp.logical_not(st_k[2])

                    def body_fun_k(st_k):
                        i_k, cc_loc_k, _, _, h_curr_k = st_k
                        # Every device needs the full previous iterate for the coupling term.
                        cc_gath_k = lax.all_gather(cc_loc_k, axis_name=self.mesh_axis_name, tiled=True)
                        # NOTE(review): this computes the coupling for ALL blocks on every
                        # device and then keeps only the local rows.
                        coup_glob_k = jnp.einsum("ikjl,jl->ik", N_cr_kernel, cc_gath_k)
                        coup_loc_k = lax.dynamic_slice_in_dim(coup_glob_k, start_idx_kernel, blocks_per_dev_kernel, axis=0)
                        rhs_loc_k = b_local_shard_k + coup_loc_k
                        ch_loc_k = solve_fn_k(M_block_shard_k, rhs_loc_k)
                        # Relaxed update: omega * new + (1 - omega) * old.
                        cn_loc_k = self.omega * ch_loc_k + (1 - self.omega) * cc_loc_k
                        # Global L2 norm of the update, reduced across devices.
                        diff_sq_loc_k = jnp.sum((cn_loc_k - cc_loc_k) ** 2)
                        err_sq_glob_k = lax.psum(diff_sq_loc_k, axis_name=self.mesh_axis_name)
                        err_next_k = jnp.sqrt(err_sq_glob_k)
                        conv_next_k = err_next_k < self.tol
                        h_new_k = h_curr_k.at[i_k].set(err_next_k)
                        return (i_k + 1, cn_loc_k, conv_next_k, err_next_k, h_new_k)

                    init_st_k = (0, coeffs_curr_local_k, False, jnp.inf, hist_ir_kernel)
                    final_st_k = lax.while_loop(cond_fun_k, body_fun_k, init_st_k)
                    final_i_k, final_c_loc_k = final_st_k[0], final_st_k[1]
                    final_c_gath_k = lax.all_gather(final_c_loc_k, axis_name=self.mesh_axis_name, tiled=True)
                    return final_c_gath_k, final_i_k

                def _smap_fori_loop_body(loop_idx_smap, carry_state_smap):
                    s_coeffs_smap, s_iters_smap = carry_state_smap
                    hist_template_smap = jnp.zeros(self.max_iter, dtype=coeffs_init_flat_rep_smap.dtype)
                    start_i_smap = lax.axis_index(self.mesh_axis_name) * blocks_per_dev
                    res_c_gath_smap, res_i_smap = solve_one_instance_spmd_kernel(
                        M_diag_rep_smap, N_coupling_rep_smap, b_system_rep_smap,
                        coeffs_init_flat_rep_smap, hist_template_smap,
                        blocks_per_dev, start_i_smap
                    )
                    return (s_coeffs_smap + jnp.sum(res_c_gath_smap), s_iters_smap + res_i_smap)

                init_carry_smap_coeffs_dtype = coeffs_init_flat_rep_smap.dtype if hasattr(coeffs_init_flat_rep_smap, 'dtype') else jnp.float32
                init_carry_smap = (jnp.array(0.0, dtype=init_carry_smap_coeffs_dtype), jnp.array(0, dtype=jnp.int32))

                final_sum_coeffs_smap, final_sum_iters_smap = lax.fori_loop(
                    0, num_runs_smap, _smap_fori_loop_body, init_carry_smap
                )
                return final_sum_coeffs_smap, final_sum_iters_smap

            in_specs_smap = (self.replicate_partition_spec, self.replicate_partition_spec, self.replicate_partition_spec,
                             self.replicate_partition_spec) # num_runs_smap will be static via partial
            out_specs_smap = (self.replicate_partition_spec, self.replicate_partition_spec)

            log_info("Running warm-up iterations for distributed solve kernel...")
            warmup_runs_count = min(5, num_runs) if num_runs > 0 else 0
            smap_fn_for_benchmark = partial(_spmd_benchmark_fori_loop_wrapper, num_runs_smap=warmup_runs_count)

            if warmup_runs_count > 0:
                warmup_sum, warmup_iters = shard_map(
                    smap_fn_for_benchmark, mesh=self.mesh,
                    in_specs=in_specs_smap, out_specs=out_specs_smap, check_rep=False
                )(M_dr, N_cr, b_sr, c_if)
                warmup_sum.block_until_ready(); warmup_iters.block_until_ready()
            log_info("Warm-up complete.")

            if num_runs == 0:
                total_time_val, final_coeffs_sum_item, final_iters_sum_item = 0.0, 0.0, 0
            else:
                log_info(f"Starting timed benchmark for distributed solve kernel ({num_runs} runs)...")
                smap_fn_for_timed_benchmark = partial(_spmd_benchmark_fori_loop_wrapper, num_runs_smap=num_runs)
                t_start_val = time.perf_counter()
                final_coeffs_sum, final_iters_sum = shard_map(
                    smap_fn_for_timed_benchmark, mesh=self.mesh,
                    in_specs=in_specs_smap, out_specs=out_specs_smap, check_rep=False
                )(M_dr, N_cr, b_sr, c_if)
                final_coeffs_sum.block_until_ready(); final_iters_sum.block_until_ready()
                total_time_val = time.perf_counter() - t_start_val
                final_coeffs_sum_item, final_iters_sum_item = final_coeffs_sum.item(), final_iters_sum.item()

        avg_time_run = total_time_val / num_runs if num_runs > 0 else 0.0
        avg_iters_run = float(final_iters_sum_item) / num_runs if num_runs > 0 else 0.0
        meaningful_iters_for_time_per_iter = final_iters_sum_item > 0
        if solver_type_to_benchmark == "direct":
             meaningful_iters_for_time_per_iter = False
        avg_time_iter = total_time_val / final_iters_sum_item if meaningful_iters_for_time_per_iter else avg_time_run

        log_info("Benchmark for SOLVE KERNEL complete.")
        log_info(f"Total time for {num_runs} SOLVE KERNEL runs: {total_time_val:.6f} seconds")
        log_info(f"Average time per SOLVE KERNEL call: {avg_time_run:.8f} seconds")
        if solver_type_to_benchmark != "direct":
            log_info(f"Total iterations accumulated (all runs): {final_iters_sum_item}")
            log_info(f"Average iterations per SOLVE KERNEL call: {avg_iters_run:.2f}")
            if meaningful_iters_for_time_per_iter and avg_iters_run > 1.001 :
                 log_info(f"Average time per single iteration: {avg_time_iter:.8f} seconds")

        return {
            'solver_benchmarked': f"{solver_type_to_benchmark}_solve_kernel", 'level': level, 'num_runs': num_runs,
            'total_time_seconds': total_time_val, 'average_time_per_run_seconds': avg_time_run,
            'total_iters': final_iters_sum_item, 'average_iters_per_run': avg_iters_run,
            'average_time_per_iter_seconds': avg_time_iter,
            'final_coeffs_sum_check': final_coeffs_sum_item
        }
