"""
Data generation and visualization for Poisson equation PINN example.
"""

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import os
from functools import partial
from jax import grad, vmap, jit

@jit
def gaussian_delta(x, y, sigma=0.01):
   """
   Evaluate a normalized Gaussian that stands in for the Dirac delta δ(x-y)

   Args:
      x: Evaluation point
      y: Source point (centre of the Gaussian)
      sigma: Standard deviation of the Gaussian; it sets how wide the peak is

   Returns:
      Value at x of the Gaussian centred at y
   """
   # Prefactor 1 / (sigma * sqrt(2*pi)) of the normal density
   normalization = 1.0 / (sigma * jnp.sqrt(2 * jnp.pi))
   # Distance from the source point measured in units of sigma
   scaled_offset = (x - y) / sigma
   return normalization * jnp.exp(-0.5 * scaled_offset ** 2)

@jit
def source_term(x, y, sigma):
   """
   Source term (right-hand side) of the Green's function equation

   The equation described for this module is
   -D * ∂²G/∂x² + C * ∂G/∂x = δ(x-y)
   with D the (positive) diffusion coefficient and C the convection
   coefficient; with C=0 it is the pure-diffusion (Poisson) case.
   Only the right-hand side is evaluated here, with δ(x-y) replaced by
   the Gaussian approximation from gaussian_delta.

   Args:
      x: Spatial coordinate where the equation is evaluated
      y: Source point where the Green's function is centred
      sigma: Width parameter of the Gaussian approximation

   Returns:
      Gaussian approximation of δ(x-y) evaluated at x
   """
   delta_approximation = gaussian_delta(x, y, sigma)
   return delta_approximation

# JIT-compiled sampling functions
@partial(jit, static_argnums=(3,))
def _sample_interior_points_jit(key, x_min, x_max, n_points):
   """JIT-compiled version with static shape parameter"""
   return jax.random.uniform(key, (n_points, 2), minval=x_min, maxval=x_max)

def sample_random_interior_points(key, n_points, x_min=0.0, x_max=1.0):
   """
   Sample random interior points (x,y) with both coordinates uniform
   between x_min and x_max (the unit square for the default bounds)

   Thin wrapper that forwards to the JIT-compiled sampler.

   Args:
      key: JAX PRNG key
      n_points: Number of interior points to sample
      x_min, x_max: Domain bounds

   Returns:
      points: Shape (n_points, 2) array of interior points
   """
   # The compiled sampler takes the count as a static argument, so hand it a plain int
   return _sample_interior_points_jit(key, x_min, x_max, int(n_points))

# Pure function for grid parameters calculation - JIT compatible
def compute_grid_params(epsilon, x_min=0.0, x_max=1.0):
   """
   Derive the sampling grid (cell count and cell width) from epsilon

   Contains no caching or other side effects, so it can also be called
   from inside JIT-traced code.

   Args:
      epsilon: Width parameter of the delta function approximation
      x_min, x_max: Domain bounds

   Returns:
      A tuple of (num_cells, h) where:
      - num_cells: Number of cells in the grid (at least 2)
      - h: Cell size, chosen so that num_cells cells exactly span the domain
   """
   domain_length = x_max - x_min

   # Target cell width of 3 * epsilon (3 sigma captures most of the Gaussian)
   target_width = epsilon * 3.0

   # Whole number of target-width cells that fit, but never fewer than two
   cells_that_fit = jnp.floor(domain_length / target_width).astype(jnp.int32)
   num_cells = jnp.maximum(2, cells_that_fit)

   # Re-derive the width so the cells tile the domain with no remainder
   return num_cells, domain_length / num_cells

# Dictionary cache used by get_grid_params when its arguments are concrete values
_grid_params_cache = {}
MAX_CACHE_SIZE = 100  # Limit cache size to avoid memory growth

def get_grid_params(epsilon, x_min=0.0, x_max=1.0):
   """
   Get or compute grid parameters for a given epsilon
   Caching wrapper around compute_grid_params with a bounded cache

   The cache is only used when all three arguments can be hashed and
   converted to float. Otherwise (the case this is meant to catch is
   being called with traced values under JIT) the result is computed
   directly and nothing is cached.

   NOTE(review): the check only inspects the arguments. If this is
   called while tracing with plain Python numbers, the values returned
   by compute_grid_params may themselves be traced and would be stored
   in the cache - confirm callers only use this outside jit.

   Args:
      epsilon: Width parameter of the delta function approximation
      x_min, x_max: Domain bounds

   Returns:
      A tuple of (num_cells, h) where:
      - num_cells: Number of cells in the grid
      - h: Cell size
   """
   # Keep the try body limited to building the cache key, so that errors
   # raised by the actual computation are not swallowed and retried
   try:
      hash((epsilon, x_min, x_max))
      key = (float(epsilon), float(x_min), float(x_max))
   except (TypeError, ValueError):
      # Arguments are not usable as a cache key: compute without caching
      return compute_grid_params(epsilon, x_min, x_max)

   if key in _grid_params_cache:
      return _grid_params_cache[key]

   result = compute_grid_params(epsilon, x_min, x_max)

   if len(_grid_params_cache) >= MAX_CACHE_SIZE:
      # dicts iterate in insertion order, so this evicts the oldest entry (FIFO).
      # A true LRU policy would need additional access tracking.
      del _grid_params_cache[next(iter(_grid_params_cache))]

   _grid_params_cache[key] = result
   return result

@partial(jit, static_argnums=(4,))
def _sample_closexy_points_jit(key, num_cells, h, x_min, n_points):
   """
   JIT-compiled implementation for static n_points
   """
   key1, key2, key3 = jax.random.split(key, 3)

   # Sample starting cell indices (from 0 to num_cells-2)
   cell_indices = jax.random.randint(key1, (n_points,), 0, num_cells-1)

   # Randomly decide if x is in cell k or k+1, and y is in the other cell
   x_in_first_cell = jax.random.bernoulli(key2, 0.5, (n_points,))

   # Create masks for which indices go where
   x_cell_mask = x_in_first_cell.reshape(-1, 1)
   y_cell_mask = (1 - x_in_first_cell).reshape(-1, 1)

   # Sample offsets within cells (0 to h)
   offsets = jax.random.uniform(key3, (n_points, 2), minval=0, maxval=h)

   # Calculate cell positions (as column vectors)
   cell_k = cell_indices.reshape(-1, 1)
   cell_k_plus_1 = (cell_indices + 1).reshape(-1, 1)

   # Calculate x and y coordinates
   # x is either in cell k or k+1
   x = x_min + (x_cell_mask * cell_k + (1 - x_cell_mask) * cell_k_plus_1) * h + offsets[:, 0:1]
   # y is in the other cell from x
   y = x_min + (y_cell_mask * cell_k + (1 - y_cell_mask) * cell_k_plus_1) * h + offsets[:, 1:2]

   # Combine into points
   return jnp.concatenate([x, y], axis=1)

def sample_closexy_interior_points(key, n_points, epsilon, x_min=0.0, x_max=1.0):
   """
   Sample interior points whose x and y coordinates are close together

   The grid parameters come from get_grid_params; the actual sampling is
   delegated to the JIT-compiled adjacent-cell sampler.

   Args:
      key: JAX PRNG key
      n_points: Number of interior points to sample
      epsilon: Width parameter of the delta function approximation
      x_min, x_max: Domain bounds

   Returns:
      points: Shape (n_points, 2) array of interior points where x and y are close
   """
   # The compiled sampler takes the count as a static argument, so make it a plain int
   point_count = int(n_points)

   # (num_cells, h) for this epsilon and domain
   grid = get_grid_params(epsilon, x_min, x_max)

   return _sample_closexy_points_jit(key, *grid, x_min, point_count)

@partial(jit, static_argnums=(3,))
def _sample_boundary_points_jit(key, x_min, x_max, n_boundary):
   """JIT-compiled implementation with fixed shape parameter"""
   # Sample all y coordinates at once
   y_coords = jax.random.uniform(key, (n_boundary, 1))

   # Create a mask for left (x=0) vs right (x=1) boundary
   # First half goes to left boundary, second half to right
   is_right = jnp.arange(n_boundary) >= (n_boundary // 2)

   # Set x coordinates using the mask (vectorized)
   x_coords = jnp.where(is_right.reshape(-1, 1), x_max, x_min)

   # Combine into points
   return jnp.concatenate([x_coords, y_coords], axis=1)

def sample_boundary_points(key, n_boundary, x_min=0.0, x_max=1.0):
   """
   Sample boundary points with x fixed at x_min or x_max (x=0 or x=1 for
   the default bounds) and a random y

   Thin wrapper that forwards to the JIT-compiled sampler.

   Args:
      key: JAX PRNG key
      n_boundary: Number of boundary points to sample
      x_min, x_max: Domain bounds

   Returns:
      boundary_points: Shape (n_boundary, 2) array of boundary points
   """
   # The compiled sampler takes the count as a static argument, so hand it a plain int
   return _sample_boundary_points_jit(key, x_min, x_max, int(n_boundary))

@jit
def _compute_source_terms_jit(points, epsilon):
   """JIT-compiled source terms: gaussian_delta of each row's (x, y) pair with width epsilon"""
   # Map over the x and y columns together; epsilon is shared by every row
   per_point_delta = vmap(gaussian_delta, in_axes=(0, 0, None))
   return per_point_delta(points[:, 0], points[:, 1], epsilon)

def _generate_data_impl(n_random_interior, n_close_interior, n_boundary, epsilon, seed):
   """
   Core data generation on the unit domain [0, 1]

   Kept free of caching and Python-side state so that it can be wrapped
   in jit. Returns (random_interior_points, close_interior_points,
   boundary_points, interior_source_terms), where the source terms cover
   the random interior points followed by the close interior points.
   """
   domain_lo, domain_hi = 0.0, 1.0

   # One independent PRNG stream per point set
   interior_key, close_key, boundary_key = jax.random.split(jax.random.PRNGKey(seed), 3)

   # Uniformly scattered interior points
   random_interior_points = _sample_interior_points_jit(
      interior_key, domain_lo, domain_hi, n_random_interior)

   # Interior points with x close to y; the uncached grid helper is used
   # because this function may run under tracing
   num_cells, h = compute_grid_params(epsilon, domain_lo, domain_hi)
   close_interior_points = _sample_closexy_points_jit(
      close_key, num_cells, h, domain_lo, n_close_interior)

   # Points on the two x-boundaries
   boundary_points = _sample_boundary_points_jit(
      boundary_key, domain_lo, domain_hi, n_boundary)

   # Source term at every interior point (random block first, then close block)
   all_interior = jnp.concatenate([random_interior_points, close_interior_points], axis=0)
   interior_source_terms = _compute_source_terms_jit(all_interior, epsilon)

   return random_interior_points, close_interior_points, boundary_points, interior_source_terms

@partial(jit, static_argnums=(0, 1, 2))
def _generate_data_jit(n_random_interior, n_close_interior, n_boundary, epsilon, seed):
   """JIT entry point for data generation; the three point counts are static because they fix array shapes"""
   return _generate_data_impl(
      n_random_interior=n_random_interior,
      n_close_interior=n_close_interior,
      n_boundary=n_boundary,
      epsilon=epsilon,
      seed=seed,
   )

def generate_data(n_random_interior, n_close_interior, n_boundary, epsilon, seed=42):
   """
   Generate training data points for the Green's function problem

   Converts the point counts to plain ints and forwards everything to
   the JIT-compiled generator.

   Args:
      n_random_interior: Number of random interior points
      n_close_interior: Number of interior points with close x and y
      n_boundary: Number of boundary points
      epsilon: Width parameter for the delta function approximation
      seed: Random seed

   Returns:
      random_interior_points: Shape (n_random_interior, 2) array of random interior points
      close_interior_points: Shape (n_close_interior, 2) array of interior points with close x and y
      boundary_points: Shape (n_boundary, 2) array of boundary points
      interior_source_terms: Source term values for both types of interior points
   """
   # The counts are static arguments of the compiled function, so they must be concrete ints
   static_counts = (int(n_random_interior), int(n_close_interior), int(n_boundary))
   return _generate_data_jit(*static_counts, epsilon, seed)

def visualize_dataset(random_interior_points, close_interior_points, boundary_points, figure_dir, sigma, filename="dataset_visualization.png"):
   """
   Visualize the sampled interior and boundary points to verify sampling correctness

   Draws the three point sets as a scatter plot together with the unit
   square outline and the x=y diagonal, saves the figure as
   figure_dir/filename and prints the output path. figure_dir is created
   if it does not exist, and the figure is closed even when plotting or
   saving fails.

   Args:
      random_interior_points: Array of random interior points
      close_interior_points: Array of interior points with close x and y
      boundary_points: Array of boundary points
      figure_dir: Directory to save the visualization
      sigma: Width parameter for the Gaussian delta function
      filename: Name of the visualization file
   """
   # Make sure the target directory exists so savefig cannot fail on a missing
   # folder (an empty figure_dir means the current directory; nothing to create)
   if figure_dir:
      os.makedirs(figure_dir, exist_ok=True)
   output_path = os.path.join(figure_dir, filename)

   fig, ax = plt.subplots(figsize=(12, 10))
   try:
      # Plot random interior points (blue)
      ax.scatter(random_interior_points[:, 0], random_interior_points[:, 1],
                 s=5, alpha=0.6, c='blue', label=f'Random Interior ({len(random_interior_points)} points)')

      # Plot close interior points (green)
      ax.scatter(close_interior_points[:, 0], close_interior_points[:, 1],
                 s=15, alpha=0.8, c='green', label=f'Close (x,y) Interior ({len(close_interior_points)} points)')

      # Plot boundary points (red)
      ax.scatter(boundary_points[:, 0], boundary_points[:, 1],
                 s=10, alpha=0.8, c='red', label=f'Boundary ({len(boundary_points)} points)')

      # Draw the domain boundary (unit square)
      ax.plot([0, 1, 1, 0, 0], [0, 0, 1, 1, 0], 'k-', lw=2)

      # Add a diagonal line to visualize "closeness" of points
      ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='x=y line')

      # Annotate the sigma parameter
      ax.annotate(f'σ = {sigma:.4f}', xy=(0.05, 0.95), xycoords='axes fraction',
                  bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))

      ax.set_xlim(-0.05, 1.05)
      ax.set_ylim(-0.05, 1.05)
      ax.set_xlabel('x', fontsize=20)
      ax.set_ylabel('y', fontsize=20)
      ax.set_title("Green's Function Dataset Points Visualization", fontsize=25)
      ax.grid(True, alpha=0.3)
      ax.legend(fontsize=20)

      fig.savefig(output_path, dpi=300)
   finally:
      # Always release the figure, even if plotting or saving raised
      plt.close(fig)

   print(f"Dataset visualization saved to {output_path}")