"""
Lorenz system problem implementation.
"""
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch

from losses.physics_losses import LorenzPhysicsLoss


class LorenzSystem:
    """
    Stochastic Lorenz system problem for testing aleatoric uncertainty.

    The deterministic Lorenz vector field f(x, y, z) is:
    dx/dt = σ(y - x)
    dy/dt = x(ρ - z) - y
    dz/dt = xy - βz

    Stochastic forcing is additive and isotropic, i.e. trajectories follow
    the Itô SDE  dX = f(X) dt + noise_std · dW,  with W a 3-D standard
    Wiener process, integrated with the Euler-Maruyama scheme.
    """

    def __init__(
        self,
        sigma: float = 10.0,
        rho: float = 28.0,
        beta: float = 8.0/3.0,
        noise_std: float = 0.5,
        device: str = "cpu"
    ):
        """
        Initialize Lorenz system.

        Args:
            sigma: Prandtl number
            rho: Rayleigh number
            beta: Geometric factor
            noise_std: Diffusion coefficient of the additive noise, i.e. the
                standard deviation of the forcing per unit sqrt-time
            device: Device on which generated tensors are created
        """
        self.sigma = sigma
        self.rho = rho
        self.beta = beta
        self.noise_std = noise_std
        self.device = device

        # Residual helper used by compute_physics_loss.
        self.physics_loss = LorenzPhysicsLoss(sigma, rho, beta, device)

    def _drift(self, state: np.ndarray) -> np.ndarray:
        """Return the deterministic Lorenz vector field f(state) for one (x, y, z) state."""
        x, y, z = state
        return np.array([
            self.sigma * (y - x),
            x * (self.rho - z) - y,
            x * y - self.beta * z,
        ])

    def generate_trajectory(
        self,
        t_start: float = 0.0,
        t_end: float = 10.0,
        dt: float = 0.01,
        initial_conditions: Optional[Tuple[float, float, float]] = None,
        add_noise: bool = True
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate a trajectory of the Lorenz system.

        The grid starts at ``t_start`` and advances in steps of ``dt`` until
        ``t_end`` is reached. When ``(t_end - t_start) / dt`` is not an
        integer, the last sample lies just past ``t_end``.

        Args:
            t_start: Start time
            t_end: End time
            dt: Time step (must be positive)
            initial_conditions: Initial (x, y, z) values; drawn at random
                (x, y in [-20, 20], z in [0, 50]) when None
            add_noise: Whether to add the stochastic forcing term

        Returns:
            Tuple of (time_points, trajectory). time_points has shape (n,),
            and trajectory has shape (n, 3) with columns x, y, z.

        Raises:
            ValueError: If dt is not positive or t_end < t_start.
        """
        if dt <= 0:
            raise ValueError(f"dt must be positive, got {dt}")
        if t_end < t_start:
            raise ValueError(
                f"t_end ({t_end}) must not be smaller than t_start ({t_start})"
            )

        if initial_conditions is None:
            # Random initial state from a box around the attractor.
            initial_conditions = (
                np.random.uniform(-20, 20),
                np.random.uniform(-20, 20),
                np.random.uniform(0, 50)
            )

        # Build the grid from an integer step count rather than
        # np.arange(t_start, t_end + dt, dt): a float arange can emit a
        # spurious extra sample past t_end. Rounding before ceil absorbs
        # representation error, so a (nearly) integral span/dt gives exactly
        # span/dt steps.
        n_steps = int(np.ceil(round((t_end - t_start) / dt, 9)))
        t = t_start + dt * np.arange(n_steps + 1)

        trajectory = np.zeros((n_steps + 1, 3))
        trajectory[0] = initial_conditions

        # Wiener increments dW ~ N(0, dt) per component, scaled by the
        # diffusion coefficient.
        if add_noise:
            diffusion = self.noise_std * np.random.normal(0.0, np.sqrt(dt), (n_steps, 3))
        else:
            diffusion = np.zeros((n_steps, 3))

        # Euler-Maruyama: X_{n+1} = X_n + f(X_n) dt + noise_std * dW_n.
        # The diffusion term must NOT be multiplied by dt; doing so shrinks the
        # per-step noise variance from O(dt) to O(dt^3).
        for i in range(n_steps):
            trajectory[i + 1] = (
                trajectory[i] + dt * self._drift(trajectory[i]) + diffusion[i]
            )

        return t, trajectory

    def generate_training_data(
        self,
        num_trajectories: int = 10,
        num_time_points: int = 1000,
        t_start: float = 0.0,
        t_end: float = 10.0,
        corruption_level: float = 0.1,
        corruption_std: float = 1.0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate training data for the Lorenz system.

        Each trajectory starts from a random initial state and includes
        stochastic forcing. Trajectories are concatenated one after another,
        so the returned time column restarts at ``t_start`` for every
        trajectory.

        Args:
            num_trajectories: Number of different trajectories
            num_time_points: Number of time points per trajectory, with both
                t_start and t_end included (must be >= 2)
            t_start: Start time
            t_end: End time
            corruption_level: Fraction of rows (clamped to at most all rows)
                that receive additive Gaussian corruption
            corruption_std: Standard deviation of that corruption

        Returns:
            Tuple of (time_points, trajectory_data) tensors with shapes
            (num_trajectories * num_time_points, 1) and
            (num_trajectories * num_time_points, 3).

        Raises:
            ValueError: If num_time_points < 2 or the time span is invalid.
        """
        if num_time_points < 2:
            raise ValueError(
                f"num_time_points must be at least 2, got {num_time_points}"
            )

        # num_time_points samples span num_time_points - 1 intervals.
        dt = (t_end - t_start) / (num_time_points - 1)

        all_times = []
        all_trajectories = []

        for _ in range(num_trajectories):
            t, trajectory = self.generate_trajectory(
                t_start=t_start,
                t_end=t_end,
                dt=dt,
                add_noise=True
            )

            all_times.append(t)
            all_trajectories.append(trajectory)

        # Combine all trajectories (np.concatenate rejects an empty list).
        if all_trajectories:
            time_points = np.concatenate(all_times)
            trajectory_data = np.concatenate(all_trajectories)
        else:
            time_points = np.empty(0)
            trajectory_data = np.empty((0, 3))

        # Corrupt a random subset of rows on all three components.
        n_rows = len(trajectory_data)
        if corruption_level > 0 and n_rows > 0:
            # Clamp: sampling without replacement cannot exceed the population.
            n_corrupt = min(int(n_rows * corruption_level), n_rows)
            corrupt_indices = np.random.choice(n_rows, n_corrupt, replace=False)
            trajectory_data[corrupt_indices] += np.random.normal(
                0, corruption_std, (n_corrupt, 3)
            )

        # Convert to tensors
        time_tensor = torch.tensor(time_points, dtype=torch.float32, device=self.device).unsqueeze(1)
        data_tensor = torch.tensor(trajectory_data, dtype=torch.float32, device=self.device)

        return time_tensor, data_tensor

    def generate_collocation_points(
        self,
        num_points: int = 10000,
        t_start: float = 0.0,
        t_end: float = 10.0
    ) -> torch.Tensor:
        """
        Generate collocation points for the physics loss.

        Args:
            num_points: Number of collocation points
            t_start: Start time
            t_end: End time

        Returns:
            Tensor of shape (num_points, 1) holding times drawn uniformly
            from [t_start, t_end).
        """
        t_points = np.random.uniform(t_start, t_end, num_points)
        t_tensor = torch.tensor(t_points, dtype=torch.float32, device=self.device).unsqueeze(1)

        return t_tensor

    def compute_physics_loss(self, model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the physics loss for the Lorenz system.

        The loss is the mean squared residual returned by
        ``self.physics_loss.compute_residual``.

        Args:
            model: Neural network model
            x: Time points (presumably shaped (N, 1), as produced by
                generate_collocation_points; confirm against the residual impl)

        Returns:
            Scalar physics loss tensor
        """
        residual = self.physics_loss.compute_residual(model, x)
        return torch.mean(residual**2)

    def get_problem_info(self) -> Dict[str, Any]:
        """Return the problem parameters and input/output dimensionality."""
        return {
            'problem_type': 'lorenz',
            'sigma': self.sigma,
            'rho': self.rho,
            'beta': self.beta,
            'noise_std': self.noise_std,
            'input_dim': 1,  # Time only
            'output_dim': 3,  # x, y, z components
        }
