"""
MIMIC-IV Data Loader for CNCRC Drug Recommendation Task

This module provides functionality to load and preprocess MIMIC-IV dataset
for the drug recommendation task within the CNCRC framework.

Key Components:
- MimicDataLoader: Main class for loading MIMIC data
- Data preprocessing and cleaning functions
- Integration with CNCRC data structures (ClinicalContext, etc.)
- Support for train/calibration/test splits
- Efficient data loading with caching support

The loader assumes preprocessed MIMIC data is available locally in
data/processed/mimic/ directory with train/validation/test splits.
"""

import os
import pandas as pd
import numpy as np
import json
from typing import Dict, List, Tuple, Optional, Any, Union
from pathlib import Path
import logging
from dataclasses import dataclass, field
import warnings
from datetime import datetime

from ..core.data_structures import ClinicalContext, DrugInteraction, PredictionSet
from .data_splitter import split_data, DataSplitConfig

# Configure logging
# NOTE(review): basicConfig() runs at import time and therefore configures the
# root logger for any application that imports this module (it does nothing if
# the root logger already has handlers). Libraries usually leave this to the
# application entry point -- confirm this is intended before relying on it.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every function and method in this file.
logger = logging.getLogger(__name__)


@dataclass
class MimicDataConfig:
    """
    Settings that control how MIMIC data is located, filtered and cleaned.

    Attributes:
        data_root: Directory containing the preprocessed split files.
        use_iid_split: Use the IID split files when True, the group split
            files when False.
        max_text_length: Maximum text length (in characters) kept per note.
        min_patient_notes: Minimum notes per patient.
        cache_enabled: Keep loaded splits in memory for repeated access.
        include_note_types: Note types to keep; each instance gets its own
            fresh list.
        exclude_empty_labels: Drop samples that carry no drug labels.
        normalize_text: Apply basic text normalization.
        seed: Random seed for reproducibility.
    """
    data_root: str = "data/processed/mimic"
    use_iid_split: bool = True
    max_text_length: int = 2048
    min_patient_notes: int = 1
    cache_enabled: bool = True
    include_note_types: List[str] = field(
        default_factory=lambda: ["Discharge summary", "Nursing/other", "Radiology", "Physician"]
    )
    exclude_empty_labels: bool = True
    normalize_text: bool = True
    seed: int = 42


class MimicDataLoader:
    """
    Main data loader for MIMIC dataset in CNCRC framework.

    This loader handles:
    - Loading preprocessed MIMIC data from Parquet/JSONL files
    - Converting to CNCRC-compatible data structures
    - Providing train/calibration/test splits
    - Data preprocessing and cleaning
    - Integration with clinical context

    Example:
        >>> config = MimicDataConfig(data_root="data/processed/mimic")
        >>> loader = MimicDataLoader(config)
        >>> train_data = loader.load_train_data()
        >>> print(f"Training samples: {len(train_data)}")
    """

    def __init__(self, config: "MimicDataConfig"):
        """
        Initialize MIMIC data loader.

        Args:
            config: Data loading configuration

        Raises:
            ValueError: If ``config.data_root`` does not exist.
        """
        self.config = config
        self.data_root = Path(config.data_root)
        # Preprocessed frames keyed by "<split>_<use_iid_split>".
        self._cache: Dict[str, pd.DataFrame] = {}

        # Validate data directory exists
        if not self.data_root.exists():
            raise ValueError(f"Data directory not found: {self.data_root}")

        # Set random seed (note: this seeds NumPy's *global* RNG)
        np.random.seed(config.seed)

        logger.info(f"Initialized MIMIC loader with data root: {self.data_root}")

    # ------------------------------------------------------------------
    # Label helpers shared by every method that reads the 'labels' column
    # ------------------------------------------------------------------

    @staticmethod
    def _is_valid_label(label: Any) -> bool:
        """Return False for the raw label encodings that mean "no labels"."""
        if label is None:
            return False
        if isinstance(label, float) and np.isnan(label):
            return False
        if isinstance(label, str):
            return label not in ('', '[]', 'null')
        if isinstance(label, (list, tuple, np.ndarray)):
            return len(label) > 0
        # Any other type is kept.
        return True

    @staticmethod
    def _parse_labels(label_value: Any) -> List[Any]:
        """
        Normalize one raw 'labels' cell into a list of non-empty labels.

        Accepted encodings: None/NaN (no labels), a JSON string, a
        comma-separated string, a list/tuple/ndarray, or a single scalar
        (stringified). Entries that are None or blank after ``str().strip()``
        are dropped.
        """
        if label_value is None:
            return []
        if isinstance(label_value, float) and np.isnan(label_value):
            return []

        if isinstance(label_value, str):
            try:
                parsed = json.loads(label_value)
            except json.JSONDecodeError:
                # Not JSON: treat as a comma-separated string.
                parsed = [part.strip() for part in label_value.split(',')]
            else:
                # JSON null / NaN mean "no labels"; without this check they
                # would be stringified into a bogus "None" / "nan" label.
                if parsed is None or (isinstance(parsed, float) and np.isnan(parsed)):
                    return []
        elif isinstance(label_value, (list, tuple, np.ndarray)):
            parsed = list(label_value)
        else:
            # Convert single value to list
            parsed = [str(label_value)]

        # A JSON scalar or object decodes to a non-list; wrap it.
        if not isinstance(parsed, list):
            parsed = [str(parsed)]

        return [item for item in parsed if item is not None and str(item).strip()]

    # ------------------------------------------------------------------
    # Split loading
    # ------------------------------------------------------------------

    def load_train_data(self) -> pd.DataFrame:
        """Load training data."""
        return self._load_split("train")

    def load_validation_data(self) -> pd.DataFrame:
        """Load validation/calibration data."""
        return self._load_split("validation")

    def load_test_data(self) -> pd.DataFrame:
        """Load test data."""
        return self._load_split("test")

    def load_all_data(self) -> pd.DataFrame:
        """Load complete dataset."""
        return self._load_split("all")

    def _load_split(self, split: str) -> pd.DataFrame:
        """
        Load a specific data split.

        Args:
            split: Data split name ("train", "validation", "test", "all")

        Returns:
            DataFrame with loaded data. When caching is enabled the cached
            object itself is returned, so callers should copy before mutating.

        Raises:
            FileNotFoundError: If neither the Parquet nor the JSONL file exists.
            ValueError: If the file format is unsupported or required columns
                are missing.
        """
        cache_key = f"{split}_{self.config.use_iid_split}"

        # Check cache first
        if self.config.cache_enabled and cache_key in self._cache:
            logger.info(f"Loading {split} data from cache")
            return self._cache[cache_key]

        # Determine file name
        split_suffix = "iid" if self.config.use_iid_split else "group"
        if split == "all":
            filename = "all.parquet"
        else:
            filename = f"{split}_{split_suffix}.parquet"

        file_path = self.data_root / filename

        # Fallback to JSONL if Parquet not available
        if not file_path.exists():
            file_path = self.data_root / filename.replace(".parquet", ".jsonl")

        if not file_path.exists():
            raise FileNotFoundError(f"Data file not found: {filename}")

        logger.info(f"Loading {split} data from {file_path}")

        # Load data
        if file_path.suffix == ".parquet":
            df = pd.read_parquet(file_path)
        elif file_path.suffix == ".jsonl":
            df = pd.read_json(file_path, lines=True)
        else:
            raise ValueError(f"Unsupported file format: {file_path.suffix}")

        # Apply preprocessing
        df = self._preprocess_data(df)

        # Cache if enabled
        if self.config.cache_enabled:
            self._cache[cache_key] = df

        logger.info(f"Loaded {len(df)} samples for {split} split")
        return df

    def _preprocess_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess loaded data.

        Steps: validate required columns, filter by note type, drop samples
        with empty labels, truncate and normalize text, stringify ID columns.
        The input frame is never modified.

        Args:
            df: Raw dataframe

        Returns:
            Preprocessed dataframe

        Raises:
            ValueError: If a required column is missing.
        """
        logger.info(f"Preprocessing data: {len(df)} samples")

        # Fail fast: every later step relies on these columns.
        required_columns = ['patient_id', 'text_input', 'labels']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

        # Filter by note types if specified
        if self.config.include_note_types and 'note_type' in df.columns:
            df = df[df['note_type'].isin(self.config.include_note_types)]
            logger.info(f"Filtered by note types: {len(df)} samples remaining")

        # Remove samples with empty labels if specified
        if self.config.exclude_empty_labels:
            # apply() on an empty Series does not yield a bool dtype; cast so
            # the result is always used as a row mask, never as column keys.
            keep = df['labels'].apply(self._is_valid_label).astype(bool)
            df = df[keep]
            logger.info(f"Filtered empty labels: {len(df)} samples remaining")

        # Copy once, after filtering, so the column assignments below act on
        # an independent frame (no SettingWithCopyWarning, caller's data intact).
        df = df.copy()

        # Truncate text if too long
        # NOTE(review): astype(str) turns missing text into the literal
        # strings "nan"/"None" -- confirm whether that is acceptable downstream.
        if self.config.max_text_length > 0:
            df['text_input'] = df['text_input'].astype(str).str[:self.config.max_text_length]

        # Basic text normalization
        if self.config.normalize_text:
            df['text_input'] = df['text_input'].astype(str).str.strip()
            # Remove excessive whitespace
            df['text_input'] = df['text_input'].str.replace(r'\s+', ' ', regex=True)

        # Convert IDs to strings for consistency
        for id_col in ['patient_id', 'admission_id', 'note_id']:
            if id_col in df.columns:
                df[id_col] = df[id_col].astype(str)

        logger.info(f"Preprocessing complete: {len(df)} samples")
        return df

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_unique_patients(self, split: str = "train") -> List[str]:
        """
        Get list of unique patient IDs in a split.

        Args:
            split: Data split name

        Returns:
            List of unique patient IDs
        """
        df = self._load_split(split)
        return df['patient_id'].unique().tolist()

    def get_patient_data(self, patient_id: str, split: str = "train") -> pd.DataFrame:
        """
        Get all data for a specific patient.

        Args:
            patient_id: Patient identifier
            split: Data split to search in

        Returns:
            DataFrame with patient's data (a copy, safe to modify)
        """
        df = self._load_split(split)
        return df[df['patient_id'] == patient_id].copy()

    def convert_to_clinical_contexts(self, df: pd.DataFrame) -> "List[ClinicalContext]":
        """
        Convert DataFrame rows to ClinicalContext objects.

        Age, gender, medications, allergies and admission type are
        placeholders; the parsed labels are stored as ``diagnoses``.

        Args:
            df: DataFrame with patient data

        Returns:
            List of ClinicalContext objects, one per row
        """
        contexts = []

        for _, row in df.iterrows():
            # Best effort: a row whose labels cannot be parsed still yields a
            # context, just with no labels.
            try:
                labels = self._parse_labels(row['labels'])
            except Exception as e:
                logger.warning(f"Failed to parse labels for patient {row['patient_id']}: {e}")
                labels = []

            # Keep only a short preview of the note text in the metadata.
            note_text = str(row.get('text_input', ''))
            text_preview = note_text[:200] + "..." if len(note_text) > 200 else note_text

            # Create clinical context
            context = ClinicalContext(
                patient_id=str(row['patient_id']),
                age=65,  # Default age - would need additional MIMIC tables for real age
                gender="UNKNOWN",  # Would need PATIENTS table for real gender
                diagnoses=labels,  # Using labels as diagnoses for now
                current_medications=[],  # Would need prescriptions data
                allergies=[],  # Would need allergies data
                weight=None,
                admission_type="UNKNOWN",
                metadata={
                    'note_id': str(row.get('note_id', '')),
                    'admission_id': str(row.get('admission_id', '')),
                    'note_type': str(row.get('note_type', '')),
                    'text_input': text_preview
                }
            )

            contexts.append(context)

        return contexts

    def get_drug_vocabulary(self, splits: Optional[List[str]] = None) -> Dict[str, int]:
        """
        Build drug vocabulary from labels across specified splits.

        Splits that fail to load and label cells that fail to parse are
        logged and skipped.

        Args:
            splits: List of splits to include (default: all splits)

        Returns:
            Dictionary mapping drug names to indices (indices follow the
            sorted order of the drug names)
        """
        if splits is None:
            splits = ["train", "validation", "test"]

        all_drugs = set()

        for split in splits:
            try:
                df = self._load_split(split)

                for label_value in df['labels']:
                    try:
                        all_drugs.update(self._parse_labels(label_value))
                    except Exception as e:
                        logger.warning(f"Failed to parse labels in split {split}: {e}")
                        continue

            except Exception as e:
                logger.warning(f"Failed to load split {split}: {e}")
                continue

        # Create vocabulary mapping
        vocab = {drug: idx for idx, drug in enumerate(sorted(all_drugs))}

        logger.info(f"Built drug vocabulary with {len(vocab)} unique drugs")
        return vocab

    def get_dataset_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive statistics about the dataset.

        Returns:
            Dictionary keyed by split name. Each value holds the split's
            statistics, or ``{'error': <message>}`` if they could not be
            computed.
        """
        stats = {}

        for split in ["train", "validation", "test"]:
            try:
                df = self._load_split(split)

                # Basic statistics
                split_stats = {
                    'num_samples': len(df),
                    'num_unique_patients': df['patient_id'].nunique(),
                    'avg_text_length': df['text_input'].str.len().mean() if 'text_input' in df.columns else 0,
                    'note_types': df['note_type'].value_counts().to_dict() if 'note_type' in df.columns else {}
                }

                # Label statistics
                all_labels = []
                for label_value in df['labels']:
                    try:
                        all_labels.extend(self._parse_labels(label_value))
                    except Exception as e:
                        # Skip unparseable cells, but say so instead of
                        # swallowing the error silently.
                        logger.warning(f"Failed to parse labels in split {split}: {e}")
                        continue

                split_stats['num_unique_labels'] = len(set(all_labels))
                split_stats['avg_labels_per_sample'] = len(all_labels) / len(df) if len(df) > 0 else 0

                stats[split] = split_stats

            except Exception as e:
                logger.warning(f"Failed to compute statistics for split {split}: {e}")
                stats[split] = {'error': str(e)}

        return stats

    def create_data_splits_by_patient(
        self,
        train_ratio: float = 0.7,
        calibration_ratio: float = 0.15,
        test_ratio: float = 0.15,
        group_by: str = "patient_id"
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Create new data splits grouped by patient to avoid data leakage.

        Loads the "all" split and delegates the grouping to ``split_data``,
        seeded with ``config.seed``.

        Args:
            train_ratio: Fraction for training set
            calibration_ratio: Fraction for calibration set
            test_ratio: Fraction for test set
            group_by: Column to group by for splitting

        Returns:
            Tuple of (train_df, calibration_df, test_df)
        """
        # Load all data
        all_data = self.load_all_data()

        # Convert DataFrame to list of dictionaries for data splitter
        data_list = all_data.to_dict('records')

        # Create split configuration
        split_config = DataSplitConfig(
            is_grouped=True,
            train_ratio=train_ratio,
            calibration_ratio=calibration_ratio,
            test_ratio=test_ratio,
            group_key=group_by,
            random_seed=self.config.seed
        )

        # Use existing data splitter
        train_list, cal_list, test_list = split_data(data_list, split_config)

        # Convert back to DataFrames
        train_df = pd.DataFrame(train_list)
        cal_df = pd.DataFrame(cal_list)
        test_df = pd.DataFrame(test_list)

        logger.info(f"Created new splits: train={len(train_df)}, cal={len(cal_df)}, test={len(test_df)}")

        return train_df, cal_df, test_df


def load_mimic_data(
    data_root: str = "data/processed/mimic",
    use_iid_split: bool = True,
    **kwargs
) -> MimicDataLoader:
    """
    Build a MimicDataLoader from keyword settings in a single call.

    Args:
        data_root: Path to preprocessed MIMIC data
        use_iid_split: Whether to use IID or group-based splits
        **kwargs: Any further MimicDataConfig fields, passed through as-is

    Returns:
        MimicDataLoader constructed from the resulting MimicDataConfig

    Example:
        >>> loader = load_mimic_data()
        >>> train_data = loader.load_train_data()
        >>> contexts = loader.convert_to_clinical_contexts(train_data.head(10))
    """
    loader_config = MimicDataConfig(data_root=data_root, use_iid_split=use_iid_split, **kwargs)
    loader = MimicDataLoader(loader_config)
    return loader


# Module-level convenience wrappers kept for backward compatibility
# (each call builds a new loader, so nothing is cached between calls)
def load_train_data(data_root: str = "data/processed/mimic", **kwargs) -> pd.DataFrame:
    """Build a loader via load_mimic_data and return its training split."""
    return load_mimic_data(data_root=data_root, **kwargs).load_train_data()


def load_calibration_data(data_root: str = "data/processed/mimic", **kwargs) -> pd.DataFrame:
    """Build a loader via load_mimic_data and return its validation split, used as calibration data."""
    return load_mimic_data(data_root=data_root, **kwargs).load_validation_data()


def load_test_data(data_root: str = "data/processed/mimic", **kwargs) -> pd.DataFrame:
    """Build a loader via load_mimic_data and return its test split."""
    return load_mimic_data(data_root=data_root, **kwargs).load_test_data()
