import jax
import jax.numpy as jnp
from jax import vmap,jit
from jax import jacobian
from jax import lax
from jax import random
from jax.tree_util import tree_map

import optax
from flax.training import train_state
import orbax.checkpoint as orbax_ckpt
from flax.training import orbax_utils
import flax.linen as nn



import wandb
import time
from functools import partial



from jax.sharding import PartitionSpec as P, NamedSharding
from jax.experimental.shard_map import shard_map
from jax.lax import pmean

from src.models_2d import  PairwiseNLFluxFieldsNorm


from src.whitney_utils import construct_delta0
from src.data_utils import create_dataloaders





def boundary_indices_fn(Npou, actual_Npou):
    """Return a boolean mask over ``actual_Npou`` entries marking boundary ones.

    The first ``Npou`` entries are ``False`` (interior) and the remaining
    ``actual_Npou - Npou`` entries are ``True`` (boundary).
    """
    interior_mask = jnp.zeros(Npou, dtype=jnp.bool_)
    boundary_mask = jnp.ones(actual_Npou - Npou, dtype=jnp.bool_)
    return jnp.concatenate([interior_mask, boundary_mask])


def construct_W(W_int):
    """Normalise interior weights and append a boundary channel.

    ``W_int`` (at least 3-D, channels on axis 0) is softmax-normalised over
    axis 0 and zero-padded by one node on each side of the last two axes.
    One extra channel is appended that is 1 on that padded ring and 0
    inside. For an input of shape ``(K, nx, ny)`` the result has shape
    ``(K + 1, nx + 2, ny + 2)``, and the channels sum to 1 at every node.
    """
    ring = ((0, 0), (1, 1), (1, 1))

    normalised = nn.softmax(W_int, axis=0)
    interior = jnp.pad(normalised, pad_width=ring, mode='constant', constant_values=0)

    # A single all-zero channel whose padded border is set to 1.
    boundary = jnp.pad(jnp.zeros_like(normalised[:1]), pad_width=ring,
                       mode='constant', constant_values=1)

    return jnp.concatenate([interior, boundary], axis=0)

class Trainer:
    """Jointly train a main model (``state``) and a flux network (``flux_state``).

    Each step solves a constrained problem per sample with a damped Newton
    iteration (:meth:`newton_method`). The unknowns are the coefficients
    ``psi_hat`` and the Lagrange multipliers ``lambdas``. The Lagrangian
    (:meth:`loss_fn`) is then differentiated at the stop-gradient solution
    with respect to both networks' parameters.

    NOTE(review): several pieces are not provided by this class:

    * :meth:`create_train_state` reads ``self.tr_model`` and ``self.tr_apply``,
      whose assignments in ``__init__`` are commented out.
    * :meth:`nonlinear_solve` and :meth:`inference` call
      ``self.apply(inputs, params)``, but :meth:`apply` here is a stub.

    Presumably a subclass supplies all of these -- confirm.
    """

    def __init__(self, config):
        """Build the flux model, boundary mask, interior grid and checkpoint managers.

        Args:
            config: Experiment configuration. Fields read here are ``flux``
                (kwargs for ``PairwiseNLFluxFieldsNorm``), ``Npou``,
                ``actual_Npou``, ``num_nodes`` and ``training.save_folder``.
        """
        self.config = config
        # NOTE(review): the main model is not created here; ``self.tr_model``
        # and ``self.tr_apply`` must be set before ``create_train_state``.
        # self.tr_model=CrModel(tr_config=config.model)
        # self.tr_apply=self.tr_model.apply

        self.flux_model = PairwiseNLFluxFieldsNorm(**config.flux)
        self.flux_apply = self.flux_model.apply

        self.boundary_indices = boundary_indices_fn(config.Npou, config.actual_Npou)
        self.actual_Npou = config.actual_Npou

        # Interior nodes of a uniform grid on [0, 1]^2 (end points dropped).
        x = jnp.linspace(0, 1., config.num_nodes[0])[1:-1]
        y = jnp.linspace(0, 1., config.num_nodes[1])[1:-1]
        xx, yy = jnp.meshgrid(x, y, indexing='ij')
        self.coords = jnp.hstack([xx.flatten()[:, None], yy.flatten()[:, None]])

        self.ckpt_manager, self.ckpt_manager_cpu = self.create_ckpt_manager(
            config.training.save_folder)

    def train(self, train_ds, test_ds, K=1):
        """Run the training loop with periodic validation, logging and checkpointing.

        Both train states and the Newton key are replicated over all local
        devices. Each batch is split along the ``'batch'`` mesh axis with
        ``shard_map``.

        Args:
            train_ds: Training dataset, passed to ``create_dataloaders``.
            test_ds: Validation dataset, passed to ``create_dataloaders``.
            K: Stored on ``self.K``; not otherwise used in this method.

        Returns:
            ``(flux_state, state)`` after the last step.
        """
        self.K = K
        config = self.config

        val_num_samples = config.dataset.val_num_samples
        val_batch_size = config.training.val_batch_size
        num_steps = config.training.num_steps
        # Evaluate at least one batch, so the averaged validation metrics are
        # never the mean of an empty array (NaN).
        num_val_batches = max(1, val_num_samples // val_batch_size)

        train_dl, val_dl = create_dataloaders(config.training.batch_size,
                                              config.training.val_batch_size,
                                              train_ds,
                                              test_ds)

        tx, _ = self.create_opt()
        state = self.create_train_state(tx)

        # ``lr`` is the flux-network schedule; it is the one that gets logged.
        flux_tx, lr = self.create_flux_opt()
        flux_state = self.create_flux_state(flux_tx)

        mesh = jax.make_mesh((jax.local_device_count(),), ('batch',))
        replic_sharding = NamedSharding(mesh, P())
        state, flux_state = tree_map(lambda x: jax.device_put(x, replic_sharding),
                                     (state, flux_state))

        # States and the Newton key are replicated; only the batch is sharded.
        train_step_fn = jit(shard_map(
            self.train_step,
            mesh,
            in_specs=(P(), P(), P("batch"), P()),
            out_specs=(P()),
            check_rep=False,))

        eval_step_fn = jit(shard_map(
            self.eval_step,
            mesh,
            in_specs=(P(), P(), P("batch"), P()),
            out_specs=(P()),
            check_rep=False,))

        def next_newton_key(key):
            """Split ``key``; return the new carry key and a replicated Newton key."""
            key, newton_key = random.split(key)
            return key, jax.device_put(newton_key, replic_sharding)

        save_step = 0
        key, newton_key = next_newton_key(config.newton.key)

        start_time = time.time()

        for step in range(num_steps + 1):

            batch = next(train_dl)
            flux_state, state, metrics = train_step_fn(flux_state,
                                                       state,
                                                       batch,
                                                       newton_key)
            key, newton_key = next_newton_key(key)

            train_mse = metrics['mse_loss']
            train_flux = metrics['flux_loss']

            if (step + 1) % config.training.log_loss == 0:
                t = time.time() - start_time

                mse_l, flux_l, loss_l, itrs_l = [], [], [], []
                for _ in range(num_val_batches):
                    batch = next(val_dl)
                    metrics = eval_step_fn(flux_state,
                                           state,
                                           batch,
                                           newton_key)
                    key, newton_key = next_newton_key(key)

                    mse_l.append(metrics["mse_loss"])
                    flux_l.append(metrics["flux_loss"])
                    loss_l.append(metrics["loss"])
                    itrs_l.append(metrics["newton_itrs"])

                mse_loss = jnp.array(mse_l).mean()
                flux_loss = jnp.array(flux_l).mean()
                loss = jnp.array(loss_l).mean()
                itrs = jnp.array(itrs_l).mean()
                current_lr = lr(flux_state.step)

                log_dict = {'loss': loss, 'test_mse': mse_loss, 'test_flux': flux_loss, 'lr': current_lr, 'train_mse': train_mse, 'train_flux': train_flux}
                wandb.log(log_dict, step)

                print(f"Step: {step:>5},lr: {current_lr:.5e}, Loss: {loss:.5e}, Train MSE: {train_mse:.5e}, Train Flux: {train_flux:.5e}, MSE Loss: {mse_loss:.5e}, Flux Loss: {flux_loss:.5e}, Iterations: {itrs: >2}, Time {t:.4f},")

                start_time = time.time()

            if (step + 1) % config.training.steps_per_save == 0 and config.training.save_flag:
                self.ckpt_save(flux_state, state, save_step)
                save_step += 1

        return flux_state, state

    @staticmethod
    def _mean_metrics(loss, mse_loss, flux_loss, itrs):
        """Average the scalar metrics over the ``'batch'`` mesh axis."""
        return {"loss": pmean(loss, axis_name='batch'),
                "mse_loss": pmean(mse_loss, axis_name='batch'),
                "flux_loss": pmean(flux_loss, axis_name='batch'),
                "newton_itrs": pmean(itrs, axis_name='batch')}

    def train_step(self,
                   flux_state,
                   state,
                   batch,
                   newton_key
                   ):
        """Take one optimisation step for both networks (runs inside ``shard_map``).

        Gradients of :meth:`nonlinear_solve` with respect to both parameter
        sets are averaged over the ``'batch'`` mesh axis before being applied.

        Returns:
            ``(flux_state, state, metrics)``.
        """
        grad_fn = jax.value_and_grad(self.nonlinear_solve,
                                     argnums=(0, 1),
                                     has_aux=True)

        (loss, (mse_loss, flux_loss, itrs, _)), (flux_grads, grads) = grad_fn(
            flux_state.params,
            state.params,
            batch,
            newton_key,
            train=True)

        flux_grads, grads = tree_map(lambda g: pmean(g, axis_name='batch'),
                                     (flux_grads, grads))

        flux_state = flux_state.apply_gradients(grads=flux_grads)
        state = state.apply_gradients(grads=grads)

        return flux_state, state, self._mean_metrics(loss, mse_loss, flux_loss, itrs)

    def eval_step(
        self,
        flux_state,
        state,
        batch,
        newton_key,
    ):
        """Evaluate :meth:`nonlinear_solve` with ``train=False``; return averaged metrics."""
        loss, (mse_loss, flux_loss, itrs, _) = self.nonlinear_solve(
            flux_state.params,
            state.params,
            batch,
            newton_key,
            train=False)

        return self._mean_metrics(loss, mse_loss, flux_loss, itrs)

    def nonlinear_solve(self, flux_params, params, batch, newton_key, train):
        """Solve for ``[psi_hat, lambdas]`` and evaluate the Lagrangian loss there.

        ``batch`` is indexed as ``(solns, inputs, fluxes)``. The Newton solve
        starts from a standard-normal draw of shape
        ``(batch_size, 2 * actual_Npou * num_fields)``. Its result is wrapped
        in ``stop_gradient``, so gradients reach the parameters only through
        the explicit dependence of :meth:`loss_fn`.

        Returns:
            ``loss`` and the auxiliary tuple
            ``(mse_loss, flux_loss, newton_iterations, parameters)``.
        """
        solns = batch[0]
        inputs = batch[1]
        fluxes = batch[2]
        batch_size = solns.shape[0]
        num_fields = solns.shape[-1]
        actual_Npou = self.actual_Npou

        W, M_1, or_areas, z, bcs, force, flux_eval = self.apply(inputs, params)

        parameters = random.normal(newton_key, (batch_size, 2 * actual_Npou * num_fields))

        parameters, itrs = self.newton_method(
            flux_params,
            parameters,
            M_1,
            or_areas,
            z,
            W,
            flux_eval,
            solns,
            fluxes,
            bcs,
            force,
            train
        )

        parameters = lax.stop_gradient(parameters)

        loss, mse_loss, flux_loss = self.loss_fn(
            flux_params,
            parameters,
            M_1,
            or_areas,
            z,
            W,
            flux_eval,
            solns,
            fluxes,
            bcs,
            force
        )

        return loss, (mse_loss, flux_loss, itrs, parameters)

    @partial(jit, static_argnums=(0,))
    def inference(self, flux_params, params, batch, newton_key):
        """Solve with ``train=False``; return coefficients, flux coefficients and W.

        Returns:
            ``(psi_hat, fluxes, W)``, where:

            * ``psi_hat`` is the first ``actual_Npou * num_fields`` entries
              of each solution.
            * ``fluxes`` is :meth:`flux_coef` evaluated per sample.
            * ``W`` is the first output of :meth:`apply`.
        """
        solns = batch[0]
        inputs = batch[1]
        fluxes = batch[2]

        batch_size = solns.shape[0]
        num_fields = solns.shape[-1]

        actual_Npou = self.actual_Npou

        W, M_1, or_areas, z, bcs, force, flux_eval = self.apply(inputs, params)

        parameters = random.normal(newton_key, (batch_size, 2 * actual_Npou * num_fields))

        parameters, _ = self.newton_method(
            flux_params,
            parameters,
            M_1,
            or_areas,
            z,
            W,
            flux_eval,
            solns,
            fluxes,
            bcs,
            force,
            train=False
        )

        parameters = lax.stop_gradient(parameters)

        # Keep psi_hat only; drop the Lagrange multipliers.
        parameters = parameters[:, :actual_Npou * num_fields]

        F = partial(self.flux_coef, flux_params)
        fluxes = vmap(F)(parameters,
                         or_areas,
                         z)

        return parameters, fluxes, W

    def create_opt(self):
        """Build the optimiser for the main model.

        The chain is global-norm clipping followed by AdamW, with a
        warmup + exponential-decay learning-rate schedule from ``config.lr``.

        Returns:
            ``(tx, lr_schedule)``.
        """
        config = self.config
        lr = optax.warmup_exponential_decay_schedule(
            init_value=config.lr.init_value,
            peak_value=config.lr.peak_value,
            end_value=config.lr.end_value,
            warmup_steps=config.lr.warmup_steps,
            transition_steps=config.lr.transition_steps,
            decay_rate=config.lr.decay_rate,
        )
        tx = optax.chain(
            optax.clip_by_global_norm(config.opt.clip_norm),  # clip to config.opt.clip_norm
            optax.adamw(lr, weight_decay=config.opt.weight_decay),)

        return tx, lr

    def create_train_state(self, tx):
        """Initialise the main model (``self.tr_model``) and wrap it in a TrainState.

        Initialisation uses ``config.model_key``, a ones input of shape
        ``config.nu_dim`` and the interior grid ``self.coords``.
        """
        config = self.config
        x = jnp.ones(config.nu_dim)
        coords = self.coords
        params = self.tr_model.init(config.model_key, x=x, coords=coords)

        return train_state.TrainState.create(apply_fn=self.tr_apply,
                                             params=params,
                                             tx=tx)

    def create_flux_opt(self):
        """Build the flux-network optimiser (same structure, ``flux_lr``/``flux_opt`` config).

        Returns:
            ``(tx, lr_schedule)``.
        """
        config = self.config
        lr = optax.warmup_exponential_decay_schedule(
            init_value=config.flux_lr.init_value,
            peak_value=config.flux_lr.peak_value,
            end_value=config.flux_lr.end_value,
            warmup_steps=config.flux_lr.warmup_steps,
            transition_steps=config.flux_lr.transition_steps,
            decay_rate=config.flux_lr.decay_rate,
        )
        tx = optax.chain(
            optax.clip_by_global_norm(config.flux_opt.clip_norm),  # clip to config.flux_opt.clip_norm
            optax.adamw(lr, weight_decay=config.flux_opt.weight_decay),)

        return tx, lr

    def create_flux_state(self, tx):
        """Initialise the flux network and wrap it in a TrainState.

        The dummy init input has 2 rows. Its width is
        ``Npou * num_fields + extra_params + num_one_forms``, where
        ``num_one_forms`` counts one value per node pair and dimension when
        ``oriented_areas`` is enabled, and is 0 otherwise.
        """
        config = self.config.flux

        if config.oriented_areas:
            num_one_forms = config.num_dimensions * (config.Npou * (config.Npou - 1) // 2)
        else:
            num_one_forms = 0

        params = self.flux_model.init(
            self.config.flux_key,
            jnp.ones((2, config.Npou * config.num_fields + config.extra_params + num_one_forms)))

        # Zero the ``head`` kernel and bias. This is presumably the output
        # layer, so the learned flux starts at zero -- confirm against
        # PairwiseNLFluxFieldsNorm.
        params['params']['head']['kernel'] = params['params']['head']['kernel'].at[:].set(0.0)
        params['params']['head']['bias'] = params['params']['head']['bias'].at[:].set(0.0)

        return train_state.TrainState.create(apply_fn=self.flux_apply,
                                             params=params,
                                             tx=tx)

    def flux(self,
             flux_params,
             psi_hat,
             M_1,
             oriented_areas,
             z,
             force):
        """Return the flattened residual of the discrete equations for one sample.

        The residual is
        ``psi_hat @ A + (delta_0^T @ M_1 @ g)^T - force``, where:

        * ``A = delta_0^T @ M_1 @ delta_0`` with ``delta_0 = construct_delta0(Npou)``;
        * ``g`` is the flux network applied to ``[psi_hat, z, oriented_areas]``;
        * ``psi_hat`` is reshaped to ``(num_fields, Npou)``.
        """
        flux_apply = self.flux_apply
        num_fields = self.config.num_fields

        nn_inputs = jnp.concatenate([psi_hat, z, oriented_areas])

        psi_hat = jnp.reshape(psi_hat, (num_fields, -1))
        _, Npou = psi_hat.shape

        delta_0 = construct_delta0(Npou)

        A = jnp.matmul(delta_0.T, jnp.matmul(M_1, delta_0))

        diffusive_term = jnp.matmul(psi_hat, A)

        advective_term = jnp.matmul(delta_0.T,
                                    jnp.matmul(M_1,
                                               flux_apply(flux_params, nn_inputs[None]))).T

        return jnp.reshape(diffusive_term + advective_term - force, -1)

    def flux_coef(self,
                  flux_params,
                  psi_hat,
                  oriented_areas,
                  z):
        """Return diffusive and advective coefficients stacked on a trailing axis.

        Column 0 is the flattened ``psi_hat @ delta_0^T``. Column 1 is the
        flattened, transposed output of the flux network on
        ``[psi_hat, z, oriented_areas]``.
        """
        flux_apply = self.flux_apply
        num_fields = self.config.num_fields

        nn_inputs = jnp.concatenate([psi_hat, z, oriented_areas])

        psi_hat = jnp.reshape(psi_hat, (num_fields, -1))
        _, Npou = psi_hat.shape

        delta_0 = construct_delta0(Npou)

        diffusive_term = jnp.matmul(psi_hat, delta_0.T)
        advective_term = flux_apply(flux_params, nn_inputs[None]).T

        return jnp.stack([jnp.reshape(diffusive_term, -1),
                          jnp.reshape(advective_term, -1)], axis=-1)

    def apply(self, inputs=None, params=None):
        """Map a batch of inputs to the quantities used by the solver.

        An implementation must return
        ``(W, M_1, or_areas, z, bcs, force, flux_eval)``, the tuple unpacked
        by :meth:`nonlinear_solve` and :meth:`inference`, which call
        ``self.apply(inputs, params)``.

        Raises:
            NotImplementedError: always. This base class only keeps a
                reference sketch (below) that depends on helpers not imported
                in this module.
        """
        raise NotImplementedError(
            "Trainer.apply(inputs, params) must be implemented (e.g. by a subclass).")
        # Reference implementation (needs construct_M1_2D, construct_oriented_2D,
        # construct_M0_2D, whitney_1form_vec_2D_midpoints and self.tr_apply):
        #
        # batch_size=inputs.shape[0]
        # num_nodes=self.config.num_nodes
        # h_x=self.config.h_x
        # h_y= self.config.h_y
        # coords=self.coords
        # W_int,z=self.tr_apply(params,inputs,coords)
        # W_int=jnp.moveaxis(W_int,-1,1)
        # W_int=jnp.reshape(W_int,W_int.shape[:2]+(num_nodes[0]-2,num_nodes[1]-2))
        # W=vmap(construct_W)(W_int)
        # #making them periodic for the matrices
        # M_1=vmap(construct_M1_2D,in_axes=(0,None,None))(W,h_x,h_y)
        # or_areas=10*vmap(construct_oriented_2D,in_axes=(0,None,None))(W,h_x,h_y)
        # # NOTE(review): second entry probably meant num_nodes[1]-2.
        # coords=coords.reshape((num_nodes[0]-2,num_nodes[0]-2,-1))
        # force=self.config.force*(jnp.sin(jnp.pi * coords[...,0]) * jnp.sin(jnp.pi * coords[...,1]))
        # force=jnp.pad(force,pad_width=((1, 1), (1, 1)), mode='constant', constant_values=0)[None]
        # force=vmap(construct_M0_2D,in_axes=(None,0,None,None))(force,W,h_x,h_y)
        # #return to original_shape
        # bnd_values=jnp.zeros((batch_size,1,1))
        #
        # # Alternative: point-evaluated Whitney 1-forms at cell midpoints
        # # x_nodes=jnp.linspace(0,1,W.shape[-2])
        # # y_nodes=jnp.linspace(0,1,W.shape[-1])
        # # flux_eval_points=jnp.reshape(jnp.stack(jnp.meshgrid((x_nodes[1:]+x_nodes[:-1])/2,(y_nodes[1:]+y_nodes[:-1])/2,indexing="ij"),axis=-1),(-1,2))
        # # flux_eval=vmap(whitney_1form_vec_2D,in_axes=(0,None,None,None,None,None))(W,flux_eval_points,x_nodes,y_nodes,h_x,h_y)
        # # flux_eval=jnp.reshape(flux_eval,(batch_size,)+(num_one_forms,)+(num_nodes[0]-1,)+(num_nodes[1]-1,)+(num_dims,))
        # # flux_eval=flux_eval[...,None,:]
        #
        # flux_eval=vmap(whitney_1form_vec_2D_midpoints,in_axes=(0,None,None))(W,h_x,h_y)
        # flux_eval=flux_eval[...,None,:]
        # return W, M_1, or_areas, z, bnd_values, force, flux_eval

    def newton_method(
            self,
            flux_params,
            parameters,
            M_1,
            or_areas,
            z,
            W,
            flux_eval,
            solns,
            fluxes,
            bcs,
            force,
            train
    ):
        """Solve ``grad_L(parameters) = 0`` per sample with a damped Newton iteration.

        A trial step ``parameters + epsilon * delta`` (``delta`` the Newton
        direction) is accepted only if it lowers the RMS of the Lagrangian
        gradient over the whole batch. An accepted step recomputes the
        Jacobian, the direction and the residual at the new iterate, and
        resets ``epsilon`` to 1. A rejected step halves ``epsilon``.

        The loop stops when the worst per-sample RMS of the interior
        equality residual (:meth:`flux`) is at most ``config.newton.tol``, or
        after ``config.newton.num_iters`` iterations.

        Args:
            flux_params: Flux-network parameters (shared, not batched).
            parameters: Initial ``[psi_hat, lambdas]`` per sample (leading
                batch axis).
            M_1, or_areas, z, W, flux_eval, solns, fluxes, bcs, force:
                Per-sample problem data (leading batch axis), forwarded to
                :meth:`grad_L_fn`. ``M_1``, ``or_areas``, ``z`` and ``force``
                are also forwarded to :meth:`flux`.
            train: Forwarded to :meth:`grad_L_fn`.

        Returns:
            ``(parameters, iters)``: the final iterate and the number of loop
            iterations performed.
        """
        tol = self.config.newton.tol
        num_iters = self.config.newton.num_iters

        num_fields = self.config.num_fields
        num_psi = num_fields * self.actual_Npou

        # Boundary rows are handled by separate constraints; mask them out of
        # the physics residual used for the stopping test.
        interior = ~jnp.tile(self.boundary_indices, num_fields)

        vec_F = vmap(partial(self.flux, flux_params))

        single_grad_L = partial(self.grad_L_fn, flux_params, train=train)
        vec_grad_L = vmap(single_grad_L)
        vec_jacobian_grad_L = vmap(jacobian(single_grad_L))

        data = (M_1, or_areas, z, W, flux_eval, solns, fluxes, bcs, force)

        def newton_direction(params, grad_L):
            """Return ``-J(params)^{-1} grad_L`` for every sample."""
            J = vec_jacobian_grad_L(params, *data)
            return -vmap(jnp.linalg.solve)(J, grad_L)

        def residual_rms(params):
            """Return the worst per-sample RMS of the interior equality residual."""
            residual = vec_F(params[:, :num_psi], M_1, or_areas, z, force)
            residual = jnp.where(interior, residual, 0)
            return jnp.max(jnp.sqrt(jnp.mean(residual ** 2, axis=-1)))

        grad_L = vec_grad_L(parameters, *data)
        delta = newton_direction(parameters, grad_L)
        error = residual_rms(parameters)

        def body_fun(carry):
            params, i, error, delta, grad_L, epsilon = carry

            new_params = params + epsilon * delta
            new_grad_L = vec_grad_L(new_params, *data)

            def accept():
                # The residual must be measured at the *new* iterate. Using
                # the old one makes the stopping test lag one accepted step
                # behind.
                return (new_params,
                        newton_direction(new_params, new_grad_L),
                        residual_rms(new_params),
                        new_grad_L,
                        1.0)

            def reject():
                return params, delta, error, grad_L, epsilon / 2

            improved = (jnp.sqrt(jnp.mean(new_grad_L ** 2)) <
                        jnp.sqrt(jnp.mean(grad_L ** 2)))

            params, delta, error, grad_L, epsilon = lax.cond(improved, accept, reject)

            return params, i + 1, error, delta, grad_L, epsilon

        def cond_fun(carry):
            _, i, error, _, _, _ = carry
            return jnp.logical_and(error > tol, i < num_iters)

        parameters, iters, _, _, _, _ = lax.while_loop(
            cond_fun,
            body_fun,
            init_val=(parameters, 0, error, delta, grad_L, 1.))

        return parameters, iters

    def grad_L_fn(
            self,
            flux_params,
            parameters,
            M_1,
            or_areas,
            z,
            W,
            flux_eval,
            solns,
            fluxes,
            dir_bcs,
            force,
            train=True
    ):
        """Return the gradient of the single-sample Lagrangian w.r.t. ``parameters``.

        ``parameters`` is ``[psi_hat, lambdas]``, each of length
        ``actual_Npou * num_fields``. The Lagrangian is the sum of:

        * ``mse(sum_k psi_hat_k W_k, solns)``;
        * ``flux_pen * mse(reconstructed flux, fluxes)``;
        * ``lambdas . residual`` on interior rows (residual from :meth:`flux`);
        * ``lambdas . (psi_hat - bc)`` on boundary rows, where ``bc`` is zero
          for the first ``Npou`` entries of each field and ``dir_bcs`` after.

        When ``train`` is false the two data terms are zeroed.
        """
        actual_Npou = self.actual_Npou
        num_fields = self.config.num_fields
        Npou = self.config.Npou

        F = partial(self.flux, flux_params)
        F_coef = partial(self.flux_coef, flux_params)

        boundary_indices = jnp.tile(self.boundary_indices, num_fields)

        def compute_Lagrangian(parameters):

            psi_hat = parameters[:actual_Npou * num_fields]
            lambdas = parameters[num_fields * actual_Npou:]

            f_hat = F_coef(psi_hat,
                           or_areas,
                           z)
            f_hat = f_hat.sum(axis=-1)

            f_hat_r = jnp.reshape(f_hat, (num_fields, -1)).T
            f_hat_r = jnp.reshape(f_hat_r, f_hat_r.shape[:-1] +
                                  (W.ndim - 1) * (1,) +
                                  (f_hat_r.shape[-1],) + (1,))

            flux_solns = jnp.sum(f_hat_r * flux_eval, axis=0)
            flux_error = jnp.mean((flux_solns - fluxes) ** 2)

            # Broadcast psi_hat against W so the nodal field is sum_k psi_k W_k.
            psi_hat_r = jnp.reshape(psi_hat, (num_fields, -1)).T
            psi_hat_r = jnp.reshape(psi_hat_r,
                                    psi_hat_r.shape[:-1] +
                                    (W.ndim - 1) * (1,) +
                                    (psi_hat_r.shape[-1],))

            nodal_solns = jnp.sum(psi_hat_r * W[..., None], axis=0)

            error = (nodal_solns - solns) ** 2
            mse_loss = jnp.mean(error)

            equality_constraint = F(psi_hat,
                                    M_1,
                                    or_areas,
                                    z,
                                    force
                                    )

            internal_constraints = (jnp.where(boundary_indices,
                                              0,
                                              lambdas * equality_constraint)).sum(axis=-1)

            boundary_constraints = (jnp.where(boundary_indices,
                                              lambdas * psi_hat,
                                              0)).sum(axis=-1)

            bc = jnp.zeros((num_fields, Npou))
            bc = jnp.concat([bc, dir_bcs], axis=1)
            bc = jnp.reshape(bc, (-1,))

            boundary_constraints -= (lambdas * bc).sum(axis=-1)

            mse_loss = jnp.where(train, mse_loss, 0)
            flux_error = jnp.where(train, flux_error, 0)

            return (mse_loss
                    + self.config.training.flux_pen * flux_error
                    + internal_constraints
                    + boundary_constraints
                    )

        grad_L = jax.jacobian(compute_Lagrangian)(parameters)
        return grad_L

    def loss_fn(self,
                flux_params,
                parameters,
                M_1,
                or_areas,
                z,
                W,
                flux_eval,
                solns,
                fluxes,
                dir_bcs,
                force
                ):
        """Evaluate the batched Lagrangian at ``parameters`` (always with data terms).

        This mirrors the Lagrangian in :meth:`grad_L_fn`, batched over the
        leading axis. Constraint terms are averaged over the batch.

        Returns:
            ``(loss, mse_loss, flux_error)``.
        """
        batch_size = W.shape[0]
        num_fields = self.config.num_fields
        Npou = self.config.Npou
        actual_Npou = self.actual_Npou

        boundary_indices = jnp.tile(self.boundary_indices, num_fields)

        psi_hat = parameters[:, :actual_Npou * num_fields]
        lambdas = parameters[:, actual_Npou * num_fields:]

        F_coef = vmap(partial(self.flux_coef, flux_params))
        f_hat = F_coef(psi_hat,
                       or_areas,
                       z)
        f_hat = f_hat.sum(axis=-1)

        f_hat_r = jnp.reshape(f_hat, (batch_size, num_fields, -1))
        f_hat_r = jnp.transpose(f_hat_r, (0, 2, 1))
        f_hat_r = jnp.reshape(f_hat_r, f_hat_r.shape[:-1] +
                              (W.ndim - 2) * (1,) +
                              (f_hat_r.shape[-1],) + (1,))

        flux_solns = jnp.sum(f_hat_r * flux_eval, axis=1)
        flux_error = jnp.mean((flux_solns - fluxes) ** 2)

        # Broadcast psi_hat against W so the nodal field is sum_k psi_k W_k.
        psi_hat_r = jnp.reshape(psi_hat, (batch_size, num_fields, -1))
        psi_hat_r = jnp.transpose(psi_hat_r, (0, 2, 1))
        psi_hat_r = jnp.reshape(psi_hat_r, psi_hat_r.shape[:-1] +
                                (W.ndim - 2) * (1,) +
                                (psi_hat_r.shape[-1],))

        nodal_solns = jnp.sum(psi_hat_r * W[..., None], axis=1)

        error = (nodal_solns - solns) ** 2
        mse_loss = jnp.mean(error)

        vec_F = vmap(partial(self.flux, flux_params))

        equality_constraint = vec_F(psi_hat,
                                    M_1,
                                    or_areas,
                                    z,
                                    force
                                    )

        internal_constraints = jnp.mean((jnp.where(boundary_indices,
                                                   0,
                                                   lambdas * equality_constraint)).sum(axis=-1))

        boundary_constraints = jnp.mean((jnp.where(boundary_indices,
                                                   lambdas * psi_hat,
                                                   0)).sum(axis=-1))

        bc = jnp.zeros((batch_size, num_fields, Npou))
        bc = jnp.concat([bc, dir_bcs], axis=-1)
        bc = jnp.reshape(bc, (batch_size, -1,))

        boundary_constraints -= jnp.mean((lambdas * bc).sum(axis=-1))

        loss = mse_loss + internal_constraints + boundary_constraints + self.config.training.flux_pen * flux_error

        return loss, mse_loss, flux_error

    def ckpt_save(self, flux_state, state, step):
        """Save both train states with the device and the CPU checkpoint managers.

        The CPU copy is transferred once and reused both as the saved tree and
        to derive its ``save_args``.
        """
        ckpt = {'flux_state': flux_state, 'state': state}
        self.ckpt_manager.save(
            step,
            ckpt,
            save_kwargs={'save_args': orbax_utils.save_args_from_target(ckpt)})

        cpu = jax.devices("cpu")[0]
        ckpt_cpu = tree_map(lambda x: jax.device_put(x, cpu), ckpt)
        self.ckpt_manager_cpu.save(
            step,
            ckpt_cpu,
            save_kwargs={'save_args': orbax_utils.save_args_from_target(ckpt_cpu)})

    def create_ckpt_manager(self, dir_path, max_to_keep=1):
        """Create checkpoint managers for ``dir_path`` and ``<dir_path>_cpu``.

        Both use ``PyTreeCheckpointer`` with ``max_to_keep`` and ``create=True``.

        Returns:
            ``(ckpt_manager, ckpt_manager_cpu)``.
        """
        # An f-string also works when dir_path is a pathlib.Path.
        ckpt_dir_cpu = f"{dir_path}_cpu"

        options = orbax_ckpt.CheckpointManagerOptions(max_to_keep=max_to_keep, create=True)
        ckpt_manager = orbax_ckpt.CheckpointManager(
            dir_path,
            orbax_ckpt.PyTreeCheckpointer(),
            options=options
        )
        ckpt_manager_cpu = orbax_ckpt.CheckpointManager(
            ckpt_dir_cpu,
            orbax_ckpt.PyTreeCheckpointer(),
            options=options
        )
        return ckpt_manager, ckpt_manager_cpu

    def ckpt_load_cpu(self, flux_state, state):
        """Restore the latest CPU checkpoint, using the given states as templates.

        Returns:
            ``(flux_state, state)``.
        """
        ckpt = {'flux_state': flux_state, 'state': state}
        raw = self.ckpt_manager_cpu.restore(
            self.ckpt_manager_cpu.latest_step(),
            items=ckpt,
            restore_kwargs={'restore_args': orbax_utils.restore_args_from_target(ckpt)})

        return raw['flux_state'], raw['state']

    def ckpt_load(self, flux_state, state):
        """Restore the latest device checkpoint, using the given states as templates.

        Returns:
            ``(flux_state, state)``.
        """
        ckpt = {'flux_state': flux_state, 'state': state}
        raw = self.ckpt_manager.restore(
            self.ckpt_manager.latest_step(),
            items=ckpt,
            restore_kwargs={'restore_args': orbax_utils.restore_args_from_target(ckpt)})
        return raw['flux_state'], raw['state']
    

