# Dataset building module for EHR data
# Handles splitting data and building features/labels for various prediction tasks

import numpy as np
import pandas as pd
from preprocess.parse_csv import EHRParser


def split_patients(patient_admission, admission_codes, code_map, train_num, test_num, seed=6669):
    """
    Split patients into train/validation/test sets
    For every code in code_map, the first patient (in dict order) whose admissions
    contain that code is forced into the training set, so every code that occurs
    in the data is seen during training. The patient with the most admissions is
    forced into the training set as well.
    Args:
        patient_admission: Dict mapping patient IDs to their admission records
        admission_codes: Dict mapping admission IDs to their diagnosis codes
        code_map: Dict mapping codes to their indices
        train_num: Number of patients for training
        test_num: Number of patients for testing
        seed: Random seed for reproducibility (seeds the global NumPy RNG)
    Returns:
        Tuple (train_pids, valid_pids, test_pids) of NumPy arrays of patient IDs.
        The validation set receives every patient not used for train or test.
        If more patients are forced into training than train_num, the training
        set is exactly those forced patients.
    Raises:
        ValueError: If train_num + test_num exceeds the number of patients.
    """
    if train_num + test_num > len(patient_admission):
        # Without this check the slices below overlap and test patients leak into train.
        raise ValueError("train_num + test_num exceeds the number of available patients.")

    np.random.seed(seed)
    common_pids = set()
    code_count = len(code_map)
    for i, code in enumerate(code_map):
        print('\r\t%.2f%%' % ((i + 1) * 100 / code_count), end='')
        # Stop at the first patient that has this code in any admission.
        for pid, admissions in patient_admission.items():
            if any(code in admission_codes[admission[EHRParser.adm_id_col]] for admission in admissions):
                common_pids.add(pid)
                break
    print('\r\t100%')

    # Force the patient with the most admissions into training (first one wins on ties).
    # Nothing is added when no patient has any admission.
    if patient_admission:
        pid_max_admission_num = max(patient_admission, key=lambda pid: len(patient_admission[pid]))
        if len(patient_admission[pid_max_admission_num]) > 0:
            common_pids.add(pid_max_admission_num)

    remaining_pids = np.array(list(set(patient_admission.keys()).difference(common_pids)))
    np.random.shuffle(remaining_pids)

    valid_num = len(patient_admission) - train_num - test_num
    # Clamp to zero: a negative count would be read as "all but the last k" by the slice.
    extra_train = max(0, train_num - len(common_pids))
    valid_end = max(extra_train, train_num + valid_num - len(common_pids))

    train_pids = np.array(list(common_pids.union(set(remaining_pids[:extra_train].tolist()))))
    valid_pids = remaining_pids[extra_train:valid_end]
    test_pids = remaining_pids[valid_end:]

    return train_pids, valid_pids, test_pids


def split_patients_disparity(patient_admission, admission_codes, code_map, train_num, test_num, patient_info,
                             feature_key, g1_value=None, seed=6669):
    """
    Split patients into train/validation/test sets with disparity awareness
    Produces two equally sized test sets, one per demographic group
    Args:
        patient_admission: Dict mapping patient IDs to their admission records
        admission_codes: Dict mapping admission IDs to their diagnosis codes
        code_map: Dict mapping codes to their indices
        train_num: Number of patients for training
        test_num: Total number of test patients; each group receives test_num // 2
        patient_info: Dict containing patient demographic information
        feature_key: Key in patient_info to use for splitting (e.g., 'GENDER')
        g1_value: Value of feature_key to use for group 1. If None, the value of the
            first remaining patient is used (which patient that is depends on set order)
        seed: Random seed for reproducibility (seeds the global NumPy RNG)
    Returns:
        Tuple (train_pids, valid_pids, test_pid_g1, test_pid_g2) of NumPy arrays.
        Validation receives every patient not used for train or test.
    Raises:
        ValueError: If the remaining patients do not form two groups, if a group is
            too small for the test set, or if too few patients are left for training.
    """
    np.random.seed(seed)

    # For every code, force the first patient that has it into the training set
    common_pids = set()
    code_count = len(code_map)
    for i, code in enumerate(code_map):
        print('\r\t%.2f%%' % ((i + 1) * 100 / code_count), end='')
        for pid, admissions in patient_admission.items():
            if any(code in admission_codes[admission[EHRParser.adm_id_col]] for admission in admissions):
                common_pids.add(pid)
                break
    print('\r\t100%')

    # Add patient with most admissions to common set (first one wins on ties)
    max_admission_num = 0
    pid_max_admission_num = None
    for pid, admissions in patient_admission.items():
        if len(admissions) > max_admission_num:
            max_admission_num = len(admissions)
            pid_max_admission_num = pid
    if pid_max_admission_num is not None:
        common_pids.add(pid_max_admission_num)

    # Split remaining patients by demographic feature
    all_pids = set(patient_admission.keys())
    remaining_pids = list(all_pids.difference(common_pids))

    group1 = []
    group2 = []
    group1_value = g1_value
    for pid in remaining_pids:
        feature_value = patient_info[pid][feature_key]
        if group1_value is None:
            group1_value = feature_value
            group1.append(pid)
        elif feature_value == group1_value:
            group1.append(pid)
        else:
            group2.append(pid)

    if len(group1) == 0 or len(group2) == 0:
        raise ValueError("Remaining patients do not split into two groups based on the provided feature key.")

    # Create balanced test sets
    group1 = np.array(group1)
    group2 = np.array(group2)
    np.random.shuffle(group1)
    np.random.shuffle(group2)

    required_each = test_num // 2
    if len(group1) < required_each or len(group2) < required_each:
        raise ValueError("Not enough patients in one of the groups to form the balanced test set.")

    test_pid_g1 = group1[:required_each]
    test_pid_g2 = group2[:required_each]

    # Split remaining patients into train and validation
    test_set = set(test_pid_g1.tolist() + test_pid_g2.tolist())
    remaining_for_train_valid = np.array([pid for pid in remaining_pids if pid not in test_set])
    np.random.shuffle(remaining_for_train_valid)

    # Clamp to zero: a negative count used as a slice bound would silently drop
    # all but the last few patients from the validation set.
    train_needed = max(0, train_num - len(common_pids))
    if train_needed > len(remaining_for_train_valid):
        raise ValueError("Not enough patients to fulfill the training set requirement after reserving test set.")

    common_array = np.array(list(common_pids))
    if train_needed == 0:
        train_pids = common_array
    else:
        train_pids = np.concatenate([common_array, remaining_for_train_valid[:train_needed]])

    valid_pids = remaining_for_train_valid[train_needed:]

    return train_pids, valid_pids, test_pid_g1, test_pid_g2


def build_code_x_encoded(pids, patient_admission, admission_events_encoded, event_types, max_visit_code_nums):
    """
    Build encoded features for each visit
    Every patient with v admissions contributes v - 1 samples; sample k holds the
    codes of visits 0..k (the last admission is never used as input).
    Assumes every patient in pids has at least one admission.
    Args:
        pids: List of patient IDs
        patient_admission: Dict mapping patient IDs to their admission records
        admission_events_encoded: Dict mapping event type -> admission ID -> encoded event codes
        event_types: List of event types to include ('d', 'p', 'm')
        max_visit_code_nums: Dict mapping event types to max number of codes per visit
    Returns:
        Tuple (x, lens) of dicts keyed by event type. x[t] is an int array of shape
        (n_samples, max_admission_num, max_visit_code_nums[t]) padded with zeros and
        lens[t] holds the number of visits in each sample. Entries stay None for
        event types whose max_visit_code_nums value is 0.
    """
    n = sum(len(patient_admission[pid]) - 1 for pid in pids)
    # Taken over ALL patients (not just pids), so the array width does not depend on pids.
    max_admission_num = max((len(admissions) for admissions in patient_admission.values()), default=0)
    x = {t: None for t in event_types}
    lens = {t: None for t in event_types}
    for e_type in event_types:
        max_visit_code_num = max_visit_code_nums[e_type]
        if max_visit_code_num == 0:
            continue

        x_t = np.zeros((n, max_admission_num, max_visit_code_num), dtype=int)
        lens_t = np.zeros((n,), dtype=int)
        # Running row offset of the current patient. (Indexing a cumulative array with
        # i - 1 wraps to the last element for the first patient and breaks single-patient input.)
        start = 0
        for pid in pids:
            admissions = patient_admission[pid]
            end = start + len(admissions) - 1
            for k, admission in enumerate(admissions[:-1]):
                codes = admission_events_encoded[e_type][admission[EHRParser.adm_id_col]]
                # Visit k is part of this patient's samples k, k + 1, ..., last.
                x_t[start + k:end, k, :len(codes)] = codes
                lens_t[start + k] = k + 1
            start = end
        x[e_type], lens[e_type] = x_t, lens_t
    return x, lens


def build_code_x_multi_hot(pids, patient_admission, admission_events_encoded, event_types, code_nums):
    """
    Build multi-hot encoded features for each visit
    Every patient with v admissions contributes v - 1 samples; sample k holds the
    multi-hot vectors of visits 0..k (the last admission is never used as input).
    Encoded code c sets column c - 1 (codes are presumably 1-based; a code of 0
    would wrap around to the last column).
    Assumes every patient in pids has at least one admission.
    Args:
        pids: List of patient IDs
        patient_admission: Dict mapping patient IDs to their admission records
        admission_events_encoded: Dict mapping event type -> admission ID -> encoded event codes
        event_types: List of event types to include ('d', 'p', 'm')
        code_nums: Dict mapping event types to number of unique codes
    Returns:
        Tuple (x, lens) of dicts. x[t] is a bool array of shape
        (n_samples, max_admission_num, code_nums[t]) and lens[t] holds the number of
        visits in each sample. Keys 'd', 'p' and 'm' are always present and stay
        None when the type is not listed in event_types.
    """
    n = sum(len(patient_admission[pid]) - 1 for pid in pids)
    x = {key: None for key in ['d', 'p', 'm']}
    lens = {key: None for key in ['d', 'p', 'm']}
    if not event_types:
        return x, lens

    # Taken over ALL patients (not just pids), so the array width does not depend on pids.
    max_admission_num = max(len(admissions) for admissions in patient_admission.values())

    for e_type in event_types:
        x_t = np.zeros((n, max_admission_num, code_nums[e_type]), dtype=bool)
        lens_t = np.zeros((n,), dtype=int)
        # Running row offset of the current patient. (Indexing a cumulative array with
        # i - 1 wraps to the last element for the first patient and breaks single-patient input.)
        start = 0
        for pid in pids:
            admissions = patient_admission[pid]
            end = start + len(admissions) - 1
            for k, admission in enumerate(admissions[:-1]):
                # dtype=int keeps an empty code list usable as an index array
                # (an empty float array is rejected by NumPy indexing).
                codes = np.asarray(admission_events_encoded[e_type][admission[EHRParser.adm_id_col]], dtype=int) - 1
                # Visit k is part of this patient's samples k, k + 1, ..., last.
                x_t[start + k:end, k, codes] = True
                lens_t[start + k] = k + 1
            start = end
        x[e_type], lens[e_type] = x_t, lens_t
    return x, lens


def build_code_y_binary(pids, patient_admission, task):
    """
    Build binary labels for mortality or readmission prediction
    One label is produced per pair of consecutive admissions of a patient:
      - 'mortality': True when the NEXT admission has a non-missing death time
      - 'readmission': True when the next admission starts fewer than 15 days
        after the current discharge
    Any other task value leaves all labels False.
    Assumes every patient in pids has at least one admission.
    Args:
        pids: List of patient IDs
        patient_admission: Dict mapping patient IDs to their admission records
        task: Either 'mortality' or 'readmission'
    Returns:
        Bool array with one entry per (admission, next admission) pair, ordered by pids.
    """
    n = sum(len(patient_admission[pid]) - 1 for pid in pids)
    y = np.zeros((n,), dtype=bool)
    # Running row offset of the current patient. (Indexing a cumulative array with
    # i - 1 wraps to the last element for the first patient and breaks single-patient input.)
    start = 0
    for pid in pids:
        admissions = patient_admission[pid]
        for k, admission in enumerate(admissions[:-1]):
            next_admission = admissions[k + 1]
            if task == 'mortality':
                death_time = next_admission[EHRParser.death_time_col]
                y[start + k] = not pd.isna(death_time)
            elif task == 'readmission':
                discharge_time = admission[EHRParser.disch_time_col]
                next_adm_time = next_admission[EHRParser.adm_time_col]
                y[start + k] = (next_adm_time - discharge_time).days < 15
        start += len(admissions) - 1
    return y


def build_code_y_multi_label(pids, patient_admission, admission_events_encoded, code_nums, task):
    """
    Build multi-label targets for drug recommendation or diagnosis prediction
    Every admission except a patient's first yields one target row: the multi-hot
    vector of that admission's codes. Encoded code c sets column c - 1 (codes are
    presumably 1-based; a code of 0 would wrap around to the last column).
    Assumes every patient in pids has at least one admission.
    Args:
        pids: List of patient IDs
        patient_admission: Dict mapping patient IDs to their admission records
        admission_events_encoded: Dict mapping event type -> admission ID -> encoded event codes
        code_nums: Dict mapping event types to number of unique codes
        task: 'drugrec' uses medication codes ('m'); any other value uses diagnosis codes ('d')
    Returns:
        Bool array of shape (n_samples, code_nums[event type]), ordered by pids.
    """
    n = sum(len(patient_admission[pid]) - 1 for pid in pids)
    e_type = 'm' if task == 'drugrec' else 'd'  # Only support 'drugrec' and 'diagnosis' tasks
    y = np.zeros((n, code_nums[e_type]), dtype=bool)
    # Running row offset of the current patient. (Indexing a cumulative array with
    # i - 1 wraps to the last element for the first patient and breaks single-patient input.)
    start = 0
    for pid in pids:
        admissions = patient_admission[pid]
        for k, admission in enumerate(admissions[1:]):
            # dtype=int keeps an empty code list usable as an index array.
            codes = np.asarray(admission_events_encoded[e_type][admission[EHRParser.adm_id_col]], dtype=int) - 1
            if len(codes) > 0:
                y[start + k, codes] = True
        start += len(admissions) - 1
    return y


def build_heart_failure_y(hf_prefix: str, codes_y: np.ndarray, code_map: dict) -> np.ndarray:
    """
    Build binary heart failure labels from multi-label diagnosis targets
    A sample is positive when any code whose name starts with hf_prefix is set.
    Args:
        hf_prefix: Prefix identifying heart failure codes in code_map
        codes_y: Multi-hot targets whose last axis has len(code_map) entries;
            column i corresponds to the code with index i + 1
        code_map: Dict mapping codes to their (1-based) indices
    Returns:
        Int array of 0/1 labels with the last axis of codes_y reduced away.
        All zeros when no code matches hf_prefix.
    """
    print('building train/valid/test heart failure labels ...')
    # dtype=int keeps the index array valid even when no code matches the prefix
    # (an empty float array is rejected by NumPy indexing).
    hf_columns = np.array([cid for code, cid in code_map.items() if code.startswith(hf_prefix)], dtype=int) - 1
    hf_mask = np.zeros((len(code_map),), dtype=bool)
    hf_mask[hf_columns] = True
    hf_exist = np.logical_and(codes_y, hf_mask)
    return hf_exist.any(axis=-1).astype(int)
