import os
import numpy as np
import pandas as pd
import pickle
from collections import defaultdict
import torch
from sklearn.model_selection import train_test_split

# Time-varying biomarker columns; their order defines the column order of
# every patient's 'arr' matrix.
_BIOMARKER_FEATURES = [
    'ALP', 'ALT', 'AST', 'Albumin', 'BUN', 'Bilirubin', 'Cholesterol', 'Creatinine',
    'DiasABP', 'FiO2', 'GCS', 'Glucose', 'HCO3', 'HCT', 'HR', 'K', 'Lactate', 'MAP',
    'MechVent', 'Mg', 'NIDiasABP', 'NIMAP', 'NISysABP', 'Na', 'PaCO2', 'PaO2',
    'Platelets', 'RespRate', 'SaO2', 'SysABP', 'Temp', 'TroponinI', 'TroponinT',
    'Urine', 'WBC', 'pH',
]

# Per-patient static columns, read from the first time-ordered row of a file.
_STATIC_FEATURES = ['Age', 'Gender', 'ICUType']


def _find_psv_files(data_dir):
    """Return every ``.psv`` path below ``data_dir`` in sorted order."""
    found = []
    for root, _dirs, files in os.walk(data_dir):
        for file_name in files:
            if file_name.endswith('.psv'):
                found.append(os.path.join(root, file_name))
    # os.walk order depends on the filesystem; sorting keeps patient indices
    # (and therefore the saved split indices) reproducible across machines.
    return sorted(found)


def _encode_static(static_row):
    """Build the 9-element static vector for one patient.

    Layout: [Age, Gender_0, Gender_1, Height, Weight,
             ICUType_0, ICUType_1, ICUType_2, ICUType_3].
    Missing Age/Gender/ICUType are replaced by 0 before encoding, so a missing
    Gender or ICUType lands in the ``_0`` slot.
    """
    age = 0 if pd.isna(static_row['Age']) else static_row['Age']
    gender = 0 if pd.isna(static_row['Gender']) else static_row['Gender']
    icu_type = 0 if pd.isna(static_row['ICUType']) else static_row['ICUType']

    encoded = [age]
    encoded.extend(1 if gender == code else 0 for code in (0, 1))
    # Height and Weight are not available in this dataset; kept as zero
    # placeholders so the vector has the 9 slots listed above.
    encoded.extend([0, 0])
    # NOTE(review): only ICUType codes 0-3 are encoded; any other value yields
    # an all-zero one-hot. Confirm the input files use 0-based ICU types.
    encoded.extend(1 if icu_type == code else 0 for code in range(4))
    return np.array(encoded, dtype=np.float32)


def _build_time_series(df, max_sequence_length):
    """Scatter one patient's rows into fixed-size value and time-stamp arrays.

    Row ``t`` of both arrays holds the observation whose ``ICULOS`` equals
    ``t``; rows with no observation and missing values stay 0.
    """
    time_series = np.zeros((max_sequence_length, len(_BIOMARKER_FEATURES)), dtype=np.float32)
    time_stamps = np.zeros((max_sequence_length, 1), dtype=np.float32)

    # Resolve once per file which biomarker columns actually exist.
    present = [(col, name) for col, name in enumerate(_BIOMARKER_FEATURES) if name in df.columns]

    for _, row in df.iterrows():
        time_idx = int(row['ICULOS'])
        # Reject negative indices as well: they would otherwise wrap around
        # and silently overwrite the last rows of the arrays.
        if not 0 <= time_idx < max_sequence_length:
            continue
        time_stamps[time_idx, 0] = time_idx
        for col, name in present:
            value = row[name]
            if not pd.isna(value):
                time_series[time_idx, col] = float(value)

    return time_series, time_stamps


def _load_patient(file_path, max_sequence_length):
    """Read one PSV file and return ``(patient_dict, survival)``.

    Raises whatever pandas/NumPy raise for unreadable or malformed files, and
    ``ValueError`` when the outcome is missing.
    """
    df = pd.read_csv(file_path, sep="|")
    df = df.sort_values('ICULOS')

    # Outcome is taken from the last time-ordered row.
    survival = df['Survival'].iloc[-1]
    if pd.isna(survival):
        # A NaN label cannot be counted or stratified on later.
        raise ValueError("missing Survival outcome")

    time_series, time_stamps = _build_time_series(df, max_sequence_length)
    patient_dict = {
        'arr': time_series,
        'time': time_stamps,
        'extended_static': _encode_static(df[_STATIC_FEATURES].iloc[0]),
    }
    return patient_dict, survival


def _save_split(path, train_idx, val_idx, test_idx):
    """Save the three index arrays as one 1-D object array of length 3."""
    # The parts have different lengths; passing a plain tuple to np.save makes
    # NumPy build a ragged array, which raises ValueError on NumPy >= 1.24.
    # Filling a pre-allocated object array element by element is unambiguous.
    split_data = np.empty(3, dtype=object)
    for slot, part in enumerate((train_idx, val_idx, test_idx)):
        split_data[slot] = part
    np.save(path, split_data)


def _write_readme(processed_dir, n_splits):
    """Write ``readme.md`` describing the generated files into ``processed_dir``."""
    biomarker_names = ', '.join(_BIOMARKER_FEATURES)
    static_names = ', '.join(_STATIC_FEATURES)
    readme_content = f"""# PhysioNet 2012 Challenge dataset (Converted for Raindrop)

### Processed data ###:
* /processed_data/
  * PTdict_list.npy - List of patient dictionaries with time series and static data
  * arr_outcomes.npy - Survival outcomes (0=survived, 1=died)
  
* /splits/
  * phy12_splitX.npy - Train/val/test splits where X ranges from 1 to {n_splits}

### Data format ###:
Each patient in PTdict_list contains:
- 'arr': Time series data (T x F matrix where T=time steps, F=features)
- 'time': Time stamps (T x 1 matrix)
- 'extended_static': Static features [Age, Gender_0, Gender_1, Height, Weight, ICUType_0, ICUType_1, ICUType_2, ICUType_3]

### Features ###:
Biomarker features: {biomarker_names}

Static features: {static_names}
"""

    with open(os.path.join(processed_dir, 'readme.md'), 'w') as f:
        f.write(readme_content)


def convert_physionet_to_raindrop_format(data_dir, output_dir, n_splits=5, max_sequence_length=215):
    """
    Convert PhysioNet 2012 dataset from GMAN format to Raindrop format.

    Every ``.psv`` (pipe-separated) file found recursively under ``data_dir``
    is treated as one patient. Files that cannot be processed are reported on
    stdout and skipped.

    Outputs written below ``output_dir``:
        processed_data/PTdict_list.npy  - object array of patient dicts with
            keys 'arr' (max_sequence_length x 36), 'time'
            (max_sequence_length x 1) and 'extended_static' (9,)
        processed_data/arr_outcomes.npy - one 'Survival' value per patient
        processed_data/readme.md        - description of the files
        splits/phy12_split{k}.npy       - (train, val, test) index arrays for
            k = 1..n_splits, stratified 80/10/10 on the outcome

    Args:
        data_dir: Directory containing the PhysioNet 2012 PSV files
        output_dir: Directory to save the Raindrop format data
        n_splits: Number of train/val/test splits to create
        max_sequence_length: Number of time rows per patient; observations
            whose ICULOS falls outside [0, max_sequence_length) are dropped.
            The default of 215 is the value the original code used.

    Returns:
        ``output_dir``, unchanged.

    Raises:
        ValueError: If no patient could be processed.
    """
    processed_dir = os.path.join(output_dir, 'processed_data')
    splits_dir = os.path.join(output_dir, 'splits')
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(processed_dir, exist_ok=True)
    os.makedirs(splits_dir, exist_ok=True)

    psv_files = _find_psv_files(data_dir)
    print(f"Found {len(psv_files)} PSV files")

    patients = []
    outcomes = []
    for i, file_path in enumerate(psv_files):
        if i % 100 == 0:
            print(f"Processing file {i+1}/{len(psv_files)}")

        # Best-effort conversion: one bad file must not abort the whole run.
        try:
            patient_dict, survival = _load_patient(file_path, max_sequence_length)
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue

        patients.append(patient_dict)
        outcomes.append(survival)

    if not patients:
        # Fail with a clear message instead of an IndexError on the sample
        # prints below.
        raise ValueError(f"No patients could be processed from {data_dir}")

    PTdict_list = np.array(patients, dtype=object)
    arr_outcomes = np.array(outcomes)

    print(f"Processed {len(PTdict_list)} patients")
    print(f"Outcome distribution: {np.bincount(arr_outcomes.astype(int))}")
    print(f"Sample patient time series shape: {PTdict_list[0]['arr'].shape}")
    print(f"Sample patient time stamps shape: {PTdict_list[0]['time'].shape}")
    print(f"Sample patient static features shape: {PTdict_list[0]['extended_static'].shape}")

    np.save(os.path.join(processed_dir, 'PTdict_list.npy'), PTdict_list)
    np.save(os.path.join(processed_dir, 'arr_outcomes.npy'), arr_outcomes)

    indices = np.arange(len(PTdict_list))
    for split_idx in range(1, n_splits + 1):
        # A different seed per split gives distinct but reproducible partitions.
        seed = 42 + split_idx

        # 80% train, then the remaining 20% halved into val/test; both steps
        # are stratified to maintain class balance.
        train_idx, temp_idx = train_test_split(
            indices,
            test_size=0.2,
            random_state=seed,
            stratify=arr_outcomes,
        )
        val_idx, test_idx = train_test_split(
            temp_idx,
            test_size=0.5,
            random_state=seed,
            stratify=arr_outcomes[temp_idx],
        )

        _save_split(
            os.path.join(splits_dir, f'phy12_split{split_idx}.npy'),
            train_idx,
            val_idx,
            test_idx,
        )
        print(f"Split {split_idx}: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")

    _write_readme(processed_dir, n_splits)

    print(f"Data conversion completed. Output saved to {output_dir}")
    return output_dir

if __name__ == "__main__":
    # Command-line entry point. Every option defaults to the value the script
    # previously hard-coded, so running it without arguments behaves as before.
    # argparse is imported here because only the script entry point needs it.
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert PhysioNet 2012 PSV files to the Raindrop data format."
    )
    parser.add_argument(
        "--data-dir",
        default="/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/tmp",
        help="Directory searched recursively for the PhysioNet .psv files",
    )
    parser.add_argument(
        "--output-dir",
        default="Raindrop/P12data",  # the original Raindrop P12data directory
        help="Directory that receives processed_data/ and splits/",
    )
    parser.add_argument(
        "--n-splits",
        type=int,
        default=5,
        help="Number of train/val/test splits to create",
    )
    args = parser.parse_args()

    convert_physionet_to_raindrop_format(args.data_dir, args.output_dir, n_splits=args.n_splits)