"""
NInFEA Dataset Loader for Intrinsic Dimensionality Analysis

This module provides functions to load and prepare the NInFEA (Non-Invasive 
Multimodal Foetal ECG-Doppler) dataset for multimodal intrinsic dimensionality 
estimation experiments.

Citation:
Sulas, E., Urru, M., Tumbarello, R. et al. A non-invasive multimodal foetal 
ECG-Doppler dataset for antenatal cardiology research. Sci Data 8, 30 (2021). 
https://doi.org/10.1038/s41597-021-00811-3
"""

import os
import sys
import torch
import numpy as np
import glob
from PIL import Image
from typing import List, Dict, Any, Tuple, Optional
from sklearn.preprocessing import StandardScaler

# Compatibility shim: processed .pt files pickled under a NumPy that ships
# ``numpy._core`` reference ``numpy._core.*`` module paths; NumPy 1.x only
# provides those objects under ``numpy.core.*``.
def _setup_numpy_compatibility():
    """
    Alias ``numpy._core`` and its ``numeric``/``multiarray`` submodules to the
    real ``numpy.core`` modules when ``numpy._core`` is unavailable.

    Pickle resolves globals such as ``numpy._core.multiarray._reconstruct`` or
    ``numpy._core.multiarray.scalar`` via ``sys.modules``. Pointing those names
    at the genuine ``numpy.core`` submodules (not at the top-level ``numpy``
    package, which lacks these attributes) lets such files unpickle. Existing
    ``sys.modules`` entries are never overwritten. The function does nothing
    when ``np._core`` already exists (as on NumPy 2.x).
    """
    if hasattr(np, '_core'):
        return

    legacy_core = getattr(np, 'core', None)
    if legacy_core is None:
        # Nothing sensible to alias to; leave the import system untouched.
        return

    sys.modules.setdefault('numpy._core', legacy_core)
    for submodule in ('numeric', 'multiarray'):
        target = getattr(legacy_core, submodule, None)
        if target is not None:
            sys.modules.setdefault(f'numpy._core.{submodule}', target)

_setup_numpy_compatibility()

def read_bin_file(file_path: str) -> Tuple[np.ndarray, float]:
    """
    Reads a single .bin file from the NInFEA dataset according to its custom format.

    Binary file format (from ReadBinaryFile.m), all fields little-endian:
    - fs   (8 bytes, double): sampling frequency
    - rows (8 bytes, uint64)
    - cols (8 bytes, double)  -- stored as a double in these files, not a uint64
    - data (rows x cols x 8 bytes, double), written column-wise
      (MATLAB / Fortran order)

    The stored matrix is (rows x cols) = (channels x samples). Because it is
    written column-wise, consecutive values in the file are the channels of a
    single sample.

    Args:
        file_path (str): The full path to the .bin file.

    Returns:
        Tuple[np.ndarray, float]: The data array in (samples, channels) format
        and the sampling frequency as a Python float. If the payload holds
        more or fewer values than the header announces, it is truncated or
        zero-padded and a ``UserWarning`` is emitted.

    Raises:
        ValueError: If the file is too short to contain the 24-byte header.
    """
    import warnings

    header_size = 24  # three 8-byte fields: fs, rows, cols
    with open(file_path, 'rb') as f:
        header = f.read(header_size)
        payload = f.read()

    if len(header) < header_size:
        raise ValueError(
            f"{file_path}: truncated header ({len(header)} of {header_size} bytes)"
        )

    # Return a plain float so records do not pickle numpy scalar globals.
    sampling_freq = float(np.frombuffer(header, dtype='<f8', count=1, offset=0)[0])
    num_rows = int(np.frombuffer(header, dtype='<u8', count=1, offset=8)[0])
    num_cols = int(np.frombuffer(header, dtype='<f8', count=1, offset=16)[0])

    # Drop a trailing partial value rather than failing inside np.frombuffer.
    usable_bytes = len(payload) - len(payload) % 8
    data = np.frombuffer(payload[:usable_bytes], dtype='<f8')

    expected_size = num_rows * num_cols
    if data.size != expected_size:
        action = 'truncating' if data.size > expected_size else 'zero-padding'
        warnings.warn(
            f"{file_path}: header describes {expected_size} values but "
            f"{data.size} were found; {action}.",
            stacklevel=2,
        )
        if data.size > expected_size:
            data = data[:expected_size]
        else:
            padded_data = np.zeros(expected_size)
            padded_data[:data.size] = data
            data = padded_data

    # The payload is column-major, so reshape with order='F' to recover the
    # (channels, samples) matrix, then transpose to (samples, channels).
    # A C-order reshape here would interleave values from different channels.
    samples_by_channels = data.reshape((num_rows, num_cols), order='F').T
    return samples_by_channels, sampling_freq

def read_pwd_image(file_path: str) -> np.ndarray:
    """
    Load a PWD (Doppler) bitmap from disk as a NumPy array.

    Args:
        file_path (str): Location of the image file (the NInFEA loader passes .bmp files).

    Returns:
        np.ndarray: Pixel values produced by ``np.array`` on the opened PIL image;
        its shape and dtype follow the image's mode.
    """
    with Image.open(file_path) as bitmap:
        pixels = np.array(bitmap)
    return pixels

def prepare_ninfea_on_demand(data_dir: str = "./01_data/raw/ninfea") -> List[Dict[str, Any]]:
    """
    Prepare NInFEA dataset on-demand without saving to disk.

    Expects ``<data_dir>/physionet.org/files/ninfea/1.0.0`` to contain a
    ``bin_format_ecg_and_respiration`` folder with the ``.bin`` recordings and,
    optionally, a ``pwd_images`` folder holding ``<record_id>.bmp`` files.

    Args:
        data_dir (str): Directory containing the raw NInFEA data

    Returns:
        List[Dict[str, Any]]: One record per ``.bin`` file (sorted by path) with keys
        ``record_id``, ``sampling_freq_hz`` (Python float), ``unipolar_ecg``
        (channels 1-24), ``maternal_ecg`` (channels 25-27) and
        ``maternal_respiration`` (channel 32) as float32 tensors of shape
        (samples, n_channels), plus ``pwd_image`` (float32 tensor, or None when
        no matching bitmap exists).

    Raises:
        FileNotFoundError: If the dataset folder or its .bin files are missing.
        ValueError: If a recording has fewer than the 32 channels addressed below.
    """
    data_path = os.path.join(data_dir, "physionet.org", "files", "ninfea", "1.0.0")

    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"NInFEA raw data not found at {data_path}. "
            "Please run 050_NInFEA_prep.py first to download the data."
        )

    # Find all the binary data files
    bin_files_dir = os.path.join(data_path, 'bin_format_ecg_and_respiration')
    bin_files = sorted(glob.glob(os.path.join(bin_files_dir, '*.bin')))

    if not bin_files:
        raise FileNotFoundError(f"No .bin files found in {bin_files_dir}")

    # The PWD folder is a sibling of the .bin folder; it is the same for every recording.
    pwd_images_dir = os.path.join(data_path, 'pwd_images')
    required_channels = 32  # the highest channel used below is channel 32

    print(f"Processing {len(bin_files)} recordings on-demand...")

    prepared_data = []
    for bin_file_path in bin_files:
        record_id = os.path.basename(bin_file_path).split('.')[0]

        # Read the electrophysiological and respiration data: (samples, channels)
        raw_data, sampling_freq = read_bin_file(bin_file_path)
        if raw_data.ndim != 2 or raw_data.shape[1] < required_channels:
            # Slicing past the last channel would silently yield empty tensors.
            raise ValueError(
                f"Recording {record_id}: expected at least {required_channels} "
                f"channels, got data of shape {raw_data.shape}"
            )

        # Read the corresponding PWD image, if present
        bmp_file_path = os.path.join(pwd_images_dir, f'{record_id}.bmp')
        pwd_image_data = read_pwd_image(bmp_file_path) if os.path.exists(bmp_file_path) else None

        # Separate the modalities based on the channel description
        unipolar_ecg_data = raw_data[:, 0:24]  # Channels 1-24
        maternal_ecg_data = raw_data[:, 24:27]  # Channels 25-27
        respiration_data = raw_data[:, 31:32]  # Channel 32

        record_dict = {
            'record_id': record_id,
            # Plain float keeps the record free of numpy scalars when pickled.
            'sampling_freq_hz': float(sampling_freq),
            'unipolar_ecg': torch.from_numpy(unipolar_ecg_data).float(),
            'maternal_ecg': torch.from_numpy(maternal_ecg_data).float(),
            'maternal_respiration': torch.from_numpy(respiration_data).float(),
            'pwd_image': torch.from_numpy(pwd_image_data).float() if pwd_image_data is not None else None,
        }
        prepared_data.append(record_dict)

    return prepared_data

def load_ninfea_dataset(data_dir: str = "./01_data/processed", 
                       processed_file: str = "ninfea_processed.pt",
                       raw_data_dir: str = "./01_data/raw/ninfea",
                       standardize: bool = True) -> List[Dict[str, Any]]:
    """
    Load the NInFEA dataset, either from processed file or on-demand.

    The processed file is loaded with ``torch.load(..., map_location='cpu')``.
    If unpickling fails for a compatibility reason, the dataset is rebuilt
    from the raw files via ``prepare_ninfea_on_demand``. Compatibility reasons
    include missing or renamed modules, attribute/type mismatches, and
    PyTorch's weights-only loader rejecting a global (``pickle.UnpicklingError``).

    Args:
        data_dir (str): Directory containing the processed data
        processed_file (str): Name of the processed data file
        raw_data_dir (str): Directory containing the raw data (for on-demand processing)
        standardize (bool): Whether to standardize the data

    Returns:
        List[Dict[str, Any]]: List of records with multimodal data
    """
    import pickle

    file_path = os.path.join(data_dir, processed_file)

    if os.path.exists(file_path):
        try:
            print(f"Loading processed NInFEA dataset from {file_path}")
            dataset = torch.load(file_path, map_location='cpu')
        except (ModuleNotFoundError, AttributeError, ImportError, TypeError,
                pickle.UnpicklingError) as e:
            # pickle.UnpicklingError is what torch>=2.6 raises when its default
            # weights_only loader meets a non-allowlisted global (e.g. numpy scalars).
            print(f"Warning: Could not load processed file due to compatibility issue: {e}")
            print("This is likely due to numpy version differences. Falling back to on-demand processing...")
            dataset = prepare_ninfea_on_demand(raw_data_dir)
    else:
        print(f"Processed file not found at {file_path}. Processing on-demand...")
        dataset = prepare_ninfea_on_demand(raw_data_dir)

    if standardize:
        dataset = standardize_ninfea_dataset(dataset)

    return dataset

def _numpy_pickle_safe_globals() -> List[Any]:
    """Return numpy objects a pickled numpy scalar references, for torch's weights-only allowlist."""
    # NumPy 2.x exposes numpy._core; NumPy 1.x only numpy.core.
    multiarray = getattr(getattr(np, '_core', None), 'multiarray', None)
    if multiarray is None:
        multiarray = getattr(getattr(np, 'core', None), 'multiarray', None)

    candidates = [getattr(multiarray, 'scalar', None), np.dtype]
    # A pickled dtype is rebuilt as an instance of its concrete dtype class
    # (e.g. Float64DType on newer NumPy), which torch must also allow.
    candidates.extend(
        type(np.dtype(code)) for code in ('float64', 'float32', 'int64', 'int32', 'bool')
    )

    allowlist: List[Any] = []
    for obj in candidates:
        if obj is not None and obj not in allowlist:
            allowlist.append(obj)
    return allowlist


def _is_weights_only_error(err: Exception) -> bool:
    """Heuristically recognise torch's weights-only unpickling failures from the message."""
    msg = str(err)
    return 'Weights only load failed' in msg or 'WeightsUnpickler' in msg or 'unsupported' in msg.lower()


def _torch_load_trusted(file_path: str) -> Any:
    """torch.load ``file_path`` with numpy globals allowlisted, retrying with weights_only=False."""
    allowlist = _numpy_pickle_safe_globals()
    serialization = torch.serialization
    try:
        if allowlist and hasattr(serialization, 'safe_globals'):
            # Scoped allowlist (context manager), available in recent PyTorch.
            with serialization.safe_globals(allowlist):
                return torch.load(file_path, map_location='cpu')
        if allowlist and hasattr(serialization, 'add_safe_globals'):
            # There is no removal API: the allowlist stays registered process-wide.
            serialization.add_safe_globals(allowlist)
        return torch.load(file_path, map_location='cpu')
    except Exception as err:
        if not _is_weights_only_error(err):
            raise
    # SECURITY: weights_only=False runs arbitrary pickle code; this is only
    # acceptable because the processed file is a locally produced, trusted artifact.
    print('Weights-only load failed — retrying with weights_only=False (trusted file).')
    return torch.load(file_path, map_location='cpu', weights_only=False)


def load_ninfea_dataset2(data_dir: str = "./01_data/processed", 
                        processed_file: str = "ninfea_processed.pt",
                        raw_data_dir: str = "./01_data/raw/ninfea",
                        standardize: bool = True) -> List[Dict[str, Any]]:
    """
    Alternative loader that tolerates processed files pickled with other
    numpy/PyTorch versions.

    The processed file is first loaded with the numpy scalar globals
    (``scalar``, ``np.dtype`` and concrete dtype classes) allowlisted for
    PyTorch's weights-only unpickler. If that still fails with a weights-only
    error, it is reloaded with ``weights_only=False``, which assumes the file is
    trusted. Any other load failure, or a missing file, falls back to rebuilding
    the dataset from the raw data via ``prepare_ninfea_on_demand``.

    Kept separate from ``load_ninfea_dataset`` so existing callers of that
    function are unaffected.

    Args:
        data_dir (str): Directory containing the processed data
        processed_file (str): Name of the processed data file
        raw_data_dir (str): Directory containing the raw data (for on-demand processing)
        standardize (bool): Whether to standardize the data

    Returns:
        List[Dict[str, Any]]: List of records with multimodal data
    """
    file_path = os.path.join(data_dir, processed_file)

    if os.path.exists(file_path):
        print(f"Loading processed NInFEA dataset from {file_path} (safe loader)")
        try:
            dataset = _torch_load_trusted(file_path)
        except Exception as e_load:
            # Deliberately broad: any failure to read the cached file means "rebuild".
            print(f"Safe-load attempt failed: {e_load}")
            print("Falling back to on-demand processing...")
            dataset = prepare_ninfea_on_demand(raw_data_dir)
    else:
        print(f"Processed file not found at {file_path}. Processing on-demand...")
        dataset = prepare_ninfea_on_demand(raw_data_dir)

    if standardize:
        dataset = standardize_ninfea_dataset(dataset)

    return dataset

def standardize_ninfea_dataset(dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Z-score the time-series modalities of every record with one shared scaler per modality.

    For each of ``unipolar_ecg``, ``maternal_ecg`` and ``maternal_respiration``,
    a ``StandardScaler`` is fitted on the rows of all non-None tensors
    (flattened to 2-D over the last axis) and then applied back to each record.
    Records are shallow-copied; other keys are passed through unchanged, and a
    modality that is None everywhere is left untouched.

    Args:
        dataset: List of records with multimodal data

    Returns:
        List[Dict[str, Any]]: New list of record copies with standardized float32 tensors
    """
    signal_keys = ['unipolar_ecg', 'maternal_ecg', 'maternal_respiration']

    def _as_2d(tensor):
        # Collapse all leading axes so each row is one observation of the channels.
        array = tensor.numpy()
        return array.reshape(-1, array.shape[-1])

    fitted = {}
    for key in signal_keys:
        chunks = [_as_2d(rec[key]) for rec in dataset if rec[key] is not None]
        if not chunks:
            continue
        fitted[key] = StandardScaler().fit(np.concatenate(chunks, axis=0))

    output = []
    for rec in dataset:
        new_rec = rec.copy()
        for key in signal_keys:
            if rec[key] is None or key not in fitted:
                continue
            original = rec[key].numpy()
            scaled = fitted[key].transform(original.reshape(-1, original.shape[-1]))
            new_rec[key] = torch.from_numpy(scaled.reshape(original.shape)).float()
        output.append(new_rec)

    return output

def _pad_or_truncate(series: torch.Tensor, length: int) -> torch.Tensor:
    """Return ``series`` cut or zero-padded along dim 0 to exactly ``length`` rows."""
    current = series.shape[0]
    if current > length:
        return series[:length]
    if current < length:
        # new_zeros keeps dtype/device; trailing dimensions are preserved.
        padding = series.new_zeros((length - current, *series.shape[1:]))
        return torch.cat([series, padding], dim=0)
    return series


def prepare_ninfea_for_multimodal_analysis(dataset: List[Dict[str, Any]], 
                                         max_length: Optional[int] = None) -> Dict[str, torch.Tensor]:
    """
    Prepare NInFEA data for multimodal intrinsic dimensionality analysis.

    Records missing any time-series modality (``unipolar_ecg``, ``maternal_ecg``,
    ``maternal_respiration``) are skipped. Each remaining time series is
    truncated or zero-padded along its first axis to ``max_length`` and the
    results are stacked per modality. PWD images are stacked only for records
    that have one. If only some records have one, the rows of ``pwd_image`` do
    not line up with the other modalities, and a ``UserWarning`` is emitted.

    Args:
        dataset: List of records with multimodal data
        max_length: Target sequence length; defaults to the shortest
            ``unipolar_ecg`` length among the valid records (so nothing is padded).

    Returns:
        Dict with modality names as keys and stacked tensors as values;
        ``pwd_image`` is omitted when no valid record has an image.

    Raises:
        ValueError: If no record has all time-series modalities, or ``max_length`` is negative.
    """
    import warnings

    time_series_keys = ('unipolar_ecg', 'maternal_ecg', 'maternal_respiration')
    modalities = {key: [] for key in (*time_series_keys, 'pwd_image')}

    # Skip records without essential modalities
    valid_records = [
        record for record in dataset
        if all(record[key] is not None for key in time_series_keys)
    ]
    if not valid_records:
        raise ValueError("No valid records found in dataset")

    if max_length is None:
        max_length = min(record['unipolar_ecg'].shape[0] for record in valid_records)
    if max_length < 0:
        # A negative length would silently slice off the tail instead of truncating.
        raise ValueError(f"max_length must be non-negative, got {max_length}")

    for record in valid_records:
        for key in time_series_keys:
            modalities[key].append(_pad_or_truncate(record[key], max_length))
        if record['pwd_image'] is not None:
            modalities['pwd_image'].append(record['pwd_image'])

    num_images = len(modalities['pwd_image'])
    if 0 < num_images < len(valid_records):
        warnings.warn(
            f"Only {num_images} of {len(valid_records)} valid records have a PWD image; "
            "rows of the stacked 'pwd_image' tensor are not aligned with the "
            "time-series modalities.",
            stacklevel=2,
        )

    return {key: torch.stack(items) for key, items in modalities.items() if items}

def get_ninfea_data_info(dataset: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Get information about the NInFEA dataset structure.

    Args:
        dataset: List of records. Every record must contain ``'sampling_freq_hz'``;
            every value that is a ``torch.Tensor`` is counted as a modality
            (``None`` entries are skipped).

    Returns:
        Dict with keys:
            - ``'num_records'``: number of records.
            - ``'modalities'``: per modality, ``{'shapes': [...], 'available_count': int}``.
            - ``'sampling_frequencies'``: the per-record sampling frequencies.
            - ``'mean_sampling_freq'`` / ``'std_sampling_freq'``: population mean and
              standard deviation of the sampling frequencies (NaN for an empty dataset).
    """
    info = {
        'num_records': len(dataset),
        'modalities': {},
        'sampling_frequencies': [],
    }

    for record in dataset:
        info['sampling_frequencies'].append(record['sampling_freq_hz'])

        for modality, data in record.items():
            if not isinstance(data, torch.Tensor):
                continue
            entry = info['modalities'].setdefault(
                modality, {'shapes': [], 'available_count': 0}
            )
            entry['shapes'].append(tuple(data.shape))
            entry['available_count'] += 1

    frequencies = info['sampling_frequencies']
    if frequencies:
        info['mean_sampling_freq'] = np.mean(frequencies)
        info['std_sampling_freq'] = np.std(frequencies)
    else:
        # np.mean/np.std on an empty list emit RuntimeWarnings; report NaN explicitly.
        info['mean_sampling_freq'] = float('nan')
        info['std_sampling_freq'] = float('nan')

    return info
