"""
NHL Data Converter

This module provides utilities to convert raw NHL data crawled from NHL Stats API
and Hockey-Reference into formats compatible with the HMM-GLM framework.
"""

import os
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Union, Tuple
import logging

# Setup logging
logger = logging.getLogger(__name__)


class NHLDataConverter:
    """
    Converter for NHL data to transform it into HMM-GLM compatible format.
    
    This class handles the conversion of raw NHL play-by-play data into
    the standardized format required by the HMM-GLM models, including
    feature engineering and sequence creation.
    """
    
    def __init__(self):
        """Create a converter; no instance attributes are set up."""
    
    def load_raw_data(self, data_path: str) -> pd.DataFrame:
        """
        Load raw NHL data from a CSV file or from a directory of CSV files.

        When ``data_path`` is a directory, every file whose name ends in
        ``.csv`` and contains ``pbp_`` or ``shots`` is read and the frames are
        concatenated (with a fresh integer index) in sorted file-name order.
        Files that fail to load are logged and skipped.

        Args:
            data_path: Path to the raw data CSV file or directory

        Returns:
            DataFrame containing the loaded data. An empty DataFrame is
            returned when the path is invalid, no matching file exists, or
            nothing could be read.
        """
        if os.path.isdir(data_path):
            # os.listdir() yields entries in arbitrary, platform-dependent
            # order; sorting keeps the row order of the result reproducible.
            csv_files = [
                os.path.join(data_path, name)
                for name in sorted(os.listdir(data_path))
                if name.endswith('.csv') and ('pbp_' in name or 'shots' in name)
            ]

            if not csv_files:
                logger.warning(f"No relevant CSV files found in {data_path}")
                return pd.DataFrame()

            frames = []
            for file in csv_files:
                # Best effort: one unreadable file must not abort the batch.
                try:
                    frames.append(pd.read_csv(file))
                except Exception as e:
                    logger.error(f"Error loading {file}: {e}")

            if not frames:
                return pd.DataFrame()
            return pd.concat(frames, ignore_index=True)

        if os.path.isfile(data_path) and data_path.endswith('.csv'):
            try:
                return pd.read_csv(data_path)
            except Exception as e:
                logger.error(f"Error loading {data_path}: {e}")
                return pd.DataFrame()

        logger.error(f"Invalid data path: {data_path}")
        return pd.DataFrame()
    
    def preprocess_pbp_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess raw play-by-play data.
        
        Args:
            df: Raw play-by-play DataFrame
            
        Returns:
            Preprocessed DataFrame
        """
        if df.empty:
            return df
        
        # Create a copy to avoid modifying the original
        processed_df = df.copy()
        
        # Filter to shot events only
        shot_events = ['Shot', 'Goal', 'Missed Shot', 'Blocked Shot']
        if 'event_type' in processed_df.columns:
            processed_df = processed_df[processed_df['event_type'].isin(shot_events)]
        
        # Handle missing values
        numeric_cols = processed_df.select_dtypes(include=[np.number]).columns
        processed_df[numeric_cols] = processed_df[numeric_cols].fillna(0)
        
        # Create binary outcome variable
        if 'event_type' in processed_df.columns:
            processed_df['is_goal'] = (processed_df['event_type'] == 'Goal').astype(int)
            processed_df['is_blocked'] = (processed_df['event_type'] == 'Blocked Shot').astype(int)
            processed_df['is_missed'] = (processed_df['event_type'] == 'Missed Shot').astype(int)
            processed_df['is_shot_on_goal'] = ((processed_df['event_type'] == 'Shot') | 
                                              (processed_df['event_type'] == 'Goal')).astype(int)
        
        # Create period features
        if 'period' in processed_df.columns:
            # Create binary indicators for each period
            for i in range(1, 4):
                processed_df[f'period_{i}'] = (processed_df['period'] == i).astype(int)
            
            # Create overtime indicator
            processed_df['is_overtime'] = (processed_df['period'] > 3).astype(int)
        
        # Create time features
        if 'period' in processed_df.columns and 'period_time' in processed_df.columns:
            # Convert time string (MM:SS) to seconds elapsed in period
            processed_df['period_seconds'] = processed_df['period_time'].apply(
                lambda x: int(x.split(':')[0]) * 60 + int(x.split(':')[1]) if isinstance(x, str) and ':' in x else 0
            )
            
            # Calculate total seconds elapsed in game
            processed_df['game_seconds'] = (
                (processed_df['period'] - 1) * 1200  # 20 minutes per period
                + processed_df['period_seconds']
            )
            
            # Calculate normalized game time (0 to 1)
            total_regulation_seconds = 3 * 1200  # 60 minutes
            processed_df['game_time_norm'] = np.clip(
                processed_df['game_seconds'] / total_regulation_seconds,
                0, 1
            )
        
        # Add shot coordinates if available
        if 'x_coord' in processed_df.columns and 'y_coord' in processed_df.columns:
            # NHL rink is 200 feet long and 85 feet wide
            # The goal is located at x = 89 feet from center (or -89 for the other goal)
            
            # Filter out rows with missing coordinates
            has_coords = (processed_df['x_coord'].notna() & processed_df['y_coord'].notna())
            processed_df = processed_df[has_coords]
            
            # Calculate shot distance if not already present
            if 'shot_distance' not in processed_df.columns:
                # Determine which goal the shot was directed at
                # Simplified calculation: assume shots with positive x are toward the right goal
                # and shots with negative x are toward the left goal
                goal_x = np.where(processed_df['x_coord'] < 0, -89, 89)
                
                # Calculate Euclidean distance to goal
                processed_df['shot_distance'] = np.sqrt(
                    (processed_df['x_coord'] - goal_x)**2 + processed_df['y_coord']**2
                )
            
            # Calculate shot angle if not already present
            if 'shot_angle' not in processed_df.columns:
                goal_x = np.where(processed_df['x_coord'] < 0, -89, 89)
                
                # Calculate angle (in degrees)
                # 0 degrees is straight on, 90 degrees is from the side
                processed_df['shot_angle'] = np.abs(
                    np.arctan2(processed_df['y_coord'], processed_df['x_coord'] - goal_x)
                ) * (180 / np.pi)
        
        return processed_df
    
    def engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Engineer features for HMM-GLM modeling.
        
        Args:
            df: Preprocessed DataFrame
            
        Returns:
            DataFrame with engineered features
        """
        if df.empty:
            return df
        
        # Create a copy to avoid modifying the original
        feature_df = df.copy()
        
        # Create game context features
        
        # 1. Score differential
        if 'home_score' in feature_df.columns and 'away_score' in feature_df.columns:
            feature_df['score_differential'] = feature_df['home_score'] - feature_df['away_score']
            
            # Normalize score differential
            feature_df['score_differential_norm'] = np.clip(
                feature_df['score_differential'] / 3,  # Normalize by typical max lead
                -1, 1  # Clip to [-1, 1]
            )
        
        # 2. Game situation features
        if 'score_differential' in feature_df.columns and 'game_time_norm' in feature_df.columns:
            # Create trailing team indicator
            feature_df['is_trailing'] = (feature_df['score_differential'] < 0).astype(int)
            
            # Create late game indicator (last 10 minutes of regulation)
            feature_df['is_late_game'] = (feature_df['game_time_norm'] > 0.83).astype(int)
            
            # Create high leverage situation indicator
            # (close game in the late stages)
            feature_df['is_high_leverage'] = (
                (np.abs(feature_df['score_differential']) <= 1) &  # Within 1 goal
                (feature_df['game_time_norm'] > 0.9)  # Last 6 minutes
            ).astype(int)
        
        # 3. Shot quality features based on location
        if 'shot_distance' in feature_df.columns and 'shot_angle' in feature_df.columns:
            # Expected goal probability based on distance and angle
            # These are approximate values based on NHL averages
            
            # Base probability starts high for close shots and decreases with distance
            base_prob = np.clip(0.3 - 0.01 * feature_df['shot_distance'], 0.01, 0.3)
            
            # Adjust for angle (straight-on shots are better)
            angle_factor = np.clip(1 - (feature_df['shot_angle'] / 90) ** 1.5, 0.1, 1)
            
            feature_df['expected_goal_prob'] = base_prob * angle_factor
        
        # 4. Shot type features
        if 'shot_type' in feature_df.columns:
            # One-hot encode shot types
            shot_types = ['Wrist Shot', 'Slap Shot', 'Snap Shot', 'Backhand', 'Tip-In', 'Wrap-around', 'Deflected']
            for shot_type in shot_types:
                feature_df[f'shot_type_{shot_type.lower().replace("-", "_").replace(" ", "_")}'] = (
                    feature_df['shot_type'] == shot_type
                ).astype(int)
        
        # 5. Preceding event features
        preceding_cols = ['preceding_event_1', 'preceding_event_2', 'preceding_event_3']
        if all(col in feature_df.columns for col in preceding_cols):
            # Create indicators for key preceding events
            key_events = ['Faceoff', 'Takeaway', 'Giveaway', 'Hit']
            
            for event in key_events:
                feature_df[f'preceded_by_{event.lower()}'] = (
                    feature_df['preceding_event_1'].str.contains(event, case=False, na=False) |
                    feature_df['preceding_event_2'].str.contains(event, case=False, na=False)
                ).astype(int)
            
            # Create rebound indicator if not present
            if 'is_rebound' not in feature_df.columns:
                feature_df['is_rebound'] = (
                    feature_df['preceding_event_1'].str.contains('Shot|Missed Shot|Blocked Shot', case=False, na=False)
                ).astype(int)
        
        # 6. Goalie impact features
        if 'goalie_id' in feature_df.columns:
            # Calculate shots per goalie
            goalie_shots = feature_df.groupby('goalie_id').size()
            
            # Map back to DataFrame
            feature_df['goalie_shots'] = feature_df['goalie_id'].map(goalie_shots).fillna(0)
            
            # Calculate save percentage per goalie (if we have enough data)
            if 'is_goal' in feature_df.columns:
                goalie_goals = feature_df.groupby('goalie_id')['is_goal'].sum()
                goalie_save_pct = 1 - goalie_goals / goalie_shots
                
                # Map back to DataFrame
                feature_df['goalie_save_pct'] = feature_df['goalie_id'].map(goalie_save_pct).fillna(0.9)  # Default to league average
        
        return feature_df
    
    def create_sequences(self, df: pd.DataFrame, 
                         group_by: List[str] = ['game_id', 'player1_id'],
                         sort_by: List[str] = ['game_id', 'event_idx'],
                         min_sequence_length: int = 3,
                         max_time_between: Optional[int] = 300) -> pd.DataFrame:
        """
        Create sequences for HMM-GLM modeling.
        
        Args:
            df: DataFrame with engineered features
            group_by: Columns to group by to create sequences
            sort_by: Columns to sort by within sequences
            min_sequence_length: Minimum sequence length to include
            max_time_between: Maximum seconds between events in a sequence (optional)
            
        Returns:
            DataFrame with sequence IDs and positions
        """
        if df.empty:
            return df
        
        # Create a copy to avoid modifying the original
        seq_df = df.copy()
        
        # Ensure all group_by and sort_by columns exist
        missing_cols = [col for col in group_by + sort_by if col not in seq_df.columns]
        if missing_cols:
            logger.warning(f"Missing columns for sequence creation: {missing_cols}")
            # Create dummy columns for missing ones
            for col in missing_cols:
                seq_df[col] = 0
        
        # Sort the DataFrame
        seq_df = seq_df.sort_values(by=sort_by)
        
        # Create sequence IDs and positions
        seq_df['sequence_id'] = 0
        seq_df['sequence_pos'] = 0
        
        # Group by the specified columns
        current_seq_id = 1
        
        for _, group_df in seq_df.groupby(group_by):
            if len(group_df) < min_sequence_length:
                continue
            
            # Sort group by event number
            group_df = group_df.sort_values(by=sort_by)
            
            # Initialize sequence tracking
            current_seq_pos = 0
            current_seq_start_idx = group_df.index[0]
            last_event_time = None
            
            for i, (idx, row) in enumerate(group_df.iterrows()):
                # Check if we need to start a new sequence due to time gap
                if max_time_between is not None and 'game_seconds' in row and last_event_time is not None:
                    time_gap = abs(row['game_seconds'] - last_event_time)
                    if time_gap > max_time_between:
                        # If current sequence is long enough, save it
                        if current_seq_pos >= min_sequence_length:
                            current_seq_id += 1
                        
                        # Start new sequence
                        current_seq_pos = 0
                        current_seq_start_idx = idx
                
                # Update sequence position
                seq_df.loc[idx, 'sequence_id'] = current_seq_id
                seq_df.loc[idx, 'sequence_pos'] = current_seq_pos
                current_seq_pos += 1
                
                # Update last event time
                if 'game_seconds' in row:
                    last_event_time = row['game_seconds']
            
            # Check if last sequence meets minimum length
            if current_seq_pos >= min_sequence_length:
                current_seq_id += 1
            else:
                # Reset sequence ID for the last short sequence
                seq_df.loc[seq_df['sequence_id'] == current_seq_id, 'sequence_id'] = 0
        
        # Filter out rows not part of valid sequences
        seq_df = seq_df[seq_df['sequence_id'] > 0]
        
        return seq_df
    
    def convert_to_hmm_glm_format(self, df: pd.DataFrame, 
                                  feature_cols: Optional[List[str]] = None,
                                  outcome_col: str = 'is_goal',
                                  sequence_id_col: str = 'sequence_id',
                                  sequence_pos_col: str = 'sequence_pos') -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]:
        """
        Convert DataFrame to HMM-GLM input format.
        
        Args:
            df: DataFrame with sequences and features
            feature_cols: List of feature column names (if None, use all numeric columns)
            outcome_col: Name of the outcome column
            sequence_id_col: Name of the sequence ID column
            sequence_pos_col: Name of the sequence position column
            
        Returns:
            Tuple of (X, y, sequences, lengths):
                X: Feature matrix [n_samples, n_features]
                y: Outcome vector [n_samples]
                sequences: Sequence IDs for each sample
                lengths: Length of each sequence
        """
        if df.empty:
            return np.array([]), np.array([]), np.array([]), []
        
        # Select feature columns
        if feature_cols is None:
            # Use all numeric columns except specific ones
            exclude_cols = [outcome_col, sequence_id_col, sequence_pos_col, 
                           'game_id', 'event_idx', 'player1_id', 'is_goal', 'is_blocked', 'is_missed']
            feature_cols = [col for col in df.select_dtypes(include=[np.number]).columns 
                          if col not in exclude_cols]
        
        # Check if all columns exist
        missing_cols = [col for col in feature_cols + [outcome_col, sequence_id_col, sequence_pos_col] 
                       if col not in df.columns]
        if missing_cols:
            logger.error(f"Missing required columns: {missing_cols}")
            return np.array([]), np.array([]), np.array([]), []
        
        # Extract features, outcomes, and sequence info
        X = df[feature_cols].values
        y = df[outcome_col].values
        sequences = df[sequence_id_col].values
        
        # Calculate sequence lengths
        unique_sequences = df[sequence_id_col].unique()
        lengths = [len(df[df[sequence_id_col] == seq_id]) for seq_id in unique_sequences]
        
        return X, y, sequences, lengths


def convert_nhl_data(raw_data_path: str, output_path: Optional[str] = None, 
                     min_sequence_length: int = 3) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Convert raw NHL data to HMM-GLM compatible format.

    Runs the full pipeline: load -> preprocess -> engineer features ->
    create sequences, and optionally writes the result plus a JSON metadata
    file next to it (``<output stem>_metadata.json``).

    Args:
        raw_data_path: Path to raw NHL data (CSV file or directory)
        output_path: Path to save the converted data (optional)
        min_sequence_length: Minimum sequence length to include

    Returns:
        Tuple of (converted_df, metadata):
            converted_df: DataFrame with converted data (empty when nothing
                could be loaded)
            metadata: Dictionary with metadata about the conversion
    """
    # Initialize converter and metadata
    converter = NHLDataConverter()
    metadata = {
        'source': raw_data_path,
        'min_sequence_length': min_sequence_length,
        'n_original_rows': 0,
        'n_converted_rows': 0,
        'n_sequences': 0,
        'feature_columns': [],
        'outcome_column': 'is_goal'
    }

    # Load raw data
    raw_df = converter.load_raw_data(raw_data_path)
    if raw_df.empty:
        logger.error(f"No data loaded from {raw_data_path}")
        return pd.DataFrame(), metadata

    metadata['n_original_rows'] = len(raw_df)

    # Preprocess data and engineer features
    processed_df = converter.preprocess_pbp_data(raw_df)
    feature_df = converter.engineer_features(processed_df)

    # Pick the shooter and ordering columns that are present in the data
    shooter_col = 'player1_id' if 'player1_id' in feature_df.columns else 'shooter_id'
    order_col = 'event_idx' if 'event_idx' in feature_df.columns else 'game_seconds'

    # Create sequences
    sequence_df = converter.create_sequences(
        feature_df,
        group_by=['game_id', shooter_col],
        sort_by=['game_id', order_col],
        min_sequence_length=min_sequence_length,
        max_time_between=300  # 5 minutes max between shots in a sequence
    )

    metadata['n_converted_rows'] = len(sequence_df)
    # create_sequences returns an empty input unchanged, i.e. without a
    # 'sequence_id' column (e.g. when no shot events survive preprocessing).
    if 'sequence_id' in sequence_df.columns:
        metadata['n_sequences'] = int(sequence_df['sequence_id'].nunique())

    # Identify feature columns: numeric columns that are neither outcomes,
    # sequence bookkeeping nor identifiers
    numeric_cols = sequence_df.select_dtypes(include=[np.number]).columns
    exclude_cols = ['sequence_id', 'sequence_pos', 'is_goal', 'is_blocked', 'is_missed',
                   'game_id', 'event_idx', 'player1_id', 'shooter_id']
    feature_cols = [col for col in numeric_cols if col not in exclude_cols]

    metadata['feature_columns'] = feature_cols

    # Save converted data if requested
    if output_path:
        import json

        # A bare file name has no directory part, and os.makedirs('') raises.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        sequence_df.to_csv(output_path, index=False)

        # Derive the metadata path from the extension only. A plain
        # str.replace('.csv', ...) would also rewrite '.csv' inside directory
        # names, and would leave the path unchanged (overwriting the data
        # file) when output_path has no '.csv' in it.
        output_stem, _ = os.path.splitext(output_path)
        metadata_path = output_stem + '_metadata.json'
        serializable = {
            k: v for k, v in metadata.items()
            if isinstance(v, (str, int, float, bool, list))
        }
        with open(metadata_path, 'w') as f:
            json.dump(serializable, f, indent=2)

    return sequence_df, metadata


if __name__ == "__main__":
    # Command-line entry point: <raw_data_path> [output_path]
    import sys

    cli_args = sys.argv[1:]

    if not cli_args:
        # No input path given: show how to invoke the script
        print("Usage: python nhl_converter.py <raw_data_path> [output_path]")
    else:
        raw_data_path = cli_args[0]
        # The output location is optional and falls back to a default path
        if len(cli_args) > 1:
            output_path = cli_args[1]
        else:
            output_path = "data/nhl/nhl_converted.csv"

        print(f"Converting NHL data from {raw_data_path} to {output_path}")
        df, metadata = convert_nhl_data(raw_data_path, output_path)

        print(f"Conversion complete. Converted {metadata['n_original_rows']} rows into {metadata['n_converted_rows']} rows.")
        print(f"Created {metadata['n_sequences']} sequences.")

