from jax import vmap
import jax.numpy as jnp

import flax.linen as nn

from ml_collections import ConfigDict

from src.trainer_2d import Trainer, construct_W
from src.whitney_utils import (construct_M1_2D, 
                                 construct_oriented_2D, 
                                 construct_M0_2D, 
                                 whitney_1form_vec_2D_midpoints)

from src.models_2d import CVit2DLatent as CVit2D




class CrModel(nn.Module):
    """Flax wrapper around the latent CViT 2D transformer.

    Attributes:
        tr_config: keyword arguments unpacked into ``CVit2D``'s constructor.
    """

    # Configuration forwarded verbatim to the transformer constructor.
    tr_config: ConfigDict

    @nn.compact
    def __call__(self, x, coords):
        """Run the transformer on ``x`` at query points ``coords``.

        Returns:
            The two outputs of ``CVit2D`` unchanged, as ``(out, latent)``.
        """
        # The submodule is named "transformer" so its parameters live under
        # that key in the variable tree.
        transformer = CVit2D(**self.tr_config, name="transformer")
        out, latent = transformer(x, coords)
        return out, latent




class AdvTrainer(Trainer):
    """Trainer that turns transformer predictions into Whitney-form operators.

    The transformer (``CrModel``) predicts values on the interior nodes of a
    2D grid; ``apply`` converts them into the matrices, forcing and flux
    evaluations that the base ``Trainer`` presumably consumes (not visible
    here -- confirm against ``src.trainer_2d``).
    """

    def __init__(self, config):
        """Build the transformer model, then initialise the base trainer.

        Args:
            config: experiment ConfigDict; ``config.model`` is forwarded to
                ``CrModel`` as the transformer configuration.
        """
        # Attached before super().__init__, presumably because
        # Trainer.__init__ uses them (e.g. to initialise parameters) --
        # TODO confirm against src.trainer_2d.Trainer.
        self.tr_model = CrModel(tr_config=config.model)
        self.tr_apply = self.tr_model.apply

        super().__init__(config)

    def apply(self, inputs, params):
        """Run the transformer and assemble the discrete operators for a batch.

        Args:
            inputs: batched input fields; axis 0 is the batch axis.
            params: Flax variables for ``self.tr_model``.

        Returns:
            Tuple ``(W, M_1, or_areas, z, bnd_values, force, flux_eval)``:
              W: ``construct_W`` applied per sample to the interior
                 predictions reshaped onto the interior grid.
              M_1: per-sample ``construct_M1_2D(W, h_x, h_y)``.
              or_areas: per-sample ``construct_oriented_2D(W, h_x, h_y)``
                 scaled by 10.
              z: second output of the transformer (presumably its latent).
              bnd_values: zeros of shape ``(batch_size, 1, 1)``.
              force: per-sample ``construct_M0_2D`` of the zero-padded source
                 term ``config.force * sin(pi x) * sin(pi y)``.
              flux_eval: per-sample
                 ``whitney_1form_vec_2D_midpoints(W, h_x, h_y)`` with a
                 singleton axis inserted before the last axis.
        """
        batch_size = inputs.shape[0]
        num_nodes = self.config.num_nodes
        h_x = self.config.h_x
        h_y = self.config.h_y
        # Interior node counts per axis: the grid minus one boundary node on
        # each side (boundaries are re-added by the padding of ``force``).
        n_int_x = num_nodes[0] - 2
        n_int_y = num_nodes[1] - 2

        coords = self.coords

        W_int, z = self.tr_apply(params, inputs, coords)

        # Move the last axis to position 1, then unflatten the point axis
        # onto the interior grid. Assumes W_int is (batch, n_points,
        # channels) -- TODO confirm against CVit2D's output.
        W_int = jnp.moveaxis(W_int, -1, 1)
        W_int = jnp.reshape(W_int, W_int.shape[:2] + (n_int_x, n_int_y))
        W = vmap(construct_W)(W_int)

        # NOTE(review): a previous comment here read "making them periodic
        # for the matrices"; no periodicity is visible in this block --
        # confirm whether it happens inside the construct_* helpers.
        M_1 = vmap(construct_M1_2D, in_axes=(0, None, None))(W, h_x, h_y)
        # NOTE(review): the factor 10 is an unexplained scaling -- confirm intent.
        or_areas = 10 * vmap(construct_oriented_2D, in_axes=(0, None, None))(
            W, h_x, h_y
        )

        # Must use the same (x, y) interior grid as W_int above; using
        # num_nodes[0] for both axes breaks (possibly silently, via -1)
        # on non-square grids.
        coords = coords.reshape((n_int_x, n_int_y, -1))

        force = self.config.force * (
            jnp.sin(jnp.pi * coords[..., 0]) * jnp.sin(jnp.pi * coords[..., 1])
        )
        # Zero-pad one node on every side so boundary nodes get 0, then add
        # a leading singleton axis.
        force = jnp.pad(
            force, pad_width=((1, 1), (1, 1)), mode="constant", constant_values=0
        )[None]
        # The forcing is shared across the batch (in_axes None); W is mapped.
        force = vmap(construct_M0_2D, in_axes=(None, 0, None, None))(
            force, W, h_x, h_y
        )

        bnd_values = jnp.zeros((batch_size, 1, 1))

        flux_eval = vmap(whitney_1form_vec_2D_midpoints, in_axes=(0, None, None))(
            W, h_x, h_y
        )
        flux_eval = flux_eval[..., None, :]

        return W, M_1, or_areas, z, bnd_values, force, flux_eval