"""
CNCRC Data Splitting Module

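This module provides data splitting functionality for CNCRC experiments.
It supports i.i.d. and grouped (group-level) splitting strategies, which
preserve the exchangeability assumption required by conformal prediction,
as well as a temporal split for evaluating behavior under distribution shift
and a helper for checking a split for group leakage.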

Migrated and enhanced from the original data_loader module.
"""
import random
from typing import List, Dict, Any, Tuple, Optional
import logging
import numpy as np

logger = logging.getLogger(__name__)


class DataSplitConfig:
    """Configuration for data splitting strategies.

    Holds the split strategy (i.i.d. vs. group-level), the three split ratios,
    the sample key used to identify groups, and an optional random seed.
    """

    def __init__(
        self,
        is_grouped: bool = False,
        train_ratio: float = 0.6,
        calibration_ratio: float = 0.2,
        test_ratio: float = 0.2,
        group_key: str = "group_id",
        random_seed: Optional[int] = None
    ):
        """
        Initialize data split configuration.

        Args:
            is_grouped: Whether to use group-level splitting
            train_ratio: Proportion for training set
            calibration_ratio: Proportion for calibration set
            test_ratio: Proportion for test set
            group_key: Key to identify groups in data
            random_seed: Random seed for reproducibility

        Raises:
            ValueError: If any ratio is negative, or if the ratios do not sum
                to 1.0 (within ``np.isclose`` tolerance). NaN ratios fail the
                sum check.
        """
        # The sum check alone accepts e.g. (1.5, -0.3, -0.2); a negative ratio
        # makes the resulting split sizes meaningless, so reject it explicitly.
        for name, value in (
            ("train_ratio", train_ratio),
            ("calibration_ratio", calibration_ratio),
            ("test_ratio", test_ratio),
        ):
            if value < 0:
                raise ValueError(f"{name} must be non-negative, got {value!r}")

        if not np.isclose(train_ratio + calibration_ratio + test_ratio, 1.0):
            raise ValueError("Ratios must sum to 1.0")

        self.is_grouped = is_grouped
        self.train_ratio = train_ratio
        self.calibration_ratio = calibration_ratio
        self.test_ratio = test_ratio
        self.group_key = group_key
        self.random_seed = random_seed

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(is_grouped={self.is_grouped!r}, "
            f"train_ratio={self.train_ratio!r}, "
            f"calibration_ratio={self.calibration_ratio!r}, "
            f"test_ratio={self.test_ratio!r}, group_key={self.group_key!r}, "
            f"random_seed={self.random_seed!r})"
        )


def split_data(
    data: List[Dict[str, Any]], 
    config: DataSplitConfig
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Partition ``data`` into train, calibration, and test sets.

    The strategy is chosen by ``config.is_grouped``:

    * falsy  -- sample-level (i.i.d.) shuffle and split;
    * truthy -- group-level split, so every sample sharing a
      ``config.group_key`` value lands in the same partition.

    When ``config.random_seed`` is not ``None``, both the stdlib ``random``
    module and NumPy's global generator are re-seeded before splitting. Note
    that this mutates process-wide RNG state, not only state local to this call.

    Args:
        data: Samples to split, each a dictionary.
        config: Splitting configuration (strategy, ratios, seed, group key).

    Returns:
        A ``(train_set, calibration_set, test_set)`` tuple. Empty input yields
        three empty lists and a logged warning.

    Example::

        data = [{'id': 1, 'patient_id': 'p1', 'features': [...]}, ...]
        config = DataSplitConfig(is_grouped=True, group_key='patient_id')
        train, cal, test = split_data(data, config)
    """
    # Guard clause: with nothing to split, skip seeding and dispatch entirely.
    if not data:
        logger.warning("Empty data provided")
        return [], [], []

    seed = config.random_seed
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    splitter = _split_grouped if config.is_grouped else _split_iid
    return splitter(data, config)


def _ratio_count(n: int, ratio: float) -> int:
    """Return ``floor(n * ratio)`` for a non-negative ``ratio``, tolerating float noise.

    ``int(n * ratio)`` alone can come out one short when the exact product is a
    whole number but its floating-point result lands just below it, e.g.
    ``100 * 0.57 == 56.99999999999999``. Products within a tiny relative
    tolerance of an integer are snapped to that integer before flooring.
    """
    raw = n * ratio
    nearest = round(raw)
    if abs(raw - nearest) <= 1e-9 * max(1.0, abs(raw)):
        return int(nearest)
    return int(raw)


def _split_iid(
    data: List[Dict[str, Any]],
    config: "DataSplitConfig"
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    I.I.D. data splitting strategy.

    Shuffles a copy of ``data`` with the module-level ``random`` generator and
    cuts it into consecutive train / calibration / test slices. Train and
    calibration sizes are ``floor(n * ratio)``; the test set takes every
    remaining sample, so ``config.test_ratio`` is not read here.
    """
    logger.info(f"Performing I.I.D. split on {len(data)} samples")

    # Shuffle a copy so the caller's sequence keeps its original order.
    data_copy = list(data)
    random.shuffle(data_copy)

    n = len(data_copy)
    n_train = _ratio_count(n, config.train_ratio)
    n_cal = _ratio_count(n, config.calibration_ratio)

    train_set = data_copy[:n_train]
    calibration_set = data_copy[n_train:n_train + n_cal]
    test_set = data_copy[n_train + n_cal:]

    logger.info(f"Split sizes - Train: {len(train_set)}, Cal: {len(calibration_set)}, Test: {len(test_set)}")
    if not calibration_set:
        logger.warning("I.I.D. split produced an empty calibration set; "
                       "check the dataset size and calibration_ratio")

    return train_set, calibration_set, test_set


def _split_grouped(
    data: List[Dict[str, Any]],
    config: "DataSplitConfig"
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Group-level data splitting strategy.

    Ensures all samples from the same group go to the same split. This
    maintains exchangeability when samples within groups are dependent.

    The ratios are applied to the number of *groups*, not samples, so per-split
    sample counts can deviate from the ratios when group sizes differ. Group
    order is shuffled with the module-level ``random`` generator; within each
    group, samples keep their input order.

    Raises:
        ValueError: If a sample lacks ``config.group_key`` or maps it to ``None``.
    """
    logger.info(f"Performing grouped split on {len(data)} samples using key '{config.group_key}'")

    # Group samples by group id. Dicts preserve first-seen order, so the
    # pre-shuffle ordering (and therefore a seeded result) is deterministic.
    group_to_samples: Dict[Any, List[Dict[str, Any]]] = {}
    for sample in data:
        group_id = sample.get(config.group_key)
        if group_id is None:
            raise ValueError(f"Sample missing group key '{config.group_key}': {sample}")
        group_to_samples.setdefault(group_id, []).append(sample)

    # Split groups
    group_ids = list(group_to_samples)
    random.shuffle(group_ids)

    n_groups = len(group_ids)
    n_train_groups = _ratio_count(n_groups, config.train_ratio)
    n_cal_groups = _ratio_count(n_groups, config.calibration_ratio)

    train_groups = group_ids[:n_train_groups]
    cal_groups = group_ids[n_train_groups:n_train_groups + n_cal_groups]
    test_groups = group_ids[n_train_groups + n_cal_groups:]

    # Collect samples for each split
    train_set = [sample for gid in train_groups for sample in group_to_samples[gid]]
    calibration_set = [sample for gid in cal_groups for sample in group_to_samples[gid]]
    test_set = [sample for gid in test_groups for sample in group_to_samples[gid]]

    logger.info(f"Grouped split - Groups: {n_groups}, Train groups: {len(train_groups)}, "
                f"Cal groups: {len(cal_groups)}, Test groups: {len(test_groups)}")
    logger.info(f"Sample sizes - Train: {len(train_set)}, Cal: {len(calibration_set)}, Test: {len(test_set)}")
    if not calibration_set:
        logger.warning("Grouped split produced an empty calibration set; "
                       "check the number of groups and calibration_ratio")

    return train_set, calibration_set, test_set


def validate_split(
    train_set: List[Dict[str, Any]],
    calibration_set: List[Dict[str, Any]],
    test_set: List[Dict[str, Any]],
    group_key: Optional[str] = None
) -> Dict[str, Any]:
    """
    Validate a data split by checking for group leakage between partitions.

    Partition sizes are always reported. When ``group_key`` is given, the set
    of group ids in each partition is computed, and any id shared by two
    partitions is reported as overlap. Without ``group_key``, no overlap check
    runs and ``has_overlap`` stays ``False``.

    Args:
        train_set: Training data
        calibration_set: Calibration data
        test_set: Test data
        group_key: Sample key holding the group id. Any string, including
            ``""``, enables the group-overlap check; ``None`` disables it.
            Samples lacking the key are counted as group ``None``.

    Returns:
        Dictionary with ``train_size``, ``calibration_size``, ``test_size``,
        ``total_size``, ``has_overlap`` and ``overlap_details``. When the group
        check runs, it also holds ``group_key`` and the per-partition group
        counts ``train_groups`` / ``cal_groups`` / ``test_groups``.
        ``overlap_details`` then maps ``train_cal`` / ``train_test`` /
        ``cal_test`` to the shared group ids, sorted by ``repr``.
    """
    results = {
        'train_size': len(train_set),
        'calibration_size': len(calibration_set),
        'test_size': len(test_set),
        'total_size': len(train_set) + len(calibration_set) + len(test_set),
        'has_overlap': False,
        'overlap_details': {}
    }

    # Compare against None explicitly: an empty-string key is still a valid
    # dict key and must run the check instead of being silently skipped.
    if group_key is not None:
        train_groups = {sample.get(group_key) for sample in train_set}
        cal_groups = {sample.get(group_key) for sample in calibration_set}
        test_groups = {sample.get(group_key) for sample in test_set}

        # Sorting by repr gives a stable report (set order of str ids varies
        # with hash randomization) and works for mutually incomparable ids.
        overlap_details = {
            'train_cal': sorted(train_groups & cal_groups, key=repr),
            'train_test': sorted(train_groups & test_groups, key=repr),
            'cal_test': sorted(cal_groups & test_groups, key=repr),
        }
        has_overlap = any(overlap_details.values())

        results.update({
            'has_overlap': has_overlap,
            'group_key': group_key,
            'train_groups': len(train_groups),
            'cal_groups': len(cal_groups),
            'test_groups': len(test_groups),
            'overlap_details': overlap_details,
        })

        if has_overlap:
            logger.error(f"Group overlap detected in split: {results['overlap_details']}")
        else:
            logger.info("Group-level validation passed: no overlap detected")

    return results


def temporal_split(
    data: List[Dict[str, Any]],
    time_key: str,
    train_end_time: Any,
    calibration_end_time: Any
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Split data based on temporal order.

    Useful for evaluating model performance under distribution shift. With
    ``t = sample[time_key]``, each interval includes its upper boundary:

    * train:       ``t <= train_end_time``
    * calibration: ``train_end_time < t <= calibration_end_time``
    * test:        ``t > calibration_end_time``

    Args:
        data: Data with time information
        time_key: Key containing time information
        train_end_time: End time for training data
        calibration_end_time: End time for calibration data; must not be
            earlier than ``train_end_time``

    Returns:
        Tuple of (train_set, calibration_set, test_set); each list keeps the
        input order.

    Raises:
        ValueError: If ``calibration_end_time`` is earlier than
            ``train_end_time``. Such boundaries would put every sample in
            ``(calibration_end_time, train_end_time]`` into both the train and
            the test set.
        KeyError: If a sample lacks ``time_key``.
    """
    if calibration_end_time < train_end_time:
        raise ValueError(
            f"calibration_end_time ({calibration_end_time!r}) must not be earlier "
            f"than train_end_time ({train_end_time!r})"
        )

    logger.info(f"Performing temporal split using time key '{time_key}'")

    train_set = [sample for sample in data if sample[time_key] <= train_end_time]
    calibration_set = [sample for sample in data
                       if train_end_time < sample[time_key] <= calibration_end_time]
    test_set = [sample for sample in data if sample[time_key] > calibration_end_time]

    logger.info(f"Temporal split sizes - Train: {len(train_set)}, "
                f"Cal: {len(calibration_set)}, Test: {len(test_set)}")

    # With ordered boundaries, each totally-ordered time lands in exactly one
    # split; values that compare False against everything (e.g. NaN) land in
    # none, so surface them instead of dropping them silently.
    n_unassigned = len(data) - (len(train_set) + len(calibration_set) + len(test_set))
    if n_unassigned > 0:
        logger.warning(
            f"{n_unassigned} sample(s) were not assigned to any split because their "
            f"'{time_key}' value is not ordered relative to the split boundaries (e.g. NaN)"
        )

    return train_set, calibration_set, test_set




