from functools import partial
from typing import Any, Callable, Sequence, Tuple, Optional, Dict


from flax.training import train_state
from flax import jax_utils

import jax.numpy as jnp
import numpy as np
from collections import defaultdict
from jax import lax, jit, grad, pmap, random, jacfwd, jacrev, vmap, hessian
from jax.tree_util import tree_map, tree_reduce, tree_leaves

from jax.scipy.special import logsumexp, digamma, polygamma

import optax

from phijax.models import *
from phijax.data import *
from phijax.utils import flatten_pytree, Collection

from phijax.equations.objectives import *

from scipy.stats import skew, kurtosis

from phijax.em import *


from jax.flatten_util import ravel_pytree



def mse(r):
    """Return the mean of the squared entries of ``r`` over all axes."""
    squared = jnp.square(r)
    return jnp.mean(squared)

@jax.jit
def t_objective_d(r, nu: float, lam: float):
    """Mean residual-dependent part of the Student-t negative log-likelihood.

    Per entry this is ``0.5 * (nu + 1) * log(1 + lam * r**2 / nu)``. The result
    is averaged over all entries of ``r``. ``nu`` and ``lam`` are cast to
    ``r.dtype`` so that scalars or arrays can be passed as traced arguments.
    """
    nu_arr = jnp.asarray(nu, dtype=r.dtype)
    lam_arr = jnp.asarray(lam, dtype=r.dtype)
    scaled_sq = (lam_arr * jnp.square(r)) / nu_arr
    per_entry = 0.5 * (nu_arr + 1.0) * jnp.log(1.0 + scaled_sq)
    return jnp.mean(per_entry)

def estep_w_d(r, nu: float, lam: float):
    """Per-residual E-step weights ``(nu + 1) / (nu + lam * r**2)`` of a Student-t model."""
    denominator = nu + lam * jnp.square(r)
    return (nu + 1.0) / denominator


def snll_term_loss(resid, nu, lam):
    """Student-t NLL term loss: a thin wrapper around ``t_objective_d``."""
    loss_value = t_objective_d(resid, nu=nu, lam=lam)
    return loss_value

def wls_term_loss(resid, nu, lam):
    """Weighted least-squares term loss with a constant weight ``lam / nu``.

    Returns ``0.5 * lam * mean((lam / nu) * resid**2)``. The outer
    ``jnp.mean`` of that scalar is a no-op that is kept for parity.

    NOTE(review): an E-step weight (``estep_w_d``) was previously tried here
    and is disabled. Confirm that ``lam / nu`` is the intended constant weight.
    """
    weight = lam / nu
    weighted_mse = jnp.mean(weight * jnp.square(resid))
    scaled = 0.5 * lam * weighted_mse
    return jnp.mean(scaled)





from phijax.utils_.subspace_alignment import (
    init_state as init_subspace_state,
    update_state as update_subspace_state,
    compute_metrics as subspace_metrics,
    pack_pairwise_logs as pack_subspace_pairwise_logs,
)


from .pcgrad_utils import *


def alpha_mean(array: jnp.ndarray, axis: Optional[Sequence[int]] = None, keepdims: bool = False, eps: float = 1e-12) -> jnp.ndarray:
    """Alpha-mean of an array in the sense of Amari (placeholder, not implemented).

    Reference definition: for positive x, y and real alpha,

        m_alpha(x, y) = f_alpha^{-1}((f_alpha(x) + f_alpha(y)) / 2),
        f_alpha(u) = u^{(1 - alpha) / 2}   for alpha != 1,
        f_alpha(u) = log(u)                for alpha == 1,

    which gives the arithmetic mean at alpha = -1, the geometric mean at
    alpha = 1 and the harmonic mean at alpha = 3.

    NOTE(review): the body is a stub. It ignores every argument and returns
    ``None``, and the signature has no ``alpha`` parameter yet.
    """
    return None

def group_by_prefix(losses: dict[str, jnp.ndarray]) -> dict[str, jnp.ndarray]:
    """Sum loss entries that share a prefix (the text before the first "_").

    Example: ``{"ics_u": a, "ics_v": b, "res": c}`` -> ``{"ics": a + b, "res": c}``.
    Each group is accumulated starting from ``0.0``.
    """
    totals = {}
    for name, value in losses.items():
        head = name.partition("_")[0]
        totals[head] = totals.get(head, 0.0) + value
    return totals


def flatten_pytree(pytree):
    """Concatenate every leaf of ``pytree`` into a single 1-D array.

    Leaves are raveled in ``jax.tree_util.tree_leaves`` order and joined with
    ``jnp.concatenate``, which promotes mixed dtypes to a common dtype.

    An empty pytree returns an empty 1-D array. Without this check,
    ``jnp.concatenate`` would raise on an empty sequence.

    NOTE(review): this definition shadows ``flatten_pytree`` imported from
    ``phijax.utils`` at the top of the module.
    """
    leaves = jax.tree_util.tree_leaves(pytree)
    if not leaves:
        return jnp.zeros((0,))
    return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

@jax.jit
def fisher_skewness_excess_kurtosis(x: jnp.ndarray, eps: float = 1e-12):
    """Population (biased) skewness and excess kurtosis of all entries of ``x``.

    Uses central moments m2, m3, m4 and ``sigma = sqrt(m2 + eps)``:
    ``skew = m3 / (sigma**3 + eps)``, ``ex_kurt = m4 / (sigma**4 + eps) - 3``.
    The ``eps`` terms keep constant inputs from dividing by zero.
    """
    flat = jnp.ravel(x)
    centred = flat - jnp.mean(flat)
    sq = centred * centred

    second = jnp.mean(sq)
    third = jnp.mean(sq * centred)
    fourth = jnp.mean(sq * sq)

    scale = jnp.sqrt(second + eps)
    skewness = third / (scale**3 + eps)
    excess_kurtosis = fourth / (scale**4 + eps) - 3.0
    return skewness, excess_kurtosis


def scipy_skewness_excess_kurtosis(x: jnp.ndarray):
    """Bias-corrected sample skewness and excess kurtosis via SciPy.

    ``x`` is flattened and moved to a NumPy array. Both statistics use
    ``bias=False`` so they are consistent estimators of the same kind.
    Kurtosis is the Fisher (excess) definition.

    Returns:
        ``(skewness, excess_kurtosis)`` as NumPy scalars.
    """
    x_np = np.asarray(x).ravel()
    # `bias` must be passed by keyword: the second positional parameter of
    # scipy.stats.skew is `axis`, so `skew(x, False)` silently kept bias=True.
    skewness = skew(x_np, bias=False)
    ex_kurtosis = kurtosis(x_np, fisher=True, bias=False)
    return skewness, ex_kurtosis




def _task_space_alpha(G, w, eps=1e-12, clip_min=1e-20):
    M = G.T @ G
    lam, V = jnp.linalg.eigh(M)
    pos = lam > eps

    def fallback():
        return w

    def aligned():
        lam_pos = jnp.where(pos, lam, jnp.inf)
        lam_min_pos = jnp.maximum(jnp.min(lam_pos), clip_min)
        inv_sqrt = jnp.where(pos, 1.0 / jnp.sqrt(jnp.maximum(lam, clip_min)), 0.0)
        B = jnp.sqrt(lam_min_pos) * (V @ (inv_sqrt[:, None] * V.T))
        return B @ w

    return jax.lax.cond(jnp.any(pos), aligned, fallback)


def _grads_dict_to_G(grads_dict, keys):
    flat0, _ = ravel_pytree(grads_dict[keys[0]])
    cols = [flat0]
    for k in keys[1:]:
        fk, _ = ravel_pytree(grads_dict[k])
        cols.append(fk)
    return jnp.stack(cols, axis=1)


def _alpha_to_weight_dict(keys, alpha):
    return {k: alpha[i] for i, k in enumerate(keys)}


def orthogonality_penalty(W):
    """Mean squared deviation of the smaller Gram matrix of ``W`` from identity.

    Arrays with more than two dimensions are reshaped to
    ``(W.shape[0], -1)``. For a matrix ``A`` of shape ``(d0, d1)``, the Gram
    matrix is ``A @ A.T`` when ``d0 <= d1`` and ``A.T @ A`` otherwise. The
    penalty is ``mean((Gram - I)**2)``.
    """
    mat = W if W.ndim <= 2 else W.reshape((W.shape[0], -1))
    rows, cols = mat.shape
    if rows > cols:
        gram = mat.T @ mat
        eye = jnp.eye(cols, dtype=mat.dtype)
    else:
        gram = mat @ mat.T
        eye = jnp.eye(rows, dtype=mat.dtype)
    return jnp.mean(jnp.square(gram - eye))


def layerwise_orthogonality_regularizer(params) -> jnp.ndarray:
    """Sum ``orthogonality_penalty`` over every 2-D leaf of ``params``.

    Leaves of any other rank contribute a zero of their own dtype. A params
    tree with no leaves raises, because ``tree_reduce`` is given no initial
    value.
    """
    def _leaf_penalty(leaf):
        if leaf.ndim != 2:
            return jnp.array(0.0, dtype=leaf.dtype)
        return orthogonality_penalty(leaf)

    per_leaf_penalties = jax.tree_util.tree_map(_leaf_penalty, params)
    return jax.tree_util.tree_reduce(lambda acc, term: acc + term, per_leaf_penalties)






def residual_stats_logging(residuals, qs=(0.01, 0.05)):
    """Summarise the distribution of a residual array for logging.

    The input is pulled to host memory with ``jax.device_get`` and flattened.
    Returned keys (all plain floats):

    * ``mean``, ``std`` (sample std, ``ddof=1``), ``rms``, ``max_abs`` and
      ``max_to_rms`` (NaN when ``rms == 0``).
    * ``skewness`` / ``kurtosis``: SciPy's default (biased) estimates, with
      kurtosis reported as excess (Fisher) kurtosis.
    * ``skewness_bias`` / ``kurtosis_bias``: the bias-*corrected* estimates
      (``bias=False``), despite the key names.
    * ``top_<P>pct_energy`` for each ``q`` in ``qs``: the fraction of
      ``sum(res**2)`` carried by the ``ceil(q * n)`` largest-magnitude
      residuals (at least 1, at most ``n``), with ``P = round(100 * q)``.
      These are NaN when the total energy is zero.

    An empty input returns NaN for every key instead of raising.
    """
    res = np.asarray(jax.device_get(residuals)).ravel()
    n = res.size

    def energy_key(q):
        # round(), not int(): int(0.29 * 100) == 28 because of float error.
        return f"top_{round(float(q) * 100)}pct_energy"

    if n == 0:
        # np.max raises on empty input, so report NaN for every statistic.
        nan = float("nan")
        empty_stats = dict(
            mean=nan,
            std=nan,
            skewness_bias=nan,
            kurtosis_bias=nan,
            skewness=nan,
            kurtosis=nan,
            rms=nan,
            max_abs=nan,
            max_to_rms=nan,
        )
        empty_stats.update({energy_key(q): nan for q in qs})
        return empty_stats

    mean = np.mean(res)
    std = np.std(res, ddof=1)

    skewness_bias = skew(res, bias=False)
    kurtosis_bias = kurtosis(res, bias=False)
    skewness = skew(res)
    kurt = kurtosis(res)

    energy = res**2
    rms = np.sqrt(np.mean(energy))
    max_abs = np.max(np.abs(res))
    max_to_rms = (max_abs / rms) if rms > 0 else np.nan
    total_energy = np.sum(energy)

    energy_conc = {}
    if total_energy > 0:
        abs_res = np.abs(res)
        for q in qs:
            # Clamp to [1, n] so argpartition stays in range even for q > 1.
            k = min(max(int(np.ceil(float(q) * n)), 1), n)
            top_idx = np.argpartition(abs_res, -k)[-k:]
            energy_conc[energy_key(q)] = float(np.sum(energy[top_idx]) / total_energy)
    else:
        for q in qs:
            energy_conc[energy_key(q)] = np.nan

    return dict(
        mean=float(mean),
        std=float(std),
        skewness_bias=float(skewness_bias),
        kurtosis_bias=float(kurtosis_bias),
        skewness=float(skewness),
        kurtosis=float(kurt),
        rms=float(rms),
        max_abs=float(max_abs),
        max_to_rms=float(max_to_rms),
        **energy_conc,
    )

class BaseMonitorMixin:
    """Logging and gradient-diagnostic helpers for PINN models.

    The host class must provide ``self.config``, ``self.losses``,
    ``self.residuals`` and ``self.prefix_group``. ``prefix_group`` maps a
    group name to the tuple of leaf loss keys summed in that group. Every
    ``_log_*`` helper writes into ``self.log_dict``; ``log`` resets that dict
    and returns it.
    """

    def init_monitor(self):
        """Install default ``subspace``/``logging`` config sections if missing.

        Also resets ``_subspace_state`` and ``log_dict``.
        """
        if not hasattr(self.config, "subspace"):
            self.config.subspace = Collection()
            self.config.subspace.enabled = True
            self.config.subspace.rank = 8
            self.config.subspace.oja_lr = 0.05
            self.config.subspace.normalize_grads = True
            self.config.subspace.log_matrices = False
            self.config.subspace.seed = 0

        if not hasattr(self.config, "logging"):
            self.config.logging = Collection()
            self.config.logging.log_losses = False
            self.config.logging.log_weights = False
            self.config.logging.log_grads = False
            self.config.logging.log_stats = False

        self._subspace_state = None
        self.log_dict = {}

    #### helpers to log

    def _log_losses(self, state, batch, *args):
        """Log every per-term loss under ``"<key>_loss"``.

        If the state carries ``loss_params``, also log
        ``exp(-log_sigma[key])`` under ``"scale_<key>"``.
        """
        losses = self.losses(state, batch, *args)
        for key, value in losses.items():
            self.log_dict[key + "_loss"] = value

        if "loss_params" in state.__dict__:
            log_sigma = state.loss_params["log_sigma"]
            for key, value in log_sigma.items():
                self.log_dict["scale_" + key] = jnp.exp(-value)

    def _log_weights(self, state):
        """Log each entry of the nested ``state.st_params`` dict as ``"st/<group>_<term>"``."""
        for group, per_term in state.st_params.items():
            for term, value in per_term.items():
                self.log_dict[f"st/{group}_{term}"] = value

    def _log_grads(self, state, batch, *args):
        """Log the L2 norm of each per-term loss gradient w.r.t. ``state.params``.

        The gradient is taken w.r.t. the parameters only. The full train state
        may hold non-differentiable leaves (e.g. an integer step counter), which
        ``jacrev`` rejects.
        """
        def losses_wrt_params(params):
            return self.losses(state.replace(params=params), batch, *args)

        grads = jacrev(losses_wrt_params)(state.params)
        for key, value in grads.items():
            self.log_dict[key + "_grad_norm"] = jnp.linalg.norm(flatten_pytree(value))

    def _log_stats(self, state, batch, *args):
        """Log residual statistics.

        The ``"res"`` residual gets the full ``residual_stats_logging`` summary
        under ``"stats/pde_<name>"``. The ``"ics"`` and ``"bcs"`` residuals, when
        present and not None, get var/mean/max/min under ``"stats/ic_*"`` and
        ``"stats/bc_*"``.
        """
        residuals = self.residuals(state, batch, *args)

        pde_logs = residual_stats_logging(residuals["res"], qs=(0.01, 0.05))
        for name, value in pde_logs.items():
            self.log_dict[f"stats/pde_{name}"] = value

        for short, key in (("ic", "ics"), ("bc", "bcs")):
            term = residuals.get(key, None)
            if term is None:
                continue
            self.log_dict[f"stats/{short}_var"] = jnp.var(term)
            self.log_dict[f"stats/{short}_mean"] = jnp.mean(term)
            self.log_dict[f"stats/{short}_max"] = jnp.max(term)
            self.log_dict[f"stats/{short}_min"] = jnp.min(term)

    def log(self, state, batch, *args):
        """Collect all diagnostics for one step and return the fresh ``log_dict``.

        Weights are logged only when ``config.logging.log_weights`` is set.
        Losses, per-group gradient diagnostics and residual statistics are
        always logged.

        NOTE(review): ``log_losses`` / ``log_grads`` / ``log_stats`` config flags
        are currently not consulted.
        """
        self.log_dict = {}

        if self.config.logging.log_weights:
            self._log_weights(state)

        self._log_losses(state, batch, *args)
        self._log_leaf_grad(state, batch, *args)
        self._log_stats(state, batch, *args)

        return self.log_dict

    #####  Gradient Diagnostics  #####
    # Measure gradient magnitudes of the individual loss terms and how well
    # their directions align.

    @partial(jit, static_argnums=(0,))
    def leaf_residuals(self, state, batch, *args):
        """Jitted pass-through to ``self.residuals``."""
        return self.residuals(state, batch, *args)

    @partial(jit, static_argnums=(0,))
    def leaf_losses_from_residuals(self, state, batch, *args, eps=0.0):
        """Per-term MSE of the residuals plus ``eps`` (independent of the configured objectives)."""
        r = self.leaf_residuals(state, batch, *args)
        return {k: jnp.mean(jnp.square(v)) + eps for k, v in r.items()}

    @partial(jit, static_argnums=(0,))
    def leaf_grads_mse(self, state, batch, *args):
        """Gradient of each per-term residual MSE w.r.t. ``state.params``.

        Returns a dict keyed like the residuals. Each value is a params-shaped
        gradient pytree.
        """
        keys = tuple(self.leaf_losses_from_residuals(state, batch, *args).keys())

        def loss_k(params, key):
            # Rebuild a state around the differentiated params: the residual
            # functions expect a state (they read `state.params`), not raw params.
            st = state.replace(params=params)
            return self.leaf_losses_from_residuals(st, batch, *args)[key]

        grads = {}
        for k in keys:
            grads[k] = grad(lambda p, kk=k: loss_k(p, kk))(state.params)
        return grads

    @partial(jit, static_argnums=(0,))
    def leaf_grads(self, state, batch, *args):
        """Jacobian of the per-group summed losses w.r.t. ``state.params``.

        Groups come from ``self.prefix_group``. Returns
        ``{group: params-shaped gradient}``.
        """
        def losses_wrt_params(params):
            st = state.replace(params=params)
            losses = self.losses(st, batch, *args)
            return {
                p: sum(losses[k] for k in ks)
                for p, ks in self.prefix_group.items()
            }

        return jax.jacrev(losses_wrt_params)(state.params)

    def _log_leaf_grad(self, state, batch, *args, eps=1e-12):
        """Log per-group gradient norms, contributions and pairwise alignment.

        Logged keys, for each group ``k`` of ``leaf_grads``:

        * ``grads/norm_k``: norm of the unweighted gradient.
        * ``grads/contrib_k``: projection of the weighted gradient
          ``w_k * g_k`` onto the total weighted gradient, normalised by its
          squared norm. These values sum to ~1 over ``k``.
        * ``grads/norm_total``: norm of the total weighted gradient.
        * ``grads/align_score``: ``alignment_score`` of the weighted gradients.
        * ``grads/cos_a_b``: cosine similarity of unweighted gradients for
          every pair of groups.
        """

        def alignment_score(vectors, eps=1e-8):
            # (2 / n^2) * ||sum of unit vectors||^2 - 1. For n == 2 this equals
            # the cosine similarity of the two vectors.
            n = vectors.shape[0]
            norms = jnp.linalg.norm(vectors, axis=1, keepdims=True)
            normalized_vectors = vectors / (norms + eps)
            summed_v = jnp.sum(normalized_vectors, axis=0)
            summed_norm_sq = jnp.sum(jnp.square(summed_v))
            return (2.0 / (n**2)) * summed_norm_sq - 1.0

        g = self.leaf_grads(state, batch, *args)
        w = state.weights

        flat = {k: flatten_pytree(v) for k, v in g.items()}
        keys = list(flat.keys())

        norm = {k: jnp.linalg.norm(flat[k]) for k in keys}
        wflat = {k: w[k] * flat[k] for k in keys}

        g_total = sum([wflat[k] for k in keys])
        total_norm = jnp.linalg.norm(g_total)

        denom = jnp.dot(g_total, g_total) + eps
        for k in keys:
            self.log_dict[f"grads/contrib_{k}"] = jnp.dot(wflat[k], g_total) / denom
            self.log_dict[f"grads/norm_{k}"] = norm[k]
        self.log_dict["grads/norm_total"] = total_norm
        self.log_dict["grads/align_score"] = alignment_score(jnp.stack([wflat[k] for k in keys]), eps=eps)

        for i, a in enumerate(keys):
            for b in keys[i + 1:]:
                dot = jnp.dot(flat[a], flat[b])
                pair_denom = norm[a] * norm[b] + eps
                self.log_dict[f"grads/cos_{a}_{b}"] = dot / pair_denom

        # NOTE: the string literal below is disabled subspace-alignment logging
        # kept for reference. It is never executed.
        """if getattr(self.config.subspace, "enabled", False):
            r = int(self.config.subspace.rank)
            lr = float(self.config.subspace.oja_lr)
            normalize_grads = bool(self.config.subspace.normalize_grads)
            seed = int(getattr(self.config.subspace, "seed", 0))

            if self._subspace_state is None:
                key = jax.random.PRNGKey(seed)
                self._subspace_state = init_subspace_state(key, state.params, keys, r=r)

            self._subspace_state, G = update_subspace_state(
                self._subspace_state,
                grads_dict=g,
                lr=lr,
                normalize_grads=normalize_grads,
                eps=eps,
            )
            m = subspace_metrics(self._subspace_state, G, eps=eps)

            #print(m['offA'], m['offC'])

            self.log_dict["subspace/A_mean_offdiag"] = m["A_mean_offdiag"]
            self.log_dict["subspace/A_min_offdiag"] = m["A_min_offdiag"]
            self.log_dict["subspace/cos_mean_offdiag"] = m["cos_mean_offdiag"]
            self.log_dict["subspace/cos_min_offdiag"] = m["cos_min_offdiag"]
            self.log_dict["subspace/align_to_consensus"] = m["align_to_consensus"]
            self.log_dict["subspace/consensus_gap"] = m["consensus_gap"]

            for i, kname in enumerate(keys):
                self.log_dict[f"subspace/cap_to_consensus_{kname}"] = m["cap_to_consensus_per_task"][i]
                self.log_dict[f"subspace/grad_norm_{kname}"] = m["grad_norms"][i]

            if bool(getattr(self.config.subspace, "log_matrices", False)):
                pair_logs = pack_subspace_pairwise_logs(keys, m["A_pairwise"], m["cos_pairwise"], prefix="subspace")
                self.log_dict.update(pair_logs)"""


def _is_replicated_state(state):
    # look at any param leaf and see if it has a leading device axis
    leaves = jax.tree_util.tree_leaves(state.params)
    if not leaves:
        return False
    x = leaves[0]
    return hasattr(x, "ndim") and x.ndim >= 1 and x.shape[0] == jax.local_device_count()

def _replicate_scalar_like_state(state, x):
    """Return ``x`` as an array, with a leading device axis if ``state`` is replicated.

    When ``_is_replicated_state(state)`` is true, ``x`` is broadcast to
    ``(jax.local_device_count(), *x.shape)``. Otherwise it is returned as-is.
    """
    value = jnp.asarray(x)
    if not _is_replicated_state(state):
        return value
    n_devices = jax.local_device_count()
    return jnp.broadcast_to(value, (n_devices, *value.shape))

def _replicate_pytree_like_state(state, pytree):
    """Apply ``_replicate_scalar_like_state`` to every leaf of ``pytree``."""
    def _replicate_leaf(leaf):
        return _replicate_scalar_like_state(state, leaf)

    return jax.tree_util.tree_map(_replicate_leaf, pytree)



class AdaptivePINNMixin:
    """Mixin that runs one EM pass over the per-term Student-t hyperparameters."""

    def run_em(self, state, batch, em_config):
        """Run one E-step / M-step pass for every residual term.

        Args:
            state: train state whose ``st_params`` holds ``"nu"`` and ``"lam"``
                dicts keyed by residual term.
            batch: batch forwarded to ``self.residuals``.
            em_config: forwarded to ``m_step_nu_flat``.

        Returns:
            ``(new_nu, new_lam)``: dicts keyed by residual term with the
            updated values. ``state`` itself is not modified.
        """
        residuals = self.residuals(state, batch)
        st_params = state.st_params
        nu_dict = st_params["nu"]
        lam_dict = st_params["lam"]

        new_nu = {}
        new_lam = {}

        for term in residuals:
            res = residuals[term].ravel()
            nu = nu_dict[term]
            lam = lam_dict[term]

            w, elog_eta = e_step(res, lam, nu)
            new_nu[term] = m_step_nu_flat(nu, elog_eta, w, em_config)
            new_lam[term] = m_step_lambda(res, w, a_lam=1.0, b_lam=1.0)

        # Return the *updated* nu. Returning the incoming nu_dict would
        # silently discard the nu M-step.
        return new_nu, new_lam
    
    
    
class PINN:
    """Base class for a physics-informed neural network problem.

    Subclasses implement ``residuals`` (and typically ``u_net``/``r_net``).
    ``residuals`` returns a dict with one entry per key in ``self.loss_keys``.

    Leaf loss keys are grouped by prefix (the text before the first ``"_"``)
    into ``self.prefix_group``. ``loss`` sums each group and multiplies it by
    the matching entry of ``state.weights``.
    """

    loss_keys = ("ics", "bcs", "res")

    def __init__(self, config):
        self.config = config
        ### add rotation defaults
        if not hasattr(self.config, "rotation"):
            self.config.rotation = Collection()
            self.config.rotation.rot_rho = 0.99
            self.config.rotation.rot_precond_freq = 50
            self.config.rotation.rot_eps = 1e-6
            self.config.rotation.rot_max_dim = 512

        if not hasattr(self.config, "subspace"):
            self.config.subspace = Collection()
            self.config.subspace.enabled = True
            self.config.subspace.rank = 8
            self.config.subspace.oja_lr = 0.05
            self.config.subspace.normalize_grads = True
            self.config.subspace.log_matrices = False
            self.config.subspace.seed = 0

        self.state = get_model(config)
        self.loss_keys = tuple(getattr(self, "loss_keys", ("ics", "bcs", "res")))

        if "spec" in self.state.weights:
            self.loss_keys = tuple(self.loss_keys) + ("spec",)
            self._spec_residual_fn = self._spec_residual
        else:
            self._spec_residual_fn = self._zero_spec_residual

        # Build the groups only after "spec" may have been appended. `losses`
        # computes every key in `self.loss_keys`, and `loss` tree_maps the
        # grouped dict against `state.weights`, which needs identical keys.
        self.prefix_group = self._build_prefix_groups(self.loss_keys)

        self.term_mode, self.st_handler, init_st_params = build_objectives(config, self.loss_keys)
        self.init_st_params = init_st_params

        mode_set = {str(self.term_mode[k]).lower() for k in self.loss_keys}
        uses_em = "em" in mode_set

        if uses_em and getattr(self.state, "st_params", None) is None:
            self.state = self.state.replace(st_params=init_st_params)
        self._ensure_st_params(self.init_st_params)
        self._init_objectives_dynamic()

    @staticmethod
    def _build_prefix_groups(loss_keys):
        """Map each key prefix (text before the first "_") to its loss keys, in order."""
        groups = defaultdict(list)
        for k in loss_keys:
            groups[k.split("_", 1)[0]].append(k)
        return {p: tuple(v) for p, v in groups.items()}

    def _ensure_st_params(self, init_st_params):
        """Ensure ``self.state.st_params`` has ``"nu"`` and ``"lam"``.

        Missing entries are filled from ``init_st_params``. Both entries are
        then replicated to match the state's device layout.
        """
        sp = dict(getattr(self.state, "st_params", {}) or {})
        if "nu" not in sp:
            sp["nu"] = init_st_params["nu"]
        if "lam" not in sp:
            sp["lam"] = init_st_params["lam"]
        sp["nu"] = _replicate_pytree_like_state(self.state, sp["nu"])
        sp["lam"] = _replicate_pytree_like_state(self.state, sp["lam"])
        self.state = self.state.replace(st_params=sp)

    def set_st_params(self, state, *, nu=None, lam=None):
        """Return ``state`` with ``nu``/``lam`` replaced (when given) in ``st_params``."""
        sp = dict(state.st_params)
        if nu is not None:
            sp["nu"] = _replicate_pytree_like_state(state, nu)
        if lam is not None:
            sp["lam"] = _replicate_pytree_like_state(state, lam)
        return state.replace(st_params=sp)

    def _init_objectives_fixed(self, init_st_params):
        """Build ``self._term_loss_fn`` with nu/lam frozen at ``init_st_params``.

        Not called by ``__init__``, which uses ``_init_objectives_dynamic``.
        """
        nu0 = init_st_params["nu"]
        lam0 = init_st_params["lam"]

        self._term_loss_fn = {}
        for k in self.loss_keys:
            mode = str(self.term_mode.get(k, "mse")).lower()

            if mode == "mse":
                self._term_loss_fn[k] = lambda e, state, _k=k: jnp.mean(e * e)
                print(f"Using MSE for term {k}")

            elif mode == "snll":
                fixed_snll = FixedSNLL(nu_map=nu0, lam_map=lam0)
                # Bind the objective as a default argument. A plain closure
                # would late-bind and see only the last instance created.
                self._term_loss_fn[k] = lambda e, state, _k=k, _obj=fixed_snll: _obj.term_loss(_k, e, state)
                print(f"Using fixed SNLL for term {k} with nu = {nu0[k]} and lam = {lam0[k]}")

            elif mode == "wls":
                weighted_wls = WeightedMSE(nu_map=nu0, lam_map=lam0)
                self._term_loss_fn[k] = lambda e, state, _k=k, _obj=weighted_wls: _obj.term_loss(_k, e, state)
                print(f"Using fixed WLS for term {k} with nu = {nu0[k]} and lam = {lam0[k]}")

            elif mode == "em":
                if self.st_handler is None:
                    raise ValueError("objective mode 'em' requested but no EM handler was built")
                self._term_loss_fn[k] = lambda e, state, _k=k: self.st_handler.term_loss(_k, e, state)

            else:
                raise ValueError(f"Unknown objective mode '{mode}' for term '{k}'")

    def _init_objectives_dynamic(self):
        """Build ``self._term_loss_fn``, with nu/lam read from ``state.st_params`` at call time."""
        self._term_loss_fn = {}
        for k in self.loss_keys:
            mode = str(self.term_mode.get(k, "mse")).lower()

            if mode == "mse":
                self._term_loss_fn[k] = lambda e, state, _k=k: mse(e)

            elif mode == "snll":
                self._term_loss_fn[k] = lambda e, state, _k=k: snll_term_loss(
                    e,
                    nu=state.st_params["nu"][_k],
                    lam=state.st_params["lam"][_k],
                )
                print(f"Using dynamic SNLL for term {k} with nu = {self.state.st_params['nu'][k]} and lam = {self.state.st_params['lam'][k]}")

            elif mode == "wls":
                self._term_loss_fn[k] = lambda e, state, _k=k: wls_term_loss(
                    e,
                    nu=state.st_params["nu"][_k],
                    lam=state.st_params["lam"][_k],
                )
                print(f"Using dynamic WLS for term {k} with nu = {self.state.st_params['nu'][k]} and lam = {self.state.st_params['lam'][k]}")

            elif mode == "em":
                if self.st_handler is None:
                    raise ValueError("objective mode 'em' requested but no EM handler was built")
                self._term_loss_fn[k] = lambda e, state, _k=k: self.st_handler.term_loss(_k, e, state)

            else:
                raise ValueError(f"Unknown objective mode '{mode}' for term '{k}'")

    def set_sampler(self, sampler):
        """Attach a batch sampler."""
        self.sampler = sampler

    def make_sampler(self, *, dom=None, batch_size=None, **kwargs):
        """Build a ``UniformSampler``.

        ``dom`` defaults to ``self.dom`` and ``batch_size`` defaults to
        ``config.training.batch_size``.
        """
        dom = dom if dom is not None else self.dom
        batch_size = batch_size if batch_size is not None else self.config.training.batch_size
        return UniformSampler(dom, batch_size=batch_size, **kwargs)

    def u_net(self, params, *args):
        raise NotImplementedError("Subclasses should implement this!")

    def r_net(self, params, *args):
        raise NotImplementedError("Subclasses should implement this!")

    def residuals(self, params, batch, *args):
        """Return ``{loss_key: residual array}``.

        Called with a train state (see ``losses``), despite the parameter name.
        """
        raise NotImplementedError("Subclasses should implement this!")

    @partial(jit, static_argnums=(0,))
    def losses(self, state, batch, *args):
        """Per-leaf losses ``{k: term_loss_fn[k](residual_k, state)}`` for every loss key."""
        err = self.residuals(state, batch, *args)
        return {k: self._term_loss_fn[k](err[k], state) for k in self.loss_keys}

    @partial(jit, static_argnums=(0,))
    def loss(self, state, batch, *args):
        """Total loss: sum over prefix groups of ``weights[group] * sum(group leaf losses)``."""
        leaf_losses = self.losses(state, batch, *args)
        losses = {
            p: sum(leaf_losses[k] for k in ks)
            for p, ks in self.prefix_group.items()
        }
        weighted = tree_map(lambda x, w: x * w, losses, state.weights)
        return tree_reduce(lambda a, b: a + b, weighted)

    @partial(pmap, axis_name="batch", static_broadcasted_argnums=(0,))
    def step(self, state, batch, *args):
        """One optimiser step with gradients averaged over the ``"batch"`` pmap axis."""
        grads = grad(lambda p: self.loss(state.replace(params=p), batch, *args))(state.params)
        grads = lax.pmean(grads, "batch")
        return state.apply_gradients(grads=grads)

    @partial(jit, static_argnums=(0,))
    def compute_weights(self, state, batch, *args):
        """Compute loss weights according to ``config.weighting.scheme``.

        * ``"grad_norm"``: ``w_k = mean_norm / (||g_k|| + 1e-5 * mean_norm)``.
          Here ``g_k`` is the gradient of leaf loss ``k`` w.r.t.
          ``state.params``, and the result is keyed by leaf loss key.
        * ``"align"``: delegates to ``self.compute_weights_align``.
        * anything else: returns ``state.weights`` unchanged.
        """
        scheme = self.config.weighting.scheme

        if scheme == "grad_norm":
            # Differentiate w.r.t. params ONLY. The full train state may hold
            # non-differentiable leaves (e.g. an integer step counter).
            def _losses_wrt_params(params):
                st = state.replace(params=params)
                err = self.residuals(st, batch, *args)
                return {k: self._term_loss_fn[k](err[k], state) for k in self.loss_keys}

            grads = jacrev(_losses_wrt_params)(state.params)

            grad_norm_dict = {
                key: jnp.linalg.norm(flatten_pytree(value))
                for key, value in grads.items()
            }

            mean_grad_norm = jnp.mean(jnp.stack(tree_leaves(grad_norm_dict)))
            return tree_map(lambda x: (mean_grad_norm / (x + 1e-5 * mean_grad_norm)), grad_norm_dict)

        if scheme == "align":
            return self.compute_weights_align(state, batch, *args)

        return state.weights

    @partial(pmap, axis_name="batch", static_broadcasted_argnums=(0,))
    def update_weights(self, state, batch, *args):
        """Per-step weighting-state update.

        For ``"groupdro"``, an EMA of the pmean'd per-group losses is stored via
        ``state.ema_update_losses``. For other schemes the weights are computed
        and pmean'd but not applied (the EMA update is disabled), so the state
        is returned unchanged.
        """
        scheme = self.config.weighting.scheme

        if scheme != "groupdro":
            weights = self.compute_weights(state, batch, *args)
            weights = lax.pmean(weights, "batch")
            #state = state.ema_update_weights(weights)
            return state

        leaf_losses = self.losses(state, batch, *args)
        grouped = {
            p: sum(leaf_losses[k] for k in ks)
            for p, ks in self.prefix_group.items()
        }
        grouped = lax.pmean(grouped, "batch")

        beta = getattr(self.config.weighting, "groupdro_loss_ema_beta", state.loss_ema_beta)
        return state.ema_update_losses(grouped, beta=beta)

    #####
    def maybe_update_objective(self, state, batch, step: int, *args):
        """Let the objective handler (if any) update its hyperparameters from the residuals."""
        if self.st_handler is None:
            return state
        residuals = self.residuals(state, batch, *args)
        return self.st_handler.maybe_update_hyperparams(state=state, residuals=residuals, step=step)

    def _zero_spec_residual(self, params, Z):
        """Placeholder spectral residual used when no ``"spec"`` weight exists."""
        return jnp.array([0.0])

    def _spec_residual(self, params, Z, eps=1e-12):
        """Spectral concentration residual of the model's features on inputs ``Z``.

        Features are row-normalised and ``C = F.T @ F / n`` is formed.
        Returns ``[sqrt(tr(C^2) / tr(C)^2 + eps)]``.
        NOTE(review): assumes ``apply_fn(params, z)`` returns ``(features, _)``.
        """
        feats, _ = vmap(lambda z: self.state.apply_fn(params, z))(Z)
        F = feats / (jnp.linalg.norm(feats, axis=1, keepdims=True) + eps)

        n = F.shape[0]
        C = (F.T @ F) / n

        trC = jnp.trace(C)
        trC2 = jnp.sum(C * C)

        R_spec = trC2 / (trC * trC + eps)
        return jnp.array([jnp.sqrt(R_spec + eps)])


class IVP(PINN, BaseMonitorMixin, AdaptivePINNMixin):
    """Initial-value problem: ``PINN`` plus monitoring and EM helpers.

    The error metrics evaluate ``self.u_pred_fn(state, self.t_star, self.x_star)``
    against ``self.u_ref``. Those attributes are not created here and are
    presumably set by subclasses.
    """

    def __init__(self, config):
        super().__init__(config)

        if config.weighting.use_causal:
            self.tol = config.weighting.causal_tol
            self.num_chunks = config.weighting.num_chunks
            # Strictly lower-triangular matrix of ones (transpose of the
            # strictly upper triangle), shape (num_chunks, num_chunks).
            self.M = jnp.triu(jnp.ones((self.num_chunks, self.num_chunks)), k=1).T

        if config.weighting.scheme == "student_t":
            self.lam = config.weighting.lam
            self.nu = config.weighting.nu
        self.init_monitor()

    @partial(jit, static_argnums=(0,))
    def compute_l2_error(self, state):
        """Relative L2 error ``||u_pred - u_ref|| / ||u_ref||`` on the reference grid."""
        u_pred = self.u_pred_fn(state, self.t_star, self.x_star)
        return jnp.linalg.norm(u_pred - self.u_ref) / jnp.linalg.norm(self.u_ref)

    @partial(jit, static_argnums=(0,))
    def compute_rmae(self, state):
        """Relative absolute error ``sum|u_pred - u_ref| / sum|u_ref|`` on the reference grid."""
        u_pred = self.u_pred_fn(state, self.t_star, self.x_star)
        return jnp.sum(jnp.abs(u_pred - self.u_ref)) / jnp.sum(jnp.abs(self.u_ref))

    def log_errors(self, state, *args):
        """Store both error metrics in ``self.log_dict``.

        NOTE: the relative L2 error is stored under the key ``"rmse_error"``.
        """
        l2_error = self.compute_l2_error(state)
        rmae_error = self.compute_rmae(state)
        self.log_dict["rmse_error"] = l2_error
        self.log_dict["rmae_error"] = rmae_error

    def log_preds(self, state):
        """Return model predictions on the reference grid."""
        # Use the same grid attributes as the error metrics; `self.model` is
        # not defined anywhere in this class hierarchy.
        return self.u_pred_fn(state, self.t_star, self.x_star)

    def log(self, state, batch, *args):
        """Run the base diagnostics, then add the error metrics."""
        # Forward *args so the base logger's losses/residuals calls receive the
        # same extra arguments the caller passed.
        self.log_dict = super().log(state, batch, *args)
        self.log_errors(state, *args)
        return self.log_dict
    

class BVP(PINN):
    """Boundary-value problem base class.

    Adds no behaviour on top of ``PINN`` yet. Construction defers to the next
    ``__init__`` in the MRO.
    """

    def __init__(self, config):
        super(BVP, self).__init__(config)






