"""
Basic usage example of the HMM-GLM framework.

This script demonstrates how to use the HMM-GLM framework for a simple binary classification task.
"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel
from src.core.hmm_glm import evaluate_hmm_glm_model, compare_with_baseline
from src.core.hmm_glm import plot_state_transition_matrix, plot_emission_probabilities

# Seed NumPy's global RNG so every run draws the same synthetic data;
# generate_hmm_data below samples through np.random.choice / np.random.normal.
np.random.seed(42)

# Generate synthetic data
def generate_hmm_data(n_samples=1000, n_states=3):
    """Generate a synthetic sequence from a fixed three-state HMM.

    A hidden state path is sampled from hard-coded start and transition
    probabilities. For every time step a binary outcome is drawn from the
    current state's emission distribution, and a 2-D Gaussian feature vector
    is drawn around a state-specific mean (standard deviation 0.5 per axis).

    All randomness comes from NumPy's global RNG (``np.random``); seed it
    with ``np.random.seed`` beforehand for reproducible output.

    Args:
        n_samples: Length of the sequence to generate. ``0`` yields empty
            arrays.
        n_states: Number of hidden states. The HMM parameters are hard-coded
            for three states, so this must be 3.

    Returns:
        Tuple ``(X, y, states)`` where ``X`` is a float array of shape
        ``(n_samples, 2)``, ``y`` is an int array of binary outcomes of
        shape ``(n_samples,)`` and ``states`` is the int array of hidden
        states of shape ``(n_samples,)``.

    Raises:
        ValueError: If ``n_states`` does not match the hard-coded parameters.
    """
    # Define HMM parameters (fixed for three hidden states)
    startprob = np.array([0.6, 0.3, 0.1])
    transmat = np.array([
        [0.7, 0.2, 0.1],
        [0.3, 0.6, 0.1],
        [0.2, 0.2, 0.6],
    ])
    # Columns are P(y=0), P(y=1) for each state.
    emissionprob = np.array([
        [0.8, 0.2],  # State 0: 20% success rate
        [0.5, 0.5],  # State 1: 50% success rate
        [0.2, 0.8],  # State 2: 80% success rate
    ])
    # Mean of the 2-D Gaussian feature vector for each state.
    feature_means = np.array([
        [-1.0, -1.0],
        [0.0, 0.0],
        [1.0, 1.0],
    ])
    feature_std = [0.5, 0.5]

    # Fail early with a clear message instead of an opaque size-mismatch
    # error from np.random.choice deep inside the sampling code.
    if n_states != len(startprob):
        raise ValueError(
            f"n_states must be {len(startprob)} to match the hard-coded "
            f"HMM parameters, got {n_states}"
        )

    states = np.zeros(n_samples, dtype=int)
    X = np.zeros((n_samples, 2))  # Simple features
    y = np.zeros(n_samples, dtype=int)

    # Nothing to sample for an empty sequence (states[0] would not exist).
    if n_samples == 0:
        return X, y, states

    # Generate state sequence
    states[0] = np.random.choice(n_states, p=startprob)
    for t in range(1, n_samples):
        states[t] = np.random.choice(n_states, p=transmat[states[t - 1]])

    # Generate observations. The outcome is drawn before the features at
    # every step so the RNG stream is consumed in a fixed order.
    for t in range(n_samples):
        state = states[t]
        y[t] = np.random.choice(2, p=emissionprob[state])
        X[t] = np.random.normal(feature_means[state], feature_std)

    return X, y, states

# Generate data
print("Generating synthetic data...")
X, y, true_states = generate_hmm_data(n_samples=1000, n_states=3)

# Split data into training and testing sets.
# The samples form a time-ordered sequence produced by a Markov chain, so a
# shuffled split would scramble the state dynamics. Keep the chronological
# order: the first 80% is used for training and the last 20% for testing.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=False
)

# Create HMM-GLM model
print("Creating HMM-GLM model...")
hmm_comp = CategoricalHMMComponent(n_states=3, n_categories=2, random_state=42)
glm_comp = LogisticGLMComponent(random_state=42)
model = HMMGLMModel(hmm_component=hmm_comp, glm_component=glm_comp)

# Fit the model. The binary outcomes are passed twice: as the first argument
# and again after the features.
# NOTE(review): presumably the first y feeds the HMM and the second is the
# GLM target -- confirm against HMMGLMModel.fit.
print("Fitting HMM-GLM model...")
model.fit(y_train, X_train, y_train)

# Evaluate the model on the held-out tail of the sequence
print("Evaluating HMM-GLM model...")
metrics = evaluate_hmm_glm_model(model, y_test, X_test, y_test)
print("HMM-GLM metrics:")
for metric, value in metrics.items():
    print(f"  {metric}: {value:.4f}")

# Compare with baseline models; the result is indexed below by the keys
# 'logistic', 'bernoulli' and 'delta_ll'.
print("\nComparing with baseline models...")
results = compare_with_baseline(model, y_test, X_test, y_test)

print("Logistic regression metrics:")
for metric, value in results['logistic'].items():
    print(f"  {metric}: {value:.4f}")

print("\nBernoulli baseline metrics:")
for metric, value in results['bernoulli'].items():
    print(f"  {metric}: {value:.4f}")

print("\nDelta log-likelihood:")
for model_name, value in results['delta_ll'].items():
    print(f"  {model_name}: {value:.4f}")

# Visualize the model
print("\nVisualizing the model...")
fig1, ax1 = plot_state_transition_matrix(model, title="State Transition Matrix")
fig2, ax2 = plot_emission_probabilities(model, title="Emission Probabilities by State")

# Blocks until the figure windows are closed when using an interactive backend.
plt.show()

print("\nDone!")


