import os
import subprocess
import glob
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from typing import List, Dict, Tuple, Any

# --- Citing the source ---
# Citation text for the publication describing the NInFEA dataset.
# prepare_ninfea_dataset() prints this string verbatim before doing any work.
CITATION_INFO = """
If you use this dataset in your research, please cite the following publication:
Sulas, E., Urru, M., Tumbarello, R. et al. A non-invasive multimodal foetal 
ECG-Doppler dataset for antenatal cardiology research. Sci Data 8, 30 (2021). 
https://doi.org/10.1038/s41597-021-00811-3
"""

def download_ninfea_dataset(dest_dir: str = "./01_data/raw/ninfea") -> str:
    """
    Downloads the NInFEA dataset from PhysioNet using wget.

    wget is run recursively with ``dest_dir`` as its working directory, so the
    mirrored tree is expected at ``<dest_dir>/physionet.org/files/ninfea/1.0.0``.
    If that path already exists the download is skipped entirely.

    Args:
        dest_dir (str): The destination directory to store the data.

    Returns:
        str: The path to the downloaded data files, or an empty string when
        wget is not installed, wget exits with an error, or the expected
        directory is missing after wget finished.
    """
    physionet_url = "https://physionet.org/files/ninfea/1.0.0/"
    data_path = os.path.join(dest_dir, "physionet.org", "files", "ninfea", "1.0.0")

    # NOTE: only the existence of the directory is checked, so a previously
    # interrupted download is treated as complete. Delete the directory to
    # force a fresh download.
    if os.path.exists(data_path):
        print(f"Dataset already found in '{data_path}'. Skipping download.")
        return data_path

    print(f"Downloading NInFEA dataset to '{dest_dir}'...")
    os.makedirs(dest_dir, exist_ok=True)

    # Recursive (-r), timestamp-aware (-N), resumable (-c) download that never
    # ascends to the parent directory (-np).
    command = ["wget", "-r", "-N", "-c", "-np", physionet_url]

    try:
        # wget writes relative to its working directory, so run it inside
        # dest_dir instead of changing this process's own directory.
        subprocess.run(command, check=True, cwd=dest_dir)
    except FileNotFoundError:
        print("\nERROR: `wget` command not found.")
        print("Please install wget on your system to download the dataset.")
        print("On Debian/Ubuntu: sudo apt-get install wget")
        print("On macOS (with Homebrew): brew install wget")
        return ""
    except subprocess.CalledProcessError as e:
        print(f"\nERROR: Download failed with error: {e}")
        return ""

    # A zero exit status does not prove the expected tree was created; never
    # hand callers a path that does not exist.
    if not os.path.isdir(data_path):
        print(f"\nERROR: Download finished but '{data_path}' was not created.")
        return ""

    print("Download complete.")
    return data_path

def read_bin_file(file_path: str) -> Tuple[np.ndarray, float]:
    """
    Reads a single .bin file from the NInFEA dataset according to its custom format.

    Binary file format (from ReadBinaryFile.m):
    - fs (8 bytes), 'double', 'ieee-le'
    - rows (8 bytes), 'uint64', 'ieee-le'
    - cols (8 bytes), 'double', 'ieee-le'  # Note: this should be double, not uint64!
    - data (rows x cols) x 8 bytes, 'double', 'ieee-le' (written column wise)

    Per the format notes above, the matrix is (rows x cols) = (channels x
    samples) and is written column by column. The flat stream therefore holds
    every channel of sample 0, then every channel of sample 1, and so on.
    Viewing that stream as (cols, rows) in NumPy's default row-major order
    yields the (samples, channels) layout directly.

    If the number of values in the payload differs from rows * cols, a warning
    is printed and the payload is truncated or zero-padded to the expected size.

    Args:
        file_path (str): The full path to the .bin file.

    Returns:
        Tuple[np.ndarray, float]: A tuple containing the data array
        in (samples, channels) format and the sampling frequency.

    Raises:
        ValueError: If the file is too short to contain the 24-byte header.
    """
    # Header layout: fs (double), rows (uint64), cols (double), little-endian.
    header_dtype = np.dtype([('fs', '<f8'), ('rows', '<u8'), ('cols', '<f8')])

    with open(file_path, 'rb') as f:
        header_bytes = f.read(header_dtype.itemsize)
        if len(header_bytes) != header_dtype.itemsize:
            raise ValueError(
                f"File '{file_path}' is too short: expected a "
                f"{header_dtype.itemsize}-byte header, got {len(header_bytes)} bytes"
            )
        header = np.frombuffer(header_bytes, dtype=header_dtype)[0]

        sampling_freq = float(header['fs'])
        num_rows = int(header['rows'])
        num_cols = int(header['cols'])  # stored as a double in the file

        print(f"Reading file {os.path.basename(file_path)}: {num_rows} rows, {num_cols} cols, fs={sampling_freq}")

        # np.fromfile continues from the current file position and returns a
        # writable array (np.frombuffer over bytes would be read-only, which
        # makes torch.from_numpy emit a warning downstream).
        data = np.fromfile(f, dtype='<f8')

    # Check if we have the expected amount of data
    expected_size = num_rows * num_cols
    if len(data) != expected_size:
        print(f"Warning: Expected {expected_size} values, got {len(data)}")
        if len(data) > expected_size:
            # Drop trailing values beyond the declared matrix size
            data = data[:expected_size]
        else:
            # Pad with zeros if insufficient data
            padded_data = np.zeros(expected_size)
            padded_data[:len(data)] = data
            data = padded_data

    # Column-major (rows x cols) on disk == row-major (cols x rows) in memory,
    # i.e. (samples, channels). A plain reshape(num_rows, num_cols).T would
    # produce the right shape but scramble samples across channels.
    reshaped_data = data.reshape(num_cols, num_rows)

    return reshaped_data, sampling_freq

def read_pwd_image(file_path: str) -> np.ndarray:
    """
    Loads a .bmp image from disk and returns its pixel data as a NumPy array.

    Args:
        file_path (str): The full path to the .bmp file.

    Returns:
        np.ndarray: The pixel data of the image, as produced by ``np.array``
        on the opened PIL image.
    """
    # The context manager closes the underlying file once the pixels
    # have been copied into the array.
    with Image.open(file_path) as bitmap:
        pixels = np.array(bitmap)
    return pixels

def prepare_ninfea_dataset(data_dir: str = "./01_data/raw/ninfea") -> List[Dict[str, Any]]:
    """
    Downloads (if needed), parses, and packages the NInFEA dataset.

    Prints the citation notice, ensures the dataset is available locally,
    then converts every ``*.bin`` recording into a dictionary of PyTorch
    tensors split by modality. A ``.bmp`` file with the same record id in the
    sibling ``pwd_images`` directory is attached when it exists.

    Args:
        data_dir (str): The target directory for downloading and storing the data.

    Returns:
        List[Dict[str, Any]]: One dictionary per recording with the keys
        ``record_id``, ``sampling_freq_hz``, ``unipolar_ecg``,
        ``maternal_ecg``, ``maternal_respiration`` and ``pwd_image``
        (``None`` when no image file is found). An empty list is returned if
        the download fails or no ``.bin`` files are found.
    """
    print(CITATION_INFO)

    # Make sure the raw files are on disk; an empty path signals failure.
    dataset_root = download_ninfea_dataset(dest_dir=data_dir)
    if not dataset_root:
        return []

    # Collect the binary recordings in a deterministic (sorted) order.
    bin_files_dir = os.path.join(dataset_root, 'bin_format_ecg_and_respiration')
    recording_paths = sorted(glob.glob(os.path.join(bin_files_dir, '*.bin')))
    if not recording_paths:
        print(f"Error: No .bin files found in the directory '{bin_files_dir}'.")
        print(f"Available directories in '{dataset_root}':")
        # List the sub-directories that do exist to help diagnose a layout mismatch.
        subdirectories = [
            entry
            for entry in os.listdir(dataset_root)
            if os.path.isdir(os.path.join(dataset_root, entry))
        ]
        for entry in subdirectories:
            print(f"  - {entry}")
        return []

    print(f"\nFound {len(recording_paths)} recordings. Preparing data...")

    records = []
    for recording_path in tqdm(recording_paths, desc="Processing recordings"):
        # The record id is the file name up to its first dot.
        file_name = os.path.basename(recording_path)
        record_id = file_name.split('.')[0]

        # Signals come back as (samples, channels).
        signals, fs_hz = read_bin_file(recording_path)

        # The matching image lives in 'pwd_images', next to the directory
        # holding the .bin files; it is optional.
        recordings_parent = os.path.dirname(os.path.dirname(recording_path))
        image_path = os.path.join(recordings_parent, 'pwd_images', f'{record_id}.bmp')
        if os.path.exists(image_path):
            image_array = read_pwd_image(image_path)
        else:
            image_array = None

        # Column ranges per modality (0-based, end-exclusive):
        #   columns 0-23 -> unipolar ECG          (channels 1-24)
        #   columns 24-26 -> maternal ECG         (channels 25-27)
        #   column 31    -> maternal respiration  (channel 32)
        unipolar = signals[:, 0:24]
        maternal = signals[:, 24:27]
        respiration = signals[:, 31:32]

        # Every array is converted to a float32 tensor.
        records.append({
            'record_id': record_id,
            'sampling_freq_hz': fs_hz,
            'unipolar_ecg': torch.from_numpy(unipolar).float(),
            'maternal_ecg': torch.from_numpy(maternal).float(),
            'maternal_respiration': torch.from_numpy(respiration).float(),
            'pwd_image': None if image_array is None else torch.from_numpy(image_array).float(),
        })

    print("\nDataset preparation complete.")
    return records

if __name__ == '__main__':
    # Script entry point: prepare the dataset, show a summary of the first
    # record, and save the result to disk unless it is too large.
    import sys
    from pathlib import Path

    # Make the directory two levels above this file importable, for
    # consistency with other scripts in the project.
    sys.path.append(str(Path(__file__).parent.parent.absolute()))

    # Downloads to './01_data/raw/ninfea' (relative to the current working
    # directory) when necessary, then processes the recordings.
    ninfea_dataset = prepare_ninfea_dataset()

    if ninfea_dataset:
        print(f"\nSuccessfully loaded {len(ninfea_dataset)} records.")

        # --- Example: summarise the first record ---
        sample_record = ninfea_dataset[0]
        print("\n--- Inspecting the first record ---")
        print(f"Record ID: {sample_record['record_id']}")
        print(f"Sampling Frequency: {sample_record['sampling_freq_hz']} Hz")

        # Report the shape of every tensor-valued entry; signals are
        # (Samples, Channels), images keep the layout of the decoded bitmap.
        for name, entry in sample_record.items():
            if not isinstance(entry, torch.Tensor):
                continue
            print(f"  - Modality '{name}': Tensor of shape {entry.shape}")

        print("\n--- Checking dataset size before saving ---")

        # Rough in-memory footprint: element count times bytes per element,
        # summed over every tensor in every record.
        total_bytes = sum(
            item.numel() * item.element_size()
            for rec in ninfea_dataset
            for item in rec.values()
            if isinstance(item, torch.Tensor)
        )

        size_gb = total_bytes / (1024**3)
        print(f"Estimated dataset size: {size_gb:.2f} GB")

        if size_gb <= 5.0:
            # Small enough to save as a single file.
            output_file = "./01_data/processed/ninfea_processed.pt"
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            torch.save(ninfea_dataset, output_file)
            print(f"Processed dataset saved to: {output_file}")
            print(f"To load later: ninfea_dataset = torch.load('{output_file}')")
        else:
            # NOTE(review): prepare_ninfea_on_demand is not defined in the
            # visible part of this file - confirm that it exists in
            # src.data.ninfea_loader.
            print(f"Dataset size ({size_gb:.2f} GB) exceeds 5 GB threshold.")
            print("Skipping save to disk - will process on-the-fly in analysis scripts.")
            print("The dataset has been successfully prepared and can be loaded using:")
            print("  from src.data.ninfea_loader import prepare_ninfea_on_demand")