#!/usr/bin/env python3
"""
Data generation utilities for PINN problems.
"""

import torch
import numpy as np
from typing import Dict, Tuple


def generate_problem_data(problem_type: str, n_points: int = 1000, device: str = "cpu") -> Dict[str, torch.Tensor]:
    """
    Dispatch to the data generator registered for ``problem_type``.

    Args:
        problem_type: Problem name; one of 'lorenz', 'burgers' or
            'inverse_poisson'.
        n_points: Size argument forwarded unchanged to the selected generator.
        device: Device argument forwarded unchanged to the selected generator.

    Returns:
        The dictionary produced by the selected generator.

    Raises:
        ValueError: If ``problem_type`` is not one of the known names.
    """
    # Builders are wrapped in lambdas so a generator is only looked up and
    # called once its name has actually matched.
    registry = (
        ("lorenz", lambda: generate_lorenz_data(n_points, device)),
        ("burgers", lambda: generate_burgers_data(n_points, device)),
        ("inverse_poisson", lambda: generate_inverse_poisson_data(n_points, device)),
    )

    for known_name, build in registry:
        if problem_type == known_name:
            return build()

    raise ValueError(f"Unknown problem type: {problem_type}")


def generate_lorenz_data(n_points: int = 1000, device: str = "cpu") -> Dict[str, torch.Tensor]:
    """Generate noisy samples of a Lorenz trajectory on the time interval [0, 2].

    The system (sigma=10, rho=28, beta=8/3) is integrated with forward Euler
    using a step of 0.01, starting from the state (1, 1, 1). The integration
    grid includes the initial state, and the trajectory is linearly
    interpolated onto the requested time points, so row ``i`` of ``y_data``
    is the state at the time stored in row ``i`` of ``x_data``. Gaussian noise
    with standard deviation 0.1 is then added to the samples.

    Args:
        n_points: Number of evenly spaced time points in [0, 2].
        device: Device to place the returned tensors on.

    Returns:
        Dictionary with:
            'x_collocation': time points, shape (n_points, 1)
            'x_data': the same time tensor as 'x_collocation'
            'y_data': noisy (x, y, z) states, shape (n_points, 3), float32
    """
    t_end = 2.0
    dt = 0.01
    n_steps = int(round(t_end / dt))
    sigma, rho, beta = 10.0, 28.0, 8.0 / 3.0

    # Time points handed to the network.
    t = torch.linspace(0, t_end, n_points, device=device).unsqueeze(1)

    # Forward-Euler integration. Row 0 holds the initial state so that row k
    # is the state at time k * dt.
    states = np.empty((n_steps + 1, 3))
    x, y, z = 1.0, 1.0, 1.0
    states[0] = (x, y, z)
    for step in range(1, n_steps + 1):
        # All three increments use the state from the previous step.
        dx = sigma * (y - x) * dt
        dy = (x * (rho - z) - y) * dt
        dz = (x * y - beta * z) * dt

        x += dx
        y += dy
        z += dz

        states[step] = (x, y, z)

    # Interpolate each component onto the requested times. Picking integer
    # indices instead would shift samples off their time stamps and repeat
    # rows whenever n_points exceeds the number of integration steps.
    grid_times = np.arange(n_steps + 1) * dt
    sample_times = np.linspace(0.0, t_end, n_points)
    samples = np.column_stack(
        [np.interp(sample_times, grid_times, states[:, k]) for k in range(3)]
    )
    y_data = torch.tensor(samples, dtype=torch.float32, device=device)

    # Add measurement noise.
    y_data = y_data + torch.randn_like(y_data) * 0.1

    return {
        'x_collocation': t,
        'x_data': t,
        'y_data': y_data
    }


def generate_burgers_data(n_points: int = 1000, device: str = "cpu") -> Dict[str, torch.Tensor]:
    """Generate noisy synthetic data for the 2D Burgers problem on a (t, x, y) grid.

    The grid is a regular lattice over [0, 1]^3 with the same resolution along
    t, x and y. Because the lattice has three axes, the per-axis resolution is
    the integer cube root of ``n_points`` (the largest k with k**3 <= n_points),
    so the total number of rows never exceeds ``n_points``.

    The target fields are the closed-form expressions
        u = sin(pi x) sin(pi y) exp(-0.1 t)
        v = cos(pi x) cos(pi y) exp(-0.1 t)
    with Gaussian noise of standard deviation 0.05 added.

    Args:
        n_points: Upper bound on the number of grid points to generate.
        device: Device to place the returned tensors on.

    Returns:
        Dictionary with:
            'x_collocation': grid coordinates (t, x, y), shape (k**3, 3)
            'x_data': the same coordinate tensor as 'x_collocation'
            'y_data': noisy (u, v) values, shape (k**3, 2)

    Raises:
        ValueError: If ``n_points`` is negative.
    """
    if n_points < 0:
        raise ValueError(f"n_points must be non-negative, got {n_points}")

    # Integer cube root. The float estimate can be off by one in either
    # direction (e.g. 1000 ** (1/3) == 9.999...), so correct it exactly.
    per_axis = round(n_points ** (1.0 / 3.0))
    while per_axis ** 3 > n_points:
        per_axis -= 1
    while (per_axis + 1) ** 3 <= n_points:
        per_axis += 1

    # The same 1-D axis is used for t, x and y.
    axis = torch.linspace(0, 1, per_axis, device=device)

    # Create meshgrid; 'ij' indexing keeps t as the slowest-varying coordinate.
    T, X, Y = torch.meshgrid(axis, axis, axis, indexing='ij')
    points = torch.stack([T.flatten(), X.flatten(), Y.flatten()], dim=1)

    # Generate synthetic solution
    decay = torch.exp(-0.1 * T)
    u = torch.sin(np.pi * X) * torch.sin(np.pi * Y) * decay
    v = torch.cos(np.pi * X) * torch.cos(np.pi * Y) * decay

    solution = torch.stack([u.flatten(), v.flatten()], dim=1)

    # Add noise
    solution = solution + torch.randn_like(solution) * 0.05

    return {
        'x_collocation': points,
        'x_data': points,
        'y_data': solution
    }


def generate_inverse_poisson_data(n_points: int = 1000, device: str = "cpu") -> Dict[str, torch.Tensor]:
    """Generate noisy data for the 1D inverse Poisson problem on [0, 1].

    The reference pair is u(x) = sin(2πx) and f(x) = 4π² sin(2πx), i.e.
    f = -u''. Both columns receive Gaussian noise with standard deviation 0.05.

    Args:
        n_points: Number of evenly spaced spatial points in [0, 1].
        device: Device to place the returned tensors on.

    Returns:
        Dictionary with:
            'x_collocation': spatial points, shape (n_points, 1)
            'x_data': the same tensor as 'x_collocation'
            'y_data': noisy [u, f] columns, shape (n_points, 2)
    """
    grid = torch.linspace(0, 1, n_points, device=device).unsqueeze(1)

    # Both u and f are scalar multiples of the same sine wave.
    wave = torch.sin(2 * np.pi * grid)
    source_scale = 4 * np.pi**2

    # Column 0 is the solution u, column 1 is the source f.
    clean = torch.cat((wave, source_scale * wave), dim=1)

    # Perturb every entry with independent Gaussian noise.
    observations = clean + torch.randn_like(clean) * 0.05

    return {
        'x_collocation': grid,
        'x_data': grid,
        'y_data': observations
    }
