"""
FrankeModel definition for regression example.
Contains model definition and model-related utility functions.
"""

import jax
import jax.numpy as jnp
import numpy as np
import pickle
from flax import linen as nn
from typing import List, Any, Dict, Optional
from pathlib import Path
from foo.models import FlaxNet, TrainState

# Hidden-layer widths read by FrankeModel.__call__ on every forward pass.
# Stays None until create_model() assigns it; because it is a module-level
# global, every FrankeModel instance in this process reads the same value.
hidden_layers = None

class FrankeModel(FlaxNet):
   """
   Neural network model for Franke function regression task.
   Inherits from foo.models.FlaxNet.

   The hidden-layer widths are not stored on the instance: they are read from
   the module-level ``hidden_layers`` variable every time the model is called.
   That variable must therefore be set (``create_model`` does this) before the
   model is initialized or applied.
   """
   # Number of features produced by the final Dense layer.
   _output_size: int = 1

   @nn.compact
   def __call__(self, x, train: bool = True):
      """
      Forward pass through the model.

      Applies one Dense + ReLU pair per entry in ``hidden_layers``, followed by
      a final Dense layer with ``_output_size`` features and no activation.

      Args:
         x: Input passed to the first Dense layer.
         train: Not used in this method.

      Returns:
         Output of the final Dense layer.

      Raises:
         TypeError: If the module-level ``hidden_layers`` is still None.
      """
      # Fail with an actionable message instead of the bare
      # "'NoneType' object is not iterable" the loop below would produce.
      if hidden_layers is None:
         raise TypeError(
            "hidden_layers is None; set the module-level hidden_layers "
            "(e.g. via create_model) before calling FrankeModel"
         )

      for width in hidden_layers:
         x = nn.relu(nn.Dense(width)(x))
      return nn.Dense(self._output_size)(x)

def verify_state_dtype(state, expected_dtype):
   """
   Verify that the model state parameters have the expected dtype.

   Every leaf of ``state.params`` is checked, so a parameter tree with mixed
   dtypes is reported as a mismatch. The outcome is printed as well as returned.

   Args:
      state: Model state containing parameters (must expose ``.params``)
      expected_dtype: Expected data type (e.g., jnp.float32, jnp.float64)

   Returns:
      bool: True if all parameter leaves have the expected dtype (or there are
      no parameters to check), False otherwise
   """
   # Normalize so scalar types such as jnp.float32 compare cleanly against
   # the dtype objects carried by arrays.
   target_dtype = jnp.dtype(expected_dtype)

   leaves = jax.tree_util.tree_leaves(state.params)
   if not leaves:
      # Nothing to compare; avoid indexing into an empty list.
      print("Parameter dtype verification skipped: no parameters found")
      return True

   for leaf in leaves:
      actual_dtype = leaf.dtype
      if actual_dtype != target_dtype:
         print(f"WARNING: Parameter dtype mismatch! Expected {expected_dtype}, got {actual_dtype}")
         print("This could lead to precision issues or unexpected behavior.")
         return False

   print(f"Parameter dtype verification passed: {target_dtype}")
   return True

def create_model(init_key, hidden_size, X_sample, dtype=jnp.float32, learning_rate=1e-03):
   """
   Create and initialize a model with specified parameters.

   Side effect: overwrites the module-level ``hidden_layers`` variable, which
   FrankeModel reads on every call. Any FrankeModel created earlier will use
   the new layer sizes the next time it is called.

   Args:
      init_key: JAX PRNG key for initialization
      hidden_size: List of hidden layer sizes
      X_sample: Sample input for initialization
      dtype: Data type the parameters are expected to have. It is only used
         to verify the initialized parameters; it is not passed to the model
         or used to cast anything here.
      learning_rate: Learning rate forwarded to ``init_naive_training``
         (defaults to the previously hard-coded 1e-03)

   Returns:
      tuple: (model, state) - Initialized model and its state
   """
   # FrankeModel picks up its architecture from this module-level variable.
   global hidden_layers
   hidden_layers = hidden_size

   model = FrankeModel()
   state = model.init_naive_training(init_key, X_sample, learning_rate=learning_rate)

   # Report-only check: a mismatch is printed but does not stop creation.
   verify_state_dtype(state, dtype)

   return model, state