"""
Data generation and visualization for Poisson equation PINN example.
Provides functions for sampling points, computing source terms, and visualization.
"""

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import os

def true_solution(x, y):
   """
   Exact (manufactured) solution used to validate the Poisson PINN.

   Evaluates u(x, y) = sin(2πx) · sin(2πy) element-wise, so x and y may be
   scalars or broadcast-compatible arrays.
   """
   omega = 2 * jnp.pi
   sin_x = jnp.sin(omega * x)
   sin_y = jnp.sin(omega * y)
   return sin_x * sin_y

def source_term(x, y):
   """
   Right-hand side f of the Poisson problem -Δu = f.

   With u(x, y) = sin(2πx) sin(2πy), each second derivative contributes a
   factor of -(2π)², so -Δu = 8π² sin(2πx) sin(2πy). Evaluated element-wise.
   """
   omega = 2 * jnp.pi
   prefactor = 8 * (jnp.pi**2)
   return prefactor * jnp.sin(omega * x) * jnp.sin(omega * y)

def sample_rectangle_points(n_interior, n_boundary, domain_lb=(0, 0), domain_ub=(1, 1), seed=42):
   """
   Sample points inside a rectangular domain [a,b]×[c,d] and on its boundary.

   Coordinates are drawn with jax.random.uniform, i.e. uniformly over the
   half-open interval [lower, upper) along each free axis. Boundary points are
   split as evenly as possible across the four edges; when n_boundary is not a
   multiple of 4, the first n_boundary % 4 edges (in the order bottom, right,
   top, left) each receive one extra point.

   Args:
      n_interior: Number of interior points to sample
      n_boundary: Number of boundary points to sample (distributed among the 4 sides)
      domain_lb: Lower bounds of the domain (a, c)
      domain_ub: Upper bounds of the domain (b, d)
      seed: Random seed for reproducibility

   Returns:
      interior_points: Array of sampled interior points, shape (n_interior, 2)
      boundary_points: Array of sampled boundary points, shape (n_boundary, 2),
         grouped edge by edge in the order bottom (y = c), right (x = b),
         top (y = d), left (x = a)

   Raises:
      ValueError: If n_interior or n_boundary is negative.
   """
   if n_interior < 0 or n_boundary < 0:
      raise ValueError(
         f"point counts must be non-negative, got n_interior={n_interior}, n_boundary={n_boundary}"
      )

   key = jax.random.PRNGKey(seed)

   # Unpack domain bounds
   x_min, y_min = domain_lb
   x_max, y_max = domain_ub

   # Interior: independent uniform x and y coordinates. The key-split order
   # (x, then y, then one key per edge) defines the random stream for a seed,
   # so it must stay fixed for results to be reproducible across versions.
   key, subkey = jax.random.split(key)
   x_interior = jax.random.uniform(subkey, (n_interior,), minval=x_min, maxval=x_max)
   key, subkey = jax.random.split(key)
   y_interior = jax.random.uniform(subkey, (n_interior,), minval=y_min, maxval=y_max)
   interior_points = jnp.column_stack((x_interior, y_interior))

   # Split the boundary budget across the 4 edges; leftover points go to the
   # first `remainder` edges so the total is exactly n_boundary.
   n_per_side, remainder = divmod(n_boundary, 4)
   n_sides = [n_per_side + (1 if i < remainder else 0) for i in range(4)]

   # Each edge: (free coordinate is x?, free lower bound, free upper bound, fixed value)
   edges = (
      (True, x_min, x_max, y_min),   # bottom: y = y_min
      (False, y_min, y_max, x_max),  # right:  x = x_max
      (True, x_min, x_max, y_max),   # top:    y = y_max
      (False, y_min, y_max, x_min),  # left:   x = x_min
   )

   edge_points = []
   for n_edge, (free_is_x, lo, hi, fixed) in zip(n_sides, edges):
      key, subkey = jax.random.split(key)
      free = jax.random.uniform(subkey, (n_edge,), minval=lo, maxval=hi)
      const = jnp.full_like(free, fixed)
      # Keep (x, y) column order regardless of which coordinate is free.
      columns = (free, const) if free_is_x else (const, free)
      edge_points.append(jnp.column_stack(columns))

   boundary_points = jnp.concatenate(edge_points, axis=0)

   return interior_points, boundary_points

def generate_data(n_interior, n_boundary, seed=42, domain_lb=(0, 0), domain_ub=(1, 1)):
   """
   Generate training data for the Poisson PDE with Physics-Informed Neural Network.

   Args:
      n_interior: Number of interior collocation points
      n_boundary: Number of boundary points
      seed: Random seed for reproducibility
      domain_lb: Lower bounds (x_min, y_min) of the rectangular domain
      domain_ub: Upper bounds (x_max, y_max) of the rectangular domain

   Returns:
      interior_points: Interior collocation points, shape (n_interior, 2)
      boundary_points: Boundary points, shape (n_boundary, 2)
      f_interior: Source term evaluations at interior points (for PDE residual), shape (n_interior,)
      u_boundary: True solution evaluations at boundary points (for boundary conditions), shape (n_boundary,)
   """
   # Sample interior and boundary points on the requested rectangle
   interior_points, boundary_points = sample_rectangle_points(
      n_interior, n_boundary, domain_lb=domain_lb, domain_ub=domain_ub, seed=seed
   )

   # source_term and true_solution are built from element-wise jnp operations,
   # so they broadcast over whole coordinate columns; no per-point vmap needed.
   f_interior = source_term(interior_points[:, 0], interior_points[:, 1])

   # Dirichlet data: exact solution evaluated on the boundary
   u_boundary = true_solution(boundary_points[:, 0], boundary_points[:, 1])

   return interior_points, boundary_points, f_interior, u_boundary

def visualize_dataset(interior_points, boundary_points, figure_dir, filename="dataset_visualization.png"):
   """
   Visualize the training dataset with interior collocation points and boundary points.
   Creates a 2D scatter plot showing the distribution of points in the domain
   and saves it as an image.

   Args:
      interior_points: Interior collocation points, shape (n_interior, 2)
      boundary_points: Boundary points, shape (n_boundary, 2)
      figure_dir: Directory to save the visualization (created if missing;
         an empty string means the current working directory)
      filename: Name of the output file

   Returns:
      The path of the saved image file.
   """
   # savefig does not create missing directories, so make sure it exists.
   if figure_dir:
      os.makedirs(figure_dir, exist_ok=True)
   out_path = os.path.join(figure_dir, filename)

   # Work on an explicit Figure/Axes pair instead of pyplot's implicit
   # "current figure", so we never draw onto or close an unrelated figure.
   fig, ax = plt.subplots(figsize=(8, 8))
   try:
      # Plot interior points
      ax.scatter(interior_points[:, 0], interior_points[:, 1],
                 c='blue', alpha=0.5, label='Interior points',
                 s=10, marker='o')

      # Plot boundary points
      ax.scatter(boundary_points[:, 0], boundary_points[:, 1],
                 c='red', s=15, label='Boundary points',
                 marker='x')

      # Add labels and legend
      ax.set_xlabel('x', fontsize=20)
      ax.set_ylabel('y', fontsize=20)
      ax.set_title(f'Dataset: {len(interior_points)} interior points, {len(boundary_points)} boundary points', fontsize=25)
      ax.legend(fontsize=20)
      ax.axis('equal')
      ax.grid(alpha=0.3)

      fig.savefig(out_path, dpi=300, bbox_inches='tight')
   finally:
      # Always release the figure, even if plotting or saving fails;
      # otherwise pyplot keeps a reference and memory grows across calls.
      plt.close(fig)

   print(f"Dataset visualization saved to {out_path}")
   return out_path