"""
Input constructors for GLEAM-AI.

This module contains functions for constructing input features, temporal features,
and initial conditions for the GLEAM-AI epidemiological forecasting system.
"""

import pandas as pd
import numpy as np
import datetime
from typing import Union, List, Tuple, Optional


def prepare_data(
    r0: List[float],
    days: List[int],
    frac_susceptible: List[float],
    frac_latent: List[float],
    frac_latent_vax: List[float],
    frac_recovered: List[float],
    seasonality_min: List[float],
    starting_dates: List[str],
    config: dict,
    populations: np.ndarray
) -> dict:
    """
    Prepare data for model input construction.

    The per-simulation lists are assembled into a DataFrame (one row per
    simulation) from which the static features, the temporal features and the
    initial latent prevalence are derived.

    Args:
        r0: Basic reproduction numbers
        days: Number of days
        frac_susceptible: Fraction of susceptible population
        frac_latent: Fraction of latent population
        frac_latent_vax: Fraction of vaccinated latent population
        frac_recovered: Fraction of recovered population
        seasonality_min: Minimum seasonality values
        starting_dates: Starting dates for simulations
        config: Configuration dictionary. The keys read here are
            config["data"]["frac_pops_names"], config["data"]["x_col_names"],
            config["model"]["seq_len"], config["model"]["num_nodes"] and
            config["model"]["POPULATION_SCALER"].
        populations: Per-node population array, forwarded to construct_x and
            construct_initial_conditions

    Returns:
        Dictionary with the keys "x", "xt" and "y0_latent_prev"

    Raises:
        TypeError: If any of the per-simulation arguments is not a list
    """
    # Validate inputs; keep the parameter name next to its value so the error
    # message identifies which argument is wrong instead of dumping its content.
    list_params = {
        "r0": r0,
        "days": days,
        "frac_susceptible": frac_susceptible,
        "frac_latent": frac_latent,
        "frac_latent_vax": frac_latent_vax,
        "frac_recovered": frac_recovered,
        "seasonality_min": seasonality_min,
        "starting_dates": starting_dates,
    }
    for param_name, param_value in list_params.items():
        if not isinstance(param_value, list):
            raise TypeError(f"{param_name} must be a list")

    # Column layout expected by the feature constructors
    data = {
        'R0': r0,
        'days': days,
        'Susceptible': frac_susceptible,
        'Latent': frac_latent,
        "Latent_vax": frac_latent_vax,
        'Recovered': frac_recovered,
        'seasonality_min': seasonality_min,
        'starting_date': starting_dates
    }

    # Extract configuration parameters
    frac_pops_names = config["data"]["frac_pops_names"]
    x_col_names = config["data"]["x_col_names"]
    seq_len = config["model"]["seq_len"]
    num_nodes = config["model"]["num_nodes"]
    population_scaler = config["model"]["POPULATION_SCALER"]

    df = pd.DataFrame(data)

    # Construct features
    x = construct_x(df, x_col_names, frac_pops_names, populations, population_scaler)
    xt = construct_temporal_features(df, seq_len, num_nodes)
    y0 = construct_initial_conditions(df, populations)

    return {"x": x, "xt": xt, "y0_latent_prev": y0}


def construct_x(
    df: pd.DataFrame,
    x_col_names: List[str],
    frac_pops_names: List[str],
    populations: np.ndarray,
    population_scaler: float = 1.0
) -> np.ndarray:
    """
    Build the static input feature tensor x.

    Each row of ``df`` is one sample. Its selected feature columns and its
    population-fraction columns are replicated for every node, and the scaled
    node population is appended as the last feature.

    Args:
        df: Input DataFrame; must contain a "days" column plus the columns
            named in x_col_names and frac_pops_names. It is not modified.
        x_col_names: Column names for x features
        frac_pops_names: Column names for population fractions
        populations: Population data (one entry per node)
        population_scaler: Population scaling factor; populations are divided
            by it

    Returns:
        Constructed x features [B, num_nodes, x_dim + pop_dim]
    """
    scaled_pops = populations / float(population_scaler)
    n_nodes = len(scaled_pops)

    # Normalise the day count on a copy so the caller's frame is untouched
    frame = df.copy()
    frame["days"] /= 366

    # Per-sample feature columns
    sample_feats = frame[x_col_names].values
    if np.ndim(sample_feats) < 2:
        sample_feats = sample_feats[np.newaxis, ...]

    # Per-sample population fractions
    frac_feats = frame[frac_pops_names].values
    if frac_feats.ndim < 2:
        frac_feats = frac_feats[np.newaxis, ...]

    # Replicate both per-sample blocks along a new node axis
    frac_feats = np.repeat(frac_feats[:, np.newaxis, :], n_nodes, axis=1)
    sample_feats = np.repeat(sample_feats[:, np.newaxis, :], n_nodes, axis=1)

    # Replicate the node populations along the batch axis
    batch_size = frac_feats.shape[0]
    node_pops = np.repeat(scaled_pops[np.newaxis, :, np.newaxis], batch_size, axis=0)

    return np.concatenate([sample_feats, frac_feats, node_pops], axis=-1)


def construct_y0(df: Union[pd.DataFrame, pd.Series], pops: np.ndarray) -> np.ndarray:
    """
    Construct initial conditions y0.

    The latent fraction of every sample is multiplied with every node
    population (outer product).

    Args:
        df: Input DataFrame with a "Latent" column, or a single row given as a
            Series with a "Latent" entry
        pops: Population data (one entry per node)

    Returns:
        Initial conditions [B, num_nodes]
    """
    # np.asarray also accepts the scalar obtained from a Series row, where
    # ``.values`` would raise AttributeError before the scalar guard below runs.
    latent = np.asarray(df["Latent"])

    if latent.ndim > 1:
        latent = np.squeeze(latent)
    if latent.ndim < 1:
        # Promote a scalar to a batch of one
        latent = latent[np.newaxis]

    return np.einsum("i,j->ij", latent, pops)


def construct_initial_conditions(
    df: Union[pd.DataFrame, pd.Series],
    populations: np.ndarray,
    population_scaler: float = 1.0
) -> np.ndarray:
    """
    Construct initial conditions for latent prevalence.

    The "Latent" and "Latent_vax" fractions are summed per sample, multiplied
    with every node population (outer product) and rounded.

    Args:
        df: Input DataFrame or Series holding "Latent" and "Latent_vax"
        populations: Population data (one entry per node)
        population_scaler: Population scaling factor. Accepted by the
            signature but not used in the computation below.

    Returns:
        Initial conditions [B, num_nodes]
    """
    # Total latent fraction = unvaccinated + vaccinated
    total_latent = df[["Latent"]].values + df[["Latent_vax"]].values
    total_latent = total_latent.astype(np.float64)

    if total_latent.ndim > 1:
        squeezed = np.squeeze(total_latent)
        # A single sample squeezes down to a scalar; keep it as a batch of one
        total_latent = squeezed.reshape(1) if squeezed.ndim == 0 else squeezed

    # Outer product of per-sample fractions and per-node populations
    latent_counts = np.einsum("i,j->ij", total_latent, populations)
    return np.round(latent_counts)


def construct_temporal_features(
    df: Union[pd.DataFrame, pd.Series],
    seq_len: int,
    num_nodes: int
) -> np.ndarray:
    """
    Construct temporal features.

    The seasonality curve of every sample is computed from its starting date
    and minimum seasonality and then replicated for every node.

    Args:
        df: Input DataFrame or Series with "starting_date" and
            "seasonality_min"
        seq_len: Sequence length; the curve covers the starting date plus
            seq_len following days
        num_nodes: Number of nodes

    Returns:
        Temporal features [B, seq_len + 1, num_nodes, 1]
    """
    if isinstance(df, pd.Series):
        # A single row: turn it into a one-row frame
        df = df.to_frame().transpose()

    starting_date = np.array(df["starting_date"].tolist())
    seasonality_min = np.array(df["seasonality_min"].tolist())

    # Get seasonality features
    seasonality = get_seasonality(starting_date, seasonality_min, seq_len)

    # get_seasonality squeezes singleton axes, so the batch axis (B == 1) or
    # the time axis (seq_len == 0) may be missing. Restore the [B, seq_len + 1]
    # layout explicitly instead of guessing from the number of dimensions.
    seasonality = np.reshape(seasonality, (len(starting_date), seq_len + 1))

    # Expand to all nodes and add a trailing feature axis
    seasonality = seasonality[..., np.newaxis, np.newaxis]
    return np.repeat(seasonality, num_nodes, axis=-2)


def get_seasonality(
    starting_date: Union[np.ndarray, List],
    seasonality_min: Union[np.ndarray, List, float],
    seq_len: int
) -> np.ndarray:
    """
    Compute seasonality features based on starting dates.

    For every starting date a sinusoidal factor is evaluated on the starting
    day and the seq_len following days. The factor peaks at 1.0 on January 15
    and bottoms out at seasonality_min half a year later (365-day period).

    Args:
        starting_date: Starting dates (list/array of values numpy can convert
            to datetime64[D], or a single such value)
        seasonality_min: Minimum seasonality values; a Python or numpy scalar,
            or a list/array with one entry per starting date
        seq_len: Sequence length

    Returns:
        Seasonality features [B, seq_len + 1] with singleton axes squeezed
        out (a single starting date yields shape [seq_len + 1])

    Raises:
        TypeError: If seasonality_min is neither a number nor a list/array
    """
    # Day of year of the seasonal peak (January 15, independent of leap years)
    peak_day_of_year = 15

    # Convert to 1-D numpy arrays (atleast_1d also covers 0-d array input)
    if isinstance(starting_date, (list, np.ndarray)):
        starting_date = np.atleast_1d(np.array(starting_date, dtype='datetime64[D]'))
    else:
        starting_date = np.array([starting_date], dtype='datetime64[D]')

    # np.number covers numpy scalars such as np.float32 / np.int64, which are
    # not instances of the builtin float / int.
    if isinstance(seasonality_min, (float, int, np.number)):
        seasonality_min = np.array([seasonality_min])
    elif isinstance(seasonality_min, (list, np.ndarray)):
        seasonality_min = np.atleast_1d(np.array(seasonality_min))
    else:
        raise TypeError("seasonality_min must be a number, list, or numpy array")

    # Dates covered by each sequence: start day + 0..seq_len days -> [B, seq_len + 1]
    day_offsets = np.arange(seq_len + 1)
    dates = starting_date[:, np.newaxis] + np.timedelta64(1, 'D') * day_offsets

    # Day of year (1-based), relative to January 1 of each date's own year
    day_of_year = (dates - dates.astype('datetime64[Y]')).astype(int) + 1

    # Compute seasonal adjustment
    s_r = seasonality_min[:, np.newaxis] / 1.0  # seasonality_max is 1.0
    days_since_peak = (day_of_year - peak_day_of_year + 365) % 365
    seasonal_adjustment = 0.5 * (
        (1 - s_r) * np.sin(2 * np.pi / 365 * days_since_peak + 0.5 * np.pi) + 1 + s_r
    )

    return np.squeeze(seasonal_adjustment)


def get_starting_date_array(df_starting_date: Union[pd.Series, pd.Timestamp]) -> np.ndarray:
    """
    Extract date features from starting dates.

    Args:
        df_starting_date: Starting dates, either a Series of dates or a single
            date; the input is parsed with pd.to_datetime and not modified

    Returns:
        Date features [B, 3] (year, month, day); a single date gives B == 1

    Raises:
        TypeError: If pd.to_datetime yields neither a Timestamp nor a Series
    """
    # pd.to_datetime returns a new object, so no defensive copy is needed
    # (a Timestamp has no ``copy`` method, which made that input path fail).
    dates = pd.to_datetime(df_starting_date)

    if isinstance(dates, pd.Timestamp):
        # Single date: wrap each component as a one-element array
        components = [np.array([part]) for part in (dates.year, dates.month, dates.day)]
    elif isinstance(dates, pd.Series):
        components = [dates.dt.year.values, dates.dt.month.values, dates.dt.day.values]
    else:
        raise TypeError("Unsupported date type")

    return np.stack(components, axis=-1)


def get_date_features_from_numpy(arr: np.ndarray, t: int) -> np.ndarray:
    """
    Get date features from numpy array.

    The per-node dates are decoded by convert_array_to_datetime (which reads
    the first batch entry) and the resulting features are repeated for every
    batch entry.

    Args:
        arr: Date array [B, nodes, 3]
        t: Time step (not used in the body)

    Returns:
        Date features [B, nodes, 3] holding (two-digit year, month, day)
    """
    # Unpacking three values also enforces that arr is 3-dimensional
    batch_size, _, _ = arr.shape

    # Decode one datetime per node
    node_dates = convert_array_to_datetime(arr)

    # One (year % 100, month, day) row per node
    node_rows = [np.array([stamp.year % 100, stamp.month, stamp.day]) for stamp in node_dates]
    per_node = np.stack(node_rows, axis=0)

    # Add the batch axis and copy the rows to every batch entry
    return np.repeat(per_node[np.newaxis, ...], batch_size, axis=0)


def convert_array_to_datetime(arr: np.ndarray) -> List[datetime.datetime]:
    """
    Convert numpy array to datetime objects.

    Only the first batch entry (``arr[0]``) is read; each of its rows is
    passed positionally to ``datetime.datetime``.

    Args:
        arr: Date array [B, nodes, 3]

    Returns:
        List of datetime objects, one per node
    """
    first_batch = arr[0].astype(np.int64)
    return [datetime.datetime(*row) for row in first_batch]


def apply_seasonality(
    day: datetime.datetime,
    seasonality_min: float,
    seasonality_max: float = 1.0
) -> float:
    """
    Apply seasonality adjustment to a specific day.

    The factor follows a sinusoid with a 365-day period that is largest on
    January 15 of the year of ``day`` (Northern hemisphere).

    Args:
        day: Date to apply seasonality to
        seasonality_min: Minimum seasonality value
        seasonality_max: Maximum seasonality value

    Returns:
        Seasonal adjustment factor
    """
    ratio = seasonality_min / seasonality_max

    # Signed day distance to the seasonal peak of the same calendar year
    peak_day = datetime.datetime(day.year, 1, 15)
    offset_days = (day - peak_day).days

    # The half-pi phase shift puts the sine maximum on the peak day
    wave = np.sin(2 * np.pi / 365 * offset_days + 0.5 * np.pi)

    return 0.5 * ((1 - ratio) * wave + 1 + ratio)


def validate_input_data(
    df: pd.DataFrame,
    required_columns: List[str]
) -> bool:
    """
    Validate that input DataFrame contains required columns.

    Args:
        df: Input DataFrame
        required_columns: List of required column names

    Returns:
        True if all required columns are present

    Raises:
        ValueError: If at least one required column is missing; the message
            lists the missing columns in the order given in required_columns
    """
    present = set(df.columns)
    # Keep the caller's order (dropping duplicates) so the error message is
    # deterministic, unlike the iteration order of a set difference.
    missing_columns = [
        column for column in dict.fromkeys(required_columns) if column not in present
    ]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    return True


def normalize_features(
    features: np.ndarray,
    mean: Optional[np.ndarray] = None,
    std: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Normalize features using z-score normalization.

    Args:
        features: Input features
        mean: Mean values (computed over axis 0 if None)
        std: Standard deviation values (computed over axis 0 if None)

    Returns:
        Tuple of (normalized_features, mean, std). The returned std is the
        one actually used for the division, i.e. with zeros replaced by 1e-8.
    """
    if mean is None:
        mean = np.mean(features, axis=0, keepdims=True)
    if std is None:
        std = np.std(features, axis=0, keepdims=True)

    # Avoid division by zero for constant features. The guard is applied to a
    # caller-supplied std as well, and np.where builds a new array so the
    # caller's std is never modified in place.
    std = np.asarray(std)
    std = np.where(std == 0, 1e-8, std)

    normalized_features = (features - mean) / std
    return normalized_features, mean, std


def denormalize_features(
    normalized_features: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray
) -> np.ndarray:
    """
    Undo z-score normalization.

    Args:
        normalized_features: Normalized features
        mean: Mean values used for normalization
        std: Standard deviation values used for normalization

    Returns:
        Denormalized features (normalized_features * std + mean)
    """
    # Restore the original spread first, then shift back to the original mean
    rescaled = normalized_features * std
    return rescaled + mean
