# Main preprocessing script for EHR datasets (MIMIC-III, MIMIC-IV, eICU)
# Handles data parsing, encoding, and dataset building for various medical prediction tasks
# Supported Tasks:
# - Mortality prediction
# - Readmission prediction
# - Drug recommendation
# - Diagnosis prediction
# - Heart failure prediction

import os
import pickle
import numpy as np

# Import parsers for different EHR datasets
from preprocess.parse_csv import Mimic3Parser, Mimic4Parser, EICUParser
from preprocess.parse_csv import parse_patient_info
from preprocess.encoded import encode_code

# Import dataset building functions
from preprocess.build_dataset import split_patients, split_patients_disparity, build_code_x_encoded
from preprocess.build_dataset import build_code_x_multi_hot, build_code_y_binary, build_code_y_multi_label
from preprocess.build_dataset import build_heart_failure_y


if __name__ == '__main__':
    # Configuration for different datasets and tasks
    conf = {
        'mimic3': {
            'parser': Mimic3Parser,
            'train_num': 4000,
            'test_num': 2000,
            'threshold': 0.01
        },
        'mimic4': {
            'parser': Mimic4Parser,
            'train_num': 8000,
            'test_num': 1000,
            'threshold': 0.01,
            # 'sample_num': 10000
        },
        'eicu': {
            'parser': EICUParser,
            'train_num': 8000,
            'test_num': 1000,
            'threshold': 0.01
        },
        'other': {
            'feature_keys': ['d', 'p', 'm'],  # d: diagnosis, p: procedure, m: medication
            'task': ['mortality', 'readmission', 'drugrec', 'diagnosis', 'heart'],
            'build_code_': 'encoded',
            'data_level_': 'visit'
        }
    }
    from_saved = True  # [True, False]
    is_pcs, is_med = 'p' in conf['other']['feature_keys'], 'm' in conf['other']['feature_keys']

    # Setup data paths
    data_path = '../DG-EHR/data'
    dataset = 'mimic3'  # mimic3, eICU, or mimic4
    dataset_path = os.path.join(data_path, dataset)
    raw_path = os.path.join(dataset_path, 'raw')
    if not os.path.exists(raw_path):
        os.makedirs(raw_path)
        print('please put the CSV files in `data/%s/raw`' % dataset)
        exit()
    parsed_path = os.path.join(dataset_path, 'parsed')

    if from_saved: 
        patient_admission = pickle.load(open(os.path.join(parsed_path, 'patient_admission.pkl'), 'rb'))
        admission_codes = pickle.load(open(os.path.join(parsed_path, 'admission_codes.pkl'), 'rb'))
        admission_pcs = pickle.load(open(os.path.join(parsed_path, 'admission_pcs.pkl'), 'rb')) if is_pcs else None
        admission_med = pickle.load(open(os.path.join(parsed_path, 'admission_med.pkl'), 'rb')) if is_med else None
    else:
        parser = conf[dataset]['parser'](raw_path, is_pcs, is_med)
        sample_num = conf[dataset].get('sample_num', None)
        patient_admission, admissions = parser.parse(sample_num)
        admission_codes, admission_pcs, admission_med = admissions
        del admissions
        print('saving parsed data ...')
        if not os.path.exists(parsed_path):
            os.makedirs(parsed_path)
        pickle.dump(patient_admission, open(os.path.join(parsed_path, 'patient_admission.pkl'), 'wb'))
        pickle.dump(admission_codes, open(os.path.join(parsed_path, 'admission_codes.pkl'), 'wb'))
        if is_pcs:
            pickle.dump(admission_pcs, open(os.path.join(parsed_path, 'admission_pcs.pkl'), 'wb'))
        if is_med:
            pickle.dump(admission_med, open(os.path.join(parsed_path, 'admission_med.pkl'), 'wb'))

    print("Total number of patients:", len(patient_admission))
    print("Total number of admissions:", len(admission_codes))
    
    # Calculate readmission rate (readmission defined as interval <= 15 days)
    print("\nCalculating readmission statistics...")
    total_intervals = 0
    readmission_count = 0
    
    for pid, admissions in patient_admission.items():
        if len(admissions) < 2:
            continue
            
        # Compare each admission with next admission
        for i in range(len(admissions)-1):
            curr_admission = admissions[i]
            next_admission = admissions[i+1]
            
            interval = (next_admission['admission_time'] - curr_admission['admission_time']).days
            total_intervals += 1
            if interval < 15:
                readmission_count += 1
                
    readmission_rate = readmission_count / total_intervals if total_intervals > 0 else 0
    print(f"Total number of admission intervals: {total_intervals}")
    print(f"Number of readmissions (<15 days): {readmission_count}")
    print(f"Readmission rate: {readmission_rate:.2%}")
    
    # Get patient demographic information for MIMIC datasets (Not supported for eICU)
    if dataset == 'mimic3' or dataset == 'mimic4':
        patient_info = parse_patient_info(raw_path, parsed_path)
        print(patient_info[44083])
        pickle.dump(patient_info, open(os.path.join(parsed_path, 'patient_info.pkl'), 'wb'))
        
    # Encode medical codes into numerical representations
    print('encoding code ...')
    admission_codes_encoded, code_map = encode_code(patient_admission, admission_codes)  # Code map starts from 1
    admission_pcs_encoded, pcs_map = encode_code(patient_admission, admission_pcs) if is_pcs else ({}, {})
    admission_med_encoded, med_map = encode_code(patient_admission, admission_med) if is_med else ({}, {})
    code_nums = {
        'd': len(code_map) if 'd' in conf['other']['feature_keys'] else 0,
        'p': len(pcs_map) if is_pcs else 0,
        'm': len(med_map) if is_med else 0,
    }
    print('# Diagnosis: %d; # Procedure: %d; # Medication: %d' % (code_nums['d'], code_nums['p'], code_nums['m']))
    
    # Split patients into train/validation/test sets with disparity awareness
    train_pids, valid_pids, test_pids_g1, test_pids_g2 = split_patients_disparity(
        patient_admission=patient_admission,
        admission_codes=admission_codes,
        code_map=code_map,
        train_num=conf[dataset]['train_num'],
        test_num=conf[dataset]['test_num'],
        patient_info=patient_info,
        feature_key="GENDER",
        g1_value="M",
    )
    test_pids = np.concatenate((test_pids_g1, test_pids_g2))
    print('There are %d train, %d valid, %d test samples' % (len(train_pids), len(valid_pids), len(test_pids)))
    print('There are %d test for group 1, %d for group 2' % (len(test_pids_g1), len(test_pids_g2)))
    
    # Prepare encoded events for feature building
    admission_events_encoded = dict.fromkeys(['d', 'p', 'm'], None)
    admission_events_encoded.update({
        'd': admission_codes_encoded,
        'p': admission_pcs_encoded,
        'm': admission_med_encoded
    })
    
    # Build features (X) 
    build_code_ = conf['other']['build_code_']
    if build_code_ == 'encoded':
        max_visit_code_nums = {
            t: max([len(codes) for codes in admission_codes.values()]) if admission_codes is not None else 0
            for t, admission_codes in admission_events_encoded.items()}
        x_args = (patient_admission, admission_events_encoded, conf['other']['feature_keys'], max_visit_code_nums)
        x_parser = build_code_x_encoded
    elif build_code_ == 'multi-hot':
        x_args = (patient_admission, admission_events_encoded, conf['other']['feature_keys'], code_nums)
        x_parser = build_code_x_multi_hot
    else:
        raise ValueError('Invalid build_code_')
    
    print('Building code features by visit ...')
    train_codes_tuple = x_parser(train_pids, *x_args)  # (train_codes_x, train_visit_lens)
    valid_codes_tuple = x_parser(valid_pids, *x_args)  # (valid_codes_x, valid_visit_lens)
    test_codes_tuple = x_parser(test_pids, *x_args)  # (test_codes_x, test_visit_lens)
    test_codes_tuple_g1 = x_parser(test_pids_g1, *x_args)
    test_codes_tuple_g2 = x_parser(test_pids_g2, *x_args)
    
    
    print("Total number of visits for training:", sum([len(patient_admission[pid]) - 1 for pid in train_pids]))
    train_codes_x, train_visit_lens = train_codes_tuple
    print("Train codes shape:", train_codes_x['d'].shape)
    
    # # Find a sequence with length 10 and print its codes
    # for i, length in enumerate(train_visit_lens['d']):
    #     if length == 5:
    #         print(f"Found sequence with length 10 at index {i}")
    #         print(f"Visit lengths: {train_visit_lens['d'][i]}")
    #         print("Codes for each visit:")
    #         for j in range(5):
    #             print(f"Visit {5-j}: {train_codes_x['d'][i-j][:5]}")
    #         break
    
    # Build labels (Y) for different prediction tasks
    # 1. Mortality prediction
    y_args_mort = (patient_admission, conf['other']['task'][0])
    train_y_mort = build_code_y_binary(train_pids, *y_args_mort)
    valid_y_mort = build_code_y_binary(valid_pids, *y_args_mort)
    test_y_mort = build_code_y_binary(test_pids, *y_args_mort)
    test_y_mort_g1 = build_code_y_binary(test_pids_g1, *y_args_mort)
    test_y_mort_g2 = build_code_y_binary(test_pids_g2, *y_args_mort)

    # 2. Readmission prediction
    y_args_readm = (patient_admission, conf['other']['task'][1])
    train_y_readm = build_code_y_binary(train_pids, *y_args_readm)
    valid_y_readm = build_code_y_binary(valid_pids, *y_args_readm)
    test_y_readm = build_code_y_binary(test_pids, *y_args_readm)
    test_y_readm_g1 = build_code_y_binary(test_pids_g1, *y_args_readm)
    test_y_readm_g2 = build_code_y_binary(test_pids_g2, *y_args_readm)
    print("Number of positive values in readmission prediction:")
    print(f"Train: {np.sum(train_y_readm)}  {train_y_readm.shape}")
    print(f"Valid: {np.sum(valid_y_readm)}  {valid_y_readm.shape}")
    print(f"Test: {np.sum(test_y_readm)}  {test_y_readm.shape}")

    # 3. Drug recommendation
    y_args_drugrec = (patient_admission, admission_events_encoded, code_nums, conf['other']['task'][2])
    train_y_drugrec = build_code_y_multi_label(train_pids, *y_args_drugrec)
    valid_y_drugrec = build_code_y_multi_label(valid_pids, *y_args_drugrec)
    test_y_drugrec = build_code_y_multi_label(test_pids, *y_args_drugrec)
    test_y_drugrec_g1 = build_code_y_multi_label(test_pids_g1, *y_args_drugrec)
    test_y_drugrec_g2 = build_code_y_multi_label(test_pids_g2, *y_args_drugrec)

    # 4. Diagnosis prediction
    y_args_diag = (patient_admission, admission_events_encoded, code_nums, conf['other']['task'][3])
    train_y_diag = build_code_y_multi_label(train_pids, *y_args_diag)
    valid_y_diag = build_code_y_multi_label(valid_pids, *y_args_diag)
    test_y_diag = build_code_y_multi_label(test_pids, *y_args_diag)
    test_y_diag_g1 = build_code_y_multi_label(test_pids_g1, *y_args_diag)
    test_y_diag_g2 = build_code_y_multi_label(test_pids_g2, *y_args_diag)

    # 5. Heart failure prediction
    train_y_hf = build_heart_failure_y('428', train_y_diag, code_map)
    valid_y_hf = build_heart_failure_y('428', valid_y_diag, code_map)
    test_y_hf = build_heart_failure_y('428', test_y_diag, code_map)
    test_y_hf_g1 = build_heart_failure_y('428', test_y_diag_g1, code_map)
    test_y_hf_g2 = build_heart_failure_y('428', test_y_diag_g2, code_map)
    print("Number of positive values in heart failure prediction:")
    print(f"Train: {np.sum(train_y_hf)}  {train_y_hf.shape}")
    print(f"Valid: {np.sum(valid_y_hf)}  {valid_y_hf.shape}")
    print(f"Test: {np.sum(test_y_hf)}  {test_y_hf.shape}")

    # Save processed datasets
    print('Saving standard datasets ...')
    standard_path = os.path.join(dataset_path, 'standard')
    if not os.path.exists(standard_path):
        os.makedirs(standard_path)

    # Save features
    pickle.dump({
        'train_x': train_codes_tuple,
        'valid_x': valid_codes_tuple,
        'test_x': test_codes_tuple,
        'test_x_g1': test_codes_tuple_g1,
        'test_x_g2': test_codes_tuple_g2,
    }, open(os.path.join(standard_path, 'codes_dataset.pkl'), 'wb'))

    # Save labels for all tasks
    pickle.dump({
        'mortality': (train_y_mort, valid_y_mort, test_y_mort, test_y_mort_g1, test_y_mort_g2),
        'readmission': (train_y_readm, valid_y_readm, test_y_readm, test_y_readm_g1, test_y_readm_g2),
        'drugrec': (train_y_drugrec, valid_y_drugrec, test_y_drugrec, test_y_drugrec_g1, test_y_drugrec_g2),
        'diagnosis': (train_y_diag, valid_y_diag, test_y_diag, test_y_diag_g1, test_y_diag_g2),
        'heartf': (train_y_hf, valid_y_hf, test_y_hf, test_y_hf_g1, test_y_hf_g2),
    }, open(os.path.join(standard_path, 'tasks_dataset.pkl'), 'wb'))

    print(train_y_hf.shape, valid_y_hf.shape, test_y_hf.shape, test_y_hf_g1.shape, test_y_hf_g2.shape)
    