"""
Random seed management and reproducibility utilities.
"""

import random
import warnings
from typing import Optional

import numpy as np
import torch


def set_seed(seed: int, deterministic: bool = True):
    """
    Set random seed for reproducibility across all libraries.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's
    generator, plus the generator of every CUDA device when CUDA is available.

    Args:
        seed: Random seed value
        deterministic: If True, use deterministic algorithms (may be slower)

    Note:
        When deterministic=True, some operations may be slower but results
        will be fully reproducible. Set to False for faster training if
        exact reproducibility is not required.

        If ``torch.use_deterministic_algorithms(True)`` cannot be enabled, a
        ``RuntimeWarning`` is emitted (instead of failing silently); the cuDNN
        flags are still applied. Some CUDA ops may additionally need the
        ``CUBLAS_WORKSPACE_CONFIG`` environment variable (see the PyTorch
        reproducibility notes).

        With deterministic=False only ``cudnn.benchmark`` is switched on;
        deterministic flags enabled by an earlier call are not reset here.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        # Seeds every visible GPU (covers the current device as well).
        torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # use_deterministic_algorithms only exists in PyTorch >= 1.8.
        if hasattr(torch, 'use_deterministic_algorithms'):
            try:
                torch.use_deterministic_algorithms(True)
            except Exception as err:  # best-effort: warn rather than abort seeding
                warnings.warn(
                    f"Could not enable torch.use_deterministic_algorithms: {err}",
                    RuntimeWarning,
                    stacklevel=2,
                )
    else:
        # Let cuDNN auto-tune kernels for speed.
        torch.backends.cudnn.benchmark = True

    print(f"✓ Random seed set to {seed} (deterministic={deterministic})")


def get_rng_state():
    """
    Snapshot every random number generator this module manages.

    Returns:
        Dict with keys ``'python'``, ``'numpy'`` and ``'torch'`` holding the
        respective generator states, plus ``'torch_cuda'`` (one state per CUDA
        device) when CUDA is available. Suitable for ``set_rng_state``.
    """
    snapshot = dict(
        python=random.getstate(),
        numpy=np.random.get_state(),
        torch=torch.get_rng_state(),
    )

    if not torch.cuda.is_available():
        return snapshot

    snapshot['torch_cuda'] = torch.cuda.get_rng_state_all()
    return snapshot


def set_rng_state(state: dict):
    """
    Restore generator states previously captured by ``get_rng_state()``.

    Args:
        state: Mapping with ``'python'``, ``'numpy'`` and ``'torch'`` entries;
            an optional ``'torch_cuda'`` entry is applied only when CUDA is
            available.
    """
    # Applied one at a time, in this order, so a missing key fails after the
    # earlier generators have already been restored.
    restorers = (
        ('python', random.setstate),
        ('numpy', np.random.set_state),
        ('torch', torch.set_rng_state),
    )
    for key, restore in restorers:
        restore(state[key])

    cuda_ok = torch.cuda.is_available()
    if cuda_ok and 'torch_cuda' in state:
        torch.cuda.set_rng_state_all(state['torch_cuda'])


class RandomContext:
    """
    Context manager for temporary random seed changes.

    On entry (when ``seed`` is not None) the current RNG states are captured
    with ``get_rng_state()`` and Python's ``random``, NumPy and PyTorch (plus
    all CUDA devices when available) are re-seeded with ``seed``. On exit the
    captured states are restored with ``set_rng_state()``, including when the
    body raised. With ``seed=None`` the context does nothing.

    Only generator states are changed: no ``torch.backends.cudnn`` flags are
    modified and nothing is printed, so no global settings outlive the block.

    Usage:
        with RandomContext(42):
            # Code here uses seed 42
            x = torch.randn(10)
        # Previous RNG states are restored
    """

    def __init__(self, seed: Optional[int] = None):
        self.seed = seed
        # RNG snapshot taken in __enter__; None means there is nothing to restore.
        self.state = None

    def __enter__(self):
        if self.seed is not None:
            self.state = get_rng_state()
            # Seed the generators directly rather than via set_seed():
            # set_seed(deterministic=False) also sets the global
            # cudnn.benchmark flag, which would not be undone on exit.
            random.seed(self.seed)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(self.seed)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.state is not None:
            set_rng_state(self.state)
            self.state = None
        # Implicitly returns None, so exceptions from the body propagate.


def sample_uniform(low: float, high: float, size: tuple, device: str = 'cpu') -> torch.Tensor:
    """
    Draw samples from the uniform distribution U(low, high).

    Args:
        low: Lower bound (inclusive)
        high: Upper bound
        size: Shape of the returned tensor
        device: Device on which the tensor is allocated

    Returns:
        Tensor of shape ``size`` obtained by affinely mapping ``torch.rand``
        draws from [0, 1) onto [low, high)
    """
    span = high - low
    unit = torch.rand(size, device=device)
    return unit * span + low


def sample_normal(mean: float, std: float, size: tuple, device: str = 'cpu') -> torch.Tensor:
    """
    Draw samples from the normal distribution N(mean, std^2).

    Args:
        mean: Location of the distribution
        std: Standard deviation (scale)
        size: Shape of the returned tensor
        device: Device on which the tensor is allocated

    Returns:
        Tensor of shape ``size``: standard-normal draws scaled by ``std`` and
        shifted by ``mean``
    """
    standard = torch.randn(size, device=device)
    scaled = standard * std
    return scaled + mean


def stratified_sampling_1d(n_samples: int, bounds: tuple, device: str = 'cpu') -> torch.Tensor:
    """
    Stratified sampling in 1D for better coverage.

    Divides the domain into ``n_samples`` equal-width bins and draws one
    uniform sample inside each bin, so sample ``i`` lies in bin ``i``.
    Useful for collocation point sampling with better spatial distribution.

    Args:
        n_samples: Number of samples (0 yields an empty tensor)
        bounds: (lower, upper) bounds of domain
        device: Device to create tensor on

    Returns:
        Tensor of shape (n_samples,) with stratified samples, ordered by bin

    Raises:
        ValueError: If ``n_samples`` is negative.
    """
    if n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")
    if n_samples == 0:
        # Nothing to sample; also avoids dividing by zero for the bin width.
        return torch.empty(0, device=device)

    lower, upper = bounds
    bin_width = (upper - lower) / n_samples

    # Left edge of bin i is lower + i * bin_width, for i = 0 .. n_samples - 1.
    bin_starts = torch.linspace(lower, upper - bin_width, n_samples, device=device)
    # One uniform offset in [0, bin_width) per bin.
    offsets = torch.rand(n_samples, device=device) * bin_width

    return bin_starts + offsets


def latin_hypercube_sampling(n_samples: int, n_dims: int, bounds: list, device: str = 'cpu') -> torch.Tensor:
    """
    Latin Hypercube Sampling for multi-dimensional domains.

    Each dimension is split into ``n_samples`` equal-width bins; a random
    permutation assigns every row a distinct bin in that dimension, and the
    row's coordinate is drawn uniformly inside its bin. Provides better
    space-filling properties than pure random sampling.

    Per dimension, random numbers are drawn as one ``torch.randperm(n_samples)``
    followed by one ``torch.rand(n_samples)`` call.

    Args:
        n_samples: Number of samples
        n_dims: Number of dimensions
        bounds: List of (lower, upper) tuples for each dimension; only the
            first ``n_dims`` entries are used
        device: Device to create tensor on

    Returns:
        Tensor of shape (n_samples, n_dims) with LHS samples
    """
    samples = torch.zeros(n_samples, n_dims, device=device)

    for d in range(n_dims):
        lower, upper = bounds[d]

        # n_samples + 1 edges delimit n_samples equal-width bins.
        bin_edges = torch.linspace(lower, upper, n_samples + 1, device=device)

        # Row i is assigned bin perm[i]; each bin is used exactly once.
        perm = torch.randperm(n_samples, device=device)
        bin_lower = bin_edges[perm]
        bin_upper = bin_edges[perm + 1]

        # Vectorized draw: avoids a Python loop with per-element kernel
        # launches and host/device syncs from iterating a device tensor.
        u = torch.rand(n_samples, device=device)
        samples[:, d] = u * (bin_upper - bin_lower) + bin_lower

    return samples