"""
GreensNet model definition for convection-diffusion Green's function problem.
Contains neural network model and physics-informed operator implementation.
"""

import jax
import jax.numpy as jnp
import numpy as np
import pickle
from flax import linen as nn
from jax import grad, vmap, jit
from typing import List, Any, Dict, Optional
from pathlib import Path
from foo.models import FlaxNet, TrainState

# Module-level model configuration read by GreensNet.__call__.
# create_model() stores the hidden layer widths here; until then it is None.
hidden_layers: Optional[List[int]] = None

# Coefficients of the convection-diffusion operator
#     L[G] = -D * d^2G/dx^2 + C * dG/dx
diffusion_coeff: float = 0.1    # D: diffusion coefficient
convection_coeff: float = 1.0   # C: convection coefficient

class GreensNet(FlaxNet):
   """
   Fully connected tanh network approximating a Green's function G(x, y).

   Each input row is a point (x, y): x is the evaluation point and y the
   source point. Hidden layer widths come from the module-level
   ``hidden_layers`` global (set by ``create_model``); the physics operator
   reads the module-level ``diffusion_coeff`` (D) and ``convection_coeff`` (C).
   """
   _output_size: int = 1

   @nn.compact
   def __call__(self, x, train: bool = True):
      """
      Forward pass: Dense + tanh for each width in ``hidden_layers``, followed
      by a final linear Dense layer with ``_output_size`` units.

      Args:
         x: Batch of input points, one (x, y) pair per row.
         train: Accepted for interface compatibility; no layer here behaves
            differently between training and evaluation, so it is unused.

      Raises:
         RuntimeError: If the module-level ``hidden_layers`` is not set.
      """
      if hidden_layers is None:
         # Without this guard the loop below fails with an opaque
         # "'NoneType' object is not iterable" error.
         raise RuntimeError(
            "GreensNet requires the module-level 'hidden_layers' to be set; "
            "call create_model() first."
         )
      for feat in hidden_layers:
         x = nn.Dense(feat)(x)
         x = nn.tanh(x)
      x = nn.Dense(self._output_size)(x)
      return x

   def _compute_derivatives(self, params, x_val, y_val, train):
      """
      Compute the first and second derivatives of G w.r.t. x at one point.

      Uses forward-over-reverse differentiation: ``jax.jvp`` applied to the
      gradient function yields the gradient as the primal output and the
      second derivative as the tangent output, so both come from a single
      differentiated evaluation instead of two separate ``grad`` /
      ``grad(grad)`` passes.

      Args:
         params: Model parameters (the ``'params'`` collection).
         x_val: Evaluation point x-coordinate (scalar; must be a float).
         y_val: Source point y-coordinate (scalar).
         train: Forwarded to ``__call__``.

      Returns:
         tuple: (u_x, u_xx) - First and second derivatives w.r.t. x.
      """
      def eval_fn_x(t):
         # G viewed as a scalar function of x alone, with y held fixed.
         input_point = jnp.array([[t, y_val]])
         return self.apply({'params': params}, input_point, train=train)[0, 0]

      # Primal output = dG/dx, tangent output = d^2G/dx^2 (unit tangent).
      u_x, u_xx = jax.jvp(grad(eval_fn_x), (x_val,), (jnp.ones_like(x_val),))
      return u_x, u_xx

   def convection_diffusion_operator(self, variables, x, train: bool = True):
      """
      Apply the convection-diffusion operator to the network output:
      L[G] = -D * ∂²G/∂x² + C * ∂G/∂x
      where D is ``diffusion_coeff`` and C is ``convection_coeff``.

      For the Green's function problem, L[G] should match the delta-function
      approximation.

      Args:
         variables: Flax variables dict; only ``variables['params']`` is used.
         x: Batch of input points; each row is (x, y) with x the evaluation
            point and y the source point.
         train: Forwarded to the network's ``__call__``.

      Returns:
         1-D array with one operator value per row of ``x``.
      """
      # Vectorize the single-point implementation so the operator is defined
      # in exactly one place.
      def single_point_operator(point):
         return self.convection_diffusion_operator_single_point(variables, point, train)

      return vmap(single_point_operator)(x)

   # Alias for backward compatibility
   laplacian = convection_diffusion_operator

   def convection_diffusion_operator_single_point(self, variables, x, train: bool = True):
      """
      Apply the convection-diffusion operator at a single point:
      L[G] = -D * ∂²G/∂x² + C * ∂G/∂x

      Args:
         variables: Flax variables dict; only ``variables['params']`` is used.
         x: Single input point (x, y).
         train: Forwarded to the network's ``__call__``.

      Returns:
         Scalar value of the operator applied to G at ``x``.
      """
      params = variables['params']
      u_x, u_xx = self._compute_derivatives(params, x[0], x[1], train)
      # NOTE: diffusion_coeff / convection_coeff are module-level globals read
      # when this runs; under jax.jit they are baked in at trace time.
      return -diffusion_coeff * u_xx + convection_coeff * u_x

   # Alias for backward compatibility
   laplacian_single_point = convection_diffusion_operator_single_point

   def optax_apply_fn(self, variables, x, train: bool = True):
      """
      All-in-one apply function for Optax training.

      Args:
         variables: Flax variables dict containing ``'params'``.
         x: Pair ``(x_i, x_b)`` of interior and boundary input points.
         train: Forwarded to both evaluations.

      Returns:
         tuple: (u_i, u_b) - convection-diffusion operator values at the
         interior points and raw network outputs at the boundary points.
      """
      x_i, x_b = x
      u_i = self.convection_diffusion_operator(variables, x_i, train)
      u_b = self.apply(variables, x_b, train=train)
      return (u_i, u_b)

def verify_state_dtype(state, expected_dtype):
   """
   Verify that every model parameter has the expected dtype.

   All leaves of ``state.params`` are checked, not just the first one, so a
   partially cast parameter tree is reported as a mismatch.

   Args:
      state: Model state exposing a ``params`` pytree.
      expected_dtype: Expected data type (e.g., jnp.float32, jnp.float64).

   Returns:
      bool: True if every parameter has ``expected_dtype``; False if any
      parameter differs or the tree contains no parameters.
   """
   leaves = jax.tree_util.tree_leaves(state.params)
   if not leaves:
      # Previously this raised IndexError; report it instead.
      print("WARNING: No parameters found in state; cannot verify dtype.")
      return False

   expected = jnp.dtype(expected_dtype)
   # Distinct dtypes in first-seen order (dict preserves insertion order).
   found = list(dict.fromkeys(leaf.dtype for leaf in leaves))

   if any(dt != expected for dt in found):
      actual = found[0] if len(found) == 1 else ", ".join(str(dt) for dt in found)
      print(f"WARNING: Parameter dtype mismatch! Expected {expected_dtype}, got {actual}")
      print("This could lead to precision issues or unexpected behavior.")
      return False

   print(f"Parameter dtype verification passed: {found[0]}")
   return True

def create_model(init_key, hidden_size, X_sample, dtype=jnp.float32, learning_rate=1e-03):
   """
   Create and initialize a Green's function model.

   Args:
      init_key: JAX PRNG key for initialization.
      hidden_size: List of hidden layer sizes. Stored in the module-level
         ``hidden_layers`` global, which ``GreensNet.__call__`` reads.
      X_sample: Sample input used for initialization.
      dtype: Expected data type of the model parameters. NOTE: nothing is
         cast to this dtype here; it is only passed to ``verify_state_dtype``,
         which prints a warning on mismatch.
      learning_rate: Learning rate forwarded to ``init_naive_training``
         (defaults to 1e-3, the previously hard-coded value).

   Returns:
      tuple: (model, state) - Initialized model and its state.
   """
   # GreensNet reads its architecture from this global, so it must be set
   # before the model is initialized.
   global hidden_layers
   hidden_layers = hidden_size

   model = GreensNet()
   state = model.init_naive_training(
      init_key,
      X_sample,
      learning_rate=learning_rate,
      primary_apply_fn=model.apply,
      apply_fns=[model.convection_diffusion_operator, model.apply],
      apply_fns_single_point=[model.convection_diffusion_operator_single_point, model.apply],
   )

   # Report (but do not fix) a parameter dtype mismatch.
   verify_state_dtype(state, dtype)

   return model, state


def evaluate_model(model, params, y_source_points=None, resolution=100, epsilon=0.01):
   """
   Evaluate the trained model for fixed source points y.

   The network and its operator are run with ``train=False`` because this is
   an evaluation pass.

   Args:
      model: The neural network model.
      params: The trained model parameters.
      y_source_points: Fixed source points y for which to evaluate G(x,y);
         defaults to [0.25, 0.5, 0.75].
      resolution: Number of x points in [0, 1] for evaluation.
      epsilon: Width parameter of the delta function.

   Returns:
      X_eval: x coordinates for evaluation, shape (resolution,).
      Y_source: Source points y for which G(x,y) is evaluated.
      G_pred: Predicted Green's function values G(x,y), one row per y.
      laplacian_pred: Predicted operator values (-D*∂²G/∂x² + C*∂G/∂x), one row per y.
      laplacian_true: True operator values (delta function approximation), one row per y.
      When ``y_source_points`` is empty, the three value arrays have shape
      (0, resolution).
   """
   # Import here to avoid circular imports
   from data import gaussian_delta

   if y_source_points is None:
      y_source_points = [0.25, 0.5, 0.75]

   x_eval = jnp.linspace(0, 1, resolution)
   variables = {'params': params}

   G_pred, laplacian_pred, laplacian_true = [], [], []

   # For each source point y, evaluate G(x,y) across all x.
   for y in y_source_points:
      # Input points (x, y) with y held fixed.
      xy_points = jnp.column_stack([x_eval, jnp.full_like(x_eval, y)])

      G_pred.append(model.apply(variables, xy_points, train=False).flatten())
      laplacian_pred.append(
         model.convection_diffusion_operator(variables, xy_points, train=False).flatten()
      )
      # Target: delta function approximation, evaluated point by point.
      laplacian_true.append(jnp.array([gaussian_delta(x, y, epsilon) for x in x_eval]))

   y_out = jnp.array(y_source_points)
   if not G_pred:
      # jnp.stack cannot stack an empty list; keep the (n_sources, resolution) rank.
      empty = jnp.zeros((0, resolution), dtype=x_eval.dtype)
      return x_eval, y_out, empty, empty, empty

   return (
      x_eval,
      y_out,
      jnp.stack(G_pred),
      jnp.stack(laplacian_pred),
      jnp.stack(laplacian_true),
   )