
import jax
import jax.numpy as jnp
from jax import vmap,jit
from jax import jacobian
from jax import lax
from jax import random
from jax.tree_util import tree_map


import optax
from flax.training import train_state
import orbax.checkpoint as orbax_ckpt
from flax.training import orbax_utils
import flax.linen as nn


import wandb
import time
from functools import partial

import numpy as np


from ml_collections import ConfigDict


from jax.sharding import PartitionSpec as P, NamedSharding


from jax.experimental.shard_map import shard_map
from jax.lax import pmean


from src.models_2d import  PairwiseNLFluxFieldsNorm
from src.models_2d import CVit2DLatent as CVit2D



from src.whitney_utils import (construct_delta0,
                            construct_oriented_2D, 
                            construct_M0_2D,
                            whitney_1form_vec_2D_midpoints)


from src.whitney_utils import construct_M1_2D_with_K as construct_M1_2D

from src.data_utils import create_dataloaders


class CrModel(nn.Module):
    """Flax wrapper that runs the ``CVit2D`` transformer as submodule "transformer".

    Attributes:
        tr_config: keyword arguments forwarded verbatim to ``CVit2D``.
    """

    tr_config: ConfigDict

    @nn.compact
    def __call__(self, x, coords):
        """Apply the wrapped transformer.

        Args:
            x: pair-like container; ``x[0]`` is passed as the force input and
                ``x[1]`` as the tensor input.
            coords: query coordinates forwarded to the transformer.

        Returns:
            The two outputs of ``CVit2D`` as a tuple.
        """
        force, tensor = x[0], x[1]
        transformer = CVit2D(**self.tr_config, name="transformer")
        out, latent = transformer(force, coords, tensor)
        return out, latent



def construct_W(W):
    """Normalise raw weights with a softmax over the leading axis.

    For every trailing index, the entries along axis 0 of the result are
    positive and sum to one.
    """
    return nn.softmax(W, axis=0)




class Trainer:
    """Jointly train the ``CrModel`` transformer and the flux network.

    For every batch, ``apply`` turns the transformer output into
    softmax-normalised weights ``W`` plus the operators built by
    ``whitney_utils``.  ``newton_method`` then solves ``grad L = 0`` for the
    stacked unknowns ``[psi_hat, lambdas]`` (see ``grad_L_fn``).  Finally
    ``loss_fn`` evaluates the training loss at that solution, with gradients
    stopped through the solve.  Training is data-parallel over the local
    devices via ``shard_map``.
    """

    def __init__(self, config):
        """Build both models, the evaluation grid and the checkpoint managers.

        Args:
            config: ``ConfigDict``-like object.  Fields read here are
                ``model``, ``flux``, ``dataset.measure``, ``actual_Npou``,
                ``h_x``, ``h_y`` and ``training.save_folder``.
        """
        self.config = config
        self.tr_model = CrModel(tr_config=config.model)
        self.tr_apply = self.tr_model.apply

        # Product of the entries of dataset.measure; scales the force norm in apply().
        self.measure = jnp.prod(jnp.array(config.dataset.measure))

        self.flux_model = PairwiseNLFluxFieldsNorm(**config.flux)
        self.flux_apply = self.flux_model.apply

        self.actual_Npou = config.actual_Npou

        # Defaults so apply()/inference() also work before train() overrides them.
        self.force_scale = 1
        self.tensor_scale = 1

        x = jnp.arange(0, 1, config.h_x)
        y = jnp.arange(0, 1, config.h_y)
        xx, yy = jnp.meshgrid(x, y, indexing='ij')
        self.coords = jnp.hstack([xx.flatten()[:, None], yy.flatten()[:, None]])

        self.ckpt_manager, self.ckpt_manager_cpu = self.create_ckpt_manager(
            config.training.save_folder)

    def train(self, train_ds, test_ds, force_scale=1, tensor_scale=1):
        """Run the training loop with periodic validation, logging and checkpointing.

        Args:
            train_ds, test_ds: datasets handed to ``create_dataloaders``.
            force_scale: divisor applied to the per-sample force norm in ``apply``.
            tensor_scale: scalar or array divisor for the tensor input in ``apply``.

        Every ``training.log_loss`` steps the model is evaluated on
        ``val_num_samples // val_batch_size`` validation batches, and the
        result is sent to wandb and stdout.  Every ``training.steps_per_save``
        steps (if ``training.save_flag``), a checkpoint is written when the
        magnitude of the latest validation loss is below the best magnitude
        seen so far (initially 1.0).
        """
        self.force_scale = force_scale
        self.tensor_scale = tensor_scale

        config = self.config
        val_num_samples = config.dataset.val_num_samples
        val_batch_size = config.training.val_batch_size
        num_steps = config.training.num_steps

        train_dl, val_dl = create_dataloaders(config.training.batch_size,
                                              config.training.val_batch_size,
                                              train_ds,
                                              test_ds)

        tx, _ = self.create_opt()
        state = self.create_train_state(tx)

        # Only the flux schedule is logged (as 'lr'); it is indexed by flux_state.step.
        flux_tx, lr = self.create_flux_opt()
        flux_state = self.create_flux_state(flux_tx)

        # Replicate both states on every local device; batches are split along 'batch'.
        mesh = jax.make_mesh((jax.local_device_count(),), ('batch',))
        replic_sharding = NamedSharding(mesh, P())
        state, flux_state = tree_map(lambda x: jax.device_put(x, replic_sharding),
                                     (state, flux_state))

        train_step_fn = jit(shard_map(self.train_step,
                                      mesh,
                                      in_specs=(P(), P(), P("batch"), P()),
                                      out_specs=(P()),
                                      check_rep=False))
        eval_step_fn = jit(shard_map(self.eval_step,
                                     mesh,
                                     in_specs=(P(), P(), P("batch"), P()),
                                     out_specs=(P()),
                                     check_rep=False))

        key, newton_key = random.split(config.newton.key)
        newton_key = jax.device_put(newton_key, replic_sharding)

        # Best |validation loss| at which a checkpoint was written.
        last_loss = 1.
        # Latest validation loss; None until the first validation pass so the
        # save check below never reads an undefined name.
        loss = None

        start_time = time.time()

        for step in range(num_steps):
            batch = next(train_dl)
            flux_state, state, metrics = train_step_fn(flux_state, state, batch, newton_key)

            key, newton_key = random.split(key)
            newton_key = jax.device_put(newton_key, replic_sharding)

            train_mse = metrics['mse_loss']
            train_flux = metrics['flux_loss']

            if (step + 1) % config.training.log_loss == 0:
                t = time.time() - start_time

                mse_l, flux_l, loss_l, itrs_l = [], [], [], []
                for _ in range(val_num_samples // val_batch_size):
                    batch = next(val_dl)
                    metrics = eval_step_fn(flux_state, state, batch, newton_key)

                    key, newton_key = random.split(key)
                    newton_key = jax.device_put(newton_key, replic_sharding)

                    mse_l.append(metrics["mse_loss"])
                    loss_l.append(metrics["loss"])
                    itrs_l.append(metrics["newton_itrs"])
                    flux_l.append(metrics["flux_loss"])

                mse_loss = jnp.array(mse_l).mean()
                loss = jnp.array(loss_l).mean()
                itrs = jnp.array(itrs_l).mean()
                flux_loss = jnp.array(flux_l).mean()

                log_dict = {'loss': loss,
                            'test_mse': mse_loss,
                            'test_flux': flux_loss,
                            'lr': lr(flux_state.step),
                            'train_mse': train_mse,
                            'train_flux': train_flux}
                wandb.log(log_dict, step)

                print(f"Step: {step:>5},lr: {lr(flux_state.step):.5e}, Loss: {loss:.5e}, Train MSE: {train_mse:.5e}, Train Flux: {train_flux:.5e}, MSE Loss: {mse_loss:.5e}, Test Flux:{flux_loss:.5e}, Iterations: {itrs: >2}, Time {t:.4f}")

                start_time = time.time()
                # NOTE: a disabled rollback (reload the checkpoint via ckpt_load when
                # |loss| >= 10 * last_loss) used to live here.

            if ((step + 1) % config.training.steps_per_save == 0
                    and config.training.save_flag
                    and loss is not None
                    and last_loss > np.abs(loss)):
                self.ckpt_save(flux_state, state, step)
                # Store the magnitude: loss includes the multiplier term and can be
                # negative, and a negative threshold would block every later save.
                last_loss = np.abs(loss)

    @staticmethod
    def _pmean_metrics(loss, mse_loss, flux_loss, itrs):
        """Average per-device scalars over the 'batch' mesh axis into a metrics dict."""
        return {"loss": pmean(loss, axis_name='batch'),
                "mse_loss": pmean(mse_loss, axis_name='batch'),
                "flux_loss": pmean(flux_loss, axis_name='batch'),
                "newton_itrs": pmean(itrs, axis_name='batch')}

    def train_step(self, flux_state, state, batch, newton_key):
        """Run one optimisation step for both networks (per device, under shard_map).

        Gradients of ``nonlinear_solve`` w.r.t. both parameter sets are averaged
        over the 'batch' mesh axis before the optimiser updates.

        Returns:
            ``(flux_state, state, metrics)`` with metrics averaged over devices.
        """
        grad_fn = jax.value_and_grad(self.nonlinear_solve, argnums=(0, 1), has_aux=True)
        (loss, (mse_loss, flux_loss, itrs, _)), (flux_grads, grads) = grad_fn(
            flux_state.params, state.params, batch, newton_key, train=True)

        flux_grads, grads = tree_map(lambda g: pmean(g, axis_name='batch'),
                                     (flux_grads, grads))

        flux_state = flux_state.apply_gradients(grads=flux_grads)
        state = state.apply_gradients(grads=grads)

        return flux_state, state, self._pmean_metrics(loss, mse_loss, flux_loss, itrs)

    def eval_step(self, flux_state, state, batch, newton_key):
        """Evaluate one batch (``train=False`` Lagrangian) and return device-averaged metrics."""
        loss, (mse_loss, flux_loss, itrs, _) = self.nonlinear_solve(
            flux_state.params, state.params, batch, newton_key, train=False)
        return self._pmean_metrics(loss, mse_loss, flux_loss, itrs)

    def nonlinear_solve(self, flux_params, params, batch, newton_key, train):
        """Solve for the Newton unknowns, then evaluate the loss at the solution.

        Args:
            flux_params: flux-network parameters.
            params: transformer parameters.
            batch: indexable as ``(solns, inputs, fluxes)``; ``inputs`` goes to ``apply``.
            newton_key: PRNG key for the random initial Newton guess.
            train: Python bool selecting the training or evaluation Lagrangian.

        Returns:
            ``(loss, (mse_loss, flux_loss, newton_iters, parameters))``.  The
            Newton solution is wrapped in ``stop_gradient``, so gradients of
            ``loss`` flow only through its explicit dependence on the inputs
            of ``loss_fn``.
        """
        solns, inputs, fluxes = batch[0], batch[1], batch[2]
        batch_size = solns.shape[0]
        num_fields = solns.shape[-1]

        W, M_1, or_areas, z, force, flux_eval, visc = self.apply(inputs, params)

        # Unknowns are [psi_hat, lambdas], each actual_Npou * num_fields long.
        parameters = random.normal(newton_key, (batch_size, 2 * self.actual_Npou * num_fields))

        parameters, itrs = self.newton_method(flux_params, parameters, M_1, or_areas, z, W,
                                              flux_eval, solns, fluxes, visc, force, train)
        parameters = lax.stop_gradient(parameters)

        loss, mse_loss, flux_loss = self.loss_fn(flux_params, parameters, M_1, or_areas, z, W,
                                                 flux_eval, solns, fluxes, visc, force)
        return loss, (mse_loss, flux_loss, itrs, parameters)

    @partial(jit, static_argnums=(0,))
    def inference(self, flux_params, params, batch, newton_key):
        """Solve for ``psi_hat`` and return it with its flux coefficients (jitted).

        ``self`` is a static jit argument, so attributes read through ``apply``
        (``force_scale``, ``tensor_scale``) are captured when first traced.

        Returns:
            ``(psi_hat, flux_coefficients, W, visc)``.
        """
        solns, inputs, fluxes = batch[0], batch[1], batch[2]
        batch_size = solns.shape[0]
        num_fields = solns.shape[-1]
        n_psi = self.actual_Npou * num_fields

        W, M_1, or_areas, z, force, flux_eval, visc = self.apply(inputs, params)

        parameters = random.normal(newton_key, (batch_size, 2 * n_psi))
        parameters, _ = self.newton_method(flux_params, parameters, M_1, or_areas, z, W,
                                           flux_eval, solns, fluxes, visc, force, train=False)
        psi_hat = parameters[:, :n_psi]

        coefs = vmap(partial(self.flux_coef, flux_params))(psi_hat, or_areas, z)
        return psi_hat, coefs, W, visc

    @staticmethod
    def _make_optimizer(lr_config, opt_config):
        """Build ``(tx, lr)``: clip-by-global-norm + AdamW on a warmup/exp-decay schedule."""
        lr = optax.warmup_exponential_decay_schedule(
            init_value=lr_config.init_value,
            peak_value=lr_config.peak_value,
            end_value=lr_config.end_value,
            warmup_steps=lr_config.warmup_steps,
            transition_steps=lr_config.transition_steps,
            decay_rate=lr_config.decay_rate,
        )
        tx = optax.chain(
            # Clip gradients to the configured global norm.
            optax.clip_by_global_norm(opt_config.clip_norm),
            optax.adamw(lr, weight_decay=opt_config.weight_decay),
        )
        return tx, lr

    def create_opt(self):
        """Return ``(tx, lr)`` for the transformer from ``config.lr`` / ``config.opt``."""
        return self._make_optimizer(self.config.lr, self.config.opt)

    def create_train_state(self, tx):
        """Initialise transformer parameters from ``config.model_key`` into a TrainState."""
        config = self.config
        x = (jnp.ones(config.force_dim), jnp.ones(config.tensor_dim))
        params = self.tr_model.init(config.model_key, x=x, coords=self.coords)
        return train_state.TrainState.create(apply_fn=self.tr_apply, params=params, tx=tx)

    def create_flux_opt(self):
        """Return ``(tx, lr)`` for the flux network from ``config.flux_lr`` / ``config.flux_opt``."""
        return self._make_optimizer(self.config.flux_lr, self.config.flux_opt)

    def create_flux_state(self, tx):
        """Initialise the flux network with a zeroed 'head' layer into a TrainState."""
        flux_cfg = self.config.flux

        if flux_cfg.oriented_areas:
            # num_dimensions entries per unordered pair of the Npou functions.
            num_one_forms = flux_cfg.num_dimensions * (flux_cfg.Npou * (flux_cfg.Npou - 1) // 2)
        else:
            num_one_forms = 0

        input_dim = flux_cfg.Npou * flux_cfg.num_fields + flux_cfg.extra_params + num_one_forms
        params = self.flux_model.init(self.config.flux_key, jnp.ones((2, input_dim)))

        # NOTE(review): 'head' is presumably the output layer, so zeroing it starts
        # the flux network at a zero output — confirm in PairwiseNLFluxFieldsNorm.
        for leaf in ('kernel', 'bias'):
            params['params']['head'][leaf] = jnp.zeros_like(params['params']['head'][leaf])

        return train_state.TrainState.create(apply_fn=self.flux_apply, params=params, tx=tx)

    def flux(self, flux_params, psi_hat, M_1, oriented_areas, z, force):
        """Return the flattened discrete residual ``F(psi_hat)`` for one sample.

        ``F = psi_hat @ A + (delta_0^T @ M_1[..., 1] @ N)^T - force`` where
        ``A = delta_0^T @ M_1[..., 0] @ delta_0`` and ``N`` is the flux network
        evaluated on ``concat([psi_hat, z, oriented_areas])``.  ``psi_hat`` is
        viewed as ``(num_fields, Npou)``.
        """
        num_fields = self.config.num_fields
        nn_inputs = jnp.concat([psi_hat, z, oriented_areas])

        psi_hat = jnp.reshape(psi_hat, (num_fields, -1))
        _, Npou = psi_hat.shape
        delta_0 = construct_delta0(Npou)

        A = jnp.matmul(delta_0.T, jnp.matmul(M_1[..., 0], delta_0))
        diffusive_term = jnp.matmul(psi_hat, A)
        advective_term = jnp.matmul(delta_0.T,
                                    jnp.matmul(M_1[..., 1],
                                               self.flux_apply(flux_params, nn_inputs[None]))).T

        return jnp.reshape(diffusive_term + advective_term - force, -1)

    def flux_coef(self, flux_params, psi_hat, oriented_areas, z):
        """Return coefficient pairs stacked as ``(diffusive, advective)`` on the last axis.

        The diffusive column is ``psi_hat @ delta_0^T`` (with ``psi_hat`` viewed
        as ``(num_fields, Npou)``).  The advective column is the flux network
        output on ``concat([psi_hat, z, oriented_areas])``.
        """
        num_fields = self.config.num_fields
        nn_inputs = jnp.concat([psi_hat, z, oriented_areas])

        psi_hat = jnp.reshape(psi_hat, (num_fields, -1))
        _, Npou = psi_hat.shape
        delta_0 = construct_delta0(Npou)

        diffusive_term = jnp.matmul(psi_hat, delta_0.T)
        advective_term = self.flux_apply(flux_params, nn_inputs[None]).T

        return jnp.stack([jnp.reshape(diffusive_term, -1),
                          jnp.reshape(advective_term, -1)], axis=-1)

    def apply(self, inputs, params):
        """Run the transformer and assemble the per-sample discrete operators.

        Args:
            inputs: indexable as ``(force, tensor)``.  ``tensor[:, :3]`` supplies
                the entries ``K`` of the symmetric 2x2 ``visc`` returned below.
            params: transformer parameters.

        Returns:
            ``(W, M_1, or_areas, z, force, flux_eval, visc)``.  ``W`` holds the
            ``construct_W``-normalised weights on the original grid.  ``M_1``,
            ``or_areas``, ``force`` and ``flux_eval`` come from the
            ``whitney_utils`` constructors on the periodically padded grid.
            ``visc`` is ``[[K0, K2], [K2, K1]]`` per sample.
        """
        num_nodes = self.config.num_nodes
        h_x = self.config.h_x
        h_y = self.config.h_y

        force, tensor = inputs[0], inputs[1]
        K = tensor[:, :3]

        # Normalise the network inputs by the measure-weighted force norm; the
        # scale is undone on `force` before it is projected below.
        f_norm = self.measure * jnp.linalg.norm(force, axis=(1, 2))
        f_scale = f_norm / self.force_scale
        force = force / f_scale[:, None, None]
        tensor = tensor / f_scale
        # asarray so a plain scalar tensor_scale (the train() default) also works.
        tensor = tensor / jnp.asarray(self.tensor_scale)[None]

        W, z = self.tr_apply(params, (force, tensor), self.coords)

        W = jnp.moveaxis(W, -1, 1)
        W = jnp.reshape(W, W.shape[:2] + (num_nodes[0], num_nodes[1]))
        W = vmap(construct_W)(W)

        # Periodic padding: append the first row/column before building the matrices.
        W = jnp.concat([W, W[:, :, 0:1]], axis=-2)
        W = jnp.concat([W, W[..., 0:1]], axis=-1)

        M_1 = vmap(construct_M1_2D, in_axes=(0, 0, None, None))(W, K, h_x, h_y)
        # NOTE(review): fixed factor 10 on the oriented areas; purpose not visible here.
        or_areas = 10 * vmap(construct_oriented_2D, in_axes=(0, None, None))(W, h_x, h_y)

        force = force * f_scale[:, None, None]
        force = jnp.moveaxis(force, -1, 1)
        force = jnp.concat([force, force[:, :, 0:1]], axis=-2)
        force = jnp.concat([force, force[..., 0:1]], axis=-1)
        force = vmap(construct_M0_2D, in_axes=(0, 0, None, None))(force, W, h_x, h_y)

        flux_eval = vmap(whitney_1form_vec_2D_midpoints, in_axes=(0, None, None))(W, h_x, h_y)
        flux_eval = flux_eval[..., None, :]

        # Drop the periodic padding again.
        W = W[:, :, :-1, :-1]

        visc = vmap(lambda k: jnp.array([[k[0], k[2]], [k[2], k[1]]]))(K)
        return W, M_1, or_areas, z, force, flux_eval, visc

    def newton_method(self, flux_params, parameters, M_1, or_areas, z, W,
                      flux_eval, solns, fluxes, visc, force, train):
        """Solve ``grad L(parameters) = 0`` with a damped (backtracking) Newton method.

        Each iteration tries ``parameters + epsilon * delta``.
          * If the trial lowers the RMS of ``grad L`` over the whole batch, it
            is accepted: a new Newton direction is computed and ``epsilon``
            resets to 1.
          * Otherwise the trial is rejected and ``epsilon`` is halved.
        The loop stops once the max-over-batch RMS of ``self.flux`` at the
        current iterate is ``<= config.newton.tol``, or after
        ``config.newton.num_iters`` iterations.

        Returns:
            ``(parameters, iters)``: the final iterate (same shape as the input)
            and the number of loop iterations run.
        """
        tol = self.config.newton.tol
        num_iters = self.config.newton.num_iters
        n_psi = self.config.num_fields * self.actual_Npou

        vec_F = vmap(partial(self.flux, flux_params))
        single_grad_L = partial(self.grad_L_fn, flux_params, train=train)
        vec_grad_L = vmap(single_grad_L)
        vec_jacobian_grad_L = vmap(jacobian(single_grad_L))

        # Per-sample data that stays fixed while the unknowns are updated.
        data = (M_1, or_areas, z, W, flux_eval, solns, fluxes, visc, force)

        def newton_direction(params, grad_L):
            # Minimum-norm least-squares solve (rather than jnp.linalg.solve)
            # so a rank-deficient Jacobian still yields a step.
            J = vec_jacobian_grad_L(params, *data)
            return -vmap(lambda lhs, rhs: jnp.linalg.lstsq(lhs, rhs, rcond=1e-15)[0])(J, grad_L)

        def physics_error(params):
            residual = vec_F(params[:, :n_psi], M_1, or_areas, z, force)
            return jnp.max(jnp.sqrt(jnp.mean(jnp.power(residual, 2), axis=-1)))

        grad_L = vec_grad_L(parameters, *data)
        delta = newton_direction(parameters, grad_L)
        error = physics_error(parameters)

        def body_fun(carry):
            params, i, err, step_dir, cur_grad_L, epsilon = carry

            new_params = params + epsilon * step_dir
            new_grad_L = vec_grad_L(new_params, *data)

            def accept():
                # Residual is measured at the accepted iterate, so the stopping
                # test reflects the parameters actually returned.
                return (new_params,
                        newton_direction(new_params, new_grad_L),
                        physics_error(new_params),
                        new_grad_L,
                        1.0)

            def reject():
                return params, step_dir, err, cur_grad_L, epsilon / 2

            improved = (jnp.sqrt(jnp.mean(new_grad_L ** 2))
                        < jnp.sqrt(jnp.mean(cur_grad_L ** 2)))
            params, step_dir, err, cur_grad_L, epsilon = lax.cond(improved, accept, reject)

            return params, i + 1, err, step_dir, cur_grad_L, epsilon

        def cond_fun(carry):
            _, i, err, _, _, _ = carry
            return jnp.logical_and(err > tol, i < num_iters)

        parameters, iters, _, _, _, _ = lax.while_loop(
            cond_fun, body_fun, (parameters, 0, error, delta, grad_L, 1.))
        return parameters, iters

    def grad_L_fn(self, flux_params, parameters, M_1, or_areas, z, W, flux_eval,
                  solns, fluxes, visc, force, train=True):
        """Return the gradient of the per-sample Lagrangian w.r.t. ``parameters``.

        ``parameters`` is ``[psi_hat, lambdas]``, where ``psi_hat`` has
        ``actual_Npou * num_fields`` entries.  The Lagrangian is

            L = mse + flux_pen * flux_error + lambdas . F(psi_hat)

        with ``F = self.flux``.  When ``train`` is False, ``mse`` is replaced by
        the sum of the mean squared errors on the first/last rows and columns
        of the nodal error, and ``flux_error`` is set to zero.
        """
        num_fields = self.config.num_fields
        n_psi = self.actual_Npou * num_fields

        F = partial(self.flux, flux_params)
        F_coef = partial(self.flux_coef, flux_params)

        def compute_Lagrangian(parameters):
            psi_hat = parameters[:n_psi]
            lambdas = parameters[n_psi:]

            # Coefficient pairs reshaped so they broadcast against flux_eval.
            f_hat_r = jnp.reshape(F_coef(psi_hat, or_areas, z), (num_fields, -1, 2))
            f_hat_r = jnp.swapaxes(f_hat_r, 0, 1)
            f_hat_r = jnp.reshape(f_hat_r, f_hat_r.shape[:-2]
                                  + (W.ndim - 1) * (1,)
                                  + (f_hat_r.shape[-2],) + (1,) + (2,))

            # Diffusive part is contracted with visc; the advective part is added as is.
            flux_solns = jnp.sum(f_hat_r[..., 0] * flux_eval, axis=0)
            flux_solns = jnp.sum(flux_solns[..., None] * visc[None, None, None], axis=-2)
            flux_solns += jnp.sum(f_hat_r[..., 1] * flux_eval, axis=0)
            flux_error = jnp.mean((flux_solns - fluxes) ** 2)

            # Nodal reconstruction sum_k psi_hat_k * W_k via broadcasting.
            psi_hat_r = jnp.reshape(psi_hat, (num_fields, -1)).T
            psi_hat_r = jnp.reshape(psi_hat_r, psi_hat_r.shape[:-1]
                                    + (W.ndim - 1) * (1,)
                                    + (psi_hat_r.shape[-1],))
            nodal_solns = jnp.sum(psi_hat_r * W[..., None], axis=0)
            error = (nodal_solns - solns) ** 2
            mse_loss = jnp.mean(error)

            internal_constraints = (lambdas * F(psi_hat, M_1, or_areas, z, force)).sum(axis=-1)

            boundary_mse = (jnp.mean(error[0, :]) + jnp.mean(error[:, 0])
                            + jnp.mean(error[-1, :]) + jnp.mean(error[:, -1]))
            mse_loss = jnp.where(train, mse_loss, boundary_mse)
            flux_error = jnp.where(train, flux_error, 0.0)

            return (mse_loss
                    + self.config.training.flux_pen * flux_error
                    + internal_constraints)

        return jax.jacobian(compute_Lagrangian)(parameters)

    def loss_fn(self, flux_params, parameters, M_1, or_areas, z, W, flux_eval,
                solns, fluxes, visc, force):
        """Return the batched training loss at fixed Newton ``parameters``.

        Returns:
            ``(loss, mse_loss, flux_error)``, where
            ``loss = mse_loss + mean_b(lambdas . F(psi_hat)) + flux_pen * flux_error``.
            Because of the multiplier term, ``loss`` may be negative.
        """
        batch_size = W.shape[0]
        num_fields = self.config.num_fields
        n_psi = self.actual_Npou * num_fields

        psi_hat = parameters[:, :n_psi]
        lambdas = parameters[:, n_psi:]

        # Coefficient pairs reshaped so they broadcast against flux_eval.
        f_hat = vmap(partial(self.flux_coef, flux_params))(psi_hat, or_areas, z)
        f_hat_r = jnp.reshape(f_hat, (batch_size, num_fields, -1, 2))
        f_hat_r = jnp.swapaxes(f_hat_r, 1, 2)
        f_hat_r = jnp.reshape(f_hat_r, f_hat_r.shape[:-2]
                              + (W.ndim - 2) * (1,)
                              + (f_hat_r.shape[-2],) + (1,) + (2,))

        # Diffusive part is contracted with visc; the advective part is added as is.
        flux_solns = jnp.sum(f_hat_r[..., 0] * flux_eval, axis=1)
        flux_solns = jnp.sum(flux_solns[..., None] * visc[:, None, None, None], axis=-2)
        flux_solns += jnp.sum(f_hat_r[..., 1] * flux_eval, axis=1)
        flux_error = jnp.mean((flux_solns - fluxes) ** 2)

        # Nodal reconstruction sum_k psi_hat_k * W_k via broadcasting.
        psi_hat_r = jnp.reshape(psi_hat, (batch_size, num_fields, -1))
        psi_hat_r = jnp.moveaxis(psi_hat_r, -1, 1)
        psi_hat_r = jnp.reshape(psi_hat_r, psi_hat_r.shape[:-1]
                                + (W.ndim - 2) * (1,)
                                + (psi_hat_r.shape[-1],))
        nodal_solns = jnp.sum(psi_hat_r * W[..., None], axis=1)
        mse_loss = jnp.mean((nodal_solns - solns) ** 2)

        equality_constraint = vmap(partial(self.flux, flux_params))(
            psi_hat, M_1, or_areas, z, force)
        internal_constraints = jnp.mean((lambdas * equality_constraint).sum(axis=-1))

        loss = mse_loss + internal_constraints + self.config.training.flux_pen * flux_error
        return loss, mse_loss, flux_error

    def ckpt_save(self, flux_state, state, step):
        """Save both train states at ``step`` with the device and CPU checkpoint managers."""
        ckpt = {'flux_state': flux_state, 'state': state}
        self.ckpt_manager.save(
            step,
            ckpt,
            save_kwargs={'save_args': orbax_utils.save_args_from_target(ckpt)})

        # Copy to host once; reuse the copy for both the payload and its save args.
        cpu = jax.devices("cpu")[0]
        ckpt_cpu = tree_map(lambda x: jax.device_put(x, cpu), ckpt)
        self.ckpt_manager_cpu.save(
            step,
            ckpt_cpu,
            save_kwargs={'save_args': orbax_utils.save_args_from_target(ckpt_cpu)})

    def create_ckpt_manager(self, dir_path, max_to_keep=1):
        """Create checkpoint managers for ``dir_path`` and ``dir_path + "_cpu"``.

        Both keep at most ``max_to_keep`` checkpoints and create their directory
        if needed.

        Returns:
            ``(ckpt_manager, ckpt_manager_cpu)``.
        """
        options = orbax_ckpt.CheckpointManagerOptions(max_to_keep=max_to_keep, create=True)
        ckpt_manager = orbax_ckpt.CheckpointManager(
            dir_path, orbax_ckpt.PyTreeCheckpointer(), options=options)
        ckpt_manager_cpu = orbax_ckpt.CheckpointManager(
            dir_path + "_cpu", orbax_ckpt.PyTreeCheckpointer(), options=options)
        return ckpt_manager, ckpt_manager_cpu

    def ckpt_load_cpu(self, flux_state, state):
        """Restore ``(flux_state, state)`` from the latest CPU checkpoint, using the args as targets."""
        ckpt = {'flux_state': flux_state, 'state': state}
        raw = self.ckpt_manager_cpu.restore(self.ckpt_manager_cpu.latest_step(), items=ckpt)
        return raw['flux_state'], raw['state']

    def ckpt_load(self, flux_state, state):
        """Restore ``(flux_state, state)`` from the latest device checkpoint, using the args as targets."""
        ckpt = {'flux_state': flux_state, 'state': state}
        raw = self.ckpt_manager.restore(
            self.ckpt_manager.latest_step(),
            items=ckpt,
            restore_kwargs={'restore_args': orbax_utils.restore_args_from_target(ckpt)})
        return raw['flux_state'], raw['state']