"""
Data utilities for GLEAM-AI.

This module contains utilities for data loading, preprocessing, and dataset management
for the GLEAM-AI epidemiological forecasting system.
"""

import torch
import torch.utils.data as data
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Callable, Literal, Optional, List, Dict, Any, Union, Tuple
from functools import partial

from .constructors import construct_x, construct_initial_conditions, construct_temporal_features
from ..utils.save_utils import make_x_filename, make_xt_filename, make_y_filename, make_y0_filename


def load_graph_data(data_path: Union[str, Path]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Read the weighted edge list stored in a data directory.

    The file ``weighted_edge_list.csv`` is parsed as a header-less,
    comma-separated table whose three columns are taken to be
    source, target and weight (in that order).

    Args:
        data_path: Directory that contains ``weighted_edge_list.csv``

    Returns:
        Tuple of (edge_index, edge_weight) where:
        - edge_index: [2, num_edges] array (row 0 = sources, row 1 = targets)
        - edge_weight: [num_edges] array of edge weights

    Raises:
        FileNotFoundError: If the directory or the edge list file is missing
    """
    root = data_path if isinstance(data_path, Path) else Path(data_path)
    if not root.exists():
        raise FileNotFoundError(f"Data path does not exist: {root}")

    csv_file = root / "weighted_edge_list.csv"
    if not csv_file.exists():
        raise FileNotFoundError(f"Edge list file not found: {csv_file}")

    edges = pd.read_csv(
        csv_file,
        sep=",",
        header=None,
        names=["source", "target", "weights"],
    )

    # One row per edge on disk -> transpose so that endpoints are on axis 0.
    endpoints = edges[["source", "target"]].to_numpy().transpose()
    weights = edges["weights"].to_numpy()
    return endpoints, weights


def get_filenames(
    data_path: Union[str, Path], 
    category: Literal["train", "val", "test"]
) -> Tuple[List[Path], List[Path], List[Path], List[Path], List[Path]]:
    """
    List the ``.npy`` files of every per-category sub-directory.

    The sub-directories looked up are ``x_<category>``, ``xt_<category>``,
    ``y_inc_<category>``, ``y_prev_<category>`` and ``y0_<category>``.
    Each listing is sorted independently.

    Args:
        data_path: Root data directory
        category: Data category ("train", "val", "test")

    Returns:
        Tuple of (x_filenames, xt_filenames, y_inc_filenames, y_prev_filenames, y0_filenames)

    Raises:
        FileNotFoundError: If ``data_path`` does not exist
    """
    root = data_path if isinstance(data_path, Path) else Path(data_path)
    if not root.exists():
        raise FileNotFoundError(f"Data path does not exist: {root}")

    def _listing(prefix: str) -> List[Path]:
        # Sorted so that position i refers to the same sample in every listing,
        # provided the directories hold matching file names.
        return sorted(root.joinpath(f"{prefix}_{category}").glob("*.npy"))

    return (
        _listing("x"),
        _listing("xt"),
        _listing("y_inc"),
        _listing("y_prev"),
        _listing("y0"),
    )


def get_datasets(
    meta_path: Union[str, Path],
    data_path: Union[str, Path],
    src_path: Union[str, Path],
    x_col_names: List[str],
    frac_pops_names: List[str],
    initial_col_names: List[str],
    seq_len: int,
    num_nodes: int,
    population_csv_path: Union[str, Path],
    population_scaler: float = 1.0
) -> Tuple['FeatureDataset', 'FeatureDataset', 'PoolDataset']:
    """
    Build the train / validation feature datasets and the pool dataset.

    Args:
        meta_path: Directory holding the metadata table ``x_df.csv``
        data_path: Directory holding the pre-computed feature files
        src_path: Source directory (converted to a Path but not otherwise used here)
        x_col_names: Column names for x features
        frac_pops_names: Column names for population fractions
        initial_col_names: Column names for initial conditions
        seq_len: Sequence length
        num_nodes: Number of nodes
        population_csv_path: Path to population CSV file
        population_scaler: Population scaling factor

    Returns:
        Tuple of (train_dataset, val_dataset, pool_dataset)
    """
    # Normalise every path-like argument to a Path instance.
    data_path, meta_path, src_path = (
        p if isinstance(p, Path) else Path(p)
        for p in (data_path, meta_path, src_path)
    )

    # A single population vector is shared by all three datasets.
    populations = load_population_data(population_csv_path)

    train_dataset = FeatureDataset(data_path, seq_len, "train", populations)
    val_dataset = FeatureDataset(data_path, seq_len, "val", populations)

    pool_dataset = PoolDataset(
        meta_filename=meta_path / "x_df.csv",
        x_col_names=x_col_names,
        frac_pops_names=frac_pops_names,
        initial_col_names=initial_col_names,
        seq_len=seq_len,
        num_nodes=num_nodes,
        populations=populations,
        population_scaler=population_scaler,
    )

    return train_dataset, val_dataset, pool_dataset


class PoolDataset(data.Dataset):
    """
    Pool of candidate samples for active learning.

    Nothing is pre-computed: every sample is assembled on access from one
    row of the metadata table read at construction time.
    """
    
    def __init__(
        self,
        meta_filename: Union[str, Path],
        x_col_names: List[str],
        frac_pops_names: List[str],
        initial_col_names: List[str],
        seq_len: int,
        num_nodes: int,
        populations: np.ndarray,
        population_scaler: float = 1.0,
        x_transform: Optional[Callable] = None,
        y_transform: Optional[Callable] = None
    ):
        """
        Read the metadata table and remember the construction settings.
        
        Args:
            meta_filename: Path to metadata CSV file
            x_col_names: Column names for x features
            frac_pops_names: Column names for population fractions
            initial_col_names: Column names for initial conditions
            seq_len: Sequence length
            num_nodes: Number of nodes
            populations: Population data array
            population_scaler: Population scaling factor
            x_transform: Optional callable applied to x
            y_transform: Optional callable applied to the initial conditions
        """
        # One row of this table == one pool sample.
        self.x_df = read_meta_df(meta_filename)

        # Column selections
        self.x_col_names = x_col_names
        self.frac_pops_names = frac_pops_names
        self.initial_col_names = initial_col_names

        # Geometry; the extra step accounts for the initial condition.
        self.seq_len = seq_len
        self.max_seq_len = seq_len + 1
        self.num_nodes = num_nodes

        # Population scaling
        self.pop_data = populations
        self.population_scaler = population_scaler

        # Optional transforms
        self.x_transform = x_transform
        self.y_transform = y_transform
    
    def __len__(self) -> int:
        """Return the number of rows in the metadata table."""
        return len(self.x_df)

    @staticmethod
    def _with_min_rank(arr: np.ndarray, rank: int) -> np.ndarray:
        """Prepend one axis to ``arr`` when it has fewer than ``rank`` dimensions."""
        return arr if np.ndim(arr) >= rank else arr[np.newaxis, ...]
    
    def __getitem__(self, i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
        """
        Assemble the sample for metadata row ``i``.
        
        Args:
            i: Positional row index into the metadata table
            
        Returns:
            Tuple of (x, xt, y0, index), the arrays cast to float32:
            - x: static input features, at least 3-D (leading axis added if needed)
            - xt: temporal features truncated to ``seq_len + 1`` steps along
              axis 1, at least 4-D (leading axis added if needed)
            - y0: initial conditions, at least 2-D (leading axis added if needed)
            - index: The value of ``i``
        """
        row = self.x_df.iloc[i]

        # Static features: the first element of the constructor's result is used.
        features = construct_x(
            row,
            self.x_col_names,
            self.frac_pops_names,
            self.pop_data,
            self.population_scaler,
        )[0]

        # Temporal features, cut to seq_len + 1 steps on axis 1.
        temporal = construct_temporal_features(row, self.seq_len, self.num_nodes)
        temporal = temporal[:, :self.max_seq_len, ...]

        features = self._with_min_rank(features, 3)
        temporal = self._with_min_rank(temporal, 4)

        # Initial conditions are built from the configured subset of the row.
        initial = construct_initial_conditions(
            row.loc[self.initial_col_names],
            self.pop_data,
            self.population_scaler,
        )[0]
        initial = self._with_min_rank(initial, 2)

        if self.x_transform:
            features = self.x_transform(features)

        if self.y_transform:
            initial = self.y_transform(initial)

        return (
            features.astype(np.float32),
            temporal.astype(np.float32),
            initial.astype(np.float32),
            i,
        )
    
    def update_transforms(self, x_transform: Optional[Callable], y_transform: Optional[Callable]) -> None:
        """Replace both transform callables."""
        self.x_transform, self.y_transform = x_transform, y_transform


class FeatureDataset(data.Dataset):
    """
    Dataset that serves pre-computed samples stored as ``.npy`` files.

    Five sibling directories under ``data_path`` are read: ``x_<category>``,
    ``xt_<category>``, ``y_inc_<category>``, ``y_prev_<category>`` and
    ``y0_<category>``. Files are paired across directories by their position
    in each directory's sorted listing; the dataset length is the number of
    files in the ``x`` directory.
    """
    
    def __init__(
        self,
        data_path: Union[str, Path],
        seq_len: int,
        category: Literal["train", "val", "test"],
        populations: np.ndarray,
        x_transform: Optional[Callable] = None,
        y_hosp_inc_transform: Optional[Callable] = None,
        y_hosp_prev_transform: Optional[Callable] = None,
        y_latent_inc_transform: Optional[Callable] = None,
        y_latent_prev_transform: Optional[Callable] = None
    ):
        """
        Initialize the feature dataset.

        Every transform may be either a plain callable or an object exposing
        a ``transform(array)`` method (such as ``ZNormalize``).
        
        Args:
            data_path: Path to data directory
            seq_len: Sequence length
            category: Data category ("train", "val", "test")
            populations: Population data array (stored, not used by this class)
            x_transform: Transform for x data
            y_hosp_inc_transform: Transform for hospital incidence data
            y_hosp_prev_transform: Transform for hospital prevalence data
            y_latent_inc_transform: Transform for latent incidence data
            y_latent_prev_transform: Transform for latent prevalence data
                (also applied to the initial condition y0)
        """
        data_path = Path(data_path) if not isinstance(data_path, Path) else data_path
        self.max_seq_len = seq_len + 1  # Initial condition
        
        # Set up paths
        self.x_path = data_path / f"x_{category}"
        self.xt_path = data_path / f"xt_{category}"
        self.y_inc_path = data_path / f"y_inc_{category}"
        self.y_prev_path = data_path / f"y_prev_{category}"
        self.y0_path = data_path / f"y0_{category}"
        
        self.pop_data = populations
        
        # Store transforms
        self.x_transform = x_transform
        self.y_hosp_inc_transform = y_hosp_inc_transform
        self.y_hosp_prev_transform = y_hosp_prev_transform
        self.y_latent_inc_transform = y_latent_inc_transform
        self.y_latent_prev_transform = y_latent_prev_transform

        # Build the file lists through the same code path used for refreshes,
        # so the two can never drift apart.
        self.update_dataset()

    @staticmethod
    def _list_npy(directory: Path) -> List[Path]:
        """Return the sorted ``.npy`` files found directly in ``directory``."""
        return sorted(directory.glob("*.npy"))

    @staticmethod
    def _apply_transform(transform: Optional[Any], values: np.ndarray) -> np.ndarray:
        """
        Apply ``transform`` to ``values``; ``None`` means identity.

        Objects with a ``transform`` method are invoked through it, anything
        else is called directly. This lets the partials returned by
        ``get_z_score_transform`` and ``ZNormalize`` instances be used
        interchangeably for every slot.
        """
        if transform is None:
            return values
        method = getattr(transform, "transform", None)
        if method is not None:
            return method(values)
        return transform(values)
    
    def __len__(self) -> int:
        """Get the number of samples in the dataset (files in the x directory)."""
        return len(self.x_filenames)
    
    def __getitem__(self, i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
        """
        Load sample ``i`` from disk.

        ``xt``, ``y_inc`` and ``y_prev`` are truncated to ``seq_len + 1``
        entries along axis 1. The last axis of ``y_inc`` / ``y_prev`` is
        split into two equal halves: the first half is the hospital
        component, the second the latent component. The first entry along
        axis 1 of every returned target is dropped.
        
        Args:
            i: Sample index
            
        Returns:
            Tuple of (x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0, index)
            with all arrays cast to float32
        """
        # Load in float64 so the transforms run in double precision.
        x = np.load(self.x_filenames[i]).astype(np.float64)
        xt = np.load(self.xt_filenames[i]).astype(np.float64)[:, :self.max_seq_len, ...]
        y0 = np.load(self.y0_filenames[i]).astype(np.float64)
        y_inc = np.load(self.y_inc_filenames[i]).astype(np.float64)[:, :self.max_seq_len, ...]
        y_prev = np.load(self.y_prev_filenames[i]).astype(np.float64)[:, :self.max_seq_len, ...]
        
        # Split into hospital and latent components
        y_hosp_inc, y_latent_inc = np.split(y_inc, 2, axis=-1)
        y_hosp_prev, y_latent_prev = np.split(y_prev, 2, axis=-1)
        
        # Apply transforms
        x = self._apply_transform(self.x_transform, x)
        y_latent_prev = self._apply_transform(self.y_latent_prev_transform, y_latent_prev)
        # y0 shares the latent-prevalence transform.
        y0 = self._apply_transform(self.y_latent_prev_transform, y0)
        y_latent_inc = self._apply_transform(self.y_latent_inc_transform, y_latent_inc)
        y_hosp_inc = self._apply_transform(self.y_hosp_inc_transform, y_hosp_inc)
        y_hosp_prev = self._apply_transform(self.y_hosp_prev_transform, y_hosp_prev)
        
        return (
            x.astype(np.float32),
            xt.astype(np.float32),
            y_hosp_inc[:, 1:, ...].astype(np.float32),
            y_hosp_prev[:, 1:, ...].astype(np.float32),
            y_latent_inc[:, 1:, ...].astype(np.float32),
            y_latent_prev[:, 1:, ...].astype(np.float32),
            y0.astype(np.float32),
            i
        )
    
    def update_transforms(
        self,
        x_transform: Optional[Callable],
        y_hosp_inc_transform: Optional[Callable],
        y_hosp_prev_transform: Optional[Callable],
        y_latent_inc_transform: Optional[Callable],
        y_latent_prev_transform: Optional[Callable]
    ) -> None:
        """Update transform functions."""
        self.x_transform = x_transform
        self.y_hosp_inc_transform = y_hosp_inc_transform
        self.y_hosp_prev_transform = y_hosp_prev_transform
        self.y_latent_inc_transform = y_latent_inc_transform
        self.y_latent_prev_transform = y_latent_prev_transform
    
    def update_dataset(self) -> None:
        """Update the dataset by re-reading the file lists from disk."""
        self.x_filenames = self._list_npy(self.x_path)
        self.xt_filenames = self._list_npy(self.xt_path)
        self.y_inc_filenames = self._list_npy(self.y_inc_path)
        self.y_prev_filenames = self._list_npy(self.y_prev_path)
        self.y0_filenames = self._list_npy(self.y0_path)


def collate_fn(batch: List[Tuple], device: Optional[torch.device] = None) -> Tuple[torch.Tensor, ...]:
    """
    Merge feature-dataset samples into one shuffled batch.

    Each sample is an 8-tuple whose last element (the sample index) is
    discarded. The remaining seven fields are concatenated along dim 0 and
    then reordered with one NumPy random permutation shared by all fields,
    so rows stay aligned across the outputs.

    Args:
        batch: List of samples
        device: Device to move tensors to (left where they are when None)

    Returns:
        Tuple of (x, xt, y_hosp_inc, y_hosp_prev, y_latent_inc, y_latent_prev, y0)
    """
    (
        x_parts,
        xt_parts,
        hosp_inc_parts,
        hosp_prev_parts,
        latent_inc_parts,
        latent_prev_parts,
        y0_parts,
        _sample_ids,
    ) = zip(*batch)

    field_groups = (
        x_parts,
        xt_parts,
        hosp_inc_parts,
        hosp_prev_parts,
        latent_inc_parts,
        latent_prev_parts,
        y0_parts,
    )

    # Per field: convert every piece to a float tensor, then join on dim 0.
    merged = [
        torch.cat([torch.FloatTensor(piece) for piece in group], dim=0)
        for group in field_groups
    ]

    # A single permutation, sized by the concatenated x, is reused everywhere.
    order = np.random.permutation(merged[0].shape[0])

    if device is not None:
        merged = [tensor.to(device) for tensor in merged]

    return tuple(tensor[order, ...] for tensor in merged)


def pool_collate_fn(batch: List[Tuple], device: Optional[torch.device] = None) -> Tuple[torch.Tensor, ...]:
    """
    Merge pool-dataset samples into one batch (order is preserved).

    Args:
        batch: List of pool samples, each a 4-tuple (x, xt, y0, index)
        device: Device to move the tensors to (left where they are when None)

    Returns:
        Tuple of (x, xt, y0, pool_indices); the first three are float tensors
        concatenated along dim 0, ``pool_indices`` is an int64 NumPy array
        that is never moved to ``device``
    """
    x_parts, xt_parts, y0_parts, index_parts = zip(*batch)

    def _join(pieces) -> torch.Tensor:
        # Convert each piece to a float tensor and concatenate on dim 0.
        return torch.cat([torch.FloatTensor(piece) for piece in pieces], dim=0)

    tensors = [_join(x_parts), _join(xt_parts), _join(y0_parts)]
    pool_indices = np.array(index_parts, dtype=np.int64)

    if device is not None:
        tensors = [tensor.to(device) for tensor in tensors]

    x, xt, y0 = tensors
    return x, xt, y0, pool_indices


def get_z_score_transform(
    x_mean: np.ndarray, 
    x_std: np.ndarray, 
    y_mean: np.ndarray, 
    y_std: np.ndarray
) -> Tuple[Callable, Callable]:
    """
    Build a pair of z-score callables with their statistics bound in.

    Each returned callable takes the data as its only argument and forwards
    to ``standardize`` with the corresponding mean / std.

    Args:
        x_mean: Mean for x data
        x_std: Standard deviation for x data
        y_mean: Mean for y data
        y_std: Standard deviation for y data

    Returns:
        Tuple of (x_transform, y_transform) functions
    """
    return (
        partial(standardize, mean=x_mean, std=x_std),
        partial(standardize, mean=y_mean, std=y_std),
    )


def standardize(data: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """
    Standardize data using mean and standard deviation.

    Entries of ``std`` that are exactly zero are replaced by ``1e-8`` for
    the division. The replacement is made on a fresh array, so the caller's
    ``std`` is left untouched (it is typically reused across calls, e.g.
    when bound through ``functools.partial``). Scalar ``std`` values are
    accepted as well.

    Args:
        data: Data to standardize
        mean: Mean values
        std: Standard deviation values (not modified)

    Returns:
        ``(data - mean) / std`` with zero deviations substituted by 1e-8
    """
    # np.where builds a new array instead of writing into the argument.
    safe_std = np.where(np.asarray(std) == 0.0, 1e-8, std)
    return (data - mean) / safe_std


def read_meta_df(meta_filename: Union[str, Path], use_cols: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Load the metadata table from a CSV file.

    When all columns are read, ``starting_date`` is parsed as a date column.
    When a column subset is requested no date parsing is performed. If the
    ``run_ids`` column is part of the result, each of its cells is
    evaluated from its textual form into a Python object.

    Args:
        meta_filename: Path to metadata CSV file
        use_cols: Columns to read (None for all columns)

    Returns:
        Metadata DataFrame
    """
    if use_cols is None:
        frame = pd.read_csv(meta_filename, parse_dates=["starting_date"])
        has_run_ids = True
    else:
        frame = pd.read_csv(meta_filename, usecols=use_cols)
        has_run_ids = "run_ids" in use_cols

    if has_run_ids:
        # NOTE(review): eval() executes arbitrary expressions found in the CSV.
        # Only safe for trusted metadata files; consider ast.literal_eval if
        # the cells are always plain literals.
        frame["run_ids"] = frame["run_ids"].apply(eval)

    return frame


def add_pop_data_to_x(x: np.ndarray, pop_data: np.ndarray) -> np.ndarray:
    """
    Scale the trailing feature columns of ``x`` by per-node population.

    The first two entries of the last axis are passed through unchanged;
    every remaining entry is multiplied by the population of its node
    (axis 1). The output therefore has the same shape as the input.

    Args:
        x: Input features, 3-D with nodes on axis 1
        pop_data: 1-D population array, one entry per node

    Returns:
        Array shaped like ``x`` with columns ``2:`` multiplied by population
    """
    passthrough = x[..., :2]
    fractions = x[..., 2:].astype(np.float64)

    # Multiply every (batch, node, feature) entry by the node's population.
    scaled = np.einsum("ikj,k->ikj", fractions, pop_data)

    return np.concatenate((passthrough, scaled), axis=-1)


def remove_extra_train_files(path: Union[str, Path], last_iter_train_file_id: List[int]) -> None:
    """
    Delete the training files that belong to the given file IDs.

    For every ID the corresponding file is removed from each of the
    ``x_train``, ``xt_train``, ``y_inc_train``, ``y_prev_train`` and
    ``y0_train`` directories under ``path``. Files that do not exist are
    skipped silently, so the call also works when a directory is empty or
    absent.

    NOTE(review): an earlier docstring described the IDs as the files to
    *keep*; the code has always deleted them. Confirm the intended meaning
    with the callers.

    Args:
        path: Root data directory containing the ``*_train`` sub-directories
        last_iter_train_file_id: IDs of the files to delete

    Raises:
        FileNotFoundError: If ``path`` does not exist
    """
    root = path if isinstance(path, Path) else Path(path)
    if not root.exists():
        raise FileNotFoundError(f"Data path does not exist: {root}")

    # (directory, filename builder) per data kind. The directories are derived
    # from the root directly rather than from the first file of a listing,
    # which would fail with IndexError on an empty directory.
    kinds = (
        (root / "x_train", make_x_filename),
        (root / "xt_train", make_xt_filename),
        (root / "y_inc_train", lambda idx: make_y_filename(idx, "y_inc")),
        (root / "y_prev_train", lambda idx: make_y_filename(idx, "y_prev")),
        (root / "y0_train", make_y0_filename),
    )

    # Resolve every target first so a failure while building a name happens
    # before anything has been deleted.
    doomed = [
        directory.joinpath(build_name(idx))
        for directory, build_name in kinds
        for idx in last_iter_train_file_id
    ]

    for target in doomed:
        target.unlink(missing_ok=True)


def load_population_data(population_csv_path: Union[str, Path]) -> np.ndarray:
    """
    Read per-country populations from a CSV file.

    The file must provide the columns ``country_id`` and ``population``.

    Args:
        population_csv_path: Path to population CSV file

    Returns:
        Values of the ``population`` column, ordered by ascending ``country_id``
    """
    table = pd.read_csv(population_csv_path)
    ordered = table.sort_values("country_id")
    return ordered["population"].values


class ZNormalize:
    """
    Z-score normalization class for data preprocessing.

    The statistics are stored as float64 NumPy arrays. Inputs may be NumPy
    arrays or PyTorch tensors; for tensors the statistics are moved to the
    input's device and dtype at transform time.
    """
    
    def __init__(self, mean: Union[np.ndarray, torch.Tensor], std: Union[np.ndarray, torch.Tensor]):
        """
        Initialize the normalizer.
        
        Args:
            mean: Mean values for normalization
            std: Standard deviation values for normalization

        Raises:
            TypeError: If ``mean`` or ``std`` is neither a NumPy array nor a tensor
        """
        # Added to std in the denominator to avoid division by zero.
        self.eps = 1e-8
        self.mean = self._to_float64_array(mean, "Mean")
        self.std = self._to_float64_array(std, "Std")

    @staticmethod
    def _to_float64_array(value: Union[np.ndarray, torch.Tensor], label: str) -> np.ndarray:
        """Convert a statistic to a float64 NumPy array, validating its type."""
        if isinstance(value, torch.Tensor):
            # .cpu() is required for tensors living on an accelerator;
            # it is a no-op for tensors already in host memory.
            return value.detach().cpu().numpy().astype(np.float64)
        if isinstance(value, np.ndarray):
            return value.astype(np.float64)
        raise TypeError(f"{label} must be numpy array or torch Tensor")
    
    def transform(self, y: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """
        Apply normalization transform: ``(y - mean) / (std + eps)``.
        
        Args:
            y: Data to normalize
            
        Returns:
            Normalized data, of the same container type as ``y``

        Raises:
            TypeError: If ``y`` is neither a NumPy array nor a tensor
        """
        if isinstance(y, torch.Tensor):
            # Match the input's device and dtype so the arithmetic is valid.
            torch_mean = torch.from_numpy(self.mean).to(y.device).to(y.dtype)
            torch_std = torch.from_numpy(self.std).to(y.device).to(y.dtype)
            return (y - torch_mean) / (torch_std + self.eps)
        elif isinstance(y, np.ndarray):
            return (y - self.mean) / (self.std + self.eps)
        else:
            raise TypeError("Input must be numpy array or torch Tensor")
    
    def update_stats(self, mean: Union[np.ndarray, torch.Tensor], std: Union[np.ndarray, torch.Tensor]) -> None:
        """
        Update normalization statistics.

        Both values are converted before either attribute is assigned, so a
        rejected argument leaves the previous statistics intact.

        Raises:
            TypeError: If ``mean`` or ``std`` is neither a NumPy array nor a tensor
        """
        new_mean = self._to_float64_array(mean, "Mean")
        new_std = self._to_float64_array(std, "Std")
        self.mean = new_mean
        self.std = new_std
    
    def get_params(self) -> Tuple[np.ndarray, np.ndarray]:
        """Get normalization parameters as (mean, std)."""
        return self.mean, self.std
    
    def __call__(self, y: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        """Call the transform method."""
        return self.transform(y)
