import jax
import flax 
import optax 
from jax import numpy as jnp
from flax import struct
from flax.training import train_state
from typing import Any, Callable


class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that also carries ``batch_stats`` and a loss function.

    ``batch_stats`` holds the model's ``'batch_stats'`` variable collection
    (in Flax this is typically BatchNorm running statistics) next to
    ``params``. ``loss_fn`` is stored as static (non-pytree) metadata, so it is
    not treated as a leaf by JAX tree utilities or transformations.
    """

    params: Any
    batch_stats: Any
    # Static field: excluded from the pytree, so a plain Python callable is fine.
    loss_fn: Callable = struct.field(pytree_node=False)

    def apply_fn_test(self, x_batch, **kwargs):
        """Run the model in evaluation mode on ``x_batch``.

        Calls ``apply_fn`` with the stored ``params`` and ``batch_stats``,
        passing ``train=False`` (so the model's call must accept a ``train``
        keyword) and ``mutable=False`` (no collection is updated and only the
        model output is returned). Extra keyword arguments are forwarded to
        ``apply_fn``.
        """
        variables = {'params': self.params, 'batch_stats': self.batch_stats}
        return self.apply_fn(
            variables,
            x_batch,
            train=False,
            mutable=False,
            **kwargs,
        )

    @classmethod
    def create(
        cls,
        model,
        params,
        *,
        loss_fn=optax.softmax_cross_entropy,
        tx=None,
    ):
        """Build a state from a Flax module and its initialized variables.

        Args:
          model: Flax module; its ``apply`` method becomes ``apply_fn``.
          params: Variables mapping (e.g. the result of ``model.init``). It
            must contain a ``'params'`` entry and may contain ``'batch_stats'``.
          loss_fn: Loss function stored on the state. Defaults to
            ``optax.softmax_cross_entropy``.
          tx: Optional optax gradient transformation. When given, ``opt_state``
            is initialized with ``tx.init`` so the inherited ``apply_gradients``
            can be used. When ``None`` (the default), ``tx`` and ``opt_state``
            are left as ``None`` and the state is suitable for inference only.

        Returns:
          A new ``TrainState`` with ``step`` set to 0.

        Raises:
          KeyError: if ``params`` has no ``'params'`` entry.
        """
        model_params = params['params']
        # Models without a 'batch_stats' collection get an empty FrozenDict so
        # apply_fn_test can always pass a 'batch_stats' entry to apply_fn.
        batch_stats = params.get(
            'batch_stats', flax.core.frozen_dict.FrozenDict()
        )
        opt_state = None if tx is None else tx.init(model_params)
        return cls(
            apply_fn=model.apply,
            params=model_params,
            batch_stats=batch_stats,
            step=0,
            tx=tx,
            opt_state=opt_state,
            loss_fn=loss_fn,
        )
