"""
2D Burgers' equation problem implementation.
"""
import torch
import numpy as np
from typing import Tuple, Dict, Any
from losses.physics_losses import BurgersPhysicsLoss


class Burgers2D:
    """
    2D Burgers' equation problem for testing computational scalability.

    The 2D Burgers' equation is:
    ∂u/∂t + u·∇u = ν∇²u

    where u = (u, v) is the velocity field. Model inputs are points
    ``(x, y, t)`` and outputs are the velocity components ``(u, v)``
    (see ``get_problem_info``).
    """

    def __init__(
        self,
        nu: float = 0.01,
        x_start: float = 0.0,
        x_end: float = 1.0,
        y_start: float = 0.0,
        y_end: float = 1.0,
        t_start: float = 0.0,
        t_end: float = 1.0,
        device: str = "cpu"
    ):
        """
        Initialize 2D Burgers' problem.

        Args:
            nu: Viscosity coefficient
            x_start, x_end: Spatial domain in x
            y_start, y_end: Spatial domain in y
            t_start, t_end: Temporal domain
            device: Device on which all generated tensors are allocated
        """
        self.nu = nu
        self.x_start = x_start
        self.x_end = x_end
        self.y_start = y_start
        self.y_end = y_end
        self.t_start = t_start
        self.t_end = t_end
        self.device = device

        # Residual evaluator used by compute_physics_loss.
        self.physics_loss = BurgersPhysicsLoss(nu, device)

    # ------------------------------------------------------------------
    # Sampling helpers
    # ------------------------------------------------------------------

    def _uniform(self, n: int, low: float, high: float) -> torch.Tensor:
        """Return ``n`` samples drawn uniformly from ``[low, high)`` on ``self.device``."""
        return torch.rand(n, device=self.device) * (high - low) + low

    def _constant(self, n: int, value: float) -> torch.Tensor:
        """Return a length-``n`` tensor filled with ``value`` on ``self.device``."""
        # float() keeps the default floating dtype even if an int is passed.
        return torch.full((n,), float(value), device=self.device)

    def initial_condition(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Initial condition for the velocity field.

        u0 = sin(πx)·sin(πy),  v0 = cos(πx)·cos(πy)

        Args:
            x: x coordinates
            y: y coordinates

        Returns:
            Tuple of (u0, v0) initial velocities, elementwise over x and y
        """
        u0 = torch.sin(np.pi * x) * torch.sin(np.pi * y)
        v0 = torch.cos(np.pi * x) * torch.cos(np.pi * y)
        return u0, v0

    def _initial_data(self, n: int, noise_std: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``n`` noisy initial-condition observations at ``t = t_start``."""
        x = self._uniform(n, self.x_start, self.x_end)
        y = self._uniform(n, self.y_start, self.y_end)
        # Impose the initial condition at the start of the temporal domain,
        # not at a hard-coded t = 0 that may lie outside [t_start, t_end].
        t = self._constant(n, self.t_start)

        u, v = self.initial_condition(x, y)
        u = u + torch.randn_like(u) * noise_std
        v = v + torch.randn_like(v) * noise_std

        return torch.stack([x, y, t], dim=1), torch.stack([u, v], dim=1)

    def _boundary_data(self, n: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample ``n`` zero-valued observations on the four spatial edges.

        Points are split as evenly as possible between the left, right, bottom
        and top edges (in that order). Any remainder of ``n / 4`` goes to the
        earlier edges, so exactly ``n`` points are produced.
        """
        base, extra = divmod(n, 4)
        n_left, n_right, n_bottom, n_top = (
            base + 1 if i < extra else base for i in range(4)
        )

        # Left edge (x = x_start)
        x_left = self._constant(n_left, self.x_start)
        y_left = self._uniform(n_left, self.y_start, self.y_end)
        t_left = self._uniform(n_left, self.t_start, self.t_end)

        # Right edge (x = x_end)
        x_right = self._constant(n_right, self.x_end)
        y_right = self._uniform(n_right, self.y_start, self.y_end)
        t_right = self._uniform(n_right, self.t_start, self.t_end)

        # Bottom edge (y = y_start)
        x_bottom = self._uniform(n_bottom, self.x_start, self.x_end)
        y_bottom = self._constant(n_bottom, self.y_start)
        t_bottom = self._uniform(n_bottom, self.t_start, self.t_end)

        # Top edge (y = y_end)
        x_top = self._uniform(n_top, self.x_start, self.x_end)
        y_top = self._constant(n_top, self.y_end)
        t_top = self._uniform(n_top, self.t_start, self.t_end)

        x_boundary = torch.cat([x_left, x_right, x_bottom, x_top])
        y_boundary = torch.cat([y_left, y_right, y_bottom, y_top])
        t_boundary = torch.cat([t_left, t_right, t_bottom, t_top])

        # Homogeneous Dirichlet targets: u = v = 0 on every edge.
        # NOTE(review): v0 = cos(πx)cos(πy) from initial_condition is nonzero on
        # the boundary, so these targets conflict with the initial/interior data
        # for v. Confirm this is the intended boundary condition.
        u_boundary = torch.zeros_like(x_boundary)
        v_boundary = torch.zeros_like(x_boundary)

        points = torch.stack([x_boundary, y_boundary, t_boundary], dim=1)
        values = torch.stack([u_boundary, v_boundary], dim=1)
        return points, values

    def _interior_data(self, n: int, noise_std: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample ``n`` noisy observations of a surrogate solution inside the domain.

        Targets are the initial condition scaled by exp(-ν·π²·(t - t_start)).
        This is a simplified surrogate, not an exact Burgers' solution. In
        practice these values would come from simulation or experiment.
        """
        x = self._uniform(n, self.x_start, self.x_end)
        y = self._uniform(n, self.y_start, self.y_end)
        t = self._uniform(n, self.t_start, self.t_end)

        # Measure decay from t_start so the surrogate equals the initial
        # condition at the start of the temporal domain.
        # NOTE(review): for pure diffusion the sin(πx)sin(πy) mode decays as
        # exp(-2νπ²t). The single-π² rate here is kept as-is; confirm intent.
        decay = torch.exp(-self.nu * np.pi**2 * (t - self.t_start))
        u0, v0 = self.initial_condition(x, y)
        u = u0 * decay
        v = v0 * decay

        u = u + torch.randn_like(u) * noise_std
        v = v + torch.randn_like(v) * noise_std

        return torch.stack([x, y, t], dim=1), torch.stack([u, v], dim=1)

    def generate_training_data(
        self,
        num_initial_points: int = 1000,
        num_boundary_points: int = 400,
        num_interior_points: int = 500,
        noise_std: float = 0.1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate training data for 2D Burgers' equation.

        Rows are ordered as initial-condition points, then boundary points,
        then interior points.

        Args:
            num_initial_points: Number of initial condition points (at t = t_start)
            num_boundary_points: Total number of boundary condition points,
                split as evenly as possible over the four spatial edges
            num_interior_points: Number of interior observation points
            noise_std: Standard deviation of Gaussian noise added to the
                initial and interior targets (boundary targets are noise-free)

        Returns:
            Tuple of (x_data, y_data): x_data has columns (x, y, t) and
            y_data has columns (u, v).
        """
        init_points, init_values = self._initial_data(num_initial_points, noise_std)
        bnd_points, bnd_values = self._boundary_data(num_boundary_points)
        int_points, int_values = self._interior_data(num_interior_points, noise_std)

        x_data = torch.cat([init_points, bnd_points, int_points], dim=0)
        y_data = torch.cat([init_values, bnd_values, int_values], dim=0)
        return x_data, y_data

    def generate_collocation_points(
        self,
        num_points: int = 10000,
        num_x: int = 100,
        num_y: int = 100,
        num_t: int = 50
    ) -> torch.Tensor:
        """
        Generate collocation points for physics loss.

        Builds a regular num_x × num_y × num_t grid over the domain. If the grid
        has more than ``num_points`` nodes, a random subset of ``num_points``
        nodes is returned. Otherwise the full grid is returned in grid order.

        Args:
            num_points: Maximum number of collocation points
            num_x: Number of x grid points
            num_y: Number of y grid points
            num_t: Number of t grid points

        Returns:
            Collocation points tensor with columns (x, y, t)
        """
        x = torch.linspace(self.x_start, self.x_end, num_x, device=self.device)
        y = torch.linspace(self.y_start, self.y_end, num_y, device=self.device)
        t = torch.linspace(self.t_start, self.t_end, num_t, device=self.device)

        X, Y, T = torch.meshgrid(x, y, t, indexing='ij')
        points = torch.stack([X.flatten(), Y.flatten(), T.flatten()], dim=1)

        if len(points) > num_points:
            # Create the permutation on the same device as the points it indexes.
            indices = torch.randperm(len(points), device=points.device)[:num_points]
            points = points[indices]

        return points

    def compute_physics_loss(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """
        Compute physics loss for the 2D Burgers' equation.

        Args:
            model: Neural network model
            x: Spatiotemporal points (x, y, t)

        Returns:
            Mean of the squared residual returned by
            ``self.physics_loss.compute_residual(model, x)``
        """
        residual = self.physics_loss.compute_residual(model, x)
        return torch.mean(residual**2)

    def get_problem_info(self) -> Dict[str, Any]:
        """Return a summary of the problem: type, viscosity, domains and I/O dims."""
        return {
            'problem_type': 'burgers_2d',
            'nu': self.nu,
            'x_domain': (self.x_start, self.x_end),
            'y_domain': (self.y_start, self.y_end),
            't_domain': (self.t_start, self.t_end),
            'input_dim': 3,  # x, y, t
            'output_dim': 2,  # u, v components
        }
