#!/usr/bin/env python
"""
MLB Analysis Experiment

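This script fits a per-player HMM-GLM model to MLB at-bat outcomes (hit / no hit),
compares each fit against a logistic-regression baseline, and writes per-player
diagnostic figures plus a summary results CSV.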
"""

import os
import sys
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime

# Add project root to path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

from src.data import load_mlb_data, extract_shots_data
from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel
from src.evaluation import evaluate_model, plot_confusion_matrix, compare_with_baseline
from src.evaluation import plot_state_transitions, plot_feature_importance


def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, so command-line behaviour is
            unchanged; passing an explicit list lets tests and other scripts
            drive the parser directly.

    Returns:
        argparse.Namespace with attributes ``data_path``, ``seasons`` (list of
        str), ``n_states`` (int), ``min_atbats`` (int) and ``output_dir``.
    """
    parser = argparse.ArgumentParser(description='MLB Analysis Experiment')
    parser.add_argument('--data-path', type=str, default='data',
                        help='Path to data directory')
    parser.add_argument('--seasons', type=str, nargs='+', default=['2022', '2023'],
                        help='Seasons to analyze')
    parser.add_argument('--n-states', type=int, default=3,
                        help='Number of hidden states')
    parser.add_argument('--min-atbats', type=int, default=50,
                        help='Minimum number of at-bats for a player to be included')
    parser.add_argument('--output-dir', type=str, default='results/mlb',
                        help='Output directory for results')

    return parser.parse_args(argv)


def create_player_datasets(df, min_atbats=50):
    """Create per-player datasets for every batter with enough at-bats.

    Each row of ``df`` is counted as one at-bat. Players are grouped by
    ``batter_id``, and only those with at least ``min_atbats`` rows are kept.

    Args:
        df: DataFrame with columns ``batter_id``, ``batter_name``, ``game_id``,
            ``is_hit`` and the feature columns listed below.
        min_atbats: Minimum number of rows a player needs to be included.

    Returns:
        dict mapping ``batter_id`` to a dict with keys:
            ``X``           feature matrix (rows x 6) from ``.values``,
            ``y``           ``is_hit`` values,
            ``sequences``   integer id per row from ``groupby('game_id').ngroup()``
                            within that player's rows,
            ``player_name`` first ``batter_name`` seen for the player,
            ``n_atbats``    number of rows for the player,
            ``hit_rate``    mean of ``is_hit`` for the player.
    """
    feature_columns = [
        'count_balls', 'count_strikes', 'outs_when_up',
        'inning', 'score_differential_norm', 'game_progress'
    ]

    # Group once and reuse the groups below; filtering the full frame with a
    # boolean mask per player would rescan every row for every player.
    player_groups = df.groupby('batter_id')

    # Filter players with sufficient at-bats
    player_counts = player_groups.size()
    qualified_players = player_counts[player_counts >= min_atbats].index

    print(f"Found {len(qualified_players)} players with at least {min_atbats} at-bats")

    player_datasets = {}
    for player_id in qualified_players:
        # get_group preserves the original row order; copy so adding a column
        # does not touch a view of ``df``.
        player_df = player_groups.get_group(player_id).copy()

        # One sequence id per game for this player (ids follow game_id sort order).
        player_df['sequence_id'] = player_df.groupby('game_id').ngroup()

        player_datasets[player_id] = {
            'X': player_df[feature_columns].values,
            'y': player_df['is_hit'].values,
            'sequences': player_df['sequence_id'].values,
            'player_name': player_df['batter_name'].iloc[0],
            'n_atbats': len(player_df),
            'hit_rate': player_df['is_hit'].mean()
        }

    return player_datasets


def _fit_player_model(data, n_states):
    """Build an HMM-GLM model for one player and fit it to that player's data.

    The model combines a categorical HMM with ``n_states`` hidden states and
    ``n_categories=2`` with a logistic GLM component.
    """
    hmm_comp = CategoricalHMMComponent(n_states=n_states, n_categories=2)
    glm_comp = LogisticGLMComponent()
    model = HMMGLMModel(hmm_component=hmm_comp, glm_component=glm_comp)
    model.fit(data['X'], data['y'], sequences=data['sequences'])
    return model


def _state_hit_rates(states, y, n_states):
    """Return the mean of ``y`` over the rows assigned to each state.

    States with no assigned rows get ``NaN``. Values are converted to plain
    Python floats so the list is written to CSV as e.g. ``0.25`` rather than
    ``np.float64(0.25)`` (NumPy 2 scalar repr), matching ``state_proportions``.
    """
    hit_rates = []
    for state in range(n_states):
        state_mask = (states == state)
        if np.any(state_mask):
            hit_rates.append(float(np.mean(y[state_mask])))
        else:
            hit_rates.append(np.nan)
    return hit_rates


def _baseline_metric(baseline_comparison, metric, model_name='Logistic Regression'):
    """Return ``metric`` for the ``model_name`` row of a baseline comparison table.

    Raises IndexError if the table has no row whose ``model`` equals ``model_name``.
    """
    rows = baseline_comparison.loc[baseline_comparison['model'] == model_name, metric]
    return rows.values[0]


def _plot_player_summary(model, data, player_id, n_states, hit_rates,
                         state_proportions, output_dir):
    """Save a 2x2 diagnostic figure for one player as ``player_<id>.png``.

    Panels: confusion matrix, state transition matrix, hit rate by state and
    state proportions. The figure is closed in ``finally`` so a failing plot
    helper does not leave open figures accumulating across players.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    try:
        # Plot confusion matrix
        y_pred = model.predict(data['X'], sequences=data['sequences'])
        plot_confusion_matrix(data['y'], y_pred, ax=axes[0, 0])

        # Plot state transitions
        plot_state_transitions(model.hmm_component.model.transmat_, ax=axes[0, 1])

        # Plot hit rate by state
        axes[1, 0].bar(range(n_states), hit_rates)
        axes[1, 0].set_xlabel('State')
        axes[1, 0].set_ylabel('Hit Rate')
        axes[1, 0].set_title('Hit Rate by State')

        # Plot state proportions
        axes[1, 1].bar(range(n_states), state_proportions)
        axes[1, 1].set_xlabel('State')
        axes[1, 1].set_ylabel('Proportion')
        axes[1, 1].set_title('State Proportions')

        fig.suptitle(f"Player: {data['player_name']} (ID: {player_id})")
        # Leave the top 5% free for the suptitle. Operate on ``fig`` explicitly
        # rather than pyplot's "current figure", which helpers may change.
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        fig.savefig(os.path.join(output_dir, f"player_{player_id}.png"))
    finally:
        plt.close(fig)


def run_analysis(player_datasets, n_states=3, output_dir='results/mlb'):
    """Run HMM-GLM analysis for each player and write results to ``output_dir``.

    For every player, this fits the model, evaluates it, compares it with a
    logistic-regression baseline, computes per-state proportions and hit
    rates, and saves a diagnostic figure. Failures are handled per player:
    the error is printed and the loop continues. If the failure happens while
    plotting, the player's metrics row has already been recorded.

    Args:
        player_datasets: Mapping ``player_id -> dict`` as produced by
            ``create_player_datasets`` (keys ``X``, ``y``, ``sequences``,
            ``player_name``, ``n_atbats``, ``hit_rate``).
        n_states: Number of hidden states.
        output_dir: Directory for ``results.csv`` and per-player figures
            (created if missing).

    Returns:
        DataFrame with one row per successfully analyzed player. It has no
        columns if no player succeeded.
    """
    os.makedirs(output_dir, exist_ok=True)

    results = []

    for player_id, data in player_datasets.items():
        print(f"Analyzing player {data['player_name']} (ID: {player_id})")

        try:
            model = _fit_player_model(data, n_states)

            metrics = evaluate_model(model, data['X'], data['y'], sequences=data['sequences'])

            baseline_comparison = compare_with_baseline(
                model, data['X'], data['y'], sequences=data['sequences'],
                baseline_type='logistic'
            )

            # Decoded hidden state per at-bat, and how often each state occurs.
            states = model.predict_states(data['X'], sequences=data['sequences'])
            state_counts = np.bincount(states, minlength=n_states)
            state_proportions = state_counts / len(states)

            hit_rates = _state_hit_rates(states, data['y'], n_states)

            results.append({
                'player_id': player_id,
                'player_name': data['player_name'],
                'n_atbats': data['n_atbats'],
                'hit_rate': data['hit_rate'],
                'accuracy': metrics['accuracy'],
                'auc': metrics['auc'],
                'brier_score': metrics['brier_score'],
                'delta_loglikelihood': metrics['delta_loglikelihood'],
                'state_diversity': metrics['state_diversity'],
                'baseline_accuracy': _baseline_metric(baseline_comparison, 'accuracy'),
                'baseline_auc': _baseline_metric(baseline_comparison, 'auc'),
                'state_proportions': state_proportions.tolist(),
                'state_hit_rates': hit_rates
            })

            _plot_player_summary(model, data, player_id, n_states, hit_rates,
                                 state_proportions, output_dir)

        except Exception as e:
            # Deliberate best-effort: one player's failure should not abort the run.
            print(f"Error analyzing player {player_id}: {e}")

    results_df = pd.DataFrame(results)
    results_df.to_csv(os.path.join(output_dir, 'results.csv'), index=False)

    return results_df


def main():
    """Load MLB data, run the per-player analysis and print a summary.

    If no player could be analyzed, only the player count is printed. In that
    case the results frame has no metric columns, so the remaining summary
    lines would raise KeyError.
    """
    args = parse_args()

    # Load data
    print(f"Loading MLB data from {args.data_path}")
    mlb_data = load_mlb_data(args.data_path, seasons=args.seasons)

    # Extract at-bat data
    atbats_df = extract_shots_data(mlb_data, sport='mlb')

    # Create player datasets
    player_datasets = create_player_datasets(atbats_df, min_atbats=args.min_atbats)

    # Run analysis
    results_df = run_analysis(player_datasets, n_states=args.n_states, output_dir=args.output_dir)

    # Print summary
    print("\nAnalysis Summary:")
    print(f"Number of players analyzed: {len(results_df)}")

    if results_df.empty:
        print("No players were successfully analyzed; skipping summary statistics.")
        return

    # Accuracy gain of the HMM-GLM over the logistic-regression baseline,
    # computed once and reused for the average and the ranking.
    results_df['improvement'] = results_df['accuracy'] - results_df['baseline_accuracy']

    print(f"Average accuracy: {results_df['accuracy'].mean():.3f}")
    print(f"Average AUC: {results_df['auc'].mean():.3f}")
    print(f"Average improvement over baseline: {results_df['improvement'].mean():.3f}")

    # Find players with highest improvement
    top_players = results_df.sort_values('improvement', ascending=False).head(5)

    print("\nTop 5 players with highest improvement:")
    for _, row in top_players.iterrows():
        print(f"{row['player_name']}: {row['improvement']:.3f} improvement, {row['accuracy']:.3f} accuracy")


if __name__ == '__main__':
    main()
