"""
PoissonNet model definition for Poisson equation PDE example.
Contains neural network model and related utility functions for PDE solving.
"""

import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from jax import grad, vmap
from typing import List, Any, Dict, Optional
from foo.models import FlaxNet, TrainState
from data import true_solution

# Global variable to store hidden layer configuration (used by PoissonNet).
# PoissonNet.__call__ iterates over this value and creates one Dense layer per
# entry; create_model() assigns it from its `hidden_size` argument. It stays
# None until create_model() (or another caller) sets it.
hidden_layers = None

class PoissonNet(FlaxNet):
   """
   Neural network model for Poisson equation.
   Inherits from foo.models.FlaxNet.

   The network is a stack of tanh-activated Dense layers followed by a linear
   Dense layer of width ``_output_size``. The hidden layer widths are read
   from the module-level ``hidden_layers`` variable every time ``__call__``
   runs, so that variable must be set (``create_model`` does this) before the
   model is initialized or applied.
   """
   _output_size: int = 1

   @nn.compact
   def __call__(self, x, train: bool = True):
      """
      Forward pass through the Dense/tanh stack.

      Args:
         x: Input array fed to the first Dense layer.
         train: Accepted for interface compatibility; not used by this network.

      Returns:
         Output of the final Dense layer (last axis has size ``_output_size``).

      Raises:
         TypeError: If the module-level ``hidden_layers`` has not been set.
      """
      # Fail with an actionable message instead of the bare
      # "'NoneType' object is not iterable" the loop below would produce.
      if hidden_layers is None:
         raise TypeError(
            "hidden_layers is None; expected an iterable of layer widths "
            "(create_model() sets it before building the network)"
         )
      for feat in hidden_layers:
         x = nn.Dense(feat)(x)
         x = nn.tanh(x)
      x = nn.Dense(self._output_size)(x)
      return x

   def _compute_derivatives(self, params, x_val, y_val, train):
      """
      Compute the unmixed second partial derivatives of the network output
      at the single point (x_val, y_val).

      Args:
         params: Parameter tree passed to ``self.apply`` as ``{'params': params}``.
         x_val: Scalar x-coordinate of the evaluation point.
         y_val: Scalar y-coordinate of the evaluation point.
         train: Forwarded to ``self.apply`` as the ``train`` keyword.

      Returns:
         tuple: (u_xx, u_yy), the second derivatives with respect to x and y.
      """
      # Network output as a scalar function of x alone (y held fixed)
      def eval_fn_x(t):
         point = jnp.array([[t, y_val]])
         return self.apply({'params': params}, point, train=train)[0, 0]

      # Network output as a scalar function of y alone (x held fixed)
      def eval_fn_y(t):
         point = jnp.array([[x_val, t]])
         return self.apply({'params': params}, point, train=train)[0, 0]

      # Nested grad of a scalar-in/scalar-out function gives the second derivative
      u_xx = grad(grad(eval_fn_x))(x_val)
      u_yy = grad(grad(eval_fn_y))(y_val)

      return u_xx, u_yy

   def laplacian(self, variables, x, train: bool = True):
      """
      Compute the negated Laplacian, -(u_xx + u_yy), for a batch of points.

      Args:
         variables: Mapping containing the parameter tree under ``'params'``.
         x: Batch of points; vmap maps over the leading axis and the first
            two entries of each row are used as (x, y).
         train: Forwarded to the underlying ``apply`` call.

      Returns:
         One value of -(u_xx + u_yy) per row of ``x``.
      """
      # Reuse the single-point routine so both entry points share one definition
      def single_point_laplacian(point):
         return self.laplacian_single_point(variables, point, train)

      # Use JAX's vmap to evaluate all points in one vectorized call
      return vmap(single_point_laplacian)(x)

   def laplacian_single_point(self, variables, x, train: bool = True):
      """
      Compute the negated Laplacian, -(u_xx + u_yy), for a single point
      (used by optimizers).

      Args:
         variables: Mapping containing the parameter tree under ``'params'``.
         x: Single point; ``x[0]`` and ``x[1]`` are used as (x, y).
         train: Forwarded to the underlying ``apply`` call.

      Returns:
         Scalar value of -(u_xx + u_yy) at the point.
      """
      params = variables['params']
      x_val, y_val = x[0], x[1]

      u_xx, u_yy = self._compute_derivatives(params, x_val, y_val, train)

      # Sign is flipped: the returned quantity is -(u_xx + u_yy)
      return -(u_xx + u_yy)

   def optax_apply_fn(self, variables, x, train: bool = True):
      """
      All in one apply function for Optax training.

      Args:
         variables: Variables passed to ``laplacian`` and ``apply``.
         x: Pair ``(x_i, x_b)``; ``x_i`` goes through ``laplacian`` and
            ``x_b`` goes through ``apply``.
         train: Forwarded to both calls.

      Returns:
         tuple: (u_i, u_b) - negated Laplacian at ``x_i`` and network
         output at ``x_b``.
      """
      x_i, x_b = x
      u_i = self.laplacian(variables, x_i, train)
      u_b = self.apply(variables, x_b, train)
      return (u_i, u_b)

def verify_state_dtype(state, expected_dtype):
   """
   Verify that the model state parameters have the expected dtype.

   Every leaf of ``state.params`` is checked, so a parameter tree with mixed
   dtypes is reported as a mismatch rather than being judged by its first
   leaf alone. The outcome is also printed.

   Args:
      state: Model state containing parameters (must expose ``.params``)
      expected_dtype: Expected data type (e.g., jnp.float32, jnp.float64)

   Returns:
      bool: True if all parameter leaves have the expected dtype; False on
      any mismatch or when the parameter tree has no leaves to verify
   """
   leaves = jax.tree_util.tree_leaves(state.params)

   # Nothing to inspect: report it instead of failing on leaves[0]
   if not leaves:
      print("WARNING: No parameters found; dtype could not be verified.")
      return False

   # dtype of the first leaf that deviates from the expectation, if any
   mismatched_dtype = next(
      (leaf.dtype for leaf in leaves if leaf.dtype != expected_dtype),
      None,
   )

   if mismatched_dtype is not None:
      actual_dtype = mismatched_dtype
      print(f"WARNING: Parameter dtype mismatch! Expected {expected_dtype}, got {actual_dtype}")
      print("This could lead to precision issues or unexpected behavior.")
      return False

   # All leaves agree with expected_dtype, so the first one is representative
   actual_dtype = leaves[0].dtype
   print(f"Parameter dtype verification passed: {actual_dtype}")
   return True

def create_model(init_key, hidden_size, X_sample, dtype=jnp.float32, learning_rate=1e-03):
   """
   Create and initialize a model with specified parameters.

   Side effect: assigns ``hidden_size`` to the module-level ``hidden_layers``
   variable, which PoissonNet reads when building its layers.

   Args:
      init_key: JAX PRNG key for initialization
      hidden_size: List of hidden layer sizes
      X_sample: Sample input for initialization. If a tuple is given, its
         first element is used; the first entry of that array is then passed
         to the initializer.
      dtype: Data type the initialized parameters are expected to have. It is
         only used for the post-initialization check (a warning is printed on
         mismatch); it is not applied to the parameters here.
      learning_rate: Learning rate forwarded to ``init_naive_training``
         (defaults to the previously hard-coded 1e-03)

   Returns:
      tuple: (model, state) - Initialized model and its state
   """
   # use global variable
   global hidden_layers
   hidden_layers = hidden_size

   model = PoissonNet()

   if isinstance(X_sample, tuple):
      X_sample = X_sample[0]

   # For Poisson, we need 2 apply functions: laplacian (PDE), apply (boundary)
   state = model.init_naive_training(init_key,
                                     X_sample[0],
                                     learning_rate=learning_rate,
                                     primary_apply_fn=model.apply,
                                     apply_fns=[model.laplacian, model.apply],
                                     apply_fns_single_point=[model.laplacian_single_point, model.apply])

   # Verify the model parameters have the correct dtype (result is only printed)
   verify_state_dtype(state, dtype)

   return model, state


def evaluate_model(model, params, nx=50, ny=50):
   """
   Evaluate the trained PoissonNet model on a grid and compute error metrics.

   The grid covers [0, 1] x [0, 1] with ``nx`` by ``ny`` evenly spaced points.

   Args:
      model: The trained PoissonNet model
      params: Model parameters
      nx: Number of grid points in x direction
      ny: Number of grid points in y direction

   Returns:
      X_test: x-coordinates of the evaluation grid
      Y_test: y-coordinates of the evaluation grid
      u_pred: Model predictions on the grid
      u_true: True solution on the grid
      error: Absolute error between prediction and true solution
   """
   # Create evaluation grid
   x = jnp.linspace(0, 1, nx)
   y = jnp.linspace(0, 1, ny)
   X_test, Y_test = jnp.meshgrid(x, y)

   # Flatten grid into one (x, y) row per grid point
   xy_flat = jnp.column_stack((X_test.flatten(), Y_test.flatten()))

   # Evaluate model (first output column) and restore the grid shape
   u_pred_flat = model.apply({'params': params}, xy_flat)[:, 0]
   u_pred = u_pred_flat.reshape(X_test.shape)

   # Compute true solution point by point. The coordinates are copied to a
   # NumPy array once so the Python loop iterates host scalars instead of
   # slicing the JAX array on every row.
   xy_host = np.asarray(xy_flat)
   u_true_flat = jnp.array([true_solution(xi, yi) for xi, yi in xy_host])
   u_true = u_true_flat.reshape(X_test.shape)

   # Compute error
   error = jnp.abs(u_pred - u_true)

   return X_test, Y_test, u_pred, u_true, error