"""
This file contains the PyTorch and Flax neural network models and loss functions
"""

import jax 
import jax.numpy as jnp 
from typing import Any

def NllFlax(logits, target, batch=None):
   """
   Compute the mean negative log-likelihood of integer class labels

   The logits are normalised with a log-softmax over axis 1. The log-probability
   of each sample's target class is then picked out with a one-hot mask.

   Args:
      logits (jax.numpy.ndarray, required):
         Unnormalised model outputs; axis 1 is the class dimension
      target (jax.numpy.ndarray, required):
         Integer class labels (starting from 0)
      batch (tuple or any, optional):
         The input batch data; unused by plain NLL, accepted so the signature
         matches loss functions (e.g. for PINNs) that do need it

   Returns:
      jax.numpy.ndarray:
         Scalar mean negative log-likelihood over the batch
   """
   class_count = logits.shape[1]
   # NOTE: jax.nn.one_hot encodes labels outside [0, class_count) as an
   # all-zero row, so such samples contribute 0 to the sum while still
   # counting toward the mean.
   mask = jax.nn.one_hot(target, num_classes=class_count)
   log_probs = jax.nn.log_softmax(logits, axis=1)
   # Only the target-class entry of each row survives the mask.
   picked = jnp.sum(mask * log_probs, axis=1)
   return -jnp.mean(picked)

def MSEFlax(predictions, target, batch=None):
   """
   Calculate Mean Squared Error loss for regression tasks

   If predictions and target have different shapes but the same number of
   elements (e.g. predictions of shape (batch, 1) against targets of shape
   (batch,)), predictions are reshaped to the target's shape first. Otherwise
   the subtraction would silently broadcast to (batch, batch) and give a wrong
   loss.

   Args:
      predictions (jax.numpy.ndarray, required):
         The model's output predictions
      target (jax.numpy.ndarray, required):
         The target values
      batch (tuple or any, optional):
         The input batch data, not used in standard MSE

   Returns:
      jax.numpy.ndarray:
         Scalar mean squared error over all elements
   """
   target_shape = jnp.shape(target)
   # Shapes are static under jit, so this Python-level check is trace-safe.
   if (jnp.shape(predictions) != target_shape
         and jnp.size(predictions) == jnp.size(target)):
      predictions = jnp.reshape(predictions, target_shape)
   squared_error = jnp.square(predictions - target)
   return jnp.mean(squared_error)

def PredictionsClassificationFlax(logits, batch=None):
   """
   Turn model output logits into predicted class indices

   The predicted class is the position of the largest logit along axis 1.

   Args:
      logits (jax.numpy.ndarray, required):
         The model's output logits; axis 1 is the class dimension
      batch (tuple or any, optional):
         The input batch data; unused, kept so all prediction functions share one signature

   Returns:
      jax.numpy.ndarray:
         Index of the highest-scoring class for each sample
   """
   class_axis = 1
   predicted_classes = jnp.argmax(logits, axis=class_axis)
   return predicted_classes


def AccuracyClassificationFlax(predictions, target, batch=None):
   """
   Calculate the accuracy of predictions as a percentage

   Every compared element counts once, so targets with more than one
   dimension (e.g. per-token labels) are scored correctly. If predictions and
   target differ in shape but have the same number of elements (e.g. (batch,)
   vs (batch, 1)), predictions are reshaped to the target's shape first.
   Without this, the comparison would broadcast to (batch, batch).

   Args:
      predictions (jax.numpy.ndarray, required):
         The predicted class indices.
      target (jax.numpy.ndarray, required):
         The true class labels.
      batch (tuple or any, optional):
         The input batch data, not used in this function but included for API consistency

   Returns:
      jax.numpy.ndarray:
         Scalar accuracy percentage in [0, 100] (NaN when target is empty).
   """
   target_shape = jnp.shape(target)
   # Shapes are static under jit, so this Python-level check is trace-safe.
   if (jnp.shape(predictions) != target_shape
         and jnp.size(predictions) == jnp.size(target)):
      predictions = jnp.reshape(predictions, target_shape)
   matches = predictions == target
   # Mean over all elements, not sum / target.shape[0], so multi-dim targets
   # cannot push the result above 100%.
   return jnp.mean(matches) * 100.0

def FlaxUpdateParameters(params, lr, grads):
   """
   Apply one plain gradient-descent step: new_params = params - lr * grads.

   Args:
      params (jax.tree_util.PyTreeDef, required):
         The current model parameters.
      lr (float, required):
         The learning rate (step size).
      grads (jax.tree_util.PyTreeDef, required):
         Gradients with the same tree structure as params.

   Returns:
      jax.tree_util.PyTreeDef:
         A new pytree of updated parameters; the inputs are not modified.
   """
   def _sgd_step(param, grad):
      # Update for a single leaf of the parameter tree.
      return param - lr * grad

   # params is passed first, so its structure drives the traversal and each
   # leaf is paired with the matching leaf of grads.
   return jax.tree_util.tree_map(_sgd_step, params, grads)