#!/usr/bin/env python3
"""
Converter that follows the exact Raindrop preprocessing pipeline for PhysioNet 2012
"""

import os
import numpy as np
import pandas as pd

def convert_to_raindrop_format(data_dir, output_dir):
    """
    Run the full PhysioNet 2012 -> Raindrop conversion pipeline.

    Creates ``output_dir`` together with its ``processed_data/`` and
    ``splits/`` subdirectories. It then discovers the time-series
    parameters, parses every patient file, builds the padded per-patient
    arrays, saves them, and writes the train/val/test split files.

    Args:
        data_dir: Directory containing PSV files (searched recursively).
        output_dir: Directory to save Raindrop format data.

    Returns:
        ``output_dir``, unchanged.
    """
    rule = "=" * 60
    for line in (rule, "CONVERTING TO RAINDROP FORMAT USING THEIR PREPROCESSING", rule):
        print(line)

    # Output layout: <output_dir>/processed_data and <output_dir>/splits
    os.makedirs(output_dir, exist_ok=True)
    for sub in ('processed_data', 'splits'):
        os.makedirs(os.path.join(output_dir, sub), exist_ok=True)

    print("Step 1: Extracting unique parameters...")
    ts_params = extract_unique_parameters(data_dir)
    print(f"Found {len(ts_params)} unique parameters: {ts_params}")

    # Static descriptors (same set and order as Raindrop).
    static_params = ['Age', 'Gender', 'Height', 'ICUType', 'Weight']
    print(f"Static parameters: {static_params}")

    print("Step 2: Parsing all patients...")
    patients, outcomes = parse_all_patients(data_dir, ts_params, static_params)
    print(f"Parsed {len(patients)} patients")

    print("Step 3: Creating irregular sampling...")
    records = create_irregular_sampling(patients, ts_params)
    print(f"Created {len(records)} patient dictionaries")

    print("Step 4: Saving processed data...")
    save_processed_data(output_dir, records, outcomes, ts_params, static_params)

    print("Step 5: Creating splits...")
    create_splits(output_dir, len(records))

    print("✅ Conversion completed successfully!")
    return output_dir

def extract_unique_parameters(data_dir):
    """Return the sorted time-series parameter names found across all PSV files.

    Every ``*.psv`` file under ``data_dir`` (searched recursively) contributes
    its column names. The time column (ICULOS), the static descriptors and
    the Survival outcome are excluded. Files that cannot be read are reported
    and skipped.

    Args:
        data_dir: Directory containing PSV files.

    Returns:
        Sorted list of parameter (biomarker) column names.
    """
    # Non-biomarker columns: time, static descriptors and the outcome label.
    static_cols = {'ICULOS', 'Age', 'Gender', 'Height', 'ICUType', 'Weight', 'Survival'}
    param_set = set()

    for root, _, files in os.walk(data_dir):
        for file in files:
            if not file.endswith('.psv'):
                continue
            file_path = os.path.join(root, file)
            try:
                # Only the header is needed; nrows=0 avoids parsing every data row.
                columns = pd.read_csv(file_path, sep="|", nrows=0).columns
            except Exception as e:
                # Best effort: report unreadable files and keep going.
                print(f"Error reading {file}: {e}")
                continue
            param_set.update(col for col in columns if col not in static_cols)

    # Drop degenerate names (NaN / empty header cells).
    return sorted(p for p in param_set if str(p) != 'nan' and p != '')

def _record_id_from_name(name):
    """Map a file stem to a numeric RecordID without aborting on odd names.

    Purely numeric stems (e.g. ``132539``) are converted with ``int``. For any
    other stem, the ASCII digits it contains are used (``p000001`` -> 1). If
    the stem has no digits, NaN is returned.
    """
    try:
        return int(name)
    except ValueError:
        digits = ''.join(ch for ch in name if '0' <= ch <= '9')
        return int(digits) if digits else np.nan


def parse_all_patients(data_dir, param_list, static_param_list):
    """Parse every PSV file under ``data_dir`` into Raindrop-style patient records.

    Args:
        data_dir: Directory searched recursively for ``*.psv`` files.
            Files are processed in sorted path order.
        param_list: Time-series parameter names to extract.
        static_param_list: Static columns, read from each file's first row.

    Returns:
        ``(P_list, arr_outcomes)``.

        ``P_list`` holds one dict per successfully parsed file, with keys:
            - ``'id'``: the file stem.
            - ``'static'``: tuple of ``static_param_list`` values.
            - ``'ts'``: list of ``(hours, minutes, total_minutes, param, value)``
              tuples for every non-missing value.

        ``arr_outcomes`` is an ``(n_patients, 6)`` float array with columns
        RecordID, SAPS-I, SOFA, Length_of_stay, Survival and In-hospital_death.
        SAPS-I, SOFA and Length_of_stay are not available and are left at 0.

        Files that fail to parse are reported and skipped.
    """
    psv_files = sorted(  # sorted for reproducibility
        os.path.join(root, name)
        for root, _, files in os.walk(data_dir)
        for name in files
        if name.endswith('.psv')
    )

    P_list = []
    survivals = []

    for i, file_path in enumerate(psv_files):
        if i % 1000 == 0:
            print(f"Processing patient {i+1}/{len(psv_files)}")

        try:
            file_name = os.path.splitext(os.path.basename(file_path))[0]
            df = pd.read_csv(file_path, sep="|")

            # Static features come from the first row.
            static_tuple = tuple(df[static_param_list].iloc[0].values)
            # Outcome label comes from the last row.
            survival = df['Survival'].iloc[-1]

            # Only parameters that exist as columns can yield values; checking
            # once per file replaces the per-row membership test.
            present = [p for p in param_list if p in df.columns]

            ts_list = []
            for _, row in df.iterrows():
                # NOTE(review): ICULOS is interpreted as minutes here (split into
                # hours/minutes); confirm the PSV export does not store hours.
                total_mins = int(row['ICULOS'])
                hours, minutes = divmod(total_mins, 60)
                for param in present:
                    value = row[param]
                    if not pd.isna(value):
                        ts_list.append((hours, minutes, total_mins, param, value))

            P_list.append({'id': file_name, 'static': static_tuple, 'ts': ts_list})
            survivals.append(survival)

        except Exception as e:
            # Best effort: report the bad file and continue with the rest.
            print(f"Error processing {file_path}: {e}")
            continue

    # Same column layout as Raindrop's original arr_outcomes:
    # ["RecordID","SAPS-I","SOFA","Length_of_stay","Survival","In-hospital_death"]
    arr_outcomes_array = np.zeros((len(P_list), 6))
    # A non-numeric file stem must not crash the run after all files are parsed.
    arr_outcomes_array[:, 0] = [_record_id_from_name(p['id']) for p in P_list]
    arr_outcomes_array[:, 4] = survivals
    # In-hospital_death mirrors the Survival column (binary label).
    arr_outcomes_array[:, 5] = survivals

    return P_list, arr_outcomes_array

def create_irregular_sampling(P_list, param_list, max_len=215, max_tmins=48 * 60):
    """Build fixed-size, zero-padded per-patient arrays following Raindrop's approach.

    Args:
        P_list: Patient dicts with keys ``'id'``, ``'static'`` and ``'ts'``.
            ``'static'`` is (Age, Gender, Height, ICUType, Weight).
            ``'ts'`` is a list of ``(hours, minutes, total_minutes, param, value)``.
        param_list: Ordered parameter names; defines the column order of ``'arr'``.
        max_len: Rows allocated per patient (215, as in Raindrop). Unique time
            points beyond this are dropped with a warning instead of raising
            IndexError.
        max_tmins: Only samples with ``total_minutes < max_tmins`` are kept
            (default: the first 48 hours).

    Returns:
        List of dicts, one per patient, with keys:
            - ``'id'`` and ``'static'``: copied from the input.
            - ``'extended_static'``: static vector with Gender and ICUType
              one-hot encoded.
            - ``'arr'``: ``(max_len, len(param_list))`` values.
            - ``'time'``: ``(max_len, 1)`` minutes.
            - ``'length'``: number of valid rows.
    """
    F = len(param_list)
    # Column of each parameter, computed once. setdefault keeps the first
    # occurrence, matching the previous np.where(...)[0][0] lookup.
    param_index = {}
    for col, name in enumerate(param_list):
        param_index.setdefault(name, col)

    PTdict_list = []

    for ind, patient in enumerate(P_list):
        if ind % 1000 == 0:
            print(f"Creating irregular sampling for patient {ind+1}/{len(P_list)}")

        ID = patient['id']
        static = patient['static']
        ts = patient['ts']

        # Unique time points inside the window, in order of first appearance;
        # each maps to its row in Parr/Tarr.
        time_index = {}
        for sample in ts:
            tmin = sample[2]  # total minutes
            if tmin < max_tmins and tmin not in time_index:
                time_index[tmin] = len(time_index)

        if len(time_index) > max_len:
            print(f"Warning: patient {ID} has {len(time_index)} time points; "
                  f"keeping the first {max_len}")

        # One-hot encoding of categorical static variables (following Raindrop):
        # [Age, Gender=0, Gender=1, Height, ICUType=1..4, Weight]
        extended_static = [static[0], 0, 0, static[2], 0, 0, 0, 0, static[4]]
        if static[1] == 0:
            extended_static[1] = 1
        elif static[1] == 1:
            extended_static[2] = 1
        if static[3] == 1:
            extended_static[4] = 1
        elif static[3] == 2:
            extended_static[5] = 1
        elif static[3] == 3:
            extended_static[6] = 1
        elif static[3] == 4:
            extended_static[7] = 1

        # Zero-padded arrays of maximal size (following Raindrop).
        Parr = np.zeros((max_len, F))
        Tarr = np.zeros((max_len, 1))

        for sample in ts:
            tmin, param, value = sample[2], sample[3], sample[4]
            if tmin >= max_tmins or param not in param_index:
                continue
            time_id = time_index[tmin]
            if time_id >= max_len:
                continue  # truncated (warned above)
            Parr[time_id, param_index[param]] = value
            Tarr[time_id, 0] = tmin

        PTdict_list.append({
            'id': ID,
            'static': static,
            'extended_static': extended_static,
            'arr': Parr,
            'time': Tarr,
            'length': min(len(time_index), max_len),
        })

    return PTdict_list

def save_processed_data(output_dir, PTdict_list, arr_outcomes, param_list, static_param_list):
    """Write the processed data to ``<output_dir>/processed_data`` as ``.npy`` files.

    The following files are written, in this order:
        - ``PTdict_list.npy``: the per-patient dicts (object array).
        - ``arr_outcomes.npy``: the outcome matrix.
        - ``ts_params.npy``: the time-series parameter names.
        - ``static_params.npy``: the static parameter names.
        - ``extended_static_params.npy``: the labels of each entry of a
          patient's ``extended_static`` vector.

    The ``processed_data`` directory must already exist.
    """
    target_dir = os.path.join(output_dir, 'processed_data')

    def dest(fname):
        return os.path.join(target_dir, fname)

    np.save(dest('PTdict_list.npy'), PTdict_list)
    print(f"Saved PTdict_list.npy with {len(PTdict_list)} patients")

    np.save(dest('arr_outcomes.npy'), arr_outcomes)
    print(f"Saved arr_outcomes.npy with shape {arr_outcomes.shape}")

    # Labels for the one-hot expanded static vector (following Raindrop).
    extended_labels = [
        'Age', 'Gender=0', 'Gender=1', 'Height',
        'ICUType=1', 'ICUType=2', 'ICUType=3', 'ICUType=4', 'Weight',
    ]
    for fname, names in (
        ('ts_params.npy', param_list),
        ('static_params.npy', static_param_list),
        ('extended_static_params.npy', extended_labels),
    ):
        np.save(dest(fname), names)

    print("✅ All processed data saved successfully")

def create_splits(output_dir, n_patients, seed=None):
    """Write five random 80/10/10 train/val/test index splits, as in Raindrop.

    Each split is saved to ``<output_dir>/splits/phy12_split_subset{k}.npy``
    (k = 1..5) as a length-3 object array ``(idx_train, idx_val, idx_test)``.
    Load it with ``np.load(path, allow_pickle=True)`` and unpack it into the
    three index arrays.

    Args:
        output_dir: Root output directory; its ``splits`` subdirectory must exist.
        n_patients: Number of patients; indices range over ``0..n_patients-1``.
        seed: Optional seed for reproducible splits. ``None`` (the default)
            keeps drawing from NumPy's global RNG, as before.
    """
    # 8:1:1 split as in Raindrop; the test set takes the remainder.
    p_train = 0.80
    p_val = 0.10

    n_train = round(n_patients * p_train)
    n_val = round(n_patients * p_val)
    n_test = n_patients - (n_train + n_val)

    print(f"Creating splits: Train={n_train}, Val={n_val}, Test={n_test}")

    if seed is None:
        permute = np.random.permutation
    else:
        permute = np.random.default_rng(seed).permutation

    # Create 5 splits as in Raindrop
    for j in range(5):
        p = permute(n_patients)
        # The three index arrays have different lengths. np.save on a ragged
        # tuple raises ValueError on NumPy >= 1.24, so pack them into an
        # explicit object array (same on-disk layout Raindrop unpacks).
        split = np.empty(3, dtype=object)
        split[0] = p[:n_train]
        split[1] = p[n_train:n_train + n_val]
        split[2] = p[n_train + n_val:]
        np.save(os.path.join(output_dir, 'splits', f'phy12_split_subset{j+1}.npy'), split)

    print("✅ All splits saved successfully")

if __name__ == "__main__":
    import argparse

    # Defaults preserve the original hard-coded locations; override them on
    # the command line instead of editing the script.
    parser = argparse.ArgumentParser(
        description="Convert PhysioNet 2012 PSV files to Raindrop's P12 format."
    )
    parser.add_argument(
        "--data-dir",
        default="/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/tmp",
        help="Directory searched recursively for .psv files",
    )
    parser.add_argument(
        "--output-dir",
        default="Raindrop/P12data",
        help="Directory that receives processed_data/ and splits/",
    )
    args = parser.parse_args()

    convert_to_raindrop_format(args.data_dir, args.output_dir)