"""
Class imbalance handling strategies for HMM-GLM.

This module provides functions for calculating various types of weights
to address class imbalance in sports data.
"""

import numpy as np
from scipy.stats import pearsonr

def calculate_basic_class_weights(y):
    """
    Calculate basic class weights using inverse frequency weighting.

    Every sample of a class receives ``n_samples / (2 * n_class)``, so each
    class that is present contributes a total weight of ``n_samples / 2``.

    Parameters:
    -----------
    y : array-like, shape (n_samples,)
        Binary outcome variable (1 for success, 0 for failure)

    Returns:
    --------
    weights : ndarray, shape (n_samples,)
        Sample weights. Samples whose label is neither 0 nor 1 keep weight 0.
    """
    # Coerce to an array: comparing a plain list with ``== 1`` yields a single
    # bool instead of an element-wise mask, which would leave all weights at 0.
    y = np.asarray(y)
    n_samples = len(y)

    # Count each class from its own mask rather than summing label values, so
    # the counts stay correct regardless of dtype or stray labels.
    success_mask = (y == 1)
    failure_mask = (y == 0)
    n_success = np.count_nonzero(success_mask)
    n_failure = np.count_nonzero(failure_mask)

    weights = np.zeros(n_samples)
    if n_success > 0:
        weights[success_mask] = n_samples / (2 * n_success)
    if n_failure > 0:
        weights[failure_mask] = n_samples / (2 * n_failure)

    return weights

def calculate_context_aware_weights(y, contexts, gamma=0.5):
    """
    Calculate context-aware weights.

    The basic class weight of each sample is scaled by
    ``1 + gamma * context_factor``, where the context factor is the
    importance-weighted absolute deviation of the sample's context values
    from the per-variable means. Importances are the absolute Pearson
    correlations between each context variable and ``y``, normalized to
    sum to 1.

    Parameters:
    -----------
    y : array-like, shape (n_samples,)
        Binary outcome variable (1 for success, 0 for failure)
    contexts : array-like, shape (n_samples, n_context_vars)
        Context variables for each sample. A 1-D input is treated as a
        single context variable.
    gamma : float, optional (default 0.5)
        Strength of the context adjustment.

    Returns:
    --------
    weights : ndarray, shape (n_samples,)
        Context-aware weights
    """
    y = np.asarray(y)
    contexts = np.asarray(contexts, dtype=float)
    if contexts.ndim == 1:
        contexts = contexts.reshape(-1, 1)

    # Calculate basic class weights
    base_weights = calculate_basic_class_weights(y)

    n_samples, n_context_vars = contexts.shape
    y_float = y.astype(float)

    # Importance of each context variable = |Pearson correlation with y|.
    # The correlation is undefined (NaN) when either input is constant; such
    # variables get importance 0 so that a single NaN cannot propagate into
    # every returned weight.
    delta = np.zeros(n_context_vars)
    outcome_varies = n_samples >= 2 and y_float.min() != y_float.max()
    if outcome_varies:
        for k in range(n_context_vars):
            column = contexts[:, k]
            if column.min() == column.max():
                continue
            corr, _ = pearsonr(column, y_float)
            if np.isfinite(corr):
                delta[k] = abs(corr)

    # Normalize importance weights
    total_importance = delta.sum()
    if total_importance > 0:
        delta /= total_importance

    # Importance-weighted absolute deviation from the context means,
    # computed for all samples at once.
    context_means = contexts.mean(axis=0)
    context_factors = np.abs(contexts - context_means) @ delta

    # Apply context adjustment with scaling factor
    return base_weights * (1.0 + gamma * context_factors)

def calculate_temporal_decay_weights(y, timestamps=None, sequence_ids=None,
                                     eta=1.0):
    """
    Calculate temporal decay weights.

    Within each sequence, timestamps are rescaled to [0, 1] and the basic
    class weight of every sample is multiplied by
    ``1 + eta * normalized_time``. With a positive ``eta`` the latest sample
    of a sequence is therefore weighted ``1 + eta`` times its base weight and
    the earliest sample keeps its base weight.

    Parameters:
    -----------
    y : array-like, shape (n_samples,)
        Binary outcome variable (1 for success, 0 for failure)
    timestamps : array-like, shape (n_samples,), optional
        Timestamps for each sample
    sequence_ids : array-like, shape (n_samples,), optional
        Sequence identifiers for each sample
    eta : float, optional (default 1.0)
        Temporal factor; 0 disables the temporal adjustment.

    Returns:
    --------
    weights : ndarray, shape (n_samples,)
        Temporal decay weights. Equal to the basic class weights when
        ``timestamps`` or ``sequence_ids`` is missing.
    """
    y = np.asarray(y)

    # Calculate basic class weights
    base_weights = calculate_basic_class_weights(y)

    # If timestamps or sequence_ids are not provided, return base weights
    if timestamps is None or sequence_ids is None:
        return base_weights

    # Boolean-mask indexing below requires arrays, not plain lists.
    timestamps = np.asarray(timestamps)
    sequence_ids = np.asarray(sequence_ids)

    weights = base_weights.copy()

    for seq_id in np.unique(sequence_ids):
        seq_mask = (sequence_ids == seq_id)
        seq_times = timestamps[seq_mask]

        # A sequence with fewer than two samples has no temporal extent.
        if len(seq_times) < 2:
            continue

        t_start = np.min(seq_times)
        t_end = np.max(seq_times)

        # All timestamps identical: leave the base weights untouched rather
        # than dividing by a zero duration.
        if not t_end > t_start:
            continue

        # Position of each sample within its sequence, scaled to [0, 1]
        normalized_times = (seq_times - t_start) / (t_end - t_start)
        weights[seq_mask] *= (1.0 + eta * normalized_times)

    return weights

def calculate_feature_based_weights(y, features, phi=0.3):
    """
    Calculate feature-based weights.

    The basic class weight of each sample is scaled by
    ``1 + phi * feature_factor``, where the feature factor is the sample's
    mean absolute standardized deviation (|z-score|) across all features.

    Parameters:
    -----------
    y : array-like, shape (n_samples,)
        Binary outcome variable (1 for success, 0 for failure)
    features : array-like, shape (n_samples, n_features), or None
        Feature values for each sample. A 1-D input is treated as a single
        feature. If None, the basic class weights are returned.
    phi : float, optional (default 0.3)
        Strength of the feature adjustment.

    Returns:
    --------
    weights : ndarray, shape (n_samples,)
        Feature-based weights
    """
    y = np.asarray(y)

    # Calculate basic class weights
    base_weights = calculate_basic_class_weights(y)

    # If features are not provided, return base weights
    if features is None:
        return base_weights

    features = np.asarray(features, dtype=float)
    if features.ndim == 1:
        features = features.reshape(-1, 1)

    n_samples, n_features = features.shape
    # Nothing to standardize: averaging over zero features would give NaN.
    if n_samples == 0 or n_features == 0:
        return base_weights

    feature_means = features.mean(axis=0)
    feature_stds = features.std(axis=0)

    # Detect constant columns exactly via min/max. Rounding in the mean can
    # give such a column a tiny non-zero std, which would otherwise turn
    # pure rounding noise into standardized distances of about 1.
    constant = features.max(axis=0) == features.min(axis=0)
    feature_stds[constant | (feature_stds == 0)] = 1.0  # Avoid division by zero

    deviations = np.abs((features - feature_means) / feature_stds)
    deviations[:, constant] = 0.0

    # Average standardized distance from the mean, per sample
    feature_factors = deviations.mean(axis=1)

    # Apply feature adjustment with scaling factor
    return base_weights * (1.0 + phi * feature_factors)

def calculate_combined_weights(y, contexts=None, features=None, 
                             timestamps=None, sequence_ids=None):
    """
    Calculate combined weights using all available information.

    The basic, context-aware, temporal and feature-based weights are combined
    as a weighted geometric mean with exponents 0.4, 0.3, 0.2 and 0.1. Any
    component whose inputs are missing falls back to the basic class weights.
    The result is rescaled so that the weights sum to ``n_samples``.

    Parameters:
    -----------
    y : array-like, shape (n_samples,)
        Binary outcome variable (1 for success, 0 for failure)
    contexts : ndarray, shape (n_samples, n_context_vars), optional
        Context variables for each sample
    features : ndarray, shape (n_samples, n_features), optional
        Feature values for each sample
    timestamps : ndarray, shape (n_samples,), optional
        Timestamps for each sample
    sequence_ids : ndarray, shape (n_samples,), optional
        Sequence identifiers for each sample

    Returns:
    --------
    weights : ndarray, shape (n_samples,)
        Combined weights. If every combined weight is 0 (for example when
        ``y`` is empty) the weights are returned without rescaling.
    """
    y = np.asarray(y)
    base_weights = calculate_basic_class_weights(y)

    # Components without the required inputs fall back to the base weights.
    # The full component weights are used directly: dividing them by the base
    # weights and multiplying back would yield 0/0 = NaN for any sample whose
    # base weight is 0, and that NaN would spread to every sample through the
    # normalization below.
    context_weights = base_weights
    feature_weights = base_weights
    temporal_weights = base_weights

    if contexts is not None:
        context_weights = calculate_context_aware_weights(y, contexts)

    if features is not None:
        feature_weights = calculate_feature_based_weights(y, features)

    if timestamps is not None and sequence_ids is not None:
        temporal_weights = calculate_temporal_decay_weights(
            y, timestamps, sequence_ids)

    # Combine weights using weighted geometric mean (exponents sum to 1)
    alpha1 = 0.4  # Base weight importance
    alpha2 = 0.3  # Context weight importance
    alpha3 = 0.2  # Temporal weight importance
    alpha4 = 0.1  # Feature weight importance

    combined_weights = (base_weights ** alpha1) * \
                     (context_weights ** alpha2) * \
                     (temporal_weights ** alpha3) * \
                     (feature_weights ** alpha4)

    # Normalize weights to sum to n_samples; skip when the total is zero to
    # avoid dividing by zero.
    n_samples = len(y)
    total = np.sum(combined_weights)
    if total > 0:
        combined_weights = combined_weights * n_samples / total

    return combined_weights


