"""
Context-aware transitions example.

This script demonstrates how to use context-aware transition matrices in the HMM-GLM framework.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel
from src.core.hmm_glm import evaluate_hmm_glm_model, plot_state_transition_matrix

# Seed NumPy's global RNG for reproducibility: generate_context_dependent_data
# below draws from np.random.* directly, so this fixes the synthetic dataset.
np.random.seed(42)

# Generate synthetic data with context
def generate_context_dependent_data(n_samples=1000, n_states=3):
    """Generate a synthetic sequence with context-dependent state transitions.

    A hidden state sequence is simulated step by step. At every step the
    transition matrix is the base matrix shifted by two context variables
    (score differential and time remaining), clipped and re-normalised.
    Each state then emits a binary outcome and a 2-D Gaussian feature vector.

    All randomness comes from NumPy's global RNG (``np.random.*``), so seed it
    with ``np.random.seed`` beforehand for reproducible output.

    Parameters
    ----------
    n_samples : int
        Length of the generated sequence. ``0`` yields empty arrays.
    n_states : int
        Number of hidden states. The transition, emission and feature tables
        in this function are defined for exactly 3 states, so this must be 3.

    Returns
    -------
    X : numpy.ndarray, shape (n_samples, 2)
        Gaussian features whose mean depends on the hidden state.
    y : numpy.ndarray of int, shape (n_samples,)
        Binary outcomes (0 or 1) drawn from the state's emission distribution.
    states : numpy.ndarray of int, shape (n_samples,)
        The hidden state (0-based) at each step.
    contexts : numpy.ndarray, shape (n_samples, 2)
        Columns are ``[score_diff, time_remaining]``.

    Raises
    ------
    ValueError
        If ``n_states`` does not match the 3-state tables defined here.
    """
    # Base transition probabilities (row = from state, column = to state)
    base_transmat = np.array([
        [0.7, 0.2, 0.1],
        [0.3, 0.6, 0.1],
        [0.2, 0.2, 0.6],
    ])

    # Effect of score differential: more likely to transition to higher
    # states when leading (positive differential)
    score_effect = np.array([
        [0.0, 0.01, 0.02],
        [-0.01, 0.0, 0.01],
        [-0.02, -0.01, 0.0],
    ])

    # Effect of time remaining: more likely to stay in the current state when
    # less time remains (the effect is scaled by 1 - time_remaining)
    time_effect = np.array([
        [0.1, -0.05, -0.05],
        [-0.05, 0.1, -0.05],
        [-0.05, -0.05, 0.1],
    ])

    # Emission probabilities: row = state, columns = P(y=0), P(y=1)
    emissionprob = np.array([
        [0.8, 0.2],  # State 1: 20% success rate
        [0.5, 0.5],  # State 2: 50% success rate
        [0.2, 0.8],  # State 3: 80% success rate
    ])

    # Per-state mean of the 2-D feature vector; the spread is shared
    feature_means = np.array([
        [-1.0, -1.0],
        [0.0, 0.0],
        [1.0, 1.0],
    ])
    feature_std = np.array([0.5, 0.5])

    # Fail early with a clear message instead of an obscure size-mismatch or
    # index error deep inside the sampling loops.
    if n_states != base_transmat.shape[0]:
        raise ValueError(
            f"n_states must be {base_transmat.shape[0]} to match the built-in "
            f"transition and emission tables, got {n_states}"
        )

    # Context variables
    score_diff = np.random.randint(-5, 6, n_samples)  # Score differential in [-5, 5]
    time_remaining = np.random.uniform(0, 1, n_samples)  # Normalized time remaining
    contexts = np.column_stack([score_diff, time_remaining])

    # Hidden state sequence
    states = np.zeros(n_samples, dtype=int)
    if n_samples > 0:
        states[0] = np.random.choice(n_states)

    for t in range(1, n_samples):
        # The context observed at step t drives the transition from t-1 to t
        transmat = (
            base_transmat
            + score_effect * score_diff[t]
            + time_effect * (1 - time_remaining[t])
        )

        # Ensure valid probabilities: clip, then re-normalise each row
        transmat = np.clip(transmat, 0.01, 0.99)
        transmat = transmat / transmat.sum(axis=1, keepdims=True)

        states[t] = np.random.choice(n_states, p=transmat[states[t - 1]])

    # Observations: outcome first, then features, one step at a time
    X = np.zeros((n_samples, 2))
    y = np.zeros(n_samples, dtype=int)

    for t, state in enumerate(states):
        y[t] = np.random.choice(2, p=emissionprob[state])
        X[t] = np.random.normal(feature_means[state], feature_std)

    return X, y, states, contexts

# Generate data
print("Generating synthetic data with context...")
X, y, true_states, contexts = generate_context_dependent_data(n_samples=1000, n_states=3)

# Split data into training and testing sets.
# The samples form one time-ordered sequence, so split chronologically
# (first 80% train, last 20% test). A shuffled split would scramble the
# state transitions that the HMM is supposed to learn.
X_train, X_test, y_train, y_test, contexts_train, contexts_test = train_test_split(
    X, y, contexts, test_size=0.2, shuffle=False)

# Hyperparameters shared by both models, so the only configured difference
# between them is whether the HMM component is given n_contexts.
hmm_settings = {"n_states": 3, "n_categories": 2, "random_state": 42}
glm_settings = {"random_state": 42}

# Baseline: standard HMM-GLM model (no context)
print("Creating standard HMM-GLM model...")
hmm_comp_std = CategoricalHMMComponent(**hmm_settings)
glm_comp_std = LogisticGLMComponent(**glm_settings)
model_std = HMMGLMModel(hmm_component=hmm_comp_std, glm_component=glm_comp_std)

# Train the baseline on the training split
print("Fitting standard HMM-GLM model...")
model_std.fit(y_train, X_train, y_train)

# Context-aware variant.
# NOTE(review): n_contexts=2 presumably matches the two context columns
# (score_diff, time_remaining) -- confirm against CategoricalHMMComponent.
print("Creating context-aware HMM-GLM model...")
hmm_comp_ctx = CategoricalHMMComponent(**hmm_settings, n_contexts=2)
glm_comp_ctx = LogisticGLMComponent(**glm_settings)
model_ctx = HMMGLMModel(hmm_component=hmm_comp_ctx, glm_component=glm_comp_ctx)

# Train the context-aware model, additionally passing the training contexts
print("Fitting context-aware HMM-GLM model...")
model_ctx.fit(y_train, X_train, y_train, contexts=contexts_train)

# Score both models on the test split
print("Evaluating models...")
metrics_std = evaluate_hmm_glm_model(model_std, y_test, X_test, y_test)
metrics_ctx = evaluate_hmm_glm_model(model_ctx, y_test, X_test, y_test, contexts=contexts_test)

# Print each model's metrics under its own header
for header, metrics in (
    ("Standard HMM-GLM metrics:", metrics_std),
    ("\nContext-aware HMM-GLM metrics:", metrics_ctx),
):
    print(header)
    for metric, value in metrics.items():
        print(f"  {metric}: {value:.4f}")

# Report the gain from using context, signed so that positive means better:
# 'auc' and 'accuracy' are treated as higher-is-better, every other metric
# as lower-is-better.
print("\nImprovement with context-aware transitions:")
higher_is_better = ('auc', 'accuracy')
for metric, std_value in metrics_std.items():
    ctx_value = metrics_ctx[metric]
    if metric in higher_is_better:
        improvement = ctx_value - std_value
    else:
        improvement = std_value - ctx_value
    print(f"  {metric}: {improvement:.4f}")

# Visualize transition matrices for different contexts
print("\nVisualizing transition matrices for different contexts...")

# Context scenarios to plot; columns are [score_diff, time_remaining]
contexts_viz = np.array([
    [-3, 0.9],  # Trailing by 3, early in game
    [0, 0.9],   # Tied, early in game
    [3, 0.9],   # Leading by 3, early in game
    [-3, 0.1],  # Trailing by 3, late in game
    [0, 0.1],   # Tied, late in game
    [3, 0.1]    # Leading by 3, late in game
])

# The plots need the fitted context parameters. Check once up front so a model
# without them is reported clearly instead of failing later on an undefined
# image handle when the colorbar is drawn.
hmm_ctx = model_ctx.hmm_component
if hasattr(hmm_ctx, 'alpha_') and hasattr(hmm_ctx, 'beta_'):
    # Imported lazily: only needed when there is something to plot
    from src.core.context_transitions import compute_context_dependent_transitions

    # One subplot per scenario (2 rows x 3 columns)
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    im = None
    for ax, context in zip(axes, contexts_viz):
        # Compute the transition matrix for this single context row
        transmat = compute_context_dependent_transitions(
            hmm_ctx.alpha_,
            hmm_ctx.beta_,
            context.reshape(1, -1)
        )[0]
        n_from, n_to = transmat.shape

        # Plot heatmap on a fixed [0, 1] scale so all panels share one colorbar
        im = ax.imshow(transmat, cmap='YlGnBu', vmin=0, vmax=1)

        # Annotate every cell with its probability
        for i_row in range(n_from):
            for i_col in range(n_to):
                ax.text(i_col, i_row, f'{transmat[i_row, i_col]:.2f}',
                        ha='center', va='center', color='black')

        # Title from the context; ':g' drops the trailing '.0' of float values
        # and a tied score gets no "by N" suffix.
        score_diff, time_rem = context
        if score_diff > 0:
            score_status = f"Leading by {score_diff:g}"
        elif score_diff < 0:
            score_status = f"Trailing by {abs(score_diff):g}"
        else:
            score_status = "Tied"
        game_phase = "Early" if time_rem > 0.5 else "Late"
        ax.set_title(f"{score_status}, {game_phase} Game")

        # Axis labels
        ax.set_xlabel('To State')
        ax.set_ylabel('From State')

        # Tick labels, sized from the matrix rather than a hard-coded 3
        ax.set_xticks(np.arange(n_to))
        ax.set_yticks(np.arange(n_from))
        ax.set_xticklabels([f'State {j+1}' for j in range(n_to)])
        ax.set_yticklabels([f'State {j+1}' for j in range(n_from)])

    # Set the overall title and lay out the subplot grid first, reserving the
    # right 15% and top 5% of the figure. The colorbar axes is placed manually
    # afterwards because axes created with add_axes are not managed by
    # tight_layout.
    fig.suptitle('Context-Dependent Transition Matrices', fontsize=16)
    plt.tight_layout(rect=[0, 0, 0.85, 0.95])

    # Add a shared colorbar in the reserved strip
    cbar_ax = fig.add_axes([0.88, 0.15, 0.03, 0.7])
    fig.colorbar(im, cax=cbar_ax)

    plt.show()
else:
    print("Context-aware model has no alpha_/beta_ attributes; "
          "skipping transition matrix plots.")

print("\nDone!")


