from jax import vmap

import jax.numpy as jnp
from flax.training import train_state

from ml_collections import ConfigDict

import flax.linen as nn

from src.models_2d import CVit2DLatent as CVit2D

from src.trainer_2d import Trainer, construct_W



from src.whitney_utils import (construct_M0_2D, 
                                 construct_M1_2D, 
                                 construct_oriented_2D,
                                 whitney_1form_vec_2D_midpoints)


class CrModel(nn.Module):
    """Thin Flax wrapper around the ``CVit2D`` transformer submodule.

    Attributes:
        tr_config: Keyword arguments expanded (``**``) into ``CVit2D``.
    """

    # Configuration forwarded verbatim to the CVit2D constructor.
    tr_config: ConfigDict

    @nn.compact
    def __call__(self, x, coords):
        """Run the transformer on a ``(force, tensor)`` input pair.

        Args:
            x: Indexable whose item 0 is the force input and item 1 the
                tensor input.
            coords: Coordinates, passed as the transformer's second argument.

        Returns:
            The pair of outputs produced by the ``"transformer"`` submodule.
        """
        backbone = CVit2D(**self.tr_config, name="transformer")
        # CVit2D takes (force, coords, tensor) in that order.
        pred, z_out = backbone(x[0], coords, x[1])
        return pred, z_out
    
class AdvTrainer(Trainer):
    """``Trainer`` subclass driving :class:`CrModel` and Whitney-form assembly.

    The network predicts values that ``construct_W`` turns into ``W``; ``apply``
    then builds the quantities it returns from ``W`` using the
    ``src.whitney_utils`` helpers.
    """

    def __init__(self, config):
        """Build the network, then defer to ``Trainer.__init__``.

        Args:
            config: Experiment config; ``config.model`` is forwarded as the
                keyword arguments of the ``CVit2D`` transformer.
        """
        # Created before super().__init__ — presumably because the base class
        # calls create_train_state, which needs tr_model.
        # NOTE(review): ordering assumption, confirm against Trainer.
        self.tr_model = CrModel(tr_config=config.model)
        self.tr_apply = self.tr_model.apply
        # Same defaults as train(), so apply() also works if it is called
        # before train() (e.g. on restored parameters).
        self.force_scale = 1
        self.tensor_scale = 1

        super().__init__(config)

    def train(self, train_ds, test_ds, force_scale=1, tensor_scale=1):
        """Store the input normalisation scales and run the base training loop.

        Args:
            train_ds: Training data, forwarded unchanged to ``Trainer.train``.
            test_ds: Evaluation data, forwarded unchanged to ``Trainer.train``.
            force_scale: Target value of ``h_x*h_y*||force||`` for every sample
                after normalisation in ``apply``.
            tensor_scale: Scalar or array divisor applied to the tensor input
                after the per-sample force normalisation.

        Returns:
            Whatever ``Trainer.train`` returns.
        """
        self.force_scale = force_scale
        self.tensor_scale = tensor_scale
        return super().train(train_ds, test_ds)

    def create_train_state(self, tx):
        """Initialise ``tr_model`` parameters and wrap them in a ``TrainState``.

        All-ones dummy inputs of shapes ``config.force_dim`` and
        ``config.tensor_dim`` are used only to trace the model for ``init``.

        Args:
            tx: Gradient transformation stored in the ``TrainState``
                (presumably an optax transformation).

        Returns:
            A ``flax.training.train_state.TrainState`` whose ``apply_fn`` is
            ``self.tr_apply``.
        """
        config = self.config
        x = (jnp.ones(config.force_dim), jnp.ones(config.tensor_dim))
        params = self.tr_model.init(config.model_key, x=x, coords=self.coords)

        return train_state.TrainState.create(apply_fn=self.tr_apply,
                                             params=params,
                                             tx=tx)

    def apply(self, inputs, params):
        """Run the network on a batch and assemble the discrete quantities.

        Args:
            inputs: Sequence whose items 0 and 1 are ``force`` and ``tensor``.
                ``force`` must be (at least) rank 3 with the batch on axis 0:
                its norm is taken over axes 1 and 2. ``tensor`` is assumed to
                carry the batch on its leading axis as well.
            params: Flax parameters for ``self.tr_model``.

        Returns:
            Tuple ``(W, M_1, or_areas, z, bnd_values, force, flux_eval)``:
            ``W`` from ``construct_W``; ``M_1`` from ``construct_M1_2D``;
            ``or_areas`` is 10x the ``construct_oriented_2D`` output; ``z`` is
            the network's second output; ``bnd_values`` is all-zero with shape
            ``(batch, 1, 1)``; ``force`` is the ``construct_M0_2D`` output for
            the *unnormalised* input force; ``flux_eval`` is the
            ``whitney_1form_vec_2D_midpoints`` output with a singleton axis
            inserted before its last axis.
        """
        force = inputs[0]
        tensor = inputs[1]
        batch_size = force.shape[0]

        num_nodes = self.config.num_nodes
        h_x = self.config.h_x
        h_y = self.config.h_y
        measure = h_x * h_y

        # Per-sample normalisation: after dividing by f_scale, every sample
        # satisfies measure * ||force|| == self.force_scale.
        f_norm = measure * jnp.linalg.norm(force, axis=(1, 2))
        f_scale = f_norm / self.force_scale
        force_in = force / f_scale[:, None, None]

        # Align f_scale with tensor's leading (batch) axis for any tensor rank;
        # a bare ``tensor / f_scale`` would align it with the *last* axis.
        # Identical to the bare division for rank-0/1 tensors.
        f_scale_b = jnp.reshape(
            f_scale, f_scale.shape + (1,) * (jnp.ndim(tensor) - 1))
        tensor_in = tensor / f_scale_b
        # jnp.asarray makes the default scalar tensor_scale=1 usable; indexing
        # a plain Python number with [None] raises TypeError.
        tensor_in = tensor_in / jnp.asarray(self.tensor_scale)[None]

        W_int, z = self.tr_apply(params, (force_in, tensor_in), self.coords)

        # Move the last axis to position 1, then unflatten the remaining axis
        # onto a (num_nodes[0]-2, num_nodes[1]-2) grid (presumably interior
        # nodes only, hence the -2).
        W_int = jnp.moveaxis(W_int, -1, 1)
        W_int = jnp.reshape(
            W_int, W_int.shape[:2] + (num_nodes[0] - 2, num_nodes[1] - 2))

        W = vmap(construct_W)(W_int)

        M_1 = vmap(construct_M1_2D, in_axes=(0, None, None))(W, h_x, h_y)
        # NOTE(review): the factor 10 is not explained here — confirm origin.
        or_areas = 10 * vmap(construct_oriented_2D,
                             in_axes=(0, None, None))(W, h_x, h_y)

        # Use the caller's unnormalised force directly instead of multiplying
        # the normalised one back by f_scale (avoids a lossy round trip).
        force_out = jnp.moveaxis(force, -1, 1)
        force_out = vmap(construct_M0_2D, in_axes=(0, 0, None, None))(
            force_out, W, h_x, h_y)

        bnd_values = jnp.zeros((batch_size, 1, 1))

        flux_eval = vmap(whitney_1form_vec_2D_midpoints,
                         in_axes=(0, None, None))(W, h_x, h_y)
        flux_eval = flux_eval[..., None, :]

        return W, M_1, or_areas, z, bnd_values, force_out, flux_eval
    
    


