"""
CD Dataset Preprocessing for Raindrop
Converts CD dataset from graph-based format to Raindrop's expected time series format.
"""

import numpy as np
import pandas as pd
import os
import glob
from tqdm import tqdm
import torch
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings("ignore")

# Central configuration for the CD -> Raindrop conversion.
CD_CONFIG = {
    # Column names of the biomarkers exported as time-series channels,
    # in the order they appear along the feature axis.
    'biomarkers': [
        # White blood cell subtypes
        "npu02593",
        "npu02902",
        "npu02636",
        "npu02840",
        "npu01933",
        "npu01349",
        # Inflammation markers (CRP, F-cal)
        "npu19748",
        "npu19717",
        # Platelets
        "npu03568",
        # Hemoglobin
        "npu02319",
        # Iron
        "npu02508",
        # Vitamins and folate
        "npu02070",
        "npu01700",
        "npu10267",
        # Liver function markers
        "npu19651",
        "npu01370",
        "npu19673",
    ],
    # Names of the per-patient static features.
    'static_features': ['sex'],
    # Maximum number of time steps per patient
    'max_time_steps': 50,
    # Fractions used when building the random splits.
    # NOTE(review): 'test_ratio' is not read anywhere in this file; the test
    # set is whatever remains after the train and validation slices.
    'train_ratio': 0.8,
    'val_ratio': 0.1,
    'test_ratio': 0.1,
    # Seed for the split shuffling.
    'seed': 12,
}

def convert_to_tuple(val):
    """Parse a stringified ``(value, unit, lab_code)`` tuple.

    Args:
        val: Cell content to parse. Only ``str`` input is parsed; anything
            else (NaN, ``None``, plain numbers, tuples, lists, ...) is
            treated as missing.

    Returns:
        ``(value, unit, lab_code)`` with ``value`` as ``float`` and the other
        two fields as strings (surrounding whitespace and single quotes
        removed) when ``val`` holds at least three comma-separated fields and
        the first one converts to ``float``. Otherwise
        ``(np.nan, np.nan, np.nan)``.
    """
    missing = (np.nan, np.nan, np.nan)

    # Checking the type first covers NaN/None and also keeps sequence inputs
    # from reaching pd.isna, whose array result cannot be used in a boolean
    # test.
    if not isinstance(val, str):
        return missing

    # Drop surrounding whitespace before the parentheses so padded cells
    # still parse. The literal string 'nan' ends up with a single field and
    # is therefore reported as missing below.
    parts = val.strip().strip('()').split(',')
    if len(parts) < 3:
        return missing

    try:
        value = float(parts[0].strip())
    except ValueError:
        # First field is not numeric.
        return missing

    unit = parts[1].strip().strip("'")
    lab_code = parts[2].strip().strip("'")
    return (value, unit, lab_code)

def load_cd_data(data_dir):
    """Collect the per-patient CSV paths and the metadata vocabularies.

    Args:
        data_dir: Root directory containing the ``{split}__CD__{group}``
            sub-directories (split in train/val/test, group in
            patient/control) and a ``metadata`` sub-directory.

    Returns:
        A dict with one sorted list of CSV paths per ``'{split}_{group}'``
        key, plus a ``'metadata'`` dict holding the ``'biomarkers'``,
        ``'units'`` and ``'lab_codes'`` lists read from column ``"0"`` of the
        corresponding metadata CSV files.
    """
    print("Loading CD dataset...")

    file_lists = {}
    for split in ("train", "val", "test"):
        for group in ("patient", "control"):
            pattern = os.path.join(data_dir, f"{split}__CD__{group}", "*.csv")
            # glob returns entries in arbitrary, filesystem-dependent order;
            # sorting keeps the patient order (and any seeded shuffle built
            # on top of it) reproducible between runs and machines.
            file_lists[f"{split}_{group}"] = sorted(glob.glob(pattern))

    for split in ("train", "val", "test"):
        for group in ("patient", "control"):
            files = file_lists[f"{split}_{group}"]
            print(f"Found {len(files)} {split} {group} files")

    # Load metadata
    metadata_path = os.path.join(data_dir, "metadata")
    unique_analysiscode = pd.read_csv(
        os.path.join(metadata_path, "unique_analysiscode.csv"))["0"].tolist()
    unique_units = pd.read_csv(
        os.path.join(metadata_path, "unique_units.csv"))["0"].tolist()
    unique_lab_codes = pd.read_csv(
        os.path.join(metadata_path, "unique_lab_ids.csv"))["0"].tolist()

    return {
        **file_lists,
        'metadata': {
            'biomarkers': unique_analysiscode,
            'units': unique_units,
            'lab_codes': unique_lab_codes,
        },
    }

def process_patient_data(file_path, biomarkers, max_time_steps):
    """Convert one patient CSV into a Raindrop-style record.

    The file is expected to provide the columns ``samplingdate``, ``lbnr``,
    ``dataset`` and ``sex``; id, label and sex are taken from the first row
    after sorting by date. Biomarker cells are parsed with
    ``convert_to_tuple`` and only the numeric value is kept.

    Args:
        file_path: Path of the patient CSV file.
        biomarkers: Column names to export, one per feature channel. Columns
            missing from the file stay NaN.
        max_time_steps: Number of rows in the output matrices. Only the
            earliest ``max_time_steps`` distinct sampling dates are kept.

    Returns:
        A dict with
        ``'id'`` (value of ``lbnr``),
        ``'arr'`` (``max_time_steps x len(biomarkers)`` float matrix, NaN
        where nothing was measured or for unused rows),
        ``'time'`` (``max_time_steps x 1`` hours since the first sampling
        date, 0 for unused rows),
        ``'extended_static'`` (1-element array holding ``sex``) and
        ``'label'`` (1 when ``dataset == 'patient'``, else 0).

    Raises:
        IndexError: If the file contains no rows.
        KeyError: If one of the required columns is missing.
    """
    df = pd.read_csv(file_path)

    df["samplingdate"] = pd.to_datetime(df["samplingdate"])
    # Stable sort so rows sharing a date keep their order from the file.
    df = df.sort_values('samplingdate', kind='mergesort')

    patient_id = df['lbnr'].iloc[0]
    label = 1 if df['dataset'].iloc[0] == 'patient' else 0
    sex = df['sex'].iloc[0]

    # Initialize time series matrix
    time_series = np.full((max_time_steps, len(biomarkers)), np.nan)
    timestamps = np.zeros(max_time_steps)

    # Resolve once which requested biomarkers exist in this file.
    available = [
        (b_idx, biomarker)
        for b_idx, biomarker in enumerate(biomarkers)
        if biomarker in df.columns
    ]

    # groupby yields the distinct dates in ascending order together with
    # their rows in a single pass, and skips rows whose date is NaT.
    start_time = None
    for t_idx, (date, date_data) in enumerate(df.groupby('samplingdate', sort=True)):
        if t_idx >= max_time_steps:
            break

        # Normalise to pd.Timestamp so the subtraction below always yields a
        # pd.Timedelta (numpy.timedelta64 has no total_seconds()).
        date = pd.Timestamp(date)
        if start_time is None:
            start_time = date
        timestamps[t_idx] = (date - start_time).total_seconds() / 3600  # hours

        for b_idx, biomarker in available:
            # Several rows may share a date: use the first one that actually
            # carries a parsable value instead of looking at row 0 only.
            for cell in date_data[biomarker].dropna():
                value = convert_to_tuple(cell)[0]
                if not np.isnan(value):
                    time_series[t_idx, b_idx] = value
                    break

    static_features = np.array([sex])

    return {
        'id': patient_id,
        'arr': time_series,
        'time': timestamps.reshape(-1, 1),
        'extended_static': static_features,
        'label': label
    }

def create_raindrop_format(cd_data, max_time_steps):
    """Turn every listed CSV file into a Raindrop patient record.

    For each of the train/val/test splits the patient files are processed
    first, then the control files. A file whose conversion raises is
    reported on stdout and left out of the result.

    Args:
        cd_data: Mapping with ``'{split}_patient'`` and ``'{split}_control'``
            lists of CSV paths.
        max_time_steps: Passed through to ``process_patient_data``.

    Returns:
        Dict mapping ``'train'``, ``'val'`` and ``'test'`` to the list of
        successfully converted patient records.
    """
    print("Converting to Raindrop format...")

    datasets = {}
    for split in ['train', 'val', 'test']:
        split_files = cd_data[f'{split}_patient'] + cd_data[f'{split}_control']
        print(f"Processing {split} split: {len(split_files)} files")

        records = []
        for file_path in tqdm(split_files, desc=f"Processing {split}"):
            try:
                record = process_patient_data(
                    file_path,
                    CD_CONFIG['biomarkers'],
                    max_time_steps,
                )
            except Exception as e:
                # Best effort: one unreadable file must not abort the run.
                print(f"Error processing {file_path}: {e}")
            else:
                records.append(record)

        datasets[split] = records

    return datasets

def create_splits(cd_data, splits_dir, n_splits=5, train_ratio=None,
                  val_ratio=None, seed=None):
    """Write random train/val/test partitions of all patient IDs.

    Patient IDs are the CSV file names without the ``.csv`` part, pooled over
    every ``'{split}_patient'`` / ``'{split}_control'`` list in ``cd_data``.
    Each partition is saved as ``cd_split{k}.npy`` (k = 1..n_splits), a
    length-3 object array holding the train, validation and test ID lists;
    the test list is whatever remains after the first two slices.

    NOTE(review): records built by ``process_patient_data`` carry the
    ``lbnr`` column as their id, and files that fail to convert are still
    listed here - confirm that the consumer of these split files expects file
    names (rather than e.g. row indices into PTdict_list).

    Args:
        cd_data: Mapping with the six per-split/per-group lists of CSV paths.
        splits_dir: Output directory; created if it does not exist.
        n_splits: Number of partitions to write.
        train_ratio: Fraction of IDs in the training list. Defaults to
            ``CD_CONFIG['train_ratio']``.
        val_ratio: Fraction of IDs in the validation list. Defaults to
            ``CD_CONFIG['val_ratio']``.
        seed: Seed for the shuffling. Defaults to ``CD_CONFIG['seed']``.
    """
    print("Creating data splits...")

    if train_ratio is None:
        train_ratio = CD_CONFIG['train_ratio']
    if val_ratio is None:
        val_ratio = CD_CONFIG['val_ratio']
    if seed is None:
        seed = CD_CONFIG['seed']

    # Combine all patient IDs
    all_patient_ids = []
    for split in ['train', 'val', 'test']:
        for file_path in cd_data[f'{split}_patient'] + cd_data[f'{split}_control']:
            all_patient_ids.append(os.path.basename(file_path).replace('.csv', ''))

    # Start from a canonical order so a given seed always produces the same
    # partitions, no matter in which order the files were discovered.
    all_patient_ids.sort()

    os.makedirs(splits_dir, exist_ok=True)

    # Private generator: same stream as np.random.seed(seed) would give, but
    # without reseeding the process-wide NumPy RNG.
    rng = np.random.RandomState(seed)

    n_total = len(all_patient_ids)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)

    for split_idx in range(n_splits):
        # Each round reshuffles the list left by the previous round.
        rng.shuffle(all_patient_ids)

        train_ids = all_patient_ids[:n_train]
        val_ids = all_patient_ids[n_train:n_train + n_val]
        test_ids = all_patient_ids[n_train + n_val:]

        # Fill element-wise: np.array([...], dtype=object) would build a 2-D
        # array instead of three lists whenever the lengths happen to match.
        split_data = np.empty(3, dtype=object)
        split_data[0] = train_ids
        split_data[1] = val_ids
        split_data[2] = test_ids

        split_path = os.path.join(splits_dir, f"cd_split{split_idx + 1}.npy")
        np.save(split_path, split_data)
        print(f"Saved split {split_idx + 1}: train={len(train_ids)}, val={len(val_ids)}, test={len(test_ids)}")

def save_raindrop_data(datasets, save_dir, biomarkers=None, static_features=None):
    """Write the converted patients and their descriptors to ``save_dir``.

    Patients of all splits are concatenated in the iteration order of
    ``datasets`` and stored in ``PTdict_list.npy``; ``arr_outcomes.npy``
    holds their labels as an ``(n, 1)`` column in the same order. The
    parameter-name arrays and a ``readme.md`` are written alongside.

    Args:
        datasets: Mapping from split name to a list of patient dicts, each
            providing at least a ``'label'`` entry.
        save_dir: Output directory; created if it does not exist.
        biomarkers: Names stored in ``ts_params.npy``. Defaults to
            ``CD_CONFIG['biomarkers']``.
        static_features: Names stored in ``static_params.npy`` and
            ``extended_static_params.npy``. Defaults to
            ``CD_CONFIG['static_features']``.
    """
    print("Saving Raindrop format data...")

    if biomarkers is None:
        biomarkers = CD_CONFIG['biomarkers']
    if static_features is None:
        static_features = CD_CONFIG['static_features']

    # Without this, np.save raises on a fresh output location - and only
    # after the whole conversion has already been done.
    os.makedirs(save_dir, exist_ok=True)

    # Combine all patients
    all_patients = [
        patient
        for split_data in datasets.values()
        for patient in split_data
    ]
    all_outcomes = [patient['label'] for patient in all_patients]

    # Convert to numpy arrays
    PTdict_list = np.array(all_patients, dtype=object)
    arr_outcomes = np.array(all_outcomes).reshape(-1, 1)

    # Save main data files
    np.save(os.path.join(save_dir, "PTdict_list.npy"), PTdict_list)
    np.save(os.path.join(save_dir, "arr_outcomes.npy"), arr_outcomes)

    # Save metadata files
    np.save(os.path.join(save_dir, "ts_params.npy"), np.array(biomarkers))
    np.save(os.path.join(save_dir, "static_params.npy"), np.array(static_features))
    np.save(os.path.join(save_dir, "extended_static_params.npy"), np.array(static_features))

    print(f"Saved {len(all_patients)} patients")
    print(f"Biomarkers: {len(biomarkers)}")
    print(f"Static features: {len(static_features)}")

    # Create readme
    readme_content = """# CD Dataset for Raindrop

## Processed data:
* PTdict_list.npy: Array of patient dictionaries with time series data
* arr_outcomes.npy: Array of patient outcomes (0=control, 1=patient)
* ts_params.npy: Array of biomarker names
* static_params.npy: Array of static feature names
* extended_static_params.npy: Array of static feature names

## Data format:
Each patient dictionary contains:
- 'id': Patient ID
- 'arr': Time series matrix (time_steps x num_biomarkers)
- 'time': Timestamps (time_steps x 1)
- 'extended_static': Static features (sex only, age removed to avoid data leakage)
- 'label': Binary label (0=control, 1=patient)

## Splits:
* cd_split1.npy to cd_split5.npy: 5 different train/val/test splits
"""

    # Explicit encoding so the file does not depend on the platform default.
    with open(os.path.join(save_dir, "readme.md"), 'w', encoding="utf-8") as f:
        f.write(readme_content)

def main():
    """Run the whole pipeline: load, convert, save data, then save splits.

    Input and output locations are fixed absolute paths defined below.
    """
    print("CD Dataset Preprocessing for Raindrop")
    print("=" * 50)

    # Paths (absolute locations on the project file system)
    cd_data_dir = "/ngc/projects2/predict_r/research/projects/0054_GNAN_biomarker_trajectories/datasets/Predict/17_bioms/prediag/sequential_sparse_v5.0"
    save_dir = "/ngc/projects2/predict_r/research/projects/0054_GNAN_biomarker_trajectories/Raindrop/CDdata/processed_data_new"
    split_dir = "/ngc/projects2/predict_r/research/projects/0054_GNAN_biomarker_trajectories/Raindrop/CDdata/splits_new"

    # Make sure both output locations exist before the expensive conversion
    # starts, so an unwritable or missing directory is detected immediately
    # rather than when the results are about to be saved.
    for out_dir in (save_dir, split_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Load CD data
    cd_data = load_cd_data(cd_data_dir)

    # Create Raindrop format
    datasets = create_raindrop_format(cd_data, CD_CONFIG['max_time_steps'])

    # Save data
    save_raindrop_data(datasets, save_dir)

    # Create splits
    create_splits(cd_data, split_dir)

    print("\nPreprocessing completed successfully!")
    print(f"Data saved to: {save_dir}")
    print(f"Splits saved to: {split_dir}")

# Run the preprocessing only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main() 