# Training and Evaluation

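This guide covers training HMM-GLM models and evaluating their performance. The examples assume that the training data (`X_train`, `y_train`, `sequences_train`) and test data (`X_test`, `y_test`, `sequences_test`) have already been prepared; the cross-validation example uses the full dataset (`X`, `y`, `sequences`).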

## Training Models

### Basic Training

```python
from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel

# Create HMM-GLM model
hmm_comp = CategoricalHMMComponent(n_states=3, n_categories=2)
glm_comp = LogisticGLMComponent()
model = HMMGLMModel(hmm_component=hmm_comp, glm_component=glm_comp)

# Fit the model
model.fit(X_train, y_train, sequences=sequences_train)
```

### Training with Class Weights

```python
from src.core.weighting import compute_inverse_frequency_weights

# Compute weights
weights = compute_inverse_frequency_weights(y_train)

# Fit model with weights
model.fit(X_train, y_train, sequences=sequences_train, sample_weight=weights)
```

### Training with Context-Aware Transitions

```python
from src.core.context_transitions import ContextAwareTransitionMatrix

# Create context features (e.g., score differential, time remaining)
context_features = X_train[:, [0, 1, 2]]  # Example: first three features are context

# Create context-aware transition matrix
context_transitions = ContextAwareTransitionMatrix(n_states=3, n_context_features=3)

# Create HMM component with context-aware transitions
hmm_comp = CategoricalHMMComponent(n_states=3, n_categories=2, 
                                 transition_matrix=context_transitions)

# Create and fit model
model = HMMGLMModel(hmm_component=hmm_comp, glm_component=glm_comp)
model.fit(X_train, y_train, sequences=sequences_train, context_features=context_features)
```

## Evaluating Models

### Basic Evaluation

```python
from src.evaluation import evaluate_model

# Make predictions
y_pred = model.predict(X_test, sequences=sequences_test)
y_pred_proba = model.predict_proba(X_test, sequences=sequences_test)

# Evaluate model
metrics = evaluate_model(model, X_test, y_test, sequences=sequences_test)
print(f"Accuracy: {metrics['accuracy']:.3f}")
print(f"AUC: {metrics['auc']:.3f}")
print(f"Brier Score: {metrics['brier_score']:.3f}")
print(f"Delta Log-Likelihood: {metrics['delta_loglikelihood']:.3f}")
print(f"State Diversity: {metrics['state_diversity']:.3f}")
```

### Per-State Metrics

```python
from src.evaluation import calculate_per_state_metrics

# Calculate per-state metrics
state_metrics = calculate_per_state_metrics(model, X_test, y_test, sequences=sequences_test)
print(state_metrics)
```

### Transition Metrics

```python
from src.evaluation import calculate_transition_metrics

# Calculate transition metrics
transition_metrics = calculate_transition_metrics(model, X_test, sequences=sequences_test)
print("Transition Probabilities:")
print(transition_metrics['transition_probabilities'])
print(f"Self-Transition Rate: {transition_metrics['self_transition_rate']:.3f}")
print(f"State Change Rate: {transition_metrics['state_change_rate']:.3f}")
```

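## Visualizing Results

The plotting examples reuse `y_pred` and `y_pred_proba` from the Basic Evaluation example above, and `plt` imported in the Confusion Matrix example.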

### Confusion Matrix

```python
import matplotlib.pyplot as plt
from src.evaluation import plot_confusion_matrix

# Plot confusion matrix
fig, ax = plt.subplots(figsize=(8, 6))
plot_confusion_matrix(y_test, y_pred, ax=ax)
plt.show()
```

### ROC Curve

```python
from src.evaluation import plot_roc_curve

# Plot ROC curve
fig, ax = plt.subplots(figsize=(8, 6))
plot_roc_curve(y_test, y_pred_proba, ax=ax)
plt.show()
```

### State Transitions

```python
from src.evaluation import plot_state_transitions

# Plot state transitions
fig, ax = plt.subplots(figsize=(8, 6))
plot_state_transitions(model.hmm_component.model.transmat_, ax=ax)
plt.show()
```

### Feature Importance

```python
from src.evaluation import plot_feature_importance

# Plot feature importance
feature_names = [f"Feature_{i}" for i in range(X_test.shape[1])]
fig = plot_feature_importance(model, feature_names=feature_names)
plt.show()
```

## Comparing Models

### Multiple Models

```python
from src.evaluation import compare_models

# Create models with different configurations
model1 = HMMGLMModel(
    hmm_component=CategoricalHMMComponent(n_states=2, n_categories=2),
    glm_component=LogisticGLMComponent()
)
model2 = HMMGLMModel(
    hmm_component=CategoricalHMMComponent(n_states=3, n_categories=2),
    glm_component=LogisticGLMComponent()
)
model3 = HMMGLMModel(
    hmm_component=CategoricalHMMComponent(n_states=4, n_categories=2),
    glm_component=LogisticGLMComponent()
)

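# Fit models (use a distinct loop variable so the `model` fitted earlier,
# which later sections reuse, is not overwritten by model3)
for m in [model1, model2, model3]:
    m.fit(X_train, y_train, sequences=sequences_train)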

# Compare models
model_names = ["2 States", "3 States", "4 States"]
comparison = compare_models([model1, model2, model3], model_names, 
                          X_test, y_test, sequences=sequences_test)
print(comparison)
```

### Baseline Comparison

```python
from src.evaluation import compare_with_baseline

# Compare with baseline
baseline_comparison = compare_with_baseline(model, X_test, y_test, 
                                          sequences=sequences_test,
                                          baseline_type='logistic')
print(baseline_comparison)
```

### Statistical Significance

```python
from src.evaluation import statistical_significance_test
from sklearn.linear_model import LogisticRegression

# Create baseline model
baseline = LogisticRegression().fit(X_train, y_train)

# Perform significance test
test_results = statistical_significance_test(
    model, baseline, X_test, y_test, sequences=sequences_test,
    metric='accuracy', n_bootstrap=1000
)

print(f"p-value: {test_results['p_value']:.4f}")
print(f"Better model: {test_results['better_model']}")
```

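## Cross-Validation

Cross-validation works on the full dataset (`X`, `y`, `sequences`) rather than on a pre-made train/test split; the splitting is configured with `n_splits` and `stratify`.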

```python
from src.evaluation import cross_validation_comparison

# Create models with different configurations
models = [
    HMMGLMModel(
        hmm_component=CategoricalHMMComponent(n_states=i, n_categories=2),
        glm_component=LogisticGLMComponent()
    ) for i in range(2, 5)
]

# Compare models using cross-validation
model_names = ["2 States", "3 States", "4 States"]
results_df, summary_df = cross_validation_comparison(
    models, model_names, X, y, sequences,
    n_splits=5, stratify=True
)

print("Cross-Validation Summary:")
print(summary_df)
```
