"""
Dataset classes for GLEAM-AI.

This module contains dataset classes and utility functions for loading
and processing epidemiological data.
"""

import torch
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from typing import Literal, Optional, Callable, Tuple, List, Any
import logging

logger = logging.getLogger(__name__)


class FeatureDataset(Dataset):
    """
    Dataset for epidemiological features and targets.

    Samples are stored as one ``.npy`` file per data type under
    ``data_path``, in the sub-directories ``x_<category>``,
    ``xt_<category>``, ``y_inc_<category>``, ``y_prev_<category>`` and
    ``y0_<category>``. Files are paired across sub-directories purely by
    sorted filename order, so the i-th file of every type must describe the
    same sample.
    """

    def __init__(
        self,
        data_path: str,
        seq_len: int,
        category: Literal["train", "val", "test"],
        populations: Optional[np.ndarray] = None,
        x_transform: Optional[Callable] = None,
        y_hosp_inc_transform: Optional[Callable] = None,
        y_hosp_prev_transform: Optional[Callable] = None,
        y_latent_inc_transform: Optional[Callable] = None,
        y_latent_prev_transform: Optional[Callable] = None
    ):
        """
        Index the ``.npy`` files of one split and store the transforms.

        Args:
            data_path: Root data directory (``str`` or ``Path``).
            seq_len: Number of time steps to keep; one extra step is kept
                for the initial condition (stored as ``max_seq_len``).
            category: Split to load; selects the ``*_<category>`` directories.
            populations: Population array; stored on the instance but not
                used by this class itself.
            x_transform: Optional callable applied to the static features.
            y_hosp_inc_transform: Optional callable applied to hospital incidence.
            y_hosp_prev_transform: Optional callable applied to hospital prevalence.
            y_latent_inc_transform: Optional callable applied to latent incidence.
            y_latent_prev_transform: Optional callable applied to latent prevalence.

        Raises:
            ValueError: If the data types do not all have the same number of files.
        """
        # Path() accepts both str and Path, so no isinstance check is needed.
        self.data_path = Path(data_path)
        self.max_seq_len = seq_len + 1  # +1 for initial condition
        self.category = category
        self.populations = populations

        # One sub-directory per data type, suffixed with the split name.
        self.x_path = self.data_path / f"x_{category}"
        self.xt_path = self.data_path / f"xt_{category}"
        self.y_inc_path = self.data_path / f"y_inc_{category}"
        self.y_prev_path = self.data_path / f"y_prev_{category}"
        self.y0_path = self.data_path / f"y0_{category}"

        # Sorting makes the positional pairing between data types deterministic.
        self.x_filenames = sorted(self.x_path.glob("*.npy"))
        self.xt_filenames = sorted(self.xt_path.glob("*.npy"))
        self.y_inc_filenames = sorted(self.y_inc_path.glob("*.npy"))
        self.y_prev_filenames = sorted(self.y_prev_path.glob("*.npy"))
        self.y0_filenames = sorted(self.y0_path.glob("*.npy"))

        counts = {
            "x": len(self.x_filenames),
            "xt": len(self.xt_filenames),
            "y_inc": len(self.y_inc_filenames),
            "y_prev": len(self.y_prev_filenames),
            "y0": len(self.y0_filenames),
        }
        if len(set(counts.values())) > 1:
            raise ValueError(f"Inconsistent number of files across data types: {counts}")
        if not self.x_filenames:
            # glob() on a missing directory silently yields nothing; surface it.
            logger.warning(
                "No .npy files found under %s for the %s split", self.data_path, category
            )

        # Store transforms
        self.x_transform = x_transform
        self.y_hosp_inc_transform = y_hosp_inc_transform
        self.y_hosp_prev_transform = y_hosp_prev_transform
        self.y_latent_inc_transform = y_latent_inc_transform
        self.y_latent_prev_transform = y_latent_prev_transform

        logger.info("Initialized %s dataset with %d samples", category, len(self))

    def __len__(self) -> int:
        """Return the number of samples (files per data type)."""
        return len(self.x_filenames)

    @staticmethod
    def _load_float32(path: Path) -> np.ndarray:
        """Load a ``.npy`` file and cast it to float32."""
        return np.load(path).astype(np.float32)

    @staticmethod
    def _drop_leading_singleton(arr, min_ndim: int):
        """Squeeze axis 0 of ``arr`` when it has size 1 and ``arr.ndim > min_ndim``."""
        if arr.ndim > min_ndim and arr.shape[0] == 1:
            return arr.squeeze(0)
        return arr

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, ...]:
        """
        Load, split, transform and reshape one sample.

        Args:
            idx: Index of the sample (position in the sorted file lists).

        Returns:
            Tuple ``(x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc,
            y_latent_prev, y0, idx)``:

            - ``xt``, the incidence and the prevalence arrays are truncated to
              at most ``max_seq_len`` steps along axis 1 of the stored array.
            - The hospital and latent targets are the first and second halves
              of the last axis of the stored ``y_inc`` / ``y_prev`` arrays.
            - Each array has its leading axis squeezed when that axis has
              size 1 and the array has enough dimensions. For example, a
              stored ``xt`` of shape ``[1, T, num_nodes, xt_dim]`` comes
              back as ``[T, num_nodes, xt_dim]``.
            - Arrays are float32 unless a transform changes that.
            - ``idx`` is returned unchanged.
        """
        steps = self.max_seq_len

        x = self._load_float32(self.x_filenames[idx])
        # Axis 1 of the temporal arrays is time; keep at most max_seq_len steps.
        xt = self._load_float32(self.xt_filenames[idx])[:, :steps, ...]
        y0 = self._load_float32(self.y0_filenames[idx])

        # The last axis stacks [hospital, latent] halves; np.split needs it even.
        y_inc = self._load_float32(self.y_inc_filenames[idx])[:, :steps, ...]
        y_hosp_inc, y_latent_inc = np.split(y_inc, 2, axis=-1)

        y_prev = self._load_float32(self.y_prev_filenames[idx])[:, :steps, ...]
        y_hosp_prev, y_latent_prev = np.split(y_prev, 2, axis=-1)

        # Apply transforms if specified
        if self.x_transform:
            x = self.x_transform(x)
        if self.y_hosp_inc_transform:
            y_hosp_inc = self.y_hosp_inc_transform(y_hosp_inc)
        if self.y_hosp_prev_transform:
            y_hosp_prev = self.y_hosp_prev_transform(y_hosp_prev)
        if self.y_latent_inc_transform:
            y_latent_inc = self.y_latent_inc_transform(y_latent_inc)
        if self.y_latent_prev_transform:
            y_latent_prev = self.y_latent_prev_transform(y_latent_prev)

        # Drop a size-1 leading axis (axis 0, not the time axis).
        # NOTE(review): collate_fn / pool_collate_fn concatenate along dim 0, so
        # when this squeeze fires they no longer see a per-sample leading axis
        # (e.g. squeezed xt samples would be joined along time). Confirm the
        # stored arrays' leading-axis layout before relying on this.
        x = self._drop_leading_singleton(x, 2)
        xt = self._drop_leading_singleton(xt, 3)
        y_hosp_inc = self._drop_leading_singleton(y_hosp_inc, 2)
        y_hosp_prev = self._drop_leading_singleton(y_hosp_prev, 2)
        y_latent_inc = self._drop_leading_singleton(y_latent_inc, 2)
        y_latent_prev = self._drop_leading_singleton(y_latent_prev, 2)
        y0 = self._drop_leading_singleton(y0, 1)

        return x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0, idx


class PoolDataset(Dataset):
    """
    Active-learning view over a FeatureDataset.

    Exposes a chosen subset of the wrapped dataset (all of it by default).
    For each element it returns only the model inputs needed to score a
    candidate, without the targets, together with the element's index in
    the wrapped dataset.
    """

    def __init__(self, feature_dataset: FeatureDataset, indices: Optional[List[int]] = None):
        """
        Build the pool.

        Args:
            feature_dataset: Dataset whose items are 8-tuples laid out like
                ``FeatureDataset.__getitem__`` output.
            indices: Positions in ``feature_dataset`` that form the pool;
                ``None`` means every position ``0 .. len(feature_dataset) - 1``.
        """
        self.feature_dataset = feature_dataset
        if indices is None:
            indices = list(range(len(feature_dataset)))
        self.indices = indices

    def __len__(self) -> int:
        """Return how many candidates the pool holds."""
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, ...]:
        """
        Fetch the inputs of one pool candidate.

        Args:
            idx: Position within the pool (not within the wrapped dataset).

        Returns:
            ``(x, xt, y0, pool_idx)``, where ``pool_idx`` is the index used to
            read from the wrapped dataset.
        """
        source_idx = self.indices[idx]
        (x, xt, _hosp_inc, _hosp_prev, _latent_inc, _latent_prev,
         y0, _item_idx) = self.feature_dataset[source_idx]
        return x, xt, y0, source_idx


def collate_fn(batch: List[Tuple], device: Optional[torch.device] = None) -> Tuple[torch.Tensor, ...]:
    """
    Batch FeatureDataset samples into row-shuffled float tensors.

    Each field is converted to a float tensor and the samples are joined
    along dim 0. Per-sample arrays must therefore agree on every trailing
    dimension, and 0-d arrays are first promoted to length-1 vectors. One
    permutation from NumPy's global RNG is then applied to the leading axis
    of every output, so rows stay aligned across fields. The trailing sample
    index of each item is discarded.

    Args:
        batch: Sequence of 8-tuples as produced by FeatureDataset.
        device: If given, every tensor is moved there before the shuffle.

    Returns:
        Tuple (x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0).
    """
    def join_field(pieces):
        # Wrap 0-d arrays so every piece has an axis to concatenate along.
        return torch.cat(
            [torch.FloatTensor(p) if p.ndim > 0 else torch.FloatTensor([p]) for p in pieces],
            dim=0,
        )

    x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0, _ = zip(*batch)
    fields = (x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0)
    tensors = [join_field(field) for field in fields]

    # Shuffle rows for training; one permutation shared by every output.
    order = np.random.permutation(tensors[0].shape[0])

    if device is not None:
        tensors = [t.to(device) for t in tensors]

    return tuple(t[order, ...] for t in tensors)


def pool_collate_fn(batch: List[Tuple], device: Optional[torch.device] = None) -> Tuple[torch.Tensor, ...]:
    """
    Batch PoolDataset samples for acquisition-function scoring.

    The x, xt and y0 fields are each converted to float tensors and joined
    along dim 0, with 0-d arrays promoted to length-1 vectors first. Unlike
    collate_fn, no shuffling happens, so rows keep the batch order.

    Args:
        batch: Sequence of ``(x, xt, y0, pool_idx)`` tuples from PoolDataset.
        device: If given, the three tensors are moved there.

    Returns:
        Tuple (x, xt, y0, pool_indices), where ``pool_indices`` is an int64
        NumPy array that stays on the host.
    """
    def join_field(pieces):
        # Wrap 0-d arrays so every piece has an axis to concatenate along.
        return torch.cat(
            [torch.FloatTensor(p) if p.ndim > 0 else torch.FloatTensor([p]) for p in pieces],
            dim=0,
        )

    xs, xts, y0s, raw_indices = zip(*batch)
    tensors = [join_field(xs), join_field(xts), join_field(y0s)]
    index_array = np.array(raw_indices, dtype=np.int64)

    if device is not None:
        tensors = [t.to(device) for t in tensors]

    x, xt, y0 = tensors
    return x, xt, y0, index_array


def get_z_score_stats(
    data_path: str,
    x_dim: int,
    xt_dim: int,
    y_dim: int,
    seq_len: int,
    NUM_COMP: int,
    populations: Optional[np.ndarray] = None,
    category: str = "train",
    batch_size: int = 64
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute per-feature z-score statistics (mean and std) for normalization.

    x statistics cover two feature groups:

    - Entries ``[:xt_dim]`` describe the temporal features ``xt``.
    - Entries ``[xt_dim:]`` describe the static features ``x``.

    Each group is averaged over its own element count, because the two
    groups have different numbers of observations per sample.

    y statistics cover the last-axis concatenation
    ``[hosp_inc, hosp_prev, latent_inc, latent_prev]``. Before concatenating,
    one initial-condition step is prepended along the time axis: zeros for
    the first three series and ``y0`` for latent prevalence.

    Sums are accumulated in float64 and the variance is clipped at zero.
    In float32, ``E[v^2] - E[v]^2`` suffers cancellation and can come out
    slightly negative, which would turn the std into NaN.

    Args:
        data_path: Path to data directory
        x_dim: Total feature dimension (temporal + static)
        xt_dim: Dimension of temporal features
        y_dim: Dimension of each target component
        seq_len: Sequence length
        NUM_COMP: Number of target components concatenated into y
        populations: Population data, forwarded to FeatureDataset
        category: Dataset category (train/val/test)
        batch_size: Batch size for data loading

    Returns:
        Tuple of (x_mean, x_std, y_mean, y_std) as float32 NumPy arrays of
        length ``x_dim`` (x stats) and ``y_dim * NUM_COMP`` (y stats). Zero
        standard deviations are replaced with 1e-8 to avoid division by zero.

    Raises:
        ValueError: If the selected split yields no data to average over.
    """
    dataset = FeatureDataset(data_path, seq_len, category=category, populations=populations)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    # float64 accumulators: float32 sums of squares lose too much precision.
    yd = y_dim * NUM_COMP
    xsum = torch.zeros(x_dim, dtype=torch.float64)
    xsum2 = torch.zeros(x_dim, dtype=torch.float64)
    ysum = torch.zeros(yd, dtype=torch.float64)
    ysum2 = torch.zeros(yd, dtype=torch.float64)

    x_counts = 0
    xt_counts = 0
    y_counts = 0

    for x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0 in loader:
        # Prepend the initial-condition step along the time axis (dim 1).
        zero_step = torch.zeros_like(y0).unsqueeze(1)
        y = torch.cat(
            [
                torch.cat([zero_step, y_hosp_inc], dim=1),
                torch.cat([zero_step, y_hosp_prev], dim=1),
                torch.cat([zero_step, y_latent_inc], dim=1),
                torch.cat([y0.unsqueeze(1), y_latent_prev], dim=1),
            ],
            dim=-1,
        ).double()
        x = x.double()
        xt = xt.double()

        # Static features occupy the tail of the feature vector.
        xsum[xt_dim:] += x.sum(dim=(0, 1))
        xsum2[xt_dim:] += (x ** 2.0).sum(dim=(0, 1))

        # Temporal features occupy the head of the feature vector.
        xsum[:xt_dim] += xt.sum(dim=(0, 1, 2))
        xsum2[:xt_dim] += (xt ** 2.0).sum(dim=(0, 1, 2))

        ysum += y.sum(dim=(0, 1))
        ysum2 += (y ** 2.0).sum(dim=(0, 1))

        # Number of observations per feature = product of all non-feature axes.
        x_counts += int(np.prod(x.shape[:-1]))
        xt_counts += int(np.prod(xt.shape[:-1]))
        y_counts += int(np.prod(y.shape[:-1]))

    if x_counts == 0 or xt_counts == 0 or y_counts == 0:
        raise ValueError(
            f"Cannot compute z-score statistics: no data found for category "
            f"'{category}' under {data_path}"
        )

    # Per-feature observation counts for the two x feature groups.
    x_n = np.empty(x_dim, dtype=np.float64)
    x_n[:xt_dim] = xt_counts
    x_n[xt_dim:] = x_counts

    x_mean64 = xsum.numpy() / x_n
    x_var64 = np.maximum(xsum2.numpy() / x_n - x_mean64 ** 2.0, 0.0)

    y_mean64 = ysum.numpy() / y_counts
    y_var64 = np.maximum(ysum2.numpy() / y_counts - y_mean64 ** 2.0, 0.0)

    x_mean = x_mean64.astype(np.float32)
    x_std = np.sqrt(x_var64).astype(np.float32)
    y_mean = y_mean64.astype(np.float32)
    y_std = np.sqrt(y_var64).astype(np.float32)

    # Constant features would otherwise cause a division by zero when normalizing.
    x_std[x_std == 0.0] = 1e-8
    y_std[y_std == 0.0] = 1e-8

    return x_mean, x_std, y_mean, y_std


class MySubsetRandomSampler(torch.utils.data.Sampler):
    """
    Sampler over a fixed list of indices, with optional shuffling.

    Despite its name, by default this sampler yields ``indices`` in their
    given order, with no randomization, so existing callers keep a stable
    order. Pass ``shuffle=True`` to yield them in a fresh random permutation
    on each pass; supply ``generator`` for reproducible permutations.

    The ``pool_indices`` attribute is stored for callers that need to map
    positions back to original pool indices; the sampler itself does not
    read it.
    """

    def __init__(
        self,
        indices: List[int],
        pool_indices: Optional[List[int]] = None,
        *,
        shuffle: bool = False,
        generator: Optional[torch.Generator] = None,
    ):
        """
        Initialize the sampler.

        Args:
            indices: Indices to yield, in iteration order when not shuffling.
            pool_indices: Original pool indices, kept for tracking; defaults
                to ``indices`` itself.
            shuffle: If True, yield ``indices`` in a random permutation that
                is redrawn on every iteration.
            generator: Optional torch RNG used for the permutation when
                ``shuffle`` is True.
        """
        self.indices = indices
        self.pool_indices = pool_indices if pool_indices is not None else indices
        self.shuffle = shuffle
        self.generator = generator

    def __iter__(self):
        """Iterate over the indices, in order or randomly permuted."""
        if not self.shuffle:
            return iter(self.indices)
        order = torch.randperm(len(self.indices), generator=self.generator).tolist()
        return (self.indices[i] for i in order)

    def __len__(self):
        """Return the number of indices."""
        return len(self.indices)