"""
Data generation and visualization for Franke function regression example.
"""

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import os

def franke_function(x, y):
   """
   Evaluate the Franke test function elementwise with NumPy.

   f(x, y) is a sum of three positive exponential bumps and one negative dip.
   It is evaluated on the coordinates scaled by 9.

   Args:
      x: x coordinate(s), scalar or NumPy array
      y: y coordinate(s), broadcast-compatible with x

   Returns:
      f(x, y) with the broadcast shape of x and y
   """
   # Scale once; every term is written in terms of 9*x and 9*y.
   sx = 9*x
   sy = 9*y

   bump_a = 0.75 * np.exp(-(sx - 2)**2/4 - (sy - 2)**2/4)
   # This bump is linear (not quadratic) in y, as in the reference definition.
   bump_b = 0.75 * np.exp(-(sx + 1)**2/49 - (sy + 1)/10)
   bump_c = 0.5 * np.exp(-(sx - 7)**2/4 - (sy - 3)**2/4)
   dip = -0.2 * np.exp(-(sx - 4)**2 - (sy - 7)**2)

   return bump_a + bump_b + bump_c + dip

def franke_function_jax(x, y):
   """
   Evaluate the Franke test function elementwise with jax.numpy.

   This computes the same formula as the NumPy ``franke_function``, using only
   jnp operations. It can therefore be traced by JAX transformations such as
   ``jax.vmap`` and ``jax.jit``.

   Args:
      x: x coordinate(s), scalar or JAX/NumPy array
      y: y coordinate(s), broadcast-compatible with x

   Returns:
      f(x, y) evaluated elementwise
   """
   u = 9*x
   v = 9*y

   # Three positive exponential bumps...
   peaks = (
      0.75 * jnp.exp(-(u - 2)**2/4 - (v - 2)**2/4),
      0.75 * jnp.exp(-(u + 1)**2/49 - (v + 1)/10),
      0.5 * jnp.exp(-(u - 7)**2/4 - (v - 3)**2/4),
   )
   # ...and one negative dip.
   dip = -0.2 * jnp.exp(-(u - 4)**2 - (v - 7)**2)

   return peaks[0] + peaks[1] + peaks[2] + dip

def generate_data_jax(key, n_samples, noise_level):
   """
   Generate noisy Franke-function samples at uniformly random points using JAX.

   The output depends only on ``key``: the same key always reproduces the same
   data. ``n_samples`` determines array shapes, so it must be a static Python
   int if this function is traced with ``jax.jit``.

   Args:
      key: JAX PRNG key; split internally into independent keys for x, y and noise
      n_samples: Number of samples to generate
      noise_level: Standard deviation of the additive Gaussian noise

   Returns:
      x: x coordinates drawn uniformly from [0, 1), shape (n_samples,)
      y: y coordinates drawn uniformly from [0, 1), shape (n_samples,)
      z: Franke function values at (x, y) plus Gaussian noise, shape (n_samples,).
         These values are NOT normalized.
   """
   # Independent sub-keys so x, y and the noise are uncorrelated.
   key_x, key_y, key_noise = jax.random.split(key, 3)

   # jax.random.uniform samples the half-open interval [0, 1).
   x = jax.random.uniform(key_x, (n_samples,))
   y = jax.random.uniform(key_y, (n_samples,))

   # franke_function_jax is built from elementwise jnp ops, so it can be
   # applied to the whole arrays directly; a vmap wrapper is unnecessary.
   z_true = franke_function_jax(x, y)
   noise = jax.random.normal(key_noise, (n_samples,)) * noise_level
   z = z_true + noise

   return x, y, z

def generate_data(n_samples, noise_level, seed=42):
   """
   Seed-based convenience wrapper around ``generate_data_jax``.

   Builds a JAX PRNG key from the integer ``seed`` and delegates all sampling
   to ``generate_data_jax``.

   Args:
      n_samples: Number of samples to generate
      noise_level: Standard deviation of the additive noise
      seed: Integer seed used to construct the PRNG key (default 42)

   Returns:
      The (x, y, z) tuple returned by ``generate_data_jax`` for that key
   """
   return generate_data_jax(jax.random.PRNGKey(seed), n_samples, noise_level)

def visualize_dataset(x, y, z, figure_dir, filename="dataset_visualization.png"):
   """
   Save a 3D scatter plot of the sampled dataset as a visual sanity check.

   Points are coloured by their z value. ``figure_dir`` is created if it does
   not exist. The figure is always closed, even if saving fails, so repeated
   calls do not accumulate open matplotlib figures.

   Args:
      x: x coordinates
      y: y coordinates (same length as x)
      z: z values at (x, y); also used for the colour map
      figure_dir: Directory to save the visualization in ("" means the
         current working directory)
      filename: Name of the image file written inside figure_dir
   """
   out_path = os.path.join(figure_dir, filename)
   # os.makedirs("") raises, so only create a directory when one is named.
   if figure_dir:
      os.makedirs(figure_dir, exist_ok=True)

   fig = plt.figure(figsize=(10, 8))
   try:
      # 3D scatter plot of the data
      ax = fig.add_subplot(projection='3d')
      ax.scatter3D(x, y, z, c=z, cmap='viridis', s=10, alpha=0.8)

      ax.set_xlabel('X', fontsize=20)
      ax.set_ylabel('Y', fontsize=20)
      ax.set_zlabel('Z', fontsize=20)
      ax.set_title('Franke Function Dataset Samples', fontsize=25)

      fig.savefig(out_path, dpi=300)
   finally:
      # Release the figure even when savefig raises.
      plt.close(fig)

   print(f"Dataset visualization saved to {out_path}")

def prepare_test_data(dtype=jnp.float64, jax_device=None, *, grid_size=30):
   """
   Build a regular evaluation grid on [0, 1] x [0, 1] with exact Franke values.

   Args:
      dtype: Data type of the JAX input array. NOTE: JAX honours float64 only
         when 64-bit mode (``jax_enable_x64``) is enabled; otherwise the array
         is created as float32.
      jax_device: Optional JAX device to place ``X_test_jax`` on; ``None``
         leaves JAX's default placement.
      grid_size: Number of grid points per axis (keyword-only, default 30)

   Returns:
      X_test_jax: JAX array of shape (grid_size**2, 2). Each row is an (x, y)
         pair, in the same order as the flattened meshgrid.
      X_test, Y_test: NumPy meshgrid arrays of shape (grid_size, grid_size),
         for plotting.
      Z_true: Noise-free Franke values as a flat NumPy array of shape
         (grid_size**2,); reshape to (grid_size, grid_size) to match the
         meshgrid.
   """
   # The same evenly spaced axis (endpoints included) is used for x and y.
   axis = np.linspace(0, 1, grid_size)
   X_test, Y_test = np.meshgrid(axis, axis)
   X_flat = X_test.ravel()
   Y_flat = Y_test.ravel()
   Z_true = franke_function(X_flat, Y_flat)

   # Model inputs as (x, y) rows.
   X_test_jax = jnp.array(np.column_stack((X_flat, Y_flat)), dtype=dtype)
   if jax_device is not None:
      X_test_jax = jax.device_put(X_test_jax, jax_device)

   return X_test_jax, X_test, Y_test, Z_true