"""
This file contains the pytorch and flax neural network models and loss functions
"""

import jax 
import jax.numpy as jnp 
import flax.linen as fnn
from jax.tree_util import tree_flatten, tree_unflatten
import flax
from flax.training import train_state
import optax
from pathlib import Path
from foo.utils import get_project_root
from foo.models.loss import FlaxUpdateParameters
from typing import Any
import pickle

class TrainState(train_state.TrainState):
   """
   A custom TrainState class for JAX/Flax.

   Extends flax.training.train_state.TrainState (step, apply_fn, params, tx, opt_state)
   with the extra fields filled in by FlaxNet.init_naive_training.

   Attributes:
      batch_stats: Batch statistics for normalization layers (the 'batch_stats' collection
         returned by Module.init), or None when the model has no such collection.
      rngs: Random number generator states, e.g. {'dropout': key}.
      apply_fns: Optional collection of apply functions for multi-objective models like PINNs.
         Static field (pytree_node=False): part of the tree structure, not a traced leaf.
      apply_fns_single_point: Optional single-point counterparts of apply_fns
         (init_naive_training falls back to apply_fns when none are given). Static field.
      num_apply_fns: Number of entries in apply_fns. Static field.
   """
   batch_stats: Any
   rngs: Any
   # Static (non-pytree) metadata: not traced by jit and not touched by the optimizer.
   apply_fns: Any = flax.struct.field(pytree_node=False, default=None)
   apply_fns_single_point: Any = flax.struct.field(pytree_node=False, default=None)
   num_apply_fns: int = flax.struct.field(pytree_node=False, default=1)

class FlaxNet(fnn.Module):
   """
   A base class for neural networks implemented with JAX/Flax providing reusable functions.

   Subclasses implement the architecture in ``__call__``. This base class adds helpers to
   build a TrainState, to convert parameters between PyTrees, leaf lists and flat vectors,
   to apply a plain gradient-descent step, and to pickle/unpickle a model with its parameters.

   Attributes (cached on the instance by the methods below; they are not dataclass fields):
      _opt_state (TrainState):
         Training state created by init_naive_training and replaced by update_parameters.
      _leaves (List(Any)):
         Parameter leaves recorded by the most recent pytree2leaves call.
      _tree_def (jax.tree_util.PyTreeDef):
         Parameter tree structure recorded by the most recent pytree2leaves call.

   Methods:
      __call__(x, train: bool = True):
         Forward pass of the network (implemented by subclasses).

      init_naive_training(rng, x_train, learning_rate=1e-03, subkey=None, primary_apply_fn=None,
                          apply_fns=None, apply_fns_single_point=None):
         Initialize the model parameters and an Adam-based TrainState for training.

      pytree2leaves(params):
         Convert a PyTree of parameters to a flat list of leaves.

      leaves2vector(leaves):
         Convert a list of parameter leaves to a flat vector.

      leaves2vector_batch(leaves):
         Convert a list of batched parameter leaves to a 2D array.

      pytree2vector(params):
         Convert a PyTree of parameters to a flat vector.

      vector2leaves(vector):
         Convert a flat vector to a list of parameter leaves.

      leaves2pytree(leaves):
         Convert a list of parameter leaves back to a PyTree.

      vector2pytree(vector):
         Convert a flat vector back to a PyTree of parameters.

      update_parameters(grads, lr, isflat=False):
         Apply one gradient-descent step to the parameters in the cached TrainState.

      save_model(params, name="model", path=None):
         Pickle the model class, its configuration and the given parameters.

      load_model(name="model", path=None):
         Class method that restores the (model, params) pair written by save_model.
   """

   @fnn.compact
   def __call__(self, x, train: bool = True):
      """
      Forward pass of the network.

      Subclasses override this; the base implementation does nothing and returns None.
      """
      pass

   def _store(self, name, value):
      """
      Cache a plain (non-dataclass) attribute on this module instance.

      flax.linen.Module overrides __setattr__ and can raise SetAttributeFrozenModuleError when an
      already-constructed module is assigned to outside setup(). object.__setattr__ bypasses that
      hook so bookkeeping values (_opt_state, _leaves, _tree_def) can be kept on the instance.
      """
      object.__setattr__(self, name, value)

   def init_naive_training(self, rng, x_train, learning_rate=1e-03, subkey=None, primary_apply_fn=None, apply_fns=None, apply_fns_single_point=None):
      """
      Initialize the model parameters and an Adam-based TrainState for training.

      Args:
         rng (jax.random.PRNGKey, required):
            JAX random number generator key used for parameter initialization.
         x_train (jax.numpy.ndarray or sequence of arrays, required):
            Sample input data for initializing the model; passed to ``self.init`` unchanged.
         learning_rate (float, optional):
            Learning rate for the Adam optimizer (default: 1e-03).
         subkey (jax.random.PRNGKey, optional):
            Key stored as ``rngs['dropout']``. If None, ``jax.random.PRNGKey(42)`` is used.
         primary_apply_fn (Callable, optional):
            Apply function for the model u = model(x). If None, ``self.apply`` is used.
         apply_fns (Callable or List[Callable], optional):
            Apply function(s) for the model u_i = operator_i(model(x_i)). If None,
            ``(primary_apply_fn,)`` is used. Can be a single function or a list/tuple of
            functions for multi-objective models.
         apply_fns_single_point (Callable or List[Callable], optional):
            Single-point variant(s) of ``apply_fns``. If None, ``apply_fns`` is reused.

      Returns:
         TrainState:
            The newly created training state (also cached as ``self._opt_state``).
      """
      # Precision of the sample input: a single array, or the first element of a sequence.
      dtype = x_train.dtype if hasattr(x_train, 'dtype') else x_train[0].dtype

      variables = self.init(rng, x_train, train=False)

      # The layers keep Flax's default float32 param_dtype, so promote floating-point
      # variables to float64 when the input is float64 and x64 mode is on. Non-floating
      # leaves (if any) are left untouched.
      if jax.config.x64_enabled and dtype == jnp.float64:
         variables = jax.tree_util.tree_map(
            lambda p: p.astype(jnp.float64)
            if jnp.issubdtype(p.dtype, jnp.floating) and p.dtype != jnp.float64
            else p,
            variables,
         )

      batch_stats = variables['batch_stats'] if 'batch_stats' in variables else None
      subkey = jax.random.PRNGKey(42) if subkey is None else subkey
      if primary_apply_fn is None:
         primary_apply_fn = self.apply

      def as_tuple(fns):
         # A lone callable becomes a 1-tuple; lists become tuples because these are static
         # (pytree_node=False) TrainState fields, which JAX hashes as part of the tree structure.
         return tuple(fns) if isinstance(fns, (list, tuple)) else (fns,)

      apply_fns = (primary_apply_fn,) if apply_fns is None else as_tuple(apply_fns)
      if apply_fns_single_point is None:
         apply_fns_single_point = apply_fns
      else:
         apply_fns_single_point = as_tuple(apply_fns_single_point)

      state = TrainState.create(
         apply_fn=primary_apply_fn,
         apply_fns=apply_fns,
         apply_fns_single_point=apply_fns_single_point,
         num_apply_fns=len(apply_fns),
         params=variables['params'],
         batch_stats=batch_stats,
         tx=optax.adam(learning_rate=learning_rate),
         rngs={'dropout': subkey},
      )
      self._store('_opt_state', state)
      return state

   def pytree2leaves(self, params):
      """
      Convert a PyTree of parameters to a flat list of leaves.

      The leaves and the tree structure are also recorded on the instance so that
      vector2leaves / leaves2pytree can invert the conversion later.

      Args:
         params (jax.tree_util.PyTreeDef, required):
            PyTree of parameters.

      Returns:
         List(Any):
            Flat list of parameter leaves.
      """
      leaves, tree_def = tree_flatten(params)
      self._store('_leaves', leaves)
      self._store('_tree_def', tree_def)
      return leaves

   def leaves2vector(self, leaves):
      """
      Convert a list of parameter leaves to a flat vector. Can be used after calling pytree2leaves.

      Args:
         leaves (List, required):
            Flat list of parameter leaves.

      Returns:
         jax.numpy.ndarray:
            Flat vector of parameters (leaves raveled and concatenated in list order).
      """
      return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

   def leaves2vector_batch(self, leaves):
      """
      Convert a list of batched parameter leaves to a 2D array.

      Args:
         leaves (List, required):
            List of batched parameter leaves; every leaf must share the leading (batch) axis
            of ``leaves[0]``.

      Returns:
         jax.numpy.ndarray:
            2D array with dimensions (batch, *).
      """
      batch = leaves[0].shape[0]
      return jnp.concatenate([jnp.reshape(leaf, (batch, -1)) for leaf in leaves], axis=1)

   def pytree2vector(self, params):
      """
      Convert a PyTree of parameters to a flat vector (and record its structure on the instance).

      Args:
         params (jax.tree_util.PyTreeDef, required):
            PyTree of parameters.

      Returns:
         jax.numpy.ndarray:
            Flat vector of parameters.
      """
      return self.leaves2vector(self.pytree2leaves(params))

   def vector2leaves(self, vector):
      """
      Convert a flat vector to a list of parameter leaves, using the leaf shapes recorded by
      the last pytree2leaves call.

      Args:
         vector (jax.numpy.ndarray, required):
            Flat vector of parameters.

      Returns:
         List(Any):
            List of parameter leaves.

      Raises:
         ValueError: If pytree2leaves has not been called yet.
      """
      if getattr(self, '_leaves', None) is None:
         raise ValueError("Leaves are not set. Call pytree2leaves first.")
      leaves = []
      start = 0
      for leaf in self._leaves:
         shape = jnp.shape(leaf)
         # Static Python int (1 for scalar leaves); no device round-trip needed.
         size = int(jnp.size(leaf))
         leaves.append(vector[start:start + size].reshape(shape))
         start += size
      return leaves

   def leaves2pytree(self, leaves):
      """
      Convert a list of parameter leaves back to a PyTree.

      Args:
         leaves (List, required):
            List of parameter leaves.

      Returns:
         jax.tree_util.PyTreeDef:
            Reconstructed PyTree of parameters.

      Raises:
         ValueError: If pytree2leaves has not been called yet.
      """
      if getattr(self, '_tree_def', None) is None:
         raise ValueError("Tree definition is not set. Call pytree2leaves first.")
      return tree_unflatten(self._tree_def, leaves)

   def vector2pytree(self, vector):
      """
      Convert a flat vector back to a PyTree of parameters.

      Args:
         vector (jax.numpy.ndarray, required):
            Flat vector of parameters.

      Returns:
         jax.tree_util.PyTreeDef:
            Reconstructed PyTree of parameters.
      """
      return self.leaves2pytree(self.vector2leaves(vector))

   def update_parameters(self, grads, lr, isflat=False):
      """
      Apply one plain gradient-descent step (params - lr * grads) to the parameters in
      self._opt_state. The optax optimizer state and step counter are left unchanged.

      Args:
         grads (jax.tree_util.PyTreeDef or jax.numpy.ndarray, required):
            Gradients for parameter updates.
         lr (float, required):
            Learning rate.
         isflat (bool, optional):
            Whether the gradients are provided as a flat vector (requires a prior
            pytree2leaves call) or as a PyTree (default: False).

      Returns:
         TrainState:
            The updated training state (also cached as ``self._opt_state``).
      """
      if isflat:
         grads = self.vector2pytree(grads)
      state = self._opt_state
      new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, state.params, grads)
      state = state.replace(params=new_params)
      self._store('_opt_state', state)
      return state

   def save_model(self, params, name="model", path=None):
      """
      Save the model class, its configuration and its parameters to ``<path>/<name>.pkl``.

      The configuration records ``_output_size`` (required on the model) plus ``_hidden``
      and ``_activation`` when present (None / 'relu' otherwise).

      Args:
         params (jax.tree_util.PyTreeDef, required):
            The model parameters to save
         name (str, optional):
            Name of the model file. Defaults to "model"
         path (Path or str, optional):
            Path to save directory (created if missing). Defaults to project root
      """
      path = Path(get_project_root() if path is None else path)
      path.mkdir(parents=True, exist_ok=True)

      save_dict = {
         'model_class': self.__class__,
         'model_config': {
            '_output_size': self._output_size,
            '_hidden': getattr(self, '_hidden', None),
            '_activation': getattr(self, '_activation', 'relu')
         },
         'params': params
      }

      with open(path / f"{name}.pkl", 'wb') as f:
         pickle.dump(save_dict, f)

   @classmethod
   def load_model(cls, name="model", path=None):
      """
      Load a model and its parameters written by save_model.

      Every saved configuration entry that the stored model class declares as a dataclass
      field is passed back to its constructor, so e.g. both FlaxNetDNN and FlaxNetDNNBN
      recover their hidden sizes and activation.

      Args:
         name (str, optional):
            Name of the model file. Defaults to "model"
         path (Path or str, optional):
            Path to load directory. Defaults to project root

      Returns:
         tuple: (model, params)
      """
      import dataclasses

      path = Path(get_project_root() if path is None else path)

      # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
      with open(path / f"{name}.pkl", 'rb') as f:
         save_dict = pickle.load(f)

      model_class = save_dict['model_class']
      config = save_dict['model_config']

      # Only forward settings the class actually has (LeNet/VGG/ResNet take just _output_size).
      field_names = {field.name for field in dataclasses.fields(model_class)}
      kwargs = {key: value for key, value in config.items() if key in field_names}
      model = model_class(**kwargs)

      return model, save_dict['params']

# Next is a simple DNN model

class FlaxNetDNN(FlaxNet):
   """
   A configurable fully connected (dense) network for JAX/Flax.

   Attributes:
      _output_size (int, required):
         Number of output features.
      _hidden (tuple[int, ...], optional):
         Width of each hidden layer, in order. Defaults to (100, 80).
      _activation (str, optional):
         Name of a flax.linen activation function (e.g. 'relu', 'tanh'), matched
         case-insensitively. Defaults to 'relu'.

   Methods:
      __call__(x, train: bool = True):
         Forward pass of the network.
   """
   _output_size: int
   _hidden: tuple[int, ...] = (100, 80)
   _activation: str = 'relu'

   @fnn.compact
   def __call__(self, x, train: bool = True):
      """
      Apply the hidden Dense+activation layers, then a linear Dense output layer.

      Args:
         x (jax.numpy.ndarray, required):
            Input data; Dense layers act on the last axis.
         train (bool, optional):
            Accepted for interface compatibility; this network has no train-dependent layers.
            Defaults to True.

      Returns:
         jax.numpy.ndarray:
            Output with ``_output_size`` features on the last axis.
      """
      act_fn = getattr(fnn, self._activation.lower())
      xavier = jax.nn.initializers.xavier_uniform()

      hidden_out = x
      for width in self._hidden:
         hidden_out = act_fn(fnn.Dense(width, kernel_init=xavier)(hidden_out))

      # Output layer is linear (no activation).
      return fnn.Dense(self._output_size, kernel_init=xavier)(hidden_out)

class FlaxNetDNNBN(FlaxNet):
   """
   A configurable Deep Neural Network (DNN) implementation for JAX/Flax with BatchNorm layers.

   Every Dense layer, including the output layer, is followed by BatchNorm; hidden layers
   then apply the activation.

   Attributes:
      _output_size (int, required):
         Number of output features.
      _hidden (tuple[int, ...], optional):
         Width of each hidden layer, in order. Defaults to (100, 80).
      _activation (str, optional):
         Name of a flax.linen activation function, matched case-insensitively.
         Defaults to 'relu'.
   """
   _output_size: int
   _hidden: tuple[int, ...] = (100, 80)
   _activation: str = 'relu'

   @fnn.compact
   def __call__(self, x, train: bool = True):
      """
      Forward pass of the network.

      Args:
         x (jax.numpy.ndarray, required):
            Input data; Dense layers act on the last axis.
         train (bool, optional):
            If True, BatchNorm normalizes with batch statistics (updating 'batch_stats');
            otherwise it uses the running averages. Defaults to True.

      Returns:
         jax.numpy.ndarray:
            Batch-normalized output with ``_output_size`` features on the last axis.
      """
      act_fn = getattr(fnn, self._activation.lower())
      xavier = jax.nn.initializers.xavier_uniform()
      use_running_average = not train

      def dense_bn(h, width):
         # Dense then BatchNorm, created in this order so submodule names match Dense_i/BatchNorm_i.
         h = fnn.Dense(width, kernel_init=xavier)(h)
         return fnn.BatchNorm(use_running_average=use_running_average)(h)

      h = x
      for width in self._hidden:
         h = act_fn(dense_bn(h, width))
      return dense_bn(h, self._output_size)

class FlaxNetLeNet(FlaxNet):
   """
   A LeNet-like neural network implemented with JAX/Flax.

   Two 5x5 convolutions (6 and 16 features, tanh, 2x2 average pooling) followed by dense
   layers of 120 and 84 units (tanh) and a linear output layer. The convolutions set no
   explicit padding, so Flax's default applies.

   Attributes:
      _output_size (int, required):
         The size of the output layer.

   Methods:
      __call__(x, train: bool = True):
         Forward pass of the LeNet network.
   """
   _output_size: int

   @fnn.compact
   def __call__(self, x, train: bool = True):
      """
      Forward pass of the LeNet network.

      Args:
         x (jax.numpy.ndarray, required):
            Input data: a batch of images, or a single image without a batch axis
            (presumably channels-last, as flax.linen.Conv expects features on the last axis).
         train (bool, optional):
            Accepted for interface compatibility; no layer depends on it. Defaults to True.

      Returns:
         jax.numpy.ndarray:
            Shape (batch, _output_size) for batched input, (_output_size,) for a single image.
      """
      # A single (3-D) image gets a temporary batch axis, removed again before returning.
      unbatched = x.ndim == 3
      if unbatched:
         x = x[None, ...]
      x = fnn.Conv(features=6, kernel_size=(5, 5))(x)
      x = fnn.tanh(x)
      x = fnn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
      x = fnn.Conv(features=16, kernel_size=(5, 5))(x)
      x = fnn.tanh(x)
      x = fnn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
      x = x.reshape((x.shape[0], -1))
      x = fnn.Dense(features=120, kernel_init=jax.nn.initializers.xavier_uniform())(x)
      x = fnn.tanh(x)
      x = fnn.Dense(features=84, kernel_init=jax.nn.initializers.xavier_uniform())(x)
      x = fnn.tanh(x)
      x = fnn.Dense(features=self._output_size, kernel_init=jax.nn.initializers.xavier_uniform())(x)
      # Only drop the axis we added: a genuine batch of one keeps shape (1, _output_size).
      if unbatched:
         x = x.squeeze(axis=0)
      return x

class FlaxNetVGG(FlaxNet):
   """
   A VGG11-like neural network implemented with JAX/Flax.

   Eight 3x3 convolutions (64, 128, 256, 256, 512, 512, 512, 512 features, ReLU) in five
   blocks each ending with 2x2 max pooling, then two 4096-unit ReLU dense layers and a
   linear output layer. No dropout is applied.

   Attributes:
      _output_size (int, required):
         The size of the output layer.

   Methods:
      __call__(x, train: bool = True):
         Forward pass of the VGG11 network.
   """
   _output_size: int

   @fnn.compact
   def __call__(self, x, train: bool = True):
      """
      Forward pass of the VGG11 network.

      Args:
         x (jax.numpy.ndarray, required):
            Input data: a batch of images, or a single image without a batch axis
            (presumably channels-last, as flax.linen.Conv expects features on the last axis).
         train (bool, optional):
            Accepted for interface compatibility; no layer depends on it. Defaults to True.

      Returns:
         jax.numpy.ndarray:
            Shape (batch, _output_size) for batched input, (_output_size,) for a single image.
      """
      # A single (3-D) image gets a temporary batch axis, removed again before returning.
      unbatched = x.ndim == 3
      if unbatched:
         x = x[None, ...]

      # Conv Block 1
      x = fnn.Conv(features=64, kernel_size=(3, 3))(x)
      x = fnn.relu(x)
      x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

      # Conv Block 2
      x = fnn.Conv(features=128, kernel_size=(3, 3))(x)
      x = fnn.relu(x)
      x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

      # Conv Block 3
      x = fnn.Conv(features=256, kernel_size=(3, 3))(x)
      x = fnn.relu(x)
      x = fnn.Conv(features=256, kernel_size=(3, 3))(x)
      x = fnn.relu(x)
      x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

      # Conv Block 4
      x = fnn.Conv(features=512, kernel_size=(3, 3))(x)
      x = fnn.relu(x)
      x = fnn.Conv(features=512, kernel_size=(3, 3))(x)
      x = fnn.relu(x)
      x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

      # Conv Block 5
      x = fnn.Conv(features=512, kernel_size=(3, 3))(x)
      x = fnn.relu(x)
      x = fnn.Conv(features=512, kernel_size=(3, 3))(x)
      x = fnn.relu(x)
      x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

      # Flatten
      x = x.reshape((x.shape[0], -1))

      # Fully connected layers (NOTE: dropout between them is intentionally disabled).
      x = fnn.Dense(features=4096, kernel_init=jax.nn.initializers.xavier_uniform())(x)
      x = fnn.relu(x)
      x = fnn.Dense(features=4096, kernel_init=jax.nn.initializers.xavier_uniform())(x)
      x = fnn.relu(x)
      x = fnn.Dense(features=self._output_size, kernel_init=jax.nn.initializers.xavier_uniform())(x)

      # Only drop the axis we added: a genuine batch of one keeps shape (1, _output_size).
      if unbatched:
         x = x.squeeze(axis=0)
      return x
   
class FlaxNetResNet(FlaxNet):
   """
   A ResNet-like neural network implemented with JAX/Flax.

   A 7x7 stride-2 stem convolution with BatchNorm, ReLU and 3x3 stride-2 max pooling, four
   stages of two residual blocks (64, 128, 256, 512 features; stride 2 at the start of
   stages 2-4), global average pooling and a dense output layer.

   Attributes:
      _output_size (int, required):
         The size of the output layer.

   Methods:
      __call__(x, train: bool = True):
         Forward pass of the ResNet network.
   """
   _output_size: int

   def residual_block(self, x, features, strides=(1, 1), train: bool = True):
      """
      Basic ResNet residual block with two 3x3 convolutions.

      Submodules are created inline, so this must be called from within the compact
      ``__call__``. When the block changes the tensor shape (stride or feature count), the
      shortcut is projected with a 1x1 convolution + BatchNorm.

      Args:
         x (jax.numpy.ndarray): Input feature map.
         features (int): Number of output features.
         strides (tuple[int, int]): Stride of the first convolution (and of the projection).
         train (bool): If True, BatchNorm uses batch statistics; otherwise running averages.

      Returns:
         jax.numpy.ndarray: ReLU(residual + shortcut).
      """
      shortcut = x

      # First conv block
      x = fnn.Conv(features=features, kernel_size=(3, 3), strides=strides, padding=((1, 1), (1, 1)))(x)
      x = fnn.BatchNorm(use_running_average=not train)(x)
      x = fnn.relu(x)

      # Second conv block
      x = fnn.Conv(features=features, kernel_size=(3, 3), padding=((1, 1), (1, 1)))(x)
      x = fnn.BatchNorm(use_running_average=not train)(x)

      # Shortcut connection
      if shortcut.shape != x.shape:
         shortcut = fnn.Conv(features=features, kernel_size=(1, 1), strides=strides)(shortcut)
         shortcut = fnn.BatchNorm(use_running_average=not train)(shortcut)

      x = x + shortcut
      return fnn.relu(x)

   @fnn.compact
   def __call__(self, x, train: bool = True):
      """
      Forward pass of the ResNet network.

      Args:
         x (jax.numpy.ndarray, required):
            Input data: a batch of images, or a single image without a batch axis
            (presumably channels-last, as flax.linen.Conv expects features on the last axis).
         train (bool, optional):
            If True, every BatchNorm (stem and residual blocks) uses batch statistics and
            updates 'batch_stats'; otherwise running averages are used. Defaults to True.

      Returns:
         jax.numpy.ndarray:
            Shape (batch, _output_size) for batched input, (_output_size,) for a single image.
      """
      # A single (3-D) image gets a temporary batch axis, removed again before returning.
      unbatched = x.ndim == 3
      if unbatched:
         x = x[None, ...]

      # Initial convolution
      x = fnn.Conv(features=64, kernel_size=(7, 7), strides=(2, 2), padding=((3, 3), (3, 3)))(x)
      x = fnn.BatchNorm(use_running_average=not train)(x)
      x = fnn.relu(x)
      x = fnn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))

      # ResNet blocks. `train` must be forwarded, otherwise the blocks' BatchNorm layers
      # would always use batch statistics, even in evaluation mode.
      # Layer 1
      x = self.residual_block(x, features=64, train=train)
      x = self.residual_block(x, features=64, train=train)

      # Layer 2
      x = self.residual_block(x, features=128, strides=(2, 2), train=train)
      x = self.residual_block(x, features=128, train=train)

      # Layer 3
      x = self.residual_block(x, features=256, strides=(2, 2), train=train)
      x = self.residual_block(x, features=256, train=train)

      # Layer 4
      x = self.residual_block(x, features=512, strides=(2, 2), train=train)
      x = self.residual_block(x, features=512, train=train)

      # Global average pooling over the two spatial axes
      x = jnp.mean(x, axis=(1, 2))

      # Final dense layer
      x = fnn.Dense(features=self._output_size, kernel_init=jax.nn.initializers.xavier_uniform())(x)

      # Only drop the axis we added: a genuine batch of one keeps shape (1, _output_size).
      if unbatched:
         x = x.squeeze(axis=0)
      return x

class FlaxNetResNet18_Cust(FlaxNet):
   """
   A CIFAR-10 optimized ResNet18 implementation, following the original paper's architecture.

   Unlike the ImageNet stem, the first layer is a 3x3 stride-1 convolution with no max
   pooling. Four stages of two basic blocks (64, 128, 256, 512 features; stride 2 at the start
   of stages 2-4) are followed by 4x4 average pooling and a dense output layer.

   Attributes:
      _output_size (int, required):
         The size of the output layer.
   """
   _output_size: int

   def basic_block(self, x, planes, stride=(1, 1), train: bool = True):
      """
      Basic ResNet block optimized for CIFAR-10.

      Two bias-free 3x3 convolutions with BatchNorm. The shortcut is projected with a
      bias-free 1x1 convolution + BatchNorm when the stride is not (1, 1) or the channel
      count changes. Submodules are created inline, so this must be called from within the
      compact ``__call__``.

      Args:
         x (jax.numpy.ndarray): Input feature map (channels on the last axis).
         planes (int): Number of output features.
         stride (tuple[int, int]): Stride of the first convolution (and of the projection).
         train (bool): If True, BatchNorm uses batch statistics; otherwise running averages.

      Returns:
         jax.numpy.ndarray: ReLU(residual + shortcut).
      """
      shortcut = x
      in_planes = x.shape[-1]  # Get input channels from x

      # First conv block
      out = fnn.Conv(
         features=planes,
         kernel_size=(3, 3),
         strides=stride,
         padding=((1, 1), (1, 1)),
         use_bias=False)(x)
      out = fnn.BatchNorm(use_running_average=not train)(out)
      out = fnn.relu(out)

      # Second conv block
      out = fnn.Conv(
         features=planes,
         kernel_size=(3, 3),
         strides=(1, 1),
         padding=((1, 1), (1, 1)),
         use_bias=False)(out)
      out = fnn.BatchNorm(use_running_average=not train)(out)

      # Shortcut
      if stride != (1, 1) or in_planes != planes:
         shortcut = fnn.Conv(
            features=planes,
            kernel_size=(1, 1),
            strides=stride,
            use_bias=False)(x)
         shortcut = fnn.BatchNorm(use_running_average=not train)(shortcut)

      out = out + shortcut
      out = fnn.relu(out)
      return out

   @fnn.compact
   def __call__(self, x, train: bool = True):
      """
      Forward pass of the network.

      Args:
         x (jax.numpy.ndarray, required):
            Input data: a batch of images, or a single image without a batch axis
            (presumably channels-last 32x32 CIFAR images - see the pooling note below).
         train (bool, optional):
            If True, BatchNorm uses batch statistics and updates 'batch_stats'; otherwise
            running averages are used. Defaults to True.

      Returns:
         jax.numpy.ndarray:
            Shape (batch, _output_size) for batched input, (_output_size,) for a single image.
      """
      # A single (3-D) image gets a temporary batch axis, removed again before returning.
      unbatched = x.ndim == 3
      if unbatched:
         x = x[None, ...]

      # Initial convolution - CIFAR-10 optimized
      x = fnn.Conv(
         features=64,
         kernel_size=(3, 3),
         strides=(1, 1),
         padding=((1, 1), (1, 1)),
         use_bias=False)(x)
      x = fnn.BatchNorm(use_running_average=not train)(x)
      x = fnn.relu(x)

      # Layer 1
      x = self.basic_block(x, planes=64, stride=(1, 1), train=train)
      x = self.basic_block(x, planes=64, stride=(1, 1), train=train)

      # Layer 2
      x = self.basic_block(x, planes=128, stride=(2, 2), train=train)
      x = self.basic_block(x, planes=128, stride=(1, 1), train=train)

      # Layer 3
      x = self.basic_block(x, planes=256, stride=(2, 2), train=train)
      x = self.basic_block(x, planes=256, stride=(1, 1), train=train)

      # Layer 4
      x = self.basic_block(x, planes=512, stride=(2, 2), train=train)
      x = self.basic_block(x, planes=512, stride=(1, 1), train=train)

      # Final pooling and dense layer. Three stride-2 stages reduce a 32x32 input to 4x4,
      # so this 4x4 pooling is global for CIFAR; larger inputs yield more than 512 features.
      x = fnn.avg_pool(x, window_shape=(4, 4))
      x = x.reshape((x.shape[0], -1))
      x = fnn.Dense(features=self._output_size)(x)

      # Only drop the axis we added: a genuine batch of one keeps shape (1, _output_size).
      if unbatched:
         x = x.squeeze(axis=0)
      return x
