"""
NBA Data Converter

This module provides utilities to convert raw NBA data crawled from NBA Stats API
and Basketball-Reference into formats compatible with the HMM-GLM framework.
"""

import os
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Union, Tuple
import logging

# Setup logging
logger = logging.getLogger(__name__)


class NBADataConverter:
    """
    Converter for NBA data to transform it into HMM-GLM compatible format.

    This class handles the conversion of raw NBA play-by-play data (and
    shot chart data) into the standardized format required by the HMM-GLM
    models, including feature engineering and sequence creation.

    The converter stores no state, and no method modifies the DataFrame it
    is given.
    """

    def __init__(self):
        """Initialize the NBA data converter (no configuration is stored)."""
        pass

    def load_raw_data(self, data_path: str) -> pd.DataFrame:
        """
        Load raw NBA data from CSV files.

        When ``data_path`` is a directory, every regular ``.csv`` file in it
        whose name contains ``'pbp_processed'`` or ``'merged'`` is read. Files
        are read in sorted filename order so the concatenated row order is
        reproducible. Files that fail to parse are logged and skipped.

        Args:
            data_path: Path to the raw data CSV file or directory

        Returns:
            DataFrame containing the loaded data. An empty DataFrame is
            returned when nothing could be loaded or the path is invalid.
        """
        if os.path.isdir(data_path):
            # Sort the listing: os.listdir order is arbitrary, which would make
            # the concatenated output order differ between machines/runs.
            csv_files = [
                os.path.join(data_path, name)
                for name in sorted(os.listdir(data_path))
                if name.endswith('.csv')
                and ('pbp_processed' in name or 'merged' in name)
                and os.path.isfile(os.path.join(data_path, name))
            ]

            if not csv_files:
                logger.warning(f"No relevant CSV files found in {data_path}")
                return pd.DataFrame()

            frames = []
            for file in csv_files:
                try:
                    frames.append(pd.read_csv(file))
                except Exception as e:
                    # Best effort: one unreadable file should not abort the load.
                    logger.error(f"Error loading {file}: {e}")

            if frames:
                return pd.concat(frames, ignore_index=True)
            return pd.DataFrame()

        if os.path.isfile(data_path) and data_path.endswith('.csv'):
            try:
                return pd.read_csv(data_path)
            except Exception as e:
                logger.error(f"Error loading {data_path}: {e}")
                return pd.DataFrame()

        logger.error(f"Invalid data path: {data_path}")
        return pd.DataFrame()

    def preprocess_pbp_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess raw play-by-play data.

        Keeps only field-goal events (EVENTMSGTYPE 1 = made, 2 = missed) and
        fills missing numeric values with 0. It then derives the following,
        each only when its source columns are present:

        - ``is_made_shot``
        - ``description`` / ``is_three_pointer``
        - ``quarter_1`` .. ``quarter_4`` and ``is_overtime``
        - ``period_seconds_remaining``, ``game_seconds_remaining`` and
          ``game_time_norm``

        Args:
            df: Raw play-by-play DataFrame

        Returns:
            Preprocessed DataFrame (the input itself when it is empty).
        """
        if df.empty:
            return df

        processed_df = df.copy()

        # Filter to shot events only: 1 = made field goal, 2 = missed field goal.
        shot_events = [1, 2]
        if 'EVENTMSGTYPE' in processed_df.columns:
            # .copy() so the column assignments below act on an independent frame
            # instead of a filtered view (avoids SettingWithCopyWarning).
            processed_df = processed_df[processed_df['EVENTMSGTYPE'].isin(shot_events)].copy()

        numeric_cols = processed_df.select_dtypes(include=[np.number]).columns
        processed_df[numeric_cols] = processed_df[numeric_cols].fillna(0)

        if 'EVENTMSGTYPE' in processed_df.columns:
            processed_df['is_made_shot'] = (processed_df['EVENTMSGTYPE'] == 1).astype(int)

        if 'HOMEDESCRIPTION' in processed_df.columns and 'VISITORDESCRIPTION' in processed_df.columns:
            # astype(str) guards against a description column that pandas read as
            # float (e.g. entirely empty), which would break string concatenation.
            home_text = processed_df['HOMEDESCRIPTION'].fillna('').astype(str)
            visitor_text = processed_df['VISITORDESCRIPTION'].fillna('').astype(str)
            processed_df['description'] = home_text + ' ' + visitor_text

            processed_df['is_three_pointer'] = processed_df['description'].str.contains(
                '3PT', case=False, na=False
            ).astype(int)

        if 'PERIOD' in processed_df.columns:
            for quarter in range(1, 5):
                processed_df[f'quarter_{quarter}'] = (processed_df['PERIOD'] == quarter).astype(int)
            processed_df['is_overtime'] = (processed_df['PERIOD'] > 4).astype(int)

        if 'PERIOD' in processed_df.columns and 'PCTIMESTRING' in processed_df.columns:
            # "MM:SS" clock string -> seconds remaining in the current period
            # (non-string / malformed values become 0).
            processed_df['period_seconds_remaining'] = processed_df['PCTIMESTRING'].apply(
                lambda x: int(x.split(':')[0]) * 60 + int(x.split(':')[1]) if isinstance(x, str) and ':' in x else 0
            )

            # Full 12-minute regulation quarters still to be played after the
            # current one (zero from the 4th quarter onward).
            regulation_seconds_after_period = (4 - processed_df['PERIOD'].clip(upper=4)) * 12 * 60
            # In overtime only the current period's clock remains; further OT
            # periods are not known to happen, so nothing extra is added.
            processed_df['game_seconds_remaining'] = (
                regulation_seconds_after_period + processed_df['period_seconds_remaining']
            )

            # 0 at tip-off, 1 at the end of regulation (and throughout overtime).
            total_regulation_seconds = 4 * 12 * 60
            processed_df['game_time_norm'] = 1 - (
                processed_df['game_seconds_remaining'] / total_regulation_seconds
            ).clip(0, 1)

        return processed_df

    def preprocess_shotchart_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess raw shot chart data.

        Fills missing numeric values with 0. It then derives the following,
        each only when its source columns are present:

        - ``is_made_shot`` and ``is_three_pointer``
        - integer one-hot ``zone_*`` columns
        - ``shot_distance`` and ``shot_angle``

        Args:
            df: Raw shot chart DataFrame

        Returns:
            Preprocessed DataFrame (the input itself when it is empty).
        """
        if df.empty:
            return df

        processed_df = df.copy()

        numeric_cols = processed_df.select_dtypes(include=[np.number]).columns
        processed_df[numeric_cols] = processed_df[numeric_cols].fillna(0)

        if 'SHOT_MADE_FLAG' in processed_df.columns:
            processed_df['is_made_shot'] = processed_df['SHOT_MADE_FLAG'].astype(int)

        if 'SHOT_TYPE' in processed_df.columns:
            processed_df['is_three_pointer'] = (processed_df['SHOT_TYPE'] == '3PT Field Goal').astype(int)

        if 'SHOT_ZONE_BASIC' in processed_df.columns:
            # dtype=int: pandas >= 2.0 returns bool dummies by default, and bool
            # columns are skipped by select_dtypes(include=[np.number]), which is
            # how feature columns are chosen downstream.
            zone_dummies = pd.get_dummies(processed_df['SHOT_ZONE_BASIC'], prefix='zone', dtype=int)
            processed_df = pd.concat([processed_df, zone_dummies], axis=1)

        if 'LOC_X' in processed_df.columns and 'LOC_Y' in processed_df.columns:
            # Distance from the (0, 0) origin; dividing by 10 assumes LOC_X/LOC_Y
            # are in tenths of a foot (NBA Stats shot chart convention) — confirm.
            processed_df['shot_distance'] = np.sqrt(processed_df['LOC_X']**2 + processed_df['LOC_Y']**2) / 10

            # arctan2(LOC_Y, LOC_X): 90 degrees points straight along +LOC_Y,
            # 0 and 180 degrees lie along the LOC_X axis on either side.
            processed_df['shot_angle'] = np.arctan2(processed_df['LOC_Y'], processed_df['LOC_X']) * (180 / np.pi)

            # abs() folds negative LOC_Y (behind the origin) into the 0-180 range.
            processed_df['shot_angle'] = np.abs(processed_df['shot_angle'])

        return processed_df

    def engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Engineer features for HMM-GLM modeling.

        Each feature group is added only when its source columns exist:
        score differential (from ``SCORE``), game pressure (from
        ``game_time_norm`` and the score differential), shot clock, expected
        FG% (from ``shot_distance``/``is_three_pointer``), and home-team
        indicators (from the description columns).

        Args:
            df: Preprocessed DataFrame

        Returns:
            DataFrame with engineered features (the input itself when empty).
        """
        if df.empty:
            return df

        feature_df = df.copy()

        # 1. Score differential
        if 'SCORE' in feature_df.columns:
            # NOTE(review): the first number is labelled "home" here; NBA Stats
            # play-by-play is commonly described as "VISITOR - HOME" — confirm
            # the order before relying on the sign of the differential.
            # reindex guarantees two columns even when no row contains '-'
            # (e.g. only missed shots, whose SCORE is empty).
            scores = (
                feature_df['SCORE'].str.split('-', n=1, expand=True)
                .reindex(columns=[0, 1])
                .apply(pd.to_numeric, errors='coerce')
            )
            feature_df['home_score'] = scores[0].to_numpy()
            feature_df['away_score'] = scores[1].to_numpy()

            feature_df['score_differential'] = feature_df['home_score'] - feature_df['away_score']

            # Scale by a typical large lead (20 points) and clip to [-1, 1].
            feature_df['score_differential_norm'] = np.clip(
                feature_df['score_differential'] / 20,
                -1, 1
            )

        # 2. Game pressure: close game (|diff| < 10) in the final 5 minutes.
        if 'game_time_norm' in feature_df.columns and 'score_differential' in feature_df.columns:
            regulation_seconds = 4 * 12 * 60
            late_game_seconds = 5 * 60
            is_close_game = np.abs(feature_df['score_differential']) < 10
            # game_time_norm = 1 - seconds_remaining / regulation_seconds, so this
            # selects seconds_remaining <= 300 (0.9 would only be 4.8 minutes).
            is_late_game = feature_df['game_time_norm'] >= 1 - late_game_seconds / regulation_seconds

            feature_df['game_pressure'] = (is_close_game & is_late_game).astype(int)

        # 3. Shot clock situation
        if 'SHOTCLOCK' in feature_df.columns:
            # 24 seconds is a full shot clock.
            feature_df['shot_clock_norm'] = feature_df['SHOTCLOCK'] / 24
            feature_df['shot_clock_pressure'] = (feature_df['SHOTCLOCK'] <= 4).astype(int)

        # 4. Shot quality: rough expected FG% by shot distance and type.
        if 'shot_distance' in feature_df.columns and 'is_three_pointer' in feature_df.columns:
            two_pt_mask = feature_df['is_three_pointer'] == 0
            feature_df.loc[two_pt_mask, 'expected_fg_pct'] = np.clip(
                0.65 - 0.02 * feature_df.loc[two_pt_mask, 'shot_distance'],
                0.2, 0.8
            )

            three_pt_mask = feature_df['is_three_pointer'] == 1
            feature_df.loc[three_pt_mask, 'expected_fg_pct'] = 0.35

        # 5. Home court advantage
        if 'HOMEDESCRIPTION' in feature_df.columns and 'VISITORDESCRIPTION' in feature_df.columns:
            home_desc = feature_df['HOMEDESCRIPTION']
            # On a blocked shot the blocker's team column only records the
            # block ("X BLOCK (n BLK)"), so such a row is not a shot by that team.
            home_is_block_note = home_desc.astype(str).str.contains('BLOCK', regex=False, na=False)
            home_shot = home_desc.notna() & ~home_is_block_note

            feature_df['is_home_team_shot'] = home_shot.astype(int)

            # Small fixed bonus/penalty for home/away shooters.
            feature_df['home_court_advantage'] = np.where(home_shot, 0.02, -0.02)

        return feature_df

    def create_sequences(self, df: pd.DataFrame,
                         group_by: Optional[List[str]] = None,
                         sort_by: Optional[List[str]] = None,
                         min_sequence_length: int = 3,
                         max_time_between: Optional[int] = 600) -> pd.DataFrame:
        """
        Create sequences for HMM-GLM modeling.

        Rows are ordered by ``sort_by`` and grouped by ``group_by``. Within a
        group, a new segment starts wherever two consecutive events are more
        than ``max_time_between`` apart in ``game_seconds_remaining``. Gap
        splitting only applies when that column exists and the limit is not
        None.

        Every segment with at least ``min_sequence_length`` rows gets its own
        ``sequence_id``. IDs are 1, 2, ... in group order. Each row in a
        segment also gets a ``sequence_pos`` of 0..len-1. All other rows are
        dropped.

        Args:
            df: DataFrame with engineered features
            group_by: Columns to group by (default ``['GAME_ID', 'PLAYER1_ID']``)
            sort_by: Columns to sort by (default ``['GAME_ID', 'EVENTNUM']``)
            min_sequence_length: Minimum sequence length to include
            max_time_between: Maximum seconds between consecutive events in a
                sequence (None disables splitting)

        Returns:
            DataFrame with ``sequence_id`` and ``sequence_pos`` columns,
            restricted to rows belonging to a valid sequence. Missing
            group/sort columns are added as constant 0 (and logged).
        """
        if df.empty:
            return df

        # None defaults instead of shared mutable list defaults.
        group_by = ['GAME_ID', 'PLAYER1_ID'] if group_by is None else list(group_by)
        sort_by = ['GAME_ID', 'EVENTNUM'] if sort_by is None else list(sort_by)

        seq_df = df.copy()

        missing_cols = [col for col in group_by + sort_by if col not in seq_df.columns]
        if missing_cols:
            logger.warning(f"Missing columns for sequence creation: {missing_cols}")
            for col in missing_cols:
                seq_df[col] = 0

        seq_df = seq_df.sort_values(by=sort_by)

        # Work on positional row numbers (robust to duplicate index labels);
        # the original labels are restored before returning.
        original_index = seq_df.index
        seq_df = seq_df.reset_index(drop=True)

        sequence_ids = np.zeros(len(seq_df), dtype=np.int64)
        sequence_pos = np.zeros(len(seq_df), dtype=np.int64)
        split_on_gaps = max_time_between is not None and 'game_seconds_remaining' in seq_df.columns

        next_seq_id = 1
        # groupby keeps the (already sorted) row order within each group.
        for _, group_df in seq_df.groupby(group_by):
            rows = group_df.index.to_numpy()

            if split_on_gaps:
                times = group_df['game_seconds_remaining'].to_numpy(dtype=float)
                # A gap larger than the limit closes the current segment.
                # NaN gaps compare False and therefore never split.
                breaks = np.abs(np.diff(times)) > max_time_between
                segment_of_row = np.concatenate(([0], np.cumsum(breaks)))
            else:
                segment_of_row = np.zeros(len(rows), dtype=np.int64)

            for segment in np.unique(segment_of_row):
                segment_rows = rows[segment_of_row == segment]
                if len(segment_rows) < min_sequence_length:
                    # Too short: leave sequence_id at 0 so the rows are dropped
                    # (never merged into the following segment).
                    continue
                sequence_ids[segment_rows] = next_seq_id
                sequence_pos[segment_rows] = np.arange(len(segment_rows))
                next_seq_id += 1

        seq_df['sequence_id'] = sequence_ids
        seq_df['sequence_pos'] = sequence_pos
        seq_df.index = original_index

        return seq_df[seq_df['sequence_id'] > 0]

    def convert_to_hmm_glm_format(self, df: pd.DataFrame,
                                  feature_cols: Optional[List[str]] = None,
                                  outcome_col: str = 'is_made_shot',
                                  sequence_id_col: str = 'sequence_id',
                                  sequence_pos_col: str = 'sequence_pos') -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]:
        """
        Convert DataFrame to HMM-GLM input format.

        Rows are ordered by (sequence id, sequence position). Each sequence is
        therefore contiguous in the returned arrays, in the same order as
        ``lengths``.

        Args:
            df: DataFrame with sequences and features
            feature_cols: List of feature column names (if None, use all
                numeric columns except outcome/sequence/ID columns)
            outcome_col: Name of the outcome column
            sequence_id_col: Name of the sequence ID column
            sequence_pos_col: Name of the sequence position column

        Returns:
            Tuple of (X, y, sequences, lengths):
                X: Feature matrix [n_samples, n_features]
                y: Outcome vector [n_samples]
                sequences: Sequence IDs for each sample
                lengths: Length of each sequence, in row order
            Empty arrays/list are returned for empty input or missing columns.
        """
        if df.empty:
            return np.array([]), np.array([]), np.array([]), []

        if feature_cols is None:
            exclude_cols = [outcome_col, sequence_id_col, sequence_pos_col,
                            'GAME_ID', 'EVENTNUM', 'PLAYER1_ID', 'PLAYER_ID']
            feature_cols = [col for col in df.select_dtypes(include=[np.number]).columns
                            if col not in exclude_cols]
        else:
            feature_cols = list(feature_cols)

        required_cols = feature_cols + [outcome_col, sequence_id_col, sequence_pos_col]
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            logger.error(f"Missing required columns: {missing_cols}")
            return np.array([]), np.array([]), np.array([]), []

        # Sequence-based models consume X as concatenated sequences plus a
        # lengths list, so each sequence's rows must be contiguous and ordered.
        ordered = df.sort_values(by=[sequence_id_col, sequence_pos_col])

        X = ordered[feature_cols].to_numpy()
        y = ordered[outcome_col].to_numpy()
        sequences = ordered[sequence_id_col].to_numpy()

        # sort=False keeps the sorted row order; one pass instead of one
        # boolean filter per sequence.
        lengths = ordered.groupby(sequence_id_col, sort=False, dropna=False).size().tolist()

        return X, y, sequences, lengths


def convert_nba_data(raw_data_path: str, output_path: Optional[str] = None,
                     min_sequence_length: int = 3) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Convert raw NBA data to HMM-GLM compatible format.

    Pipeline: load -> preprocess -> engineer features -> create sequences.
    Data containing ``SHOT_ZONE_BASIC`` is treated as shot chart data;
    anything else is treated as play-by-play.

    Args:
        raw_data_path: Path to raw NBA data (CSV file or directory)
        output_path: Path to save the converted data (optional). Metadata is
            written next to it as ``<output stem>_metadata.json``.
        min_sequence_length: Minimum sequence length to include

    Returns:
        Tuple of (converted_df, metadata):
            converted_df: DataFrame with converted data
            metadata: Dictionary with metadata about the conversion
    """
    import json

    converter = NBADataConverter()
    metadata = {
        'source': raw_data_path,
        'min_sequence_length': min_sequence_length,
        'n_original_rows': 0,
        'n_converted_rows': 0,
        'n_sequences': 0,
        'feature_columns': [],
        'outcome_column': 'is_made_shot'
    }

    raw_df = converter.load_raw_data(raw_data_path)
    if raw_df.empty:
        logger.error(f"No data loaded from {raw_data_path}")
        return pd.DataFrame(), metadata

    metadata['n_original_rows'] = len(raw_df)

    if 'SHOT_ZONE_BASIC' in raw_df.columns:
        processed_df = converter.preprocess_shotchart_data(raw_df)
    else:
        processed_df = converter.preprocess_pbp_data(raw_df)

    feature_df = converter.engineer_features(processed_df)

    sequence_df = converter.create_sequences(
        feature_df,
        group_by=['GAME_ID', 'PLAYER1_ID'] if 'PLAYER1_ID' in feature_df.columns else ['GAME_ID', 'PLAYER_ID'],
        sort_by=['GAME_ID', 'EVENTNUM'] if 'EVENTNUM' in feature_df.columns else ['GAME_ID', 'GAME_EVENT_ID'],
        min_sequence_length=min_sequence_length,
        max_time_between=300  # 5 minutes max between shots in a sequence
    )

    metadata['n_converted_rows'] = len(sequence_df)
    # create_sequences returns an empty input unchanged (no sequence_id column),
    # e.g. when preprocessing filtered out every row.
    metadata['n_sequences'] = (
        int(sequence_df['sequence_id'].nunique()) if 'sequence_id' in sequence_df.columns else 0
    )

    numeric_cols = sequence_df.select_dtypes(include=[np.number]).columns
    exclude_cols = ['sequence_id', 'sequence_pos', 'is_made_shot', 'GAME_ID', 'EVENTNUM', 'PLAYER1_ID', 'PLAYER_ID']
    metadata['feature_columns'] = [col for col in numeric_cols if col not in exclude_cols]

    if output_path:
        # A bare filename has no directory part; os.makedirs('') would raise.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        sequence_df.to_csv(output_path, index=False)

        # Derive from the extension rather than str.replace('.csv', ...), which
        # could match elsewhere in the path or leave it unchanged and make the
        # metadata overwrite the CSV just written.
        metadata_path = os.path.splitext(output_path)[0] + '_metadata.json'
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump({k: v for k, v in metadata.items() if isinstance(v, (str, int, float, bool, list))}, f, indent=2)

    return sequence_df, metadata


if __name__ == "__main__":
    # Command-line entry point: convert <raw_data_path> (default output
    # location when no second argument is given) and print summary counts.
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python nba_converter.py <raw_data_path> [output_path]")
    else:
        raw_data_path = cli_args[0]
        output_path = cli_args[1] if len(cli_args) > 1 else "data/nba/nba_converted.csv"

        print(f"Converting NBA data from {raw_data_path} to {output_path}")
        df, metadata = convert_nba_data(raw_data_path, output_path)

        original_rows = metadata['n_original_rows']
        converted_rows = metadata['n_converted_rows']
        print(f"Conversion complete. Converted {original_rows} rows into {converted_rows} rows.")
        print(f"Created {metadata['n_sequences']} sequences.")

