"""
Multimodal data integration example.

This script demonstrates how to integrate multiple data modalities (spatiotemporal,
biomechanical, physiological) into a unified feature representation for HMM-GLM modeling.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel
from src.core.hmm_glm import evaluate_hmm_glm_model
from src.features.multimodal import extract_multimodal_features

# Seed NumPy's global random number generator so that the synthetic data drawn
# through np.random.* in generate_multimodal_data is identical on every run.
np.random.seed(42)

# Generate synthetic data
def generate_multimodal_data(n_samples=1000):
    """Build a synthetic multimodal dataset.

    All randomness comes from NumPy's global random state (``np.random``), so
    seed it before calling if reproducible output is needed.

    Args:
        n_samples: Number of rows to generate. Rows get consecutive
            one-minute timestamps starting at 2023-01-01.

    Returns:
        A 4-tuple ``(df, biomech_data, physio_data, y)``:

        * ``df`` -- DataFrame with ``player_id``, ``date_time``, four spatial
          columns (``x``, ``y``, ``distance_to_net``, ``shot_angle``) and
          three temporal columns (``game_seconds``,
          ``period_seconds_remaining``, ``score_differential``).
        * ``biomech_data`` -- dict mapping ``(player_id, timestamp)`` to an
          array of 5 standard-normal values; each row is included with
          probability 0.7.
        * ``physio_data`` -- dict mapping ``(player_id, timestamp)`` to an
          array of 3 standard-normal values; each row is included with
          probability 0.5.
        * ``y`` -- array of 0/1 outcomes, one Bernoulli draw per row.
    """
    # Row identity: a player id in [0, 50) and a per-minute timestamp.
    player_ids = np.random.choice(50, n_samples)
    timestamps = pd.date_range(start='2023-01-01', periods=n_samples, freq='1min')

    # Spatial features: a random position plus its polar form (radius, angle)
    # measured from the origin.
    x_pos = np.random.uniform(-50, 50, n_samples)
    y_pos = np.random.uniform(-25, 25, n_samples)
    net_distance = np.sqrt(x_pos**2 + y_pos**2)
    shot_angle = np.arctan2(y_pos, x_pos)

    # Temporal features; the integer score differential is stored as float so
    # every numeric column shares one dtype.
    elapsed = np.random.uniform(0, 3600, n_samples)
    remaining = 3600 - elapsed
    score_diff = np.random.randint(-5, 6, n_samples).astype(float)

    df = pd.DataFrame({
        'player_id': player_ids,
        'date_time': timestamps,
        'x': x_pos,
        'y': y_pos,
        'distance_to_net': net_distance,
        'shot_angle': shot_angle,
        'game_seconds': elapsed,
        'period_seconds_remaining': remaining,
        'score_differential': score_diff,
    })

    def _draw_sparse_features(keep_prob, n_features):
        """Give each row, with probability ``keep_prob``, a normal feature vector."""
        modality = {}
        for row in range(n_samples):
            # One uniform draw decides whether this row has a measurement.
            if np.random.random() >= keep_prob:
                continue
            key = (player_ids[row], timestamps[row])
            modality[key] = np.random.normal(0, 1, n_features)
        return modality

    # Both side modalities cover only a subset of rows; biomechanical data is
    # drawn in full before physiological data.
    biomech_data = _draw_sparse_features(0.7, 5)
    physio_data = _draw_sparse_features(0.5, 3)

    # Logistic success model: the exponent is the *negative* log-odds, so a
    # positive term lowers the success probability. Success is more likely
    # when closer to the net, at a better angle, with more time remaining,
    # and when leading in score.
    neg_log_odds = (
        0.05 * net_distance       # distance effect
        - 0.5 * np.cos(shot_angle)  # angle effect
        + 0.0002 * elapsed        # time effect
        - 0.1 * score_diff        # score effect
    )
    success_prob = 1 / (1 + np.exp(neg_log_odds))

    y = np.random.binomial(1, success_prob)

    return df, biomech_data, physio_data, y

def _build_hmm_glm_model():
    """Return an unfitted HMM-GLM model configured as used throughout this example."""
    return HMMGLMModel(
        hmm_component=CategoricalHMMComponent(n_states=3, n_categories=2, random_state=42),
        glm_component=LogisticGLMComponent(random_state=42),
    )


def _print_metrics(scores):
    """Print each metric name and value, indented and rounded to four decimals."""
    for metric, value in scores.items():
        print(f"  {metric}: {value:.4f}")


# Identical split settings for both experiments, so the two models are trained
# and evaluated with the same train_test_split arguments.
_split_options = {'test_size': 0.2, 'random_state': 42}

# ---------------------------------------------------------------------------
# 1. Synthetic data and the integrated feature matrix
# ---------------------------------------------------------------------------
print("Generating synthetic multimodal data...")
df, biomech_data, physio_data, y = generate_multimodal_data(n_samples=1000)

# Columns of `df` that make up the spatial and temporal modalities.
spatial_cols = ['x', 'y', 'distance_to_net', 'shot_angle']
temporal_cols = ['game_seconds', 'period_seconds_remaining', 'score_differential']

print("Extracting multimodal features...")
X_integrated = extract_multimodal_features(
    df,
    biomech_data,
    physio_data,
    spatial_cols,
    temporal_cols,
    player_id_col='player_id',
    timestamp_col='date_time',
    # NOTE(review): presumably one weight per modality -- confirm against
    # extract_multimodal_features.
    weights=(1.0, 0.5, 0.3),
)

# ---------------------------------------------------------------------------
# 2. HMM-GLM on the integrated features
# ---------------------------------------------------------------------------
X_train, X_test, y_train, y_test = train_test_split(X_integrated, y, **_split_options)

print("Creating HMM-GLM model...")
model = _build_hmm_glm_model()

print("Fitting HMM-GLM model...")
# The outcome vector is passed twice. NOTE(review): presumably once as the HMM
# observation sequence and once as the GLM target -- confirm against
# HMMGLMModel.fit.
model.fit(y_train, X_train, y_train)

print("Evaluating HMM-GLM model...")
metrics = evaluate_hmm_glm_model(model, y_test, X_test, y_test)
print("HMM-GLM metrics:")
_print_metrics(metrics)

# ---------------------------------------------------------------------------
# 3. Baseline: the same model on spatiotemporal columns only
# ---------------------------------------------------------------------------
print("\nComparing different feature sets...")

X_spatiotemporal = df[spatial_cols + temporal_cols].values
X_train_st, X_test_st, y_train_st, y_test_st = train_test_split(
    X_spatiotemporal, y, **_split_options
)

model_st = _build_hmm_glm_model()
model_st.fit(y_train_st, X_train_st, y_train_st)
metrics_st = evaluate_hmm_glm_model(model_st, y_test_st, X_test_st, y_test_st)

print("Spatiotemporal features only:")
_print_metrics(metrics_st)

# ---------------------------------------------------------------------------
# 4. Side-by-side AUC summary
# ---------------------------------------------------------------------------
print("\nFeature importance comparison:")
print(f"  Integrated features: {metrics['auc']:.4f} AUC")
print(f"  Spatiotemporal only: {metrics_st['auc']:.4f} AUC")
# The improvement is the absolute AUC difference scaled by 100.
print(f"  Improvement: {(metrics['auc'] - metrics_st['auc']) * 100:.2f}%")

print("\nDone!")


