"""
Class imbalance handling example.

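This script demonstrates how to handle class imbalance in the HMM-GLM framework:
it fits one model per sample-weighting strategy (none, basic class, context-aware,
temporal-decay, feature-based, combined) and compares their ROC and
precision-recall curves along with the weight distributions.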
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel
from src.core.hmm_glm import evaluate_hmm_glm_model
from src.core.weighting import (
    calculate_basic_class_weights,
    calculate_context_aware_weights,
    calculate_temporal_decay_weights,
    calculate_feature_based_weights,
    calculate_combined_weights
)

# Seed NumPy's global (legacy) RNG so every run of this example draws the
# same synthetic data.
_SEED = 42
np.random.seed(_SEED)

# Generate synthetic imbalanced data
def generate_imbalanced_data(n_samples=1000, n_states=3, imbalance_ratio=0.1):
    """Generate a synthetic, class-imbalanced sequence from a fixed 3-state HMM.

    A hidden state sequence is drawn from a fixed 3-state Markov chain. Each
    state emits a binary outcome with its own success probability. It also
    emits two informative features centred at (-1, -1), (0, 0) or (1, 1) with
    std 0.5, plus three standard-normal noise features.

    Parameters
    ----------
    n_samples : int
        Length of the generated sequence; must be >= 1.
    n_states : int
        Number of hidden states. The chain parameters are hard-coded for
        three states, so only ``3`` is accepted.
    imbalance_ratio : float
        Target positive-class rate. The per-state success probabilities
        (2%, 10%, 30% at the reference value 0.1) are scaled linearly by
        ``imbalance_ratio / 0.1``. Must be > 0 and small enough (about 1/3
        or less) that no per-state probability exceeds 1.

    Returns
    -------
    X : ndarray, shape (n_samples, 5)
        Two state-dependent features followed by three noise features.
    y : ndarray of int, shape (n_samples,)
        Binary outcomes.
    states : ndarray of int, shape (n_samples,)
        True hidden states in ``{0, 1, 2}``.
    contexts : ndarray, shape (n_samples, 3)
        Columns are:
        - an integer "score differential" in [-3, 3];
        - a uniform "normalized time" in [0, 1);
        - a standard-normal context.
    timestamps : ndarray, shape (n_samples,)
        ``0 .. n_samples - 1``.
    sequence_ids : ndarray, shape (n_samples,)
        Consecutive groups of 20 events share one ID.

    Raises
    ------
    ValueError
        If ``n_samples < 1``, ``n_states != 3``, or ``imbalance_ratio`` is out
        of range (including NaN).
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if n_states != 3:
        raise ValueError(
            f"n_states must be 3 (the chain parameters are fixed), got {n_states}")

    # Define HMM parameters (fixed 3-state chain).
    startprob = np.array([0.6, 0.3, 0.1])
    transmat = np.array([
        [0.7, 0.2, 0.1],
        [0.3, 0.6, 0.1],
        [0.2, 0.2, 0.6]
    ])

    # Per-state success probabilities at the reference ratio 0.1 are 2%, 10%
    # and 30%. The chain's stationary distribution is (7/15, 5/15, 3/15), so
    # the long-run positive rate at the reference is ~0.103. Scaling linearly
    # by imbalance_ratio / 0.1 keeps the overall rate ~= imbalance_ratio and
    # leaves the default (0.1) probabilities unchanged.
    success = np.array([0.02, 0.10, 0.30]) * (imbalance_ratio / 0.1)
    # Phrased as a positive check so that NaN is rejected as well.
    if not (imbalance_ratio > 0 and success.max() <= 1.0):
        raise ValueError(
            "imbalance_ratio must be > 0 and at most ~1/3 so every per-state "
            f"success probability stays <= 1, got {imbalance_ratio}")
    emissionprob = np.column_stack([1.0 - success, success])

    # Mean of the two informative features for each hidden state.
    feature_means = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])

    # Generate state sequence
    states = np.zeros(n_samples, dtype=int)
    states[0] = np.random.choice(n_states, p=startprob)
    for t in range(1, n_samples):
        states[t] = np.random.choice(n_states, p=transmat[states[t - 1]])

    # Observation buffers
    X = np.zeros((n_samples, 5))  # Features
    y = np.zeros(n_samples, dtype=int)

    # Context variables. They are drawn before the per-step loop below; keep
    # this order so seeded runs produce the same random stream.
    contexts = np.zeros((n_samples, 3))
    contexts[:, 0] = np.random.randint(-3, 4, n_samples)  # Score differential
    contexts[:, 1] = np.random.uniform(0, 1, n_samples)   # Normalized time
    contexts[:, 2] = np.random.normal(0, 1, n_samples)    # Random context

    # Timestamps for temporal decay
    timestamps = np.arange(n_samples)

    # Sequence IDs (groups of 20 events)
    sequence_ids = np.repeat(np.arange(n_samples // 20 + 1), 20)[:n_samples]

    for t in range(n_samples):
        # Outcome, then state-dependent features, then noise features.
        y[t] = np.random.choice(2, p=emissionprob[states[t]])
        X[t, :2] = np.random.normal(feature_means[states[t]], [0.5, 0.5])
        X[t, 2:] = np.random.normal(0, 1, 3)

    return X, y, states, contexts, timestamps, sequence_ids

# Generate data
print("Generating synthetic imbalanced data...")
X, y, true_states, contexts, timestamps, sequence_ids = generate_imbalanced_data(
    n_samples=1000, n_states=3, imbalance_ratio=0.1)

# Check class balance
positive_rate = np.mean(y)
print(f"Class balance: {positive_rate:.2%} positive, {1-positive_rate:.2%} negative")

# Split data into training and testing sets.
# The data is one HMM-generated sequence whose states follow a Markov chain.
# A shuffled split would scramble those transitions in both the training
# sequence and the test sequence that the HMM component consumes. So split
# chronologically: first 80% for training, last 20% for testing.
X_train, X_test, y_train, y_test, contexts_train, contexts_test, \
timestamps_train, timestamps_test, sequence_ids_train, sequence_ids_test = train_test_split(
    X, y, contexts, timestamps, sequence_ids, test_size=0.2, shuffle=False)

# Calculate different sample weights
print("Calculating sample weights...")

weights_dict = {
    'Basic Class Weights': calculate_basic_class_weights(y_train),
    'Context-Aware Weights': calculate_context_aware_weights(y_train, contexts_train),
    'Temporal Decay Weights': calculate_temporal_decay_weights(
        y_train, timestamps_train, sequence_ids_train),
    'Feature-Based Weights': calculate_feature_based_weights(y_train, X_train),
    'Combined Weights': calculate_combined_weights(
        y_train, contexts_train, X_train, timestamps_train, sequence_ids_train),
}


def _make_model():
    """Build a fresh, identically configured HMM-GLM model."""
    return HMMGLMModel(
        hmm_component=CategoricalHMMComponent(n_states=3, n_categories=2, random_state=42),
        glm_component=LogisticGLMComponent(random_state=42),
    )


# Create models with different weighting strategies
print("Creating and fitting models with different weighting strategies...")

# None means "fit without sample weights" (the unweighted baseline).
strategies = {'No Weighting': None, **weights_dict}
models = {}
for name, weights in strategies.items():
    model = _make_model()
    # NOTE(review): y_train is passed both as the HMM observation sequence and
    # as the GLM target (y_test likewise at evaluation time). If the HMM state
    # posterior at step t conditions on y[t], the label leaks into its own
    # prediction. Confirm against the HMMGLMModel API.
    if weights is None:
        model.fit(y_train, X_train, y_train)
    else:
        model.fit(y_train, X_train, y_train, sample_weight=weights)
    models[name] = model

# Evaluate models
print("Evaluating models...")

results = {}
for name, model in models.items():
    results[name] = evaluate_hmm_glm_model(model, y_test, X_test, y_test)
    print(f"\n{name} metrics:")
    for metric, value in results[name].items():
        print(f"  {metric}: {value:.4f}")

# Predict once per model; the scores are reused for both ROC and PR curves.
test_probas = {name: model.predict_proba(y_test, X_test) for name, model in models.items()}

# Plot ROC curves
plt.figure(figsize=(10, 8))

for name, y_pred_proba in test_probas.items():
    fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f'{name} (AUC = {roc_auc:.3f})')

plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curves')
plt.legend(loc="lower right")

# Plot Precision-Recall curves
plt.figure(figsize=(10, 8))

for name, y_pred_proba in test_probas.items():
    precision, recall, _ = precision_recall_curve(y_test, y_pred_proba)
    avg_precision = average_precision_score(y_test, y_pred_proba)
    plt.plot(recall, precision, label=f'{name} (AP = {avg_precision:.3f})')

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curves')
plt.legend(loc="lower left")

# Plot weight distributions (5 strategies on a 2x3 grid)
plt.figure(figsize=(12, 8))

for i, (name, weights) in enumerate(weights_dict.items(), start=1):
    plt.subplot(2, 3, i)
    plt.hist(weights, bins=30, alpha=0.7)
    plt.title(name)
    plt.xlabel('Weight Value')
    plt.ylabel('Count')

plt.tight_layout()
plt.show()

print("\nDone!")


