"""
MLB Data Converter

This module provides utilities to convert raw MLB data crawled from Baseball Savant
and MLB GameDay API into formats compatible with the HMM-GLM framework.
"""

import os
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Union, Tuple
import logging

# Setup logging
logger = logging.getLogger(__name__)


class MLBDataConverter:
    """
    Converter for MLB data to transform it into HMM-GLM compatible format.

    This class handles the conversion of raw MLB play-by-play data into
    the standardized format required by the HMM-GLM models, including
    feature engineering and sequence creation.
    """

    # Outcome and identifier columns that are never used as model inputs when
    # feature columns are inferred automatically. This mirrors the exclusion
    # list used by ``convert_mlb_data`` so both code paths agree. Note that
    # ``is_hit`` is a subset of ``is_positive_outcome`` and would leak the
    # label if it were treated as a feature.
    _NON_FEATURE_COLUMNS = frozenset({
        'sequence_id', 'sequence_pos',
        'is_positive_outcome', 'is_hit', 'is_goal',
        'game_pk', 'at_bat_number', 'pitch_number',
        'game_id', 'play_id', 'event_idx',
    })

    # Hand-set leverage per ball-strike count: positive favours the batter,
    # negative favours the pitcher. Counts missing from this table map to 0.
    _COUNT_LEVERAGE = {
        '0-0': 0.0,    # Neutral
        '0-1': -0.3,   # Pitcher advantage
        '0-2': -0.6,   # Strong pitcher advantage
        '1-0': 0.2,    # Slight batter advantage
        '1-1': -0.1,   # Slight pitcher advantage
        '1-2': -0.5,   # Pitcher advantage
        '2-0': 0.4,    # Batter advantage
        '2-1': 0.2,    # Slight batter advantage
        '2-2': -0.3,   # Pitcher advantage
        '3-0': 0.7,    # Strong batter advantage
        '3-1': 0.5,    # Batter advantage
        '3-2': 0.1     # Slight batter advantage
    }

    def __init__(self):
        """Initialize the MLB data converter (it holds no per-instance state)."""

    @staticmethod
    def _format_count_component(values: pd.Series) -> pd.Series:
        """
        Render a balls or strikes column as text without a spurious '.0'.

        Integral numbers, including floats such as ``2.0`` (which appear when
        the column contained NaN at load time), become ``'2'``. Everything
        else is rendered with ``str`` exactly as ``astype(str)`` would.
        """
        numeric = pd.to_numeric(values, errors='coerce').to_numpy(dtype=float, na_value=np.nan)
        text = values.astype(str).to_numpy(dtype=object)
        integral = np.isfinite(numeric)
        integral[integral] = numeric[integral] == np.floor(numeric[integral])
        text[integral] = numeric[integral].astype(np.int64).astype(str).tolist()
        # Build from raw arrays so duplicate index labels cannot trigger alignment.
        return pd.Series(text, index=values.index)

    @classmethod
    def _count_state(cls, balls: pd.Series, strikes: pd.Series) -> pd.Series:
        """Build 'B-S' count keys (e.g. '3-2') matching ``_COUNT_LEVERAGE``."""
        return cls._format_count_component(balls) + '-' + cls._format_count_component(strikes)

    def load_raw_data(self, data_path: str) -> pd.DataFrame:
        """
        Load raw MLB data from CSV files.

        When ``data_path`` is a directory, every ``*.csv`` file whose name
        contains ``'merged'`` or ``'pbp'`` is loaded. Files are read in sorted
        name order so the concatenated result is deterministic, and the parts
        are concatenated with a fresh index. Files that fail to parse are
        logged and skipped.

        Args:
            data_path: Path to the raw data CSV file or directory

        Returns:
            DataFrame containing the loaded data (empty on any failure)
        """
        if os.path.isdir(data_path):
            # os.listdir order is arbitrary; sort so repeated runs agree.
            csv_files = sorted(
                os.path.join(data_path, name)
                for name in os.listdir(data_path)
                if name.endswith('.csv') and ('merged' in name or 'pbp' in name)
            )

            if not csv_files:
                logger.warning(f"No relevant CSV files found in {data_path}")
                return pd.DataFrame()

            dfs = []
            for file in csv_files:
                try:
                    dfs.append(pd.read_csv(file))
                except Exception as e:
                    # Best effort: one unreadable file should not abort the batch.
                    logger.error(f"Error loading {file}: {e}")

            if dfs:
                return pd.concat(dfs, ignore_index=True)
            return pd.DataFrame()

        if os.path.isfile(data_path) and data_path.endswith('.csv'):
            try:
                return pd.read_csv(data_path)
            except Exception as e:
                logger.error(f"Error loading {data_path}: {e}")
                return pd.DataFrame()

        logger.error(f"Invalid data path: {data_path}")
        return pd.DataFrame()

    def preprocess_statcast_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess raw Statcast data.

        First fills missing numeric values with 0, and missing ``pitch_type``
        and ``bb_type`` values with ``'UN'``. Then, for each source that is
        present, derives:

        - ``is_positive_outcome`` and ``is_hit`` from ``events``;
        - ``inning_half`` (0 = 'Top', 1 = 'Bot') and ``game_state`` from
          ``inning`` and ``inning_topbot``;
        - ``count_state`` ('B-S') and ``count_advantage`` (balls - strikes)
          from ``balls`` and ``strikes``;
        - ``on_1b_occupied``, ``on_2b_occupied`` and ``on_3b_occupied``, plus
          a 0-7 ``base_state`` bitmask (1st=1, 2nd=2, 3rd=4).

        Args:
            df: Raw Statcast DataFrame

        Returns:
            Preprocessed copy of ``df`` (an empty input is returned as-is)
        """
        if df.empty:
            return df

        processed_df = df.copy()

        # Handle missing values
        numeric_cols = processed_df.select_dtypes(include=[np.number]).columns
        processed_df[numeric_cols] = processed_df[numeric_cols].fillna(0)

        if 'pitch_type' in processed_df.columns:
            processed_df['pitch_type'] = processed_df['pitch_type'].fillna('UN')

        if 'bb_type' in processed_df.columns:
            processed_df['bb_type'] = processed_df['bb_type'].fillna('UN')

        # Binary outcome variables
        if 'events' in processed_df.columns:
            positive_outcomes = ['single', 'double', 'triple', 'home_run', 'walk', 'hit_by_pitch']
            processed_df['is_positive_outcome'] = processed_df['events'].isin(positive_outcomes).astype(int)

            hit_outcomes = ['single', 'double', 'triple', 'home_run']
            processed_df['is_hit'] = processed_df['events'].isin(hit_outcomes).astype(int)

        # Game state: two half-innings per inning
        if 'inning' in processed_df.columns and 'inning_topbot' in processed_df.columns:
            processed_df['inning_half'] = (processed_df['inning_topbot'] == 'Bot').astype(int)
            processed_df['game_state'] = (processed_df['inning'] - 1) * 2 + processed_df['inning_half']

        # Count state
        if 'balls' in processed_df.columns and 'strikes' in processed_df.columns:
            processed_df['count_state'] = self._count_state(processed_df['balls'], processed_df['strikes'])
            # Positive favours the batter, negative the pitcher.
            processed_df['count_advantage'] = processed_df['balls'] - processed_df['strikes']

        # Base state
        base_cols = ['on_1b', 'on_2b', 'on_3b']
        if all(col in processed_df.columns for col in base_cols):
            for col in base_cols:
                runner = processed_df[col]
                # The numeric fillna(0) above has already turned empty (NaN)
                # runner columns into 0, so notna() alone would mark every base
                # occupied. Treat both NaN and 0 as an empty base.
                processed_df[f'{col}_occupied'] = (runner.notna() & runner.ne(0)).astype(int)

            processed_df['base_state'] = (
                processed_df['on_1b_occupied'] * 1 +
                processed_df['on_2b_occupied'] * 2 +
                processed_df['on_3b_occupied'] * 4
            )

        return processed_df

    def preprocess_pbp_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess raw play-by-play data.

        Fills the listed optional columns (``preceding_event_1..3`` with
        ``'none'``; ``is_rebound`` and ``is_rush`` with ``False``). Columns
        absent from ``df`` are ignored by ``fillna``. If ``is_goal`` is
        missing but ``event_type`` exists, derives ``is_goal`` from
        ``event_type == 'Goal'``.

        NOTE(review): these column names look like they come from a
        non-baseball schema; confirm they match the MLB play-by-play export.

        Args:
            df: Raw play-by-play DataFrame

        Returns:
            Preprocessed copy of ``df`` (an empty input is returned as-is)
        """
        if df.empty:
            return df

        processed_df = df.copy()

        processed_df = processed_df.fillna({
            'preceding_event_1': 'none',
            'preceding_event_2': 'none',
            'preceding_event_3': 'none',
            'is_rebound': False,
            'is_rush': False
        })

        if 'is_goal' not in processed_df.columns and 'event_type' in processed_df.columns:
            processed_df['is_goal'] = (processed_df['event_type'] == 'Goal').astype(int)

        return processed_df

    def engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Engineer features for HMM-GLM modeling.

        Adds, when the source columns exist:

        - ``score_differential`` (home - away) and ``score_differential_norm``
          (divided by 5 and clipped to [-1, 1]);
        - ``game_progress`` in [0, 1], assuming a 9-inning game;
        - ``pitcher_fatigue`` = ``pitch_number / 100``, clipped to [0, 1];
        - ``platoon_advantage`` (1 when pitcher and batter hands differ);
        - ``count_leverage``, looked up from ``_COUNT_LEVERAGE``.

        Args:
            df: Preprocessed DataFrame

        Returns:
            Copy of ``df`` with engineered features (an empty input is
            returned as-is)
        """
        if df.empty:
            return df

        feature_df = df.copy()

        # 1. Score differential
        if 'home_score' in feature_df.columns and 'away_score' in feature_df.columns:
            feature_df['score_differential'] = feature_df['home_score'] - feature_df['away_score']
            feature_df['score_differential_norm'] = np.clip(
                feature_df['score_differential'] / 5,  # Normalize by typical max lead
                -1, 1
            )

        # 2. Game time (normalized)
        if 'inning' in feature_df.columns:
            max_innings = 9

            if 'inning_half' in feature_df.columns:
                feature_df['game_progress'] = (
                    (feature_df['inning'] - 1) * 2 + feature_df['inning_half']
                ) / (max_innings * 2)
            else:
                feature_df['game_progress'] = feature_df['inning'] / max_innings

            # Extra innings would exceed 1.
            feature_df['game_progress'] = np.clip(feature_df['game_progress'], 0, 1)

        # 3. Pitcher fatigue proxy
        if 'pitch_number' in feature_df.columns:
            feature_df['pitcher_fatigue'] = np.clip(feature_df['pitch_number'] / 100, 0, 1)

        # 4. Platoon advantage
        if 'p_throws' in feature_df.columns and 'stand' in feature_df.columns:
            feature_df['platoon_advantage'] = (
                ((feature_df['p_throws'] == 'L') & (feature_df['stand'] == 'R')) |
                ((feature_df['p_throws'] == 'R') & (feature_df['stand'] == 'L'))
            ).astype(int)

        # 5. Count leverage
        if 'balls' in feature_df.columns and 'strikes' in feature_df.columns:
            # Always rebuild normalized keys for the lookup, so a float-typed
            # count ('3.0-0.0') still matches the '3-0' table entry.
            count_key = self._count_state(feature_df['balls'], feature_df['strikes'])
            if 'count_state' not in feature_df.columns:
                feature_df['count_state'] = count_key
            feature_df['count_leverage'] = count_key.map(self._COUNT_LEVERAGE).fillna(0)

        return feature_df

    def create_sequences(self, df: pd.DataFrame,
                         group_by: Optional[List[str]] = None,
                         sort_by: Optional[List[str]] = None,
                         min_sequence_length: int = 3) -> pd.DataFrame:
        """
        Create sequences for HMM-GLM modeling.

        Rows are sorted by ``sort_by`` and grouped by ``group_by``. Groups are
        numbered ``1..k`` in sorted key order. Groups shorter than
        ``min_sequence_length`` are dropped, so the surviving IDs may have
        gaps. ``sequence_pos`` is the 0-based position of each row within its
        group, following the sorted order. Rows whose group key contains NaN
        belong to no group and are dropped.

        Args:
            df: DataFrame with engineered features
            group_by: Columns to group by to create sequences
                (default ``['game_pk', 'at_bat_number']``)
            sort_by: Columns to sort by within sequences
                (default ``['game_pk', 'at_bat_number', 'pitch_number']``)
            min_sequence_length: Minimum sequence length to include

        Returns:
            DataFrame with ``sequence_id`` and ``sequence_pos`` columns. Any
            group or sort column missing from ``df`` is added as a constant 0.
        """
        # None sentinels avoid shared mutable list defaults.
        if group_by is None:
            group_by = ['game_pk', 'at_bat_number']
        if sort_by is None:
            sort_by = ['game_pk', 'at_bat_number', 'pitch_number']
        group_by = list(group_by)
        sort_by = list(sort_by)

        if df.empty:
            return df

        seq_df = df.copy()

        missing_cols = [col for col in group_by + sort_by if col not in seq_df.columns]
        if missing_cols:
            logger.warning(f"Missing columns for sequence creation: {missing_cols}")
            for col in missing_cols:
                seq_df[col] = 0

        seq_df = seq_df.sort_values(by=sort_by)

        # Vectorised replacement for a per-group loop. ngroup() numbers groups
        # in sorted key order; cumcount() gives each row's position within its
        # group in the current (sorted) row order. Rows with NaN keys come back
        # as NaN or -1 depending on the pandas version, so normalise them to -1.
        # Working on NumPy arrays also avoids index alignment, which keeps
        # duplicate index labels safe.
        grouped = seq_df.groupby(group_by, sort=True)
        group_ids = np.nan_to_num(
            np.asarray(grouped.ngroup(), dtype=float), nan=-1.0
        ).astype(np.int64)
        positions = np.nan_to_num(
            np.asarray(grouped.cumcount(), dtype=float), nan=0.0
        ).astype(np.int64)

        in_group = group_ids >= 0
        sizes = np.zeros_like(group_ids)
        if in_group.any():
            sizes[in_group] = np.bincount(group_ids[in_group])[group_ids[in_group]]
        keep = in_group & (sizes >= min_sequence_length)

        seq_df['sequence_id'] = np.where(keep, group_ids + 1, 0)
        seq_df['sequence_pos'] = np.where(keep, positions, 0)

        # Only rows belonging to a sufficiently long sequence survive.
        return seq_df[seq_df['sequence_id'] > 0]

    def convert_to_hmm_glm_format(self, df: pd.DataFrame,
                                  feature_cols: Optional[List[str]] = None,
                                  outcome_col: str = 'is_positive_outcome',
                                  sequence_id_col: str = 'sequence_id',
                                  sequence_pos_col: str = 'sequence_pos') -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]:
        """
        Convert DataFrame to HMM-GLM input format.

        Rows are first ordered by ``(sequence_id_col, sequence_pos_col)``. This
        makes each sequence a contiguous, in-order block of ``X``, and
        ``lengths`` lists the block sizes in the same order.

        Args:
            df: DataFrame with sequences and features
            feature_cols: List of feature column names. If None, all numeric
                columns are used except the outcome, sequence and identifier
                columns (including the auxiliary labels ``is_hit`` and
                ``is_goal``).
            outcome_col: Name of the outcome column
            sequence_id_col: Name of the sequence ID column
            sequence_pos_col: Name of the sequence position column

        Returns:
            Tuple of (X, y, sequences, lengths):
                X: Feature matrix [n_samples, n_features]
                y: Outcome vector [n_samples]
                sequences: Sequence IDs for each sample
                lengths: Length of each sequence, in the order they appear in X
            All four are empty if ``df`` is empty or a required column is
            missing (the latter is logged).
        """
        if df.empty:
            return np.array([]), np.array([]), np.array([]), []

        if feature_cols is None:
            excluded = set(self._NON_FEATURE_COLUMNS) | {outcome_col, sequence_id_col, sequence_pos_col}
            feature_cols = [col for col in df.select_dtypes(include=[np.number]).columns
                            if col not in excluded]
        else:
            feature_cols = list(feature_cols)

        missing_cols = [col for col in feature_cols + [outcome_col, sequence_id_col, sequence_pos_col]
                        if col not in df.columns]
        if missing_cols:
            logger.error(f"Missing required columns: {missing_cols}")
            return np.array([]), np.array([]), np.array([]), []

        # Make every sequence a contiguous, position-ordered run of rows so
        # that `lengths` can be used to split X back into sequences.
        ordered = df.sort_values([sequence_id_col, sequence_pos_col])

        X = ordered[feature_cols].to_numpy()
        y = ordered[outcome_col].to_numpy()
        sequences = ordered[sequence_id_col].to_numpy()

        # One O(n) pass; with sort=False, groups follow first appearance,
        # which here equals the sorted order.
        lengths = ordered.groupby(sequence_id_col, sort=False, dropna=False).size().astype(int).tolist()

        return X, y, sequences, lengths


def convert_mlb_data(raw_data_path: str, output_path: Optional[str] = None, 
                     min_sequence_length: int = 3) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Convert raw MLB data to HMM-GLM compatible format.

    Pipeline: load -> preprocess (Statcast if an ``events`` column exists,
    otherwise play-by-play) -> engineer features -> build sequences. Data
    with an ``at_bat_number`` column is sequenced per (game_pk, at_bat_number)
    ordered by pitch_number. Other data is sequenced per (game_id, play_id)
    ordered by event_idx.

    Args:
        raw_data_path: Path to raw MLB data (CSV file or directory)
        output_path: Path to save the converted data (optional). The metadata
            is written next to it as ``<output stem>_metadata.json``.
        min_sequence_length: Minimum sequence length to include

    Returns:
        Tuple of (converted_df, metadata):
            converted_df: DataFrame with converted data
            metadata: Dictionary with metadata about the conversion
    """
    import json

    converter = MLBDataConverter()
    metadata = {
        'source': raw_data_path,
        'min_sequence_length': min_sequence_length,
        'n_original_rows': 0,
        'n_converted_rows': 0,
        'n_sequences': 0,
        'feature_columns': [],
        # NOTE(review): fixed to the Statcast outcome; the play-by-play path
        # produces `is_goal` instead -- confirm downstream expectations.
        'outcome_column': 'is_positive_outcome'
    }

    raw_df = converter.load_raw_data(raw_data_path)
    if raw_df.empty:
        logger.error(f"No data loaded from {raw_data_path}")
        return pd.DataFrame(), metadata

    metadata['n_original_rows'] = len(raw_df)

    if 'events' in raw_df.columns:
        # This looks like Statcast data
        processed_df = converter.preprocess_statcast_data(raw_df)
    else:
        # This looks like play-by-play data
        processed_df = converter.preprocess_pbp_data(raw_df)

    feature_df = converter.engineer_features(processed_df)

    # Choose grouping and ordering from ONE schema check so the two key sets
    # never mix Statcast and play-by-play identifiers.
    if 'at_bat_number' in feature_df.columns:
        group_by = ['game_pk', 'at_bat_number']
        sort_by = ['game_pk', 'at_bat_number', 'pitch_number']
    else:
        group_by = ['game_id', 'play_id']
        sort_by = ['game_id', 'play_id', 'event_idx']

    sequence_df = converter.create_sequences(
        feature_df,
        group_by=group_by,
        sort_by=sort_by,
        min_sequence_length=min_sequence_length
    )

    metadata['n_converted_rows'] = len(sequence_df)
    metadata['n_sequences'] = (
        int(sequence_df['sequence_id'].nunique()) if 'sequence_id' in sequence_df.columns else 0
    )

    # Feature columns: numeric columns minus outcomes, sequence bookkeeping
    # and identifiers.
    numeric_cols = sequence_df.select_dtypes(include=[np.number]).columns
    exclude_cols = ['sequence_id', 'sequence_pos', 'is_positive_outcome', 'is_hit', 'is_goal',
                    'game_pk', 'at_bat_number', 'pitch_number', 'game_id', 'play_id', 'event_idx']
    metadata['feature_columns'] = [col for col in numeric_cols if col not in exclude_cols]

    if output_path:
        # dirname is '' for a bare filename; os.makedirs('') would raise.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        sequence_df.to_csv(output_path, index=False)

        # Derive the metadata path from the file stem only. A plain
        # str.replace('.csv', ...) would rewrite directory names too, and for
        # non-.csv paths would overwrite the data file itself.
        metadata_path = os.path.splitext(output_path)[0] + '_metadata.json'
        serializable = {k: v for k, v in metadata.items()
                        if isinstance(v, (str, int, float, bool, list))}
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(serializable, f, indent=2)

    return sequence_df, metadata


if __name__ == "__main__":
    # Command-line entry point:
    #   python mlb_converter.py <raw_data_path> [output_path]
    import sys

    # Route the module logger's warnings/errors through a standard handler.
    logging.basicConfig(level=logging.INFO)

    if len(sys.argv) > 1:
        raw_data_path = sys.argv[1]
        output_path = sys.argv[2] if len(sys.argv) > 2 else "data/mlb/mlb_converted.csv"

        print(f"Converting MLB data from {raw_data_path} to {output_path}")
        df, metadata = convert_mlb_data(raw_data_path, output_path)

        if metadata['n_original_rows'] == 0:
            # Nothing was loaded and nothing was written: report failure.
            print(f"No data could be loaded from {raw_data_path}.", file=sys.stderr)
            sys.exit(1)

        print(f"Conversion complete. Converted {metadata['n_original_rows']} rows into {metadata['n_converted_rows']} rows.")
        print(f"Created {metadata['n_sequences']} sequences.")
    else:
        # Misuse is an error: print usage to stderr and exit non-zero.
        print("Usage: python mlb_converter.py <raw_data_path> [output_path]", file=sys.stderr)
        sys.exit(2)

