"""
PyTorch utilities for GLEAM-AI.

This module contains utility functions for PyTorch operations, model loading,
and evaluation metrics used in the GLEAM-AI system.
"""

import torch
import torch.nn as nn
import torch.distributions as dist
import numpy as np
import yaml
from pathlib import Path
from typing import Sequence, Iterator, Optional, Union, Dict, Any
from torch.utils.data.sampler import Sampler


def build_mlp(
    in_features: int, 
    hidden_dims: list, 
    out_features: int
) -> nn.Sequential:
    """
    Build a ReLU multi-layer perceptron.

    Each entry of ``hidden_dims`` contributes a ``Linear`` followed by a
    ``ReLU``; a final ``Linear`` (without activation) maps to
    ``out_features``. An empty or falsy ``hidden_dims`` yields a single
    ``Linear(in_features, out_features)``.

    Args:
        in_features: Number of input features
        hidden_dims: List of hidden layer widths (may be empty)
        out_features: Number of output features

    Returns:
        ``nn.Sequential`` containing the layers in forward order
    """
    # Chain of layer widths: input -> hidden... (output appended separately).
    widths = [in_features, *(hidden_dims or [])]

    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]

    # Output projection from the last width in the chain (no activation).
    modules.append(nn.Linear(widths[-1], out_features))
    return nn.Sequential(*modules)


def load_model(model: nn.Module, checkpoint_path: Union[str, Path]) -> nn.Module:
    """
    Load model weights from checkpoint.

    The checkpoint is loaded onto the CPU and must be a mapping holding the
    model's state dict under ``"state_dict"`` or, as written by
    ``save_checkpoint`` in this module, under ``"model_state_dict"``.
    ``"state_dict"`` is checked first.

    Args:
        model: Model to load weights into (modified in place)
        checkpoint_path: Path to checkpoint file

    Returns:
        The same ``model`` object, with loaded weights

    Raises:
        KeyError: If the checkpoint has neither supported key.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # "state_dict" is the key this loader has always read; "model_state_dict"
    # is the key save_checkpoint() writes, so accept checkpoints from both.
    for key in ("state_dict", "model_state_dict"):
        if key in checkpoint:
            model.load_state_dict(checkpoint[key])
            return model

    raise KeyError(
        "Checkpoint contains neither 'state_dict' nor 'model_state_dict'"
    )


def load_active_data(active_data_path: Union[str, Path]) -> Dict[str, Any]:
    """
    Load active learning data from a YAML file.

    The file is decoded as UTF-8 explicitly instead of with the platform's
    locale encoding, so the same file parses identically on every OS.
    Parsing uses PyYAML's safe loader, which does not construct arbitrary
    Python objects.

    Args:
        active_data_path: Path to active data YAML file

    Returns:
        Dictionary containing active learning data (the top-level object of
        the YAML document)
    """
    with open(active_data_path, "r", encoding="utf-8") as fp:
        active_data = yaml.safe_load(fp)
    return active_data


def crps_normal(mu: torch.Tensor, sigma: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """
    Compute the Continuous Ranked Probability Score (CRPS) for normal distributions.

    Uses the closed form for a normal forecast N(mu, sigma^2) evaluated at
    the observation y (Gneiting et al., 2005):

        CRPS = (y - mu) * (2 * Phi(z) - 1) + sigma * (2 * phi(z) - 1 / sqrt(pi)),
        z = (y - mu) / sigma

    where Phi / phi are the standard normal CDF / PDF. The ``sigma * z``
    factor is written as ``y - mu`` so a degenerate forecast
    (``sigma == 0``) yields its correct limit ``|y - mu|``; a small epsilon
    keeps ``z`` finite in that case.

    Args:
        mu: Mean of the normal distribution, e.g. [batch_size, seq_len, y_dim]
        sigma: Standard deviation (non-negative), broadcastable with ``mu``
        y_true: True values, broadcastable with ``mu``

    Returns:
        Element-wise CRPS scores with the broadcast shape of the inputs
        (e.g. [batch_size, seq_len, y_dim])
    """
    diff = y_true - mu
    z = diff / (sigma + 1e-6)

    # Standard normal CDF and PDF, written out via erf/exp.
    cdf = 0.5 * (1.0 + torch.erf(z / np.sqrt(2.0)))
    pdf = torch.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)

    crps = diff * (2.0 * cdf - 1.0) + sigma * (2.0 * pdf - 1.0 / np.sqrt(np.pi))
    return crps


def crps_gaussian(y: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """
    Compute CRPS for Gaussian distributions.

    This implementation follows the formula from:
    "Calibrated Probabilistic Forecasting Using Ensemble Model Output
    Statistics and Minimum CRPS Estimation" by Gneiting et al.

        CRPS = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)),
        z = (y - mu) / sigma

    with ``sigma * z`` evaluated as ``y - mu``. This keeps the degenerate
    point forecast (``sigma == 0``) at its correct limit ``|y - mu|``
    instead of collapsing it to 0.

    Args:
        y: Observations
        mu: Mean of the forecast normal distribution
        sigma: Standard deviation of the forecast distribution (non-negative)

    Returns:
        Element-wise CRPS scores (broadcast shape of ``y``, ``mu``, ``sigma``)
    """
    diff = y - mu

    # Standardized observations; the epsilon keeps z finite when sigma == 0.
    sx = diff / (sigma + 1e-6)

    # Standard normal distribution
    normal = dist.Normal(torch.zeros_like(mu), torch.ones_like(sigma))

    # Precompute PDF and CDF
    pdf = normal.log_prob(sx).exp()
    cdf = normal.cdf(sx)

    # CRPS formula, with sigma * sx replaced by the exact (y - mu).
    pi_inv = 1.0 / np.sqrt(np.pi)
    crps = diff * (2 * cdf - 1) + sigma * (2 * pdf - pi_inv)

    return crps


def crps_by_each_cdf(y_true: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
    """
    Compute CRPS using empirical CDF for each sample.

    For every element, the predictive samples are sorted along dim 0. The
    k-th order statistic ``x_k`` gets empirical CDF height ``(k + 1) / n``.
    The score is the mean over k of ``(F_hat(x_k) - 1{x_k >= y})**2``, where
    ``y`` is that element's own observation.

    NOTE(review): this averages over sample positions rather than
    integrating over x (it is not weighted by the spacing between sorted
    samples), so it is a rank-based approximation of CRPS.

    Args:
        y_true: True values [batch_size, seq_len, y_dim]
        yhat: Predicted samples [n_samples, batch_size, seq_len, y_dim]

    Returns:
        CRPS scores [batch_size, seq_len, y_dim] (float32)
    """
    yhat_sort, _ = torch.sort(yhat, dim=0)
    n_samples = yhat_sort.size(0)

    # Empirical CDF height of each order statistic, shaped to broadcast along
    # the sample axis only: [n_samples, 1, 1, ...].
    cdf_yhat = torch.arange(
        1, n_samples + 1, device=yhat.device, dtype=torch.float32
    ) / n_samples
    cdf_yhat = cdf_yhat.view(-1, *([1] * (yhat_sort.dim() - 1)))

    # Compare each element's samples with that element's own observation.
    # Comparing against y_true[i] for every batch element would produce a
    # [batch, batch, ...] cross-product instead of the documented shape.
    heaviside = (yhat_sort >= y_true.unsqueeze(0)).to(cdf_yhat.dtype)

    return ((cdf_yhat - heaviside) ** 2).mean(dim=0)


def crps_cdf(y_true: torch.Tensor, yhat: torch.Tensor) -> tuple:
    """
    Compare step-CDF heights of predictions and observations.

    Both inputs are sorted along dim 0. Each sorted array is paired with
    step heights ``(k + 1) / n`` (float32, where ``n`` is its length along
    dim 0). The returned score is the mean squared difference of the two
    height arrays, so both height arrays must broadcast against each other.

    NOTE(review): the heights are built from ``ones_like`` and depend only on
    the sample counts, not on the values. The score is therefore constant for
    fixed shapes. Confirm whether this is the intended metric.

    Args:
        y_true: True values
        yhat: Predicted values

    Returns:
        Tuple of (crps, cdf_yhat, yhat_sort, cdf_y_true, y_true_sort)
    """
    def _step_heights(values: torch.Tensor) -> torch.Tensor:
        # (k + 1) / n for every position along dim 0.
        unit = torch.ones_like(values, dtype=torch.float32)
        return unit.cumsum(dim=0) / values.size(0)

    yhat_sort = torch.sort(yhat, dim=0).values
    cdf_yhat = _step_heights(yhat_sort)

    y_true_sort = torch.sort(y_true, dim=0).values
    cdf_y_true = _step_heights(y_true_sort)

    crps = (cdf_yhat - cdf_y_true).pow(2.0).mean()
    return crps, cdf_yhat, yhat_sort, cdf_y_true, y_true_sort


class MySubsetRandomSampler(Sampler[int]):
    """
    Random-order sampler that yields entries of a pool-index list.

    On every iteration a fresh random permutation of
    ``range(len(indices))`` is drawn with ``generator``. Each position ``i``
    in that permutation is mapped to ``pool_indices_list[i]``.

    NOTE(review): ``indices`` only contributes its length; the yielded values
    come from ``pool_indices_list``. This presumably assumes both sequences
    have the same length; confirm against callers.
    """

    def __init__(
        self, 
        indices: Sequence[int], 
        pool_indices_list: Sequence[int], 
        generator: Optional[torch.Generator] = None
    ) -> None:
        """
        Store the subset, its pool mapping, and the RNG.

        Args:
            indices: Subset whose length sets the number of samples
            pool_indices_list: Values yielded for each permuted position
            generator: Optional RNG passed to ``torch.randperm``
        """
        self.indices, self.pool_indices_list, self.generator = (
            indices,
            pool_indices_list,
            generator,
        )

    def __iter__(self) -> Iterator[int]:
        """Yield pool indices in a random order (permutation drawn lazily)."""
        order = torch.randperm(len(self.indices), generator=self.generator)
        yield from (self.pool_indices_list[pos] for pos in order)

    def __len__(self) -> int:
        """Return the number of samples per iteration, i.e. ``len(indices)``."""
        size = len(self.indices)
        return size


def compute_model_parameters(model: nn.Module) -> Dict[str, int]:
    """
    Count a model's parameters, split by whether they require gradients.

    Args:
        model: PyTorch model

    Returns:
        Dict with ``total_parameters``, ``trainable_parameters`` (those with
        ``requires_grad``) and ``non_trainable_parameters`` (the difference).
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count

    return {
        "total_parameters": total,
        "trainable_parameters": trainable,
        "non_trainable_parameters": total - trainable
    }


def set_seed(seed: int) -> None:
    """
    Set random seed for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU
    and CUDA devices). It also switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Random seed value
    """
    # Local import keeps this fix self-contained; Python's own RNG is used by
    # many data pipelines and was previously left unseeded.
    import random

    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_device() -> torch.device:
    """
    Pick the preferred available device.

    Preference order: CUDA, then Apple MPS (when this PyTorch build exposes
    ``torch.backends.mps`` and reports it available), then CPU.

    Returns:
        PyTorch device
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    mps_ready = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    return torch.device("mps" if mps_ready else "cpu")


def move_to_device(tensor_or_model: Union[torch.Tensor, nn.Module], device: torch.device) -> Union[torch.Tensor, nn.Module]:
    """
    Relocate a tensor or module to ``device``.

    This is a thin wrapper around the object's own ``.to(device)``; whatever
    that call returns is passed back unchanged.

    Args:
        tensor_or_model: Tensor or model to move
        device: Target device

    Returns:
        Result of ``tensor_or_model.to(device)``
    """
    relocated = tensor_or_model.to(device)
    return relocated


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    loss: float,
    filepath: Union[str, Path],
    **kwargs
) -> None:
    """
    Save model checkpoint.
    
    Args:
        model: Model to save
        optimizer: Optimizer state
        epoch: Current epoch
        loss: Current loss
        filepath: Path to save checkpoint
        **kwargs: Additional data to save
    """
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        **kwargs
    }
    torch.save(checkpoint, filepath)


def load_checkpoint(
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer],
    filepath: Union[str, Path]
) -> Dict[str, Any]:
    """
    Restore model (and optionally optimizer) state from a checkpoint file.

    The file is loaded onto the CPU. The model's weights are read from
    ``"model_state_dict"``. When an optimizer is given, its state is read
    from ``"optimizer_state_dict"``.

    Args:
        model: Model to load weights into
        optimizer: Optimizer to restore, or ``None`` to skip it
        filepath: Path to checkpoint file

    Returns:
        The full loaded checkpoint dictionary
    """
    state = torch.load(filepath, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])

    if optimizer is None:
        return state

    optimizer.load_state_dict(state["optimizer_state_dict"])
    return state


def compute_gradient_norm(model: nn.Module) -> float:
    """
    Global L2 norm over all parameter gradients.

    Parameters whose ``grad`` is ``None`` are skipped. Per-parameter norms
    are squared and accumulated as Python floats before the final square
    root. The result is ``0.0`` when no gradients exist.

    Args:
        model: Model to compute gradient norm for

    Returns:
        Gradient norm
    """
    squared_sum = 0.0
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        squared_sum += grad.data.norm(2).item() ** 2
    return squared_sum ** 0.5


def clip_gradients(model: nn.Module, max_norm: float) -> float:
    """
    Rescale the model's gradients in place so their global norm is at most ``max_norm``.

    Delegates to ``torch.nn.utils.clip_grad_norm_``.

    Args:
        model: Model whose gradients are clipped
        max_norm: Maximum gradient norm

    Returns:
        Total gradient norm measured before clipping, as a Python float
    """
    params = model.parameters()
    norm_before_clip = torch.nn.utils.clip_grad_norm_(params, max_norm)
    return norm_before_clip.item()
