"""
Data statistics computation for GLEAM-AI.

This module contains functions for computing data statistics used in
normalization and preprocessing of epidemiological data.
"""

import torch
import numpy as np
from pathlib import Path
from typing import Tuple, Union

from .utils import FeatureDataset, collate_fn


def get_z_score_stat(
    data_path: Union[str, Path],
    x_dim: int,
    xt_dim: int,
    y_dim: int,
    seq_len: int,
    num_comp: int,
    populations: np.ndarray,
    category: str = "train",
    batch_size: int = 64
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute z-score statistics for data normalization.

    Streams the ``category`` split of ``FeatureDataset`` once, accumulating
    per-feature sums and sums of squares, and derives
    ``mean = E[v]`` and ``std = sqrt(E[v**2] - E[v]**2)``.

    Layout of the returned statistics:

    * ``x_mean`` / ``x_std`` have length ``x_dim``. Entries ``[:xt_dim]``
      describe the temporal features ``xt``; entries ``[xt_dim:]`` describe
      the spatial features ``x``. Each group is averaged over its own number
      of elements (the product of all leading dimensions of its tensor).
    * ``y_mean`` / ``y_std`` have length ``y_dim * num_comp`` and describe the
      last-axis concatenation ``[hosp_inc, hosp_prev, latent_inc,
      latent_prev]``. Before concatenation, each sequence is prefixed along
      axis 1 with an initial step: zeros for the first three, ``y0`` for
      latent prevalence.

    Args:
        data_path: Path to data directory
        x_dim: Total feature width (temporal + spatial)
        xt_dim: Width of the temporal features ``xt``
        y_dim: Dimension of y features
        seq_len: Sequence length, forwarded to ``FeatureDataset``
        num_comp: Number of compartments; ``y_dim * num_comp`` must equal the
            width of the concatenated targets
        populations: Population data, forwarded to ``FeatureDataset``
        category: Dataset split the statistics are computed from
            ("train", "val", "test")
        batch_size: Batch size used while streaming the dataset

    Returns:
        Tuple ``(x_mean, x_std, y_mean, y_std)`` of float32 arrays. Standard
        deviations that are exactly zero are replaced with ``1e-8``.

    Raises:
        ValueError: If the dataset yields no elements to average over.
    """
    data_path = Path(data_path)

    dataset = FeatureDataset(data_path, seq_len, category=category, populations=populations)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    yd = y_dim * num_comp
    # Accumulate in float64: the E[v**2] - E[v]**2 formula loses precision
    # badly (and can even go negative) when the sums are kept in float32.
    xsum = torch.zeros(x_dim, dtype=torch.float64)
    xsum2 = torch.zeros(x_dim, dtype=torch.float64)
    ysum = torch.zeros(yd, dtype=torch.float64)
    ysum2 = torch.zeros(yd, dtype=torch.float64)
    x_counts = 0
    xt_counts = 0
    y_counts = 0

    def _moments(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
        # Collapse every leading axis so the per-feature sums and the element
        # count always describe the same set of values, whatever the rank.
        flat = t.to(torch.float64).flatten(end_dim=-2)
        return flat.sum(dim=0), (flat * flat).sum(dim=0), flat.shape[0]

    for x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0 in loader:
        # Prepend the initial step to each target sequence (axis 1).
        zero_step = torch.zeros(*y0.size()).unsqueeze(1)
        y = torch.cat(
            [
                torch.cat([zero_step, y_hosp_inc], dim=1),
                torch.cat([zero_step, y_hosp_prev], dim=1),
                torch.cat([zero_step, y_latent_inc], dim=1),
                torch.cat([y0.unsqueeze(1), y_latent_prev], dim=1),
            ],
            dim=-1,
        )

        # Spatial features fill the tail of the x statistics.
        s, s2, n = _moments(x)
        xsum[xt_dim:] += s
        xsum2[xt_dim:] += s2
        x_counts += n

        # Temporal features fill the head of the x statistics.
        s, s2, n = _moments(xt)
        xsum[:xt_dim] += s
        xsum2[:xt_dim] += s2
        xt_counts += n

        s, s2, n = _moments(y)
        ysum += s
        ysum2 += s2
        y_counts += n

    if min(x_counts, xt_counts, y_counts) == 0:
        raise ValueError(
            f"No samples found for category '{category}' under {data_path}; "
            "cannot compute z-score statistics."
        )

    def _mean_std(total: torch.Tensor, total_sq: torch.Tensor, count: int):
        mean = total.numpy() / count
        # Clamp tiny negative variances caused by rounding before the sqrt.
        std = np.sqrt(np.maximum(total_sq.numpy() / count - mean ** 2, 0.0))
        return mean, std

    xt_mean, xt_std = _mean_std(xsum[:xt_dim], xsum2[:xt_dim], xt_counts)
    sp_mean, sp_std = _mean_std(xsum[xt_dim:], xsum2[xt_dim:], x_counts)
    y_mean, y_std = _mean_std(ysum, ysum2, y_counts)

    x_mean = np.concatenate([xt_mean, sp_mean]).astype(np.float32)
    x_std = np.concatenate([xt_std, sp_std]).astype(np.float32)
    y_mean = y_mean.astype(np.float32)
    y_std = y_std.astype(np.float32)

    # Floor zero std (constant features) so the stats are safe divisors.
    # Done after the float32 cast so values that underflow to 0 are caught too.
    x_std[x_std == 0.0] = 1e-8
    y_std[y_std == 0.0] = 1e-8

    return x_mean, x_std, y_mean, y_std


def compute_dataset_statistics(
    dataset: FeatureDataset,
    batch_size: int = 64,
    device: str = "cpu"
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute per-feature z-score statistics for a given dataset.

    The x statistics list the temporal features (``xt``) first, followed by
    the spatial features (``x``). Each group is averaged over its own number
    of elements (the product of all leading dimensions of its tensor), so the
    two groups may have different shapes.

    The y statistics describe the last-axis concatenation
    ``[hosp_inc, hosp_prev, latent_inc, latent_prev]``. Before concatenation,
    each sequence is prefixed along axis 1 with an initial step: zeros for the
    first three, ``y0`` for latent prevalence.

    Args:
        dataset: FeatureDataset to compute statistics for
        batch_size: Batch size for computation
        device: Device on which the per-batch reductions run. Reductions are
            done in float64, so the device must support float64 tensors.

    Returns:
        Tuple of (x_mean, x_std, y_mean, y_std) as float32 arrays. Standard
        deviations that are exactly zero are replaced with ``1e-8``.

    Raises:
        ValueError: If the dataset yields no elements to average over.
    """
    dev = torch.device(device)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    # group name -> (sum, sum of squares, element count), kept in float64 on
    # the CPU to avoid precision loss in E[v**2] - E[v]**2.
    totals = {}

    def _accumulate(name: str, t: torch.Tensor) -> None:
        # Collapse all leading axes so sums and counts always match.
        flat = t.to(device=dev, dtype=torch.float64).flatten(end_dim=-2)
        s = flat.sum(dim=0).cpu()
        s2 = (flat * flat).sum(dim=0).cpu()
        n = flat.shape[0]
        if name in totals:
            prev_s, prev_s2, prev_n = totals[name]
            s, s2, n = prev_s + s, prev_s2 + s2, prev_n + n
        totals[name] = (s, s2, n)

    for x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0 in loader:
        zero_step = torch.zeros(*y0.size()).unsqueeze(1)
        y_combined = torch.cat(
            [
                torch.cat([zero_step, y_hosp_inc], dim=1),
                torch.cat([zero_step, y_hosp_prev], dim=1),
                torch.cat([zero_step, y_latent_inc], dim=1),
                torch.cat([y0.unsqueeze(1), y_latent_prev], dim=1),
            ],
            dim=-1,
        )
        _accumulate("xt", xt)
        _accumulate("x", x)
        _accumulate("y", y_combined)

    if not totals or min(n for _, _, n in totals.values()) == 0:
        raise ValueError("Dataset is empty; cannot compute statistics.")

    def _mean_std(name: str):
        s, s2, n = totals[name]
        mean = s.numpy() / n
        # Clamp tiny negative variances caused by rounding before the sqrt.
        std = np.sqrt(np.maximum(s2.numpy() / n - mean ** 2, 0.0))
        return mean, std

    xt_mean, xt_std = _mean_std("xt")
    sp_mean, sp_std = _mean_std("x")
    y_mean, y_std = _mean_std("y")

    # Temporal (xt) features first, then spatial (x) features.
    x_mean = np.concatenate([xt_mean, sp_mean]).astype(np.float32)
    x_std = np.concatenate([xt_std, sp_std]).astype(np.float32)
    y_mean = y_mean.astype(np.float32)
    y_std = y_std.astype(np.float32)

    # Floor zero std after the float32 cast so values that underflow to 0
    # are caught as well.
    x_std[x_std == 0.0] = 1e-8
    y_std[y_std == 0.0] = 1e-8

    return x_mean, x_std, y_mean, y_std


def normalize_data(
    data: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    eps: float = 1e-8
) -> np.ndarray:
    """
    Standardize ``data`` as ``(data - mean) / std``.

    Wherever ``std`` is exactly zero, ``eps`` is used as the divisor instead,
    so constant features do not produce inf/NaN.

    Args:
        data: Values to standardize
        mean: Mean subtracted from ``data`` (NumPy broadcasting applies)
        std: Standard deviation the centered data is divided by
        eps: Substitute divisor used wherever ``std == 0``

    Returns:
        The standardized data
    """
    divisor = np.where(std != 0.0, std, eps)
    centered = data - mean
    return centered / divisor


def denormalize_data(
    data: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray
) -> np.ndarray:
    """
    Denormalize data using mean and standard deviation.
    
    Args:
        data: Normalized data
        mean: Mean values used for normalization
        std: Standard deviation values used for normalization
        
    Returns:
        Denormalized data
    """
    return data * std + mean


def compute_feature_statistics(
    data: np.ndarray,
    axis: Union[int, Tuple[int, ...], None] = None,
    eps: float = 1e-8
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute mean and standard deviation for data along specified axes.

    Args:
        data: Input data
        axis: Axis or axes to reduce over; ``None`` reduces over all elements
        eps: Value substituted for standard deviations that are exactly zero,
            so the result can be used directly as a divisor

    Returns:
        Tuple of (mean, std) as produced by ``np.mean`` / ``np.std`` (population
        std, ``ddof=0``), with zero std entries replaced by ``eps``
    """
    mean = np.mean(data, axis=axis)
    std = np.std(data, axis=axis)

    # Ensure non-zero standard deviations (constant features).
    std = np.where(std == 0.0, eps, std)

    return mean, std


def validate_statistics(
    mean: np.ndarray,
    std: np.ndarray,
    data_shape: Tuple[int, ...]
) -> bool:
    """
    Validate that statistics are compatible with data shape.

    Checks, in order:

    * ``mean`` and ``std`` have identical shapes;
    * their number of dimensions equals ``len(data_shape)`` (only the rank is
      compared; per-axis broadcast compatibility is not checked);
    * every entry of ``mean`` and ``std`` is finite (NaN/inf are rejected,
      since ``NaN < 0`` would otherwise let them through);
    * no entry of ``std`` is negative.

    Args:
        mean: Mean values (anything ``np.asarray`` accepts)
        std: Standard deviation values (anything ``np.asarray`` accepts)
        data_shape: Expected data shape

    Returns:
        True if statistics are valid, False otherwise
    """
    mean = np.asarray(mean)
    std = np.asarray(std)

    if mean.shape != std.shape:
        return False

    if mean.ndim != len(data_shape):
        return False

    if not (np.all(np.isfinite(mean)) and np.all(np.isfinite(std))):
        return False

    if np.any(std < 0):
        return False

    return True
