#!/usr/bin/env python
"""
NHL Analysis Experiment

This script performs HMM-GLM analysis on NHL data with goalie impact adjustments.
"""

import os
import sys
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime

# Add project root to path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from src.data import load_nhl_data, extract_shots_data
from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel
from src.evaluation import evaluate_model, plot_confusion_matrix, compare_with_baseline
from src.evaluation import plot_state_transitions, plot_feature_importance
from src.nhl_adv_seq.goalie_impact import (
    model_goalie_save_probability,
    adjust_shots_for_goalie_effect,
    calculate_goalie_quality_index,
    build_integrated_model_with_goalie_effects
)


def parse_args():
    """Read the experiment's command-line options.

    Returns
    -------
    argparse.Namespace
        Attributes ``data_path``, ``seasons``, ``n_states``, ``min_shots``,
        ``output_dir`` and ``adjust_for_goalie``.
    """
    parser = argparse.ArgumentParser(description='NHL Analysis Experiment')

    # (flag, add_argument keyword arguments) in the order they are registered.
    option_specs = (
        ('--data-path', {'type': str, 'default': 'data',
                         'help': 'Path to data directory'}),
        ('--seasons', {'type': str, 'nargs': '+', 'default': ['2022', '2023'],
                       'help': 'Seasons to analyze'}),
        ('--n-states', {'type': int, 'default': 3,
                        'help': 'Number of hidden states'}),
        ('--min-shots', {'type': int, 'default': 50,
                         'help': 'Minimum number of shots for a player to be included'}),
        ('--output-dir', {'type': str, 'default': 'results/nhl',
                          'help': 'Output directory for results'}),
        ('--adjust-for-goalie', {'action': 'store_true',
                                 'help': 'Adjust for goalie effects'}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()


# Geometry / clock constants used to derive shot and game-context features.
_GOAL_LINE_X = 89            # magnitude of the x-coordinate of each goal (goals at -89 and +89)
_PERIOD_SECONDS = 20 * 60    # length of one regulation period, in seconds
_REGULATION_PERIODS = 3      # number of regulation periods in a game


def _clock_to_seconds(value):
    """Convert an ``'MM:SS'`` clock string to seconds; non-strings pass through unchanged."""
    if isinstance(value, str):
        minutes, _, seconds = value.partition(':')
        return int(minutes) * 60 + int(seconds)
    return value


def preprocess_shots(shots_df):
    """Derive model-ready features from a table of shot events.

    The input is not modified; a new DataFrame is returned. Each feature is
    added only when the columns it needs are present and, for derived
    columns, only when it does not already exist:

    - ``is_goal``: 1 where ``event_type == 'GOAL'``, else 0.
    - ``x_coord`` / ``y_coord``: copied from ``x`` / ``y``.
    - ``shot_distance``: Euclidean distance from the shot to the goal at
      ``x = +/-89`` on the same side of centre ice as the shot.
    - ``shot_angle``: angle in degrees between the shot location and the
      goal's centre line (y = 0). 0 means straight on, 90 means from the goal
      line, and values above 90 mean from behind the net. Mirrored shots on
      the two ends of the rink get the same angle.
    - ``period_seconds``: ``period_time`` in seconds. ``'MM:SS'`` strings are
      parsed; numeric values are assumed to already be seconds.
    - ``game_progress``: elapsed fraction of regulation, clipped to [0, 1] so
      overtime counts as 1. This treats ``period_seconds`` as time elapsed in
      the period.
    - ``score_differential`` / ``score_differential_norm``: home minus away
      goals, and that difference divided by 3 and clipped to [-1, 1].

    Parameters
    ----------
    shots_df : pandas.DataFrame
        One row per shot event.

    Returns
    -------
    pandas.DataFrame
        A copy of ``shots_df`` with the derived columns added.
    """
    # Work on a copy so callers' frames are not mutated behind their back.
    shots_df = shots_df.copy()

    if 'is_goal' not in shots_df.columns:
        shots_df['is_goal'] = (shots_df['event_type'] == 'GOAL').astype(int)

    # Only fall back to the raw x/y columns when coordinates are not already present.
    if (not {'x_coord', 'y_coord'}.issubset(shots_df.columns)
            and {'x', 'y'}.issubset(shots_df.columns)):
        shots_df['x_coord'] = shots_df['x']
        shots_df['y_coord'] = shots_df['y']

    if {'x_coord', 'y_coord'}.issubset(shots_df.columns):
        x = shots_df['x_coord']
        # Assumes each shot targets the goal in the half of the rink it was taken from.
        goal_x = np.where(x < 0, -_GOAL_LINE_X, _GOAL_LINE_X)
        # Signed distance out from the goal line: positive in front of the net,
        # negative behind it, independent of which end the goal is on.
        depth = np.sign(goal_x) * (goal_x - x)
        lateral = np.abs(shots_df['y_coord'])

        if 'shot_distance' not in shots_df.columns:
            shots_df['shot_distance'] = np.hypot(depth, lateral)
        if 'shot_angle' not in shots_df.columns:
            shots_df['shot_angle'] = np.degrees(np.arctan2(lateral, depth))

    if {'period', 'period_time'}.issubset(shots_df.columns):
        if 'period_seconds' not in shots_df.columns:
            # Parse per element so mixed/missing values do not depend on row 0.
            shots_df['period_seconds'] = pd.to_numeric(
                shots_df['period_time'].map(_clock_to_seconds), errors='coerce'
            )

        elapsed = (shots_df['period'] - 1) * _PERIOD_SECONDS + shots_df['period_seconds']
        regulation_length = _REGULATION_PERIODS * _PERIOD_SECONDS
        # Clip to [0, 1] so overtime shots count as the end of regulation.
        shots_df['game_progress'] = np.clip(elapsed / regulation_length, 0, 1)

    if {'goals_home', 'goals_away'}.issubset(shots_df.columns):
        # NOTE(review): this is from the home team's perspective, not the shooter's.
        shots_df['score_differential'] = shots_df['goals_home'] - shots_df['goals_away']
        # Normalise by a 3-goal lead and clip to [-1, 1].
        shots_df['score_differential_norm'] = np.clip(
            shots_df['score_differential'] / 3, -1, 1
        )

    return shots_df


# Base per-shot features, in the column order used for the model matrix.
_BASE_FEATURES = ('shot_distance', 'shot_angle', 'game_progress', 'score_differential_norm')


def _attach_goalie_quality(shots_df):
    """Apply the goalie-effect adjustment and add a per-shot ``quality_index`` column."""
    goalie_model = model_goalie_save_probability(shots_df)
    shots_df = adjust_shots_for_goalie_effect(shots_df, goalie_model)
    goalie_quality = calculate_goalie_quality_index(shots_df)

    # validate='many_to_one' fails loudly if a goalie appears more than once in
    # goalie_quality, instead of silently duplicating that goalie's shots.
    shots_df = pd.merge(
        shots_df,
        goalie_quality[['goalie_id', 'quality_index']],
        on='goalie_id',
        how='left',
        validate='many_to_one',
    )

    # Goalies without a quality estimate get 0.
    # NOTE(review): 0 equals the average only if calculate_goalie_quality_index
    # centres quality_index at 0 -- confirm.
    shots_df['quality_index'] = shots_df['quality_index'].fillna(0)
    return shots_df


def _select_feature_columns(shots_df):
    """Return the model feature columns present in ``shots_df``, warning about missing ones."""
    wanted = list(_BASE_FEATURES)
    if 'quality_index' in shots_df.columns:
        wanted.append('quality_index')
    # One-hot shot-type indicator columns, if the upstream extraction produced them.
    wanted.extend(
        col for col in shots_df.columns
        if isinstance(col, str) and col.startswith('shot_type_')
    )

    missing = [col for col in wanted if col not in shots_df.columns]
    if missing:
        print(f"Warning: skipping unavailable features: {missing}")
    return [col for col in wanted if col in shots_df.columns]


def create_player_datasets(df, min_shots=50, adjust_for_goalie=False):
    """Build one modelling dataset per player with at least ``min_shots`` shots.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw shot events; passed through ``preprocess_shots`` first.
    min_shots : int
        Players with fewer shots than this are excluded.
    adjust_for_goalie : bool
        If True, apply the goalie-effect adjustment and add ``quality_index``
        as a feature.

    Returns
    -------
    dict
        Maps player id to a dict with keys ``'X'`` (feature matrix), ``'y'``
        (``is_goal`` values), ``'sequences'`` (per-row game index, one
        sequence per game), ``'player_name'``, ``'n_shots'``, ``'goal_rate'``
        and ``'feature_names'`` (column names of ``X``).
    """
    shots_df = preprocess_shots(df)

    if adjust_for_goalie:
        print("Adjusting for goalie effects...")
        shots_df = _attach_goalie_quality(shots_df)

    # Column names differ between data sources; pick whichever is present.
    player_column = 'shooter_id' if 'shooter_id' in shots_df.columns else 'player1_id'
    player_name_column = 'shooter_name' if 'shooter_name' in shots_df.columns else 'player1_name'
    game_column = 'game_id' if 'game_id' in shots_df.columns else 'game_pk'

    feature_columns = _select_feature_columns(shots_df)

    shot_counts = shots_df.groupby(player_column).size()
    qualified_players = shot_counts[shot_counts >= min_shots].index
    print(f"Found {len(qualified_players)} players with at least {min_shots} shots")

    # Iterate the groups once instead of re-filtering the full frame per player.
    qualified_shots = shots_df[shots_df[player_column].isin(qualified_players)]

    player_datasets = {}
    for player_id, player_df in qualified_shots.groupby(player_column):
        # Stable sort so each game's shots form one contiguous run, in their
        # original order, in case the HMM treats sequences as contiguous
        # chunks (TODO confirm against HMMGLMModel).
        player_df = player_df.sort_values(game_column, kind='mergesort')
        sequences = player_df.groupby(game_column).ngroup().to_numpy()

        if player_name_column in player_df.columns:
            player_name = player_df[player_name_column].iloc[0]
        else:
            player_name = str(player_id)

        player_datasets[player_id] = {
            'X': player_df[feature_columns].values,
            'y': player_df['is_goal'].values,
            'sequences': sequences,
            'player_name': player_name,
            'n_shots': len(player_df),
            'goal_rate': player_df['is_goal'].mean(),
            # Per-player copy so callers cannot mutate a list shared by all players.
            'feature_names': list(feature_columns),
        }

    return player_datasets


def _state_summary(states, y, n_states):
    """Return per-state shot proportions and goal rates for a decoded state path.

    The goal rate is NaN for states that were never visited.
    """
    states = np.asarray(states)
    y = np.asarray(y)
    counts = np.bincount(states, minlength=n_states)
    proportions = counts / len(states)
    goal_rates = [
        float(np.mean(y[states == state])) if counts[state] > 0 else np.nan
        for state in range(n_states)
    ]
    return proportions, goal_rates


def _baseline_metric(comparison, metric):
    """Return ``metric`` from the 'Logistic Regression' row of a baseline comparison table."""
    values = comparison.loc[comparison['model'] == 'Logistic Regression', metric]
    if values.empty:
        raise ValueError(
            f"baseline comparison has no 'Logistic Regression' row for {metric!r}"
        )
    return values.iloc[0]


def _save_player_figures(model, data, player_id, n_states,
                         state_proportions, goal_rates, output_dir):
    """Write the per-player diagnostics figure and feature-importance figure.

    Each figure is closed even if drawing or saving it fails, so a failing
    player does not leave matplotlib figures open across the loop.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    try:
        y_pred = model.predict(data['X'], sequences=data['sequences'])
        plot_confusion_matrix(data['y'], y_pred, ax=axes[0, 0])

        plot_state_transitions(model.hmm_component.model.transmat_, ax=axes[0, 1])

        axes[1, 0].bar(range(n_states), goal_rates)
        axes[1, 0].set_xlabel('State')
        axes[1, 0].set_ylabel('Goal Rate')
        axes[1, 0].set_title('Goal Rate by State')

        axes[1, 1].bar(range(n_states), state_proportions)
        axes[1, 1].set_xlabel('State')
        axes[1, 1].set_ylabel('Proportion')
        axes[1, 1].set_title('State Proportions')

        fig.suptitle(f"Player: {data['player_name']} (ID: {player_id})")
        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        fig.savefig(os.path.join(output_dir, f"player_{player_id}.png"))
    finally:
        plt.close(fig)

    fig = plot_feature_importance(model, feature_names=data['feature_names'])
    try:
        fig.suptitle(f"Feature Importance: {data['player_name']}")
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        fig.savefig(os.path.join(output_dir, f"player_{player_id}_features.png"))
    finally:
        plt.close(fig)


def run_analysis(player_datasets, n_states=3, output_dir='results/nhl'):
    """Fit and evaluate an HMM-GLM for every player and save figures and a results CSV.

    Parameters
    ----------
    player_datasets : dict
        Player id -> dataset dict as built by ``create_player_datasets``.
    n_states : int
        Number of hidden states in the HMM component.
    output_dir : str
        Directory for the per-player PNGs and ``results.csv`` (created if needed).

    Returns
    -------
    pandas.DataFrame
        One row per successfully fitted player. If no player succeeds, the
        frame is empty and has no columns.
    """
    os.makedirs(output_dir, exist_ok=True)

    results = []
    for player_id, data in player_datasets.items():
        print(f"Analyzing player {data['player_name']} (ID: {player_id})")

        try:
            hmm_comp = CategoricalHMMComponent(n_states=n_states, n_categories=2)
            glm_comp = LogisticGLMComponent()
            model = HMMGLMModel(hmm_component=hmm_comp, glm_component=glm_comp)
            model.fit(data['X'], data['y'], sequences=data['sequences'])

            metrics = evaluate_model(model, data['X'], data['y'], sequences=data['sequences'])
            baseline_comparison = compare_with_baseline(
                model, data['X'], data['y'], sequences=data['sequences'],
                baseline_type='logistic'
            )

            states = model.predict_states(data['X'], sequences=data['sequences'])
            state_proportions, goal_rates = _state_summary(states, data['y'], n_states)

            # Record metrics before plotting so a plotting failure does not discard them.
            results.append({
                'player_id': player_id,
                'player_name': data['player_name'],
                'n_shots': data['n_shots'],
                'goal_rate': data['goal_rate'],
                'accuracy': metrics['accuracy'],
                'auc': metrics['auc'],
                'brier_score': metrics['brier_score'],
                'delta_loglikelihood': metrics['delta_loglikelihood'],
                'state_diversity': metrics['state_diversity'],
                'baseline_accuracy': _baseline_metric(baseline_comparison, 'accuracy'),
                'baseline_auc': _baseline_metric(baseline_comparison, 'auc'),
                'state_proportions': state_proportions.tolist(),
                'state_goal_rates': goal_rates,
            })

            _save_player_figures(model, data, player_id, n_states,
                                 state_proportions, goal_rates, output_dir)
        except Exception as e:
            # Best effort per player: report the failure and continue with the rest.
            print(f"Error analyzing player {player_id}: {e}")

    results_df = pd.DataFrame(results)
    results_df.to_csv(os.path.join(output_dir, 'results.csv'), index=False)

    return results_df


def main():
    """Run the full experiment: load data, build per-player datasets, fit and summarise.

    Prints average metrics, then the five players whose HMM-GLM accuracy
    improved most over the logistic-regression baseline.
    """
    args = parse_args()

    print(f"Loading NHL data from {args.data_path}")
    nhl_data = load_nhl_data(args.data_path, seasons=args.seasons)
    shots_df = extract_shots_data(nhl_data, sport='nhl')

    player_datasets = create_player_datasets(
        shots_df,
        min_shots=args.min_shots,
        adjust_for_goalie=args.adjust_for_goalie,
    )

    results_df = run_analysis(player_datasets, n_states=args.n_states, output_dir=args.output_dir)

    print("\nAnalysis Summary:")
    print(f"Number of players analyzed: {len(results_df)}")
    if results_df.empty:
        # No player qualified or every fit failed. The results frame then has
        # no metric columns, so there is nothing further to summarise.
        print("No results to summarize.")
        return

    # Accuracy gain of the HMM-GLM over the logistic-regression baseline.
    results_df['improvement'] = results_df['accuracy'] - results_df['baseline_accuracy']

    print(f"Average accuracy: {results_df['accuracy'].mean():.3f}")
    print(f"Average AUC: {results_df['auc'].mean():.3f}")
    print(f"Average improvement over baseline: {results_df['improvement'].mean():.3f}")

    top_players = results_df.sort_values('improvement', ascending=False).head(5)

    print("\nTop 5 players with highest improvement:")
    for _, row in top_players.iterrows():
        print(f"{row['player_name']}: {row['improvement']:.3f} improvement, {row['accuracy']:.3f} accuracy")


if __name__ == '__main__':
    # Run the experiment only when executed as a script, not when imported.
    main()
