"""
NHL-specific adjustments example.

This script demonstrates how to apply NHL-specific adjustments to account for
goalie influence in hockey shot data.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

from src.features.nhl_specific import (
    model_goalie_save_probability,
    adjust_shooter_performance,
    calculate_goalie_quality_index,
    integrate_goalie_adjustments
)

# Set random seed for reproducibility. This seeds NumPy's global random state,
# which is what the np.random.* calls in generate_nhl_shot_data() draw from.
np.random.seed(42)

# Generate synthetic NHL shot data
def generate_nhl_shot_data(n_shots=1000, n_players=50, n_goalies=30):
    """Simulate a table of NHL shots with known shooter and goalie skill.

    All randomness is drawn from NumPy's global random state, so call
    ``np.random.seed`` first if reproducible output is needed.

    Args:
        n_shots: Number of shots (rows) to simulate.
        n_players: Number of distinct shooter IDs, starting at 8000000.
        n_goalies: Number of distinct goalie IDs, starting at 9000000.

    Returns:
        A tuple ``(df, goalie_skill, shooter_skill)``. ``df`` is a DataFrame
        with one row per shot (IDs, shot features, game context, ``is_goal``
        and ``event_type``). The two dicts map every goalie/shooter ID to the
        skill offset that was used when simulating the outcomes.
    """
    shooter_pool = range(8000000, 8000000 + n_players)
    goalie_pool = range(9000000, 9000000 + n_goalies)

    # The sequence of random draws below fixes the output for a given seed,
    # so the statements must stay in this order.
    shooter_ids = np.random.choice(shooter_pool, n_shots)
    goalie_ids = np.random.choice(goalie_pool, n_shots)
    game_ids = np.random.choice(range(2000000, 2001000), n_shots)
    shot_ids = np.arange(1, n_shots + 1)

    # Shot geometry: distance in feet, angle in degrees
    shot_distance = np.random.gamma(shape=2.0, scale=10.0, size=n_shots)
    shot_angle = np.random.uniform(-90, 90, n_shots)
    shot_types = np.random.choice(['WRIST', 'SLAP', 'SNAP', 'BACKHAND'], n_shots)

    # Game context
    period = np.random.choice([1, 2, 3], n_shots)
    period_seconds = np.random.uniform(0, 1200, n_shots)
    period_seconds_remaining = 1200 - period_seconds
    score_differential = np.random.randint(-3, 4, n_shots)

    # Situational flags
    rebound = np.random.choice([0, 1], n_shots, p=[0.9, 0.1])
    rush_shot = np.random.choice([0, 1], n_shots, p=[0.8, 0.2])

    # Per-player random effects (goalies first, then shooters)
    goalie_skill = {gid: np.random.normal(0, 0.5) for gid in goalie_pool}
    shooter_skill = {sid: np.random.normal(0, 0.5) for sid in shooter_pool}

    # Baseline scoring chance from shot geometry
    goal_prob = 0.3 - 0.01 * shot_distance - 0.002 * np.abs(shot_angle)

    # Shot-type bonus (BACKHAND gets none)
    goal_prob += np.select(
        [shot_types == 'WRIST', shot_types == 'SLAP', shot_types == 'SNAP'],
        [0.02, 0.01, 0.015],
        default=0.0,
    )

    # Game context: better when leading, harder in later periods
    goal_prob += 0.01 * score_differential
    goal_prob -= 0.02 * (period - 2)

    # Rebounds and rush shots are more dangerous
    goal_prob += np.where(rebound.astype(bool), 0.1, 0.0)
    goal_prob += np.where(rush_shot.astype(bool), 0.05, 0.0)

    # Shooter skill raises the chance, goalie skill lowers it
    shooter_effect = np.array(
        [shooter_skill.get(sid, 0) for sid in shooter_ids], dtype=float
    )
    goalie_effect = np.array(
        [goalie_skill.get(gid, 0) for gid in goalie_ids], dtype=float
    )
    goal_prob = np.clip(goal_prob + shooter_effect - goalie_effect, 0.01, 0.99)

    # One uniform draw per shot decides the outcome
    is_goal = np.random.random(n_shots) < goal_prob

    df = pd.DataFrame({
        'shot_id': shot_ids,
        'game_id': game_ids,
        'shooter_id': shooter_ids,
        'goalie_id': goalie_ids,
        'shot_distance': shot_distance,
        'shot_angle': shot_angle,
        'shot_type': shot_types,
        'period': period,
        'period_seconds': period_seconds,
        'period_seconds_remaining': period_seconds_remaining,
        'score_differential': score_differential,
        'rebound': rebound,
        'rush_shot': rush_shot,
        'is_goal': is_goal
    })

    # Label each row as a goal or a (non-scoring) shot
    df['event_type'] = np.where(df['is_goal'], 'GOAL', 'SHOT')

    return df, goalie_skill, shooter_skill

# Generate data: 5000 synthetic shots (shooter/goalie counts use the function
# defaults) plus the true skill dicts used later as ground truth.
print("Generating synthetic NHL shot data...")
shots_df, true_goalie_skill, true_shooter_skill = generate_nhl_shot_data(n_shots=5000)

# Check goal rate: mean of the boolean is_goal column, i.e. the share of shots
# that were goals.
goal_rate = shots_df['is_goal'].mean()
print(f"Goal rate: {goal_rate:.2%}")

# Split data for evaluation (80% train / 20% test, fixed random_state).
# NOTE(review): test_df is not referenced anywhere in the visible part of this
# script; every step below is fitted and evaluated on train_df. Confirm whether
# a held-out evaluation was intended.
train_df, test_df = train_test_split(shots_df, test_size=0.2, random_state=42)

# Apply NHL-specific adjustments
print("Applying NHL-specific adjustments...")

# 1. Model goalie save probability.
# goalie_effects is later iterated with .keys() and indexed per goalie ID.
# NOTE(review): `model` is not used in the visible part of this script, and
# min_shots_per_goalie presumably filters out low-volume goalies; confirm
# against src.features.nhl_specific.
print("1. Modeling goalie save probability...")
model, goalie_effects = model_goalie_save_probability(train_df, min_shots_per_goalie=20)

# 2. Adjust shooter performance.
# The result is grouped by 'shooter_id' below and must carry an
# 'adjusted_goal_rate' column alongside 'is_goal' and 'shot_id'.
print("2. Adjusting shooter performance...")
adjusted_df = adjust_shooter_performance(train_df, goalie_effects)

# 3. Calculate Goalie Quality Index.
# Only the 'goalie_quality_index' column of the result is used (histogram).
print("3. Calculating Goalie Quality Index...")
player_gqi = calculate_goalie_quality_index(train_df, goalie_effects)

# 4. Integrate all adjustments.
# The call returns three values; only the first is kept. Its 'expected_goal'
# and 'is_goal' columns feed the calibration plot at the end of the script.
print("4. Integrating all adjustments...")
integrated_df, _, _ = integrate_goalie_adjustments(train_df, min_shots_per_goalie=20)

# Evaluate adjustments
print("Evaluating adjustments...")


def _pearson_r(first, second):
    """Return the Pearson correlation coefficient of two equal-length sequences."""
    return np.corrcoef(first, second)[0, 1]


# Goalies: line up each estimated effect with the true skill from the generator
goalie_ids = [goalie for goalie in goalie_effects.keys()]
estimated_effects = np.array([goalie_effects[goalie] for goalie in goalie_ids])
true_effects = np.array(
    [true_goalie_skill.get(int(goalie), 0) for goalie in goalie_ids]
)

goalie_corr = _pearson_r(estimated_effects, true_effects)
print(f"Correlation between estimated goalie effects and true goalie skill: {goalie_corr:.4f}")

# Shooters: one row per shooter with raw goal rate, adjusted goal rate and
# shot count
per_shooter_aggregations = {
    'is_goal': 'mean',
    'adjusted_goal_rate': 'first',
    'shot_id': 'count',
}
shooter_stats = (
    adjusted_df
    .groupby('shooter_id')
    .agg(per_shooter_aggregations)
    .reset_index()
)

# Compare both goal-rate variants against the true shooter skill
shooter_ids = shooter_stats['shooter_id'].values
raw_rates = shooter_stats['is_goal'].values
adjusted_rates = shooter_stats['adjusted_goal_rate'].values
true_skills = np.array(
    [true_shooter_skill.get(shooter, 0) for shooter in shooter_ids]
)

raw_corr = _pearson_r(raw_rates, true_skills)
adjusted_corr = _pearson_r(adjusted_rates, true_skills)

# Difference between the two correlation coefficients, scaled by 100
corr_gain = (adjusted_corr - raw_corr) * 100

print(f"Correlation between raw goal rates and true shooter skill: {raw_corr:.4f}")
print(f"Correlation between adjusted goal rates and true shooter skill: {adjusted_corr:.4f}")
print(f"Improvement: {corr_gain:.2f}%")

# Visualize results
print("Visualizing results...")


def _label_axes(xlabel, ylabel, title):
    """Label the current axes, set its title and switch on a faint grid."""
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, alpha=0.3)


def _scatter_panel(x, y, xlabel, ylabel, title, alpha=0.7, reference=None):
    """Scatter ``y`` against ``x`` on the current axes.

    ``reference`` is an optional ``(xs, ys)`` pair drawn as a dashed red line.
    """
    plt.scatter(x, y, alpha=alpha)
    if reference is not None:
        reference_x, reference_y = reference
        plt.plot(reference_x, reference_y, 'r--')
    _label_axes(xlabel, ylabel, title)


# Plot 1: Goalie Effects (estimated vs. true, with the identity line)
plt.figure(figsize=(10, 6))
_scatter_panel(
    true_effects,
    estimated_effects,
    'True Goalie Skill',
    'Estimated Goalie Effect',
    f'Goalie Effects (Correlation: {goalie_corr:.4f})',
    reference=([-1, 1], [-1, 1]),
)

# Plot 2: Raw vs. Adjusted Goal Rates, side by side against true shooter skill
plt.figure(figsize=(12, 6))
shooter_panels = [
    (raw_rates, 'Raw Goal Rate',
     f'Raw Goal Rates (Correlation: {raw_corr:.4f})'),
    (adjusted_rates, 'Adjusted Goal Rate',
     f'Adjusted Goal Rates (Correlation: {adjusted_corr:.4f})'),
]
for panel_index, (panel_rates, panel_ylabel, panel_title) in enumerate(shooter_panels, start=1):
    plt.subplot(1, 2, panel_index)
    _scatter_panel(
        true_skills,
        panel_rates,
        'True Shooter Skill',
        panel_ylabel,
        panel_title,
        reference=([-1, 1], [0, 0.3]),
    )

plt.tight_layout()

# Plot 3: Goalie Quality Index Distribution
plt.figure(figsize=(10, 6))
sns.histplot(player_gqi['goalie_quality_index'], kde=True)
_label_axes(
    'Goalie Quality Index',
    'Count',
    'Distribution of Goalie Quality Index',
)

# Plot 4: Expected vs. Actual Goals (left as the current figure)
plt.figure(figsize=(10, 6))
_scatter_panel(
    integrated_df['expected_goal'],
    integrated_df['is_goal'],
    'Expected Goal Probability',
    'Actual Goal (1=Yes, 0=No)',
    'Expected vs. Actual Goals',
    alpha=0.1,
)

# Add a binned average line: the observed goal rate within each
# expected-probability bin, drawn on the current figure.
# NOTE(review): the bins only span [0, 0.5), so shots whose expected
# probability is 0.5 or higher appear in the scatter but not in this line.
bins = np.linspace(0, 0.5, 20)
bin_centers = (bins[:-1] + bins[1:]) / 2

# Start from NaN rather than 0 so that a bin containing no shots is treated as
# "no data" instead of being reported as a 0% observed goal rate.
bin_means = np.full_like(bin_centers, np.nan)

for i, (bin_lower, bin_upper) in enumerate(zip(bins[:-1], bins[1:])):
    mask = (integrated_df['expected_goal'] >= bin_lower) & (integrated_df['expected_goal'] < bin_upper)
    if np.sum(mask) > 0:
        bin_means[i] = np.mean(integrated_df['is_goal'][mask])

# Plot only the bins that actually contain shots; empty bins would otherwise
# drag the calibration curve down to zero.
has_shots = ~np.isnan(bin_means)
plt.plot(bin_centers[has_shots], bin_means[has_shots], 'r-', linewidth=2, label='Binned Average')
plt.plot([0, 0.5], [0, 0.5], 'k--', label='Perfect Calibration')
plt.legend()

# Display every figure created above (blocks until the windows are closed
# when an interactive backend is in use).
plt.show()

print("\nDone!")


