
import numpy as np
import torch
import os


def random_split(n=11988, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Randomly partition the indices ``0 .. n-1`` into train/val/test sets.

    With the default ratios this is an 8:1:1 split.

    Args:
        n: Number of samples to split. The default, 11988, is the original
            12000 patients with 12 outliers removed.
        train_ratio: Fraction of samples assigned to the training set.
        val_ratio: Fraction of samples assigned to the validation set.
        test_ratio: Kept for signature compatibility. The test set always
            receives whatever is left after the train and validation sets,
            so this value is not used in the computation.

    Returns:
        Tuple ``(idx_train, idx_val, idx_test)`` of disjoint integer index
        arrays that together cover ``range(n)``.
    """
    # Honour the caller's ``n``; it used to be overwritten by a constant.
    n_train = round(n * train_ratio)
    n_val = round(n * val_ratio)

    shuffled = np.random.permutation(n)
    idx_train = shuffled[:n_train]
    idx_val = shuffled[n_train:n_train + n_val]
    idx_test = shuffled[n_train + n_val:]
    return idx_train, idx_val, idx_test


def get_data_split(base_path, split_path, split_type='random', reverse=False, baseline=True, dataset='P12', predictive_label='mortality', los_threshold_days: int = 3):
    """Load a processed dataset and return its train/val/test partitions.

    Environment variables read by this function:
        USE_CACHED_DATASET: when '1'/'true'/'yes' and ``dataset == 'P12'``,
            the whole call is delegated to ``_load_cached_p12_from_psv``
            (using CACHED_PSV_DIR, SPLIT_PKL_PATH and LOS_THRESHOLD_DAYS).
        SPARSE_PROCESSED_DIR: when it names an existing directory, the
            P12 and CD arrays are loaded from there instead of ``base_path``.

    Args:
        base_path: Dataset root; arrays are read from ``processed_data``
            beneath it (CD prefers ``processed_data_new`` when present).
        split_path: Path suffix appended to ``base_path`` to locate the
            saved split file (used when ``split_type == 'random'``).
        split_type: 'random', 'age' or 'gender'. The latter two read
            pre-computed index files from a ``saved/`` directory relative
            to the current working directory.
        reverse: For 'age'/'gender', swap which group is used for training.
        baseline: When equal to True the ``saved/`` directory is used
            as-is; otherwise it is looked up under ``baselines/``.
        dataset: One of 'P12', 'P19', 'eICU', 'PAM', 'CD'.
        predictive_label: 'mortality' or 'LoS'; only consulted for
            P12/P19/PAM (the original code marks 'LoS' as P12 only).
        los_threshold_days: Threshold for the 'LoS' label; values less than
            or equal to it map to 0, everything else to 1.

    Returns:
        ``(Ptrain, Pval, Ptest, ytrain, yval, ytest)``.

    Raises:
        ValueError: If ``dataset``, ``split_type`` or ``predictive_label``
            is not one of the supported values.
    """
    # Optional override: use cached PSV dataset if env flag set (P12 only)
    if dataset == 'P12':
        use_cached_env = os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes')
        if use_cached_env:
            # Lazy import to avoid heavy import graph when not needed
            from baselines.Raindrop.code.baselines.utils_phy12 import _load_cached_p12_from_psv
            cached_dir = os.environ.get('CACHED_PSV_DIR', '/tmp')
            split_pkl_path = os.environ.get('SPLIT_PKL_PATH', 'P12_data_splits/split_1.pkl')
            los_thresh = int(os.environ.get('LOS_THRESHOLD_DAYS', str(los_threshold_days)))
            return _load_cached_p12_from_psv(cached_dir, split_pkl_path, predictive_label=predictive_label, los_threshold_days=los_thresh)

    # Fail fast on unsupported options instead of hitting an unbound local
    # variable much later, after the (slow) file loads.
    dataset_prefixes = {'P12': '', 'P19': 'P19_', 'eICU': 'eICU_', 'PAM': '', 'CD': 'CD_'}
    if dataset not in dataset_prefixes:
        raise ValueError(f"Unsupported dataset: {dataset!r}")
    if split_type not in ('random', 'age', 'gender'):
        raise ValueError(f"Unsupported split_type: {split_type!r}")
    labelled_by_column = dataset in ('P12', 'P19', 'PAM')
    if labelled_by_column and predictive_label not in ('mortality', 'LoS'):
        raise ValueError(f"Unsupported predictive_label: {predictive_label!r}")
    dataset_prefix = dataset_prefixes[dataset]

    # Optional override of processed_data directory for sparsity experiments
    override_dir = os.environ.get('SPARSE_PROCESSED_DIR', '').strip()
    processed_override = override_dir if (override_dir and os.path.isdir(override_dir)) else None

    # Locate the processed arrays. Only P19 uses different file names and
    # only P12/CD honour the override directory.
    if dataset == 'P19':
        dict_file, outcome_file = 'PT_dict_list_6.npy', 'arr_outcomes_6.npy'
    else:
        dict_file, outcome_file = 'PTdict_list.npy', 'arr_outcomes.npy'

    if processed_override and dataset in ('P12', 'CD'):
        data_dir = processed_override
    elif dataset == 'CD':
        # Prefer processed_data_new if present
        pd_new = os.path.join(base_path, 'processed_data_new')
        pd_std = os.path.join(base_path, 'processed_data')
        data_dir = pd_new if os.path.isdir(pd_new) else pd_std
    else:
        data_dir = base_path + '/processed_data'

    # NOTE(security): allow_pickle=True executes pickled objects on load;
    # only point this function at trusted files.
    Pdict_list = np.load(os.path.join(data_dir, dict_file), allow_pickle=True)
    arr_outcomes = np.load(os.path.join(data_dir, outcome_file), allow_pickle=True)

    # Kept as an equality test so non-bool arguments behave as before.
    BL_path = '' if baseline == True else 'baselines/'  # noqa: E712

    if split_type == 'random':
        # load random indices from a split
        idx_train, idx_val, idx_test = np.load(base_path + split_path, allow_pickle=True)
    else:
        # The saved/idx_*.npy files are pre-computed group memberships
        # (presumably derived from each patient's 'extended_static' record).
        if split_type == 'age':
            train_group, held_out_group = 'under_65', 'over_65'
        else:
            train_group, held_out_group = 'male', 'female'
        if reverse:
            train_group, held_out_group = held_out_group, train_group

        saved_stem = BL_path + 'saved/' + dataset_prefix + 'idx_'
        idx_train = np.load(saved_stem + train_group + '.npy', allow_pickle=True)
        idx_vt = np.load(saved_stem + held_out_group + '.npy', allow_pickle=True)

        # The held-out group is shuffled and halved into validation and test.
        np.random.shuffle(idx_vt)
        half = round(len(idx_vt) / 2)
        idx_val = idx_vt[:half]
        idx_test = idx_vt[half:]

    # Handle CD dataset differently - convert patient IDs to indices
    if dataset == 'CD':
        print("CD dataset: Converting patient IDs to indices")
        print(f"Original idx_train type: {type(idx_train)}, length: {len(idx_train)}")
        print(f"Sample idx_train values: {idx_train[:5] if len(idx_train) > 0 else 'empty'}")

        # Create mapping from patient IDs to indices
        patient_id_to_idx = {patient['id']: idx for idx, patient in enumerate(Pdict_list)}
        print(f"Created mapping for {len(patient_id_to_idx)} patients")
        print(f"Sample patient IDs in data: {list(patient_id_to_idx.keys())[:5]}")

        def _ids_to_rows(patient_ids):
            """Map patient IDs to row indices, dropping IDs absent from the data."""
            as_int = [int(pid) for pid in patient_ids]
            rows = [patient_id_to_idx[pid] for pid in as_int if pid in patient_id_to_idx]
            # dtype=int keeps an empty result usable as an index array.
            return np.array(rows, dtype=int)

        idx_train = _ids_to_rows(idx_train)
        idx_val = _ids_to_rows(idx_val)
        idx_test = _ids_to_rows(idx_test)

        print(f"Converted to indices: train={len(idx_train)}, val={len(idx_val)}, test={len(idx_test)}")

    # extract train/val/test examples
    Ptrain = Pdict_list[idx_train]
    Pval = Pdict_list[idx_val]
    Ptest = Pdict_list[idx_test]

    if labelled_by_column:
        if predictive_label == 'mortality':
            y = arr_outcomes[:, -1].reshape((-1, 1))
        else:  # 'LoS', for P12 only
            los = arr_outcomes[:, 3].reshape((-1, 1))
            # NaN compares False, so it lands in class 1 exactly as before.
            y = np.where(los <= float(los_threshold_days), 0, 1)
    else:  # eICU and CD keep their outcomes as a single label column
        y = arr_outcomes[..., np.newaxis]

    ytrain = y[idx_train]
    yval = y[idx_val]
    ytest = y[idx_test]

    return Ptrain, Pval, Ptest, ytrain, yval, ytest


def getStats(P_tensor):
    """Compute per-feature mean and standard deviation of a 3-D array.

    Only strictly positive, non-NaN entries of each feature (last axis)
    contribute. A feature with no such entries keeps mean 0 and std 1, and
    the std is never reported below 1e-7.

    Returns:
        ``(mf, stdf)``, each of shape ``(F, 1)``.
    """
    _, _, n_features = P_tensor.shape
    # One row per feature, flattened over the first two axes.
    per_feature = P_tensor.transpose((2, 0, 1)).reshape(n_features, -1)

    mf = np.zeros((n_features, 1))
    stdf = np.ones((n_features, 1))
    std_floor = 1e-7

    for f in range(n_features):
        row = per_feature[f, :]
        observed = row[(row > 0) & ~np.isnan(row)]
        if len(observed) == 0:
            # Nothing usable: leave the defaults (mean 0, std 1) in place.
            continue
        mf[f] = np.mean(observed)
        stdf[f] = np.maximum(np.std(observed), std_floor)

    return mf, stdf


def mask_normalize(P_tensor, mf, stdf):
    """Normalize time series variables and append their observation mask.

    An entry counts as observed when it is strictly positive and not NaN.
    Observed entries become ``(x - mf[f]) / (stdf[f] + 1e-18)``; all other
    entries are multiplied by a zero mask after normalization.

    The input array is left untouched. (Previously the normalization was
    written through a reshaped view and silently overwrote the caller's
    data, so normalizing the same array twice gave wrong results.)

    Args:
        P_tensor: Array of shape ``(N, T, F)``.
        mf: Per-feature means, ``F`` values (e.g. shape ``(F, 1)``).
        stdf: Per-feature standard deviations, same layout as ``mf``.

    Returns:
        Array of shape ``(N, T, 2 * F)``: the normalized values followed by
        the 0/1 observation mask along the last axis.
    """
    N, T, F = P_tensor.shape

    # Private float copy so the caller's array is never modified.
    values = np.array(P_tensor, dtype=float)

    # NaN compares False against 0, so this is "positive and not NaN".
    M = (values > 0).astype(int)

    # Replace NaN with 0 before normalizing; those slots are masked anyway.
    values[np.isnan(values)] = 0

    # Flatten the statistics so they broadcast along the feature axis.
    mean = np.asarray(mf, dtype=float).reshape(-1)[:F]
    scale = np.asarray(stdf, dtype=float).reshape(-1)[:F]

    Pnorm_tensor = (values - mean) / (scale + 1e-18) * M
    Pfinal_tensor = np.concatenate([Pnorm_tensor, M], axis=2)
    return Pfinal_tensor


def getStats_static(P_tensor, dataset='P12'):
    """Compute per-column mean and std for the static feature matrix.

    Categorical columns keep mean 0 and std 1. For continuous columns only
    strictly positive, non-NaN entries contribute; a column with no such
    entries also keeps the defaults. The std is floored at 1e-7, matching
    ``getStats``, so a constant column cannot yield a zero scale.

    Args:
        P_tensor: Array of shape ``(N, S)``.
        dataset: One of 'P12', 'P19', 'eICU', 'CD'; selects which columns
            are treated as categorical.

    Returns:
        ``(ms, ss)``, each of shape ``(S, 1)``.

    Raises:
        ValueError: If ``dataset`` has no known static layout, or if
            ``P_tensor`` has more columns than that layout describes.
    """
    N, S = P_tensor.shape
    Ps = P_tensor.transpose((1, 0))
    ms = np.zeros((S, 1))
    ss = np.ones((S, 1))
    std_floor = 1e-7

    # 1 marks a categorical column, 0 a continuous one.
    categorical_layouts = {
        # ['Age' 'Gender=0' 'Gender=1' 'Height' 'ICUType=1' 'ICUType=2' 'ICUType=3' 'ICUType=4' 'Weight']
        'P12': [0, 1, 1, 0, 1, 1, 1, 1, 0],
        # ['Age' 'Gender' 'Unit1' 'Unit2' 'HospAdmTime' 'ICULOS']
        'P19': [0, 1, 0, 0, 0, 0],
        # ['apacheadmissiondx' 'ethnicity' 'gender' 'admissionheight' 'admissionweight'] -> 399 dimensions
        'eICU': [1] * 397 + [0] * 2,
        # ['age_at_sampling_in_days_norm', 'sex'] - age is continuous, sex is categorical
        'CD': [0, 1],
    }
    if dataset not in categorical_layouts:
        raise ValueError(f"No static feature layout known for dataset: {dataset!r}")
    bool_categorical = categorical_layouts[dataset]
    if S > len(bool_categorical):
        raise ValueError(
            f"P_tensor has {S} static columns but dataset {dataset!r} describes only {len(bool_categorical)}"
        )

    for s in range(S):
        if bool_categorical[s] != 0:
            continue  # categorical: keep mean 0 / std 1
        vals_s = Ps[s, :]
        # Handle NaN values and filter out non-positive values
        vals_s = vals_s[np.logical_and(vals_s > 0, ~np.isnan(vals_s))]
        if len(vals_s) > 0:
            ms[s] = np.mean(vals_s)
            ss[s] = max(np.std(vals_s), std_floor)
    return ms, ss


def mask_normalize_static(P_tensor, ms, ss):
    """Normalize the static feature matrix column by column.

    Each column ``s`` becomes ``(x - ms[s]) / (ss[s] + 1e-18)``; afterwards
    every entry that is ``<= 0`` or NaN is set to 0.

    The input array is left untouched. (Previously the normalization was
    written through a transposed view, overwriting the caller's array, and
    the return value aliased that same memory.)

    Args:
        P_tensor: Array of shape ``(N, S)``.
        ms: Per-column means, ``S`` values (e.g. shape ``(S, 1)``).
        ss: Per-column standard deviations, same layout as ``ms``.

    Returns:
        New float array of shape ``(N, S)``.
    """
    N, S = P_tensor.shape

    # Private float copy so the caller's array is never modified.
    static = np.array(P_tensor, dtype=float)

    # Flatten the statistics so they broadcast across the rows.
    centre = np.asarray(ms, dtype=float).reshape(-1)[:S]
    scale = np.asarray(ss, dtype=float).reshape(-1)[:S]

    # input normalization
    Pnorm_tensor = (static - centre) / (scale + 1e-18)

    # set missing values to zero after normalization
    # NOTE(review): this test runs on the *normalized* values, so every
    # entry at or below its column mean is zeroed too, not only entries that
    # were missing in the raw data. Kept as-is because changing it would
    # alter model inputs - confirm whether this is intended.
    Pnorm_tensor[np.logical_or(Pnorm_tensor <= 0, np.isnan(Pnorm_tensor))] = 0
    return Pnorm_tensor


def tensorize_normalize(P, y, mf, stdf, ms, ss):
    """Pad, normalize and convert a collection of patient records to tensors.

    Each record is read through its 'arr', 'time' and 'extended_static'
    entries. Sequences shorter than the longest one are zero-padded.

    Returns:
        ``(P_tensor, P_static_tensor, P_time, y_tensor)`` where the time
        tensor has been divided by 60 and the labels are the first column
        of ``y`` as a LongTensor.
    """
    n_samples = len(P)
    n_features = P[0]['arr'].shape[1]
    n_static = len(P[0]['extended_static'])
    # Support variable-length sequences by padding to max length
    longest = max(len(sample['arr']) for sample in P)

    series = np.zeros((n_samples, longest, n_features))
    stamps = np.zeros((n_samples, longest, 1))
    static = np.zeros((n_samples, n_static))

    for i, sample in enumerate(P):
        values = sample['arr']
        times = sample['time']
        length = min(len(values), longest)
        if length > 0:
            series[i, :length, :] = values[:length]
            stamps[i, :length, 0] = times[:length, 0]
        static[i] = sample['extended_static']

    P_tensor = torch.Tensor(mask_normalize(series, mf, stdf))

    P_time = torch.Tensor(stamps) / 60.0  # convert mins to hours

    P_static_tensor = torch.Tensor(mask_normalize_static(static, ms, ss))

    y_tensor = torch.Tensor(y[:, 0]).type(torch.LongTensor)
    return P_tensor, P_static_tensor, P_time, y_tensor


def tensorize_normalize_other(P, y, mf, stdf):
    """Normalize fixed-length samples and convert them to tensors.

    Every sample gets the same synthetic time axis: ``T`` evenly spaced
    points from 0 to ``T`` (inclusive), afterwards divided by 60. There are
    no static features, so ``None`` is returned in that position.

    Returns:
        ``(P_tensor, None, P_time, y_tensor)`` where the labels are the
        first column of ``y`` as a LongTensor.
    """
    T, F = P[0].shape

    # The time axis is identical for all samples: build it once, then
    # broadcast it over the sample dimension.
    shared_axis = torch.linspace(0, T, T).reshape(-1, 1).numpy()
    stamps = np.zeros((len(P), T, 1))
    stamps[:] = shared_axis

    P_tensor = torch.Tensor(mask_normalize(P, mf, stdf))

    P_time = torch.Tensor(stamps) / 60.0

    y_tensor = torch.Tensor(y[:, 0]).type(torch.LongTensor)
    return P_tensor, None, P_time, y_tensor


def masked_softmax(A, epsilon=0.000000001):
    """Softmax-style normalization of ``A`` that ignores zero entries.

    Entries of ``A`` equal to 0 contribute nothing and come out as 0. The
    remaining exponentials are divided by their sum along dim 0 (plus
    ``epsilon`` to avoid division by zero).
    """
    # NOTE(review): the max used for shifting is taken along dim 1 while the
    # normalizing sum runs along dim 0 - confirm this mix is intended.
    row_peak = A.max(dim=1, keepdim=True).values
    keep = (A != 0).float()
    weights = torch.exp(A - row_peak) * keep
    column_total = weights.sum(dim=0, keepdim=True)
    return weights / (column_total + epsilon)


def random_sample(idx_0, idx_1, B, replace=False):
    """Draw a class-balanced batch of indices.

    ``int(B / 2)`` indices are drawn at random from ``idx_0`` and then the
    same number from ``idx_1``. Sampling is without replacement unless
    ``replace`` is set.

    Returns:
        1-D array holding the ``idx_0`` draws followed by the ``idx_1`` draws.
    """
    per_class = int(B / 2)
    # Draw order (idx_0 first, then idx_1) is kept for reproducibility
    # under a fixed random seed.
    draws = [
        np.random.choice(pool, size=per_class, replace=replace)
        for pool in (idx_0, idx_1)
    ]
    return np.concatenate(draws, axis=0)


def evaluate(model, P_tensor, P_time_tensor, P_static_tensor, batch_size=100, n_classes=2, static=1):
    """Run the model over all samples in mini-batches and collect its outputs.

    The model is switched to eval mode and inputs are moved to the device
    of its parameters. Samples are indexed along dim 1 of ``P_tensor`` and
    ``P_time_tensor`` and along dim 0 of ``P_static_tensor``.

    Args:
        model: Module whose ``forward(P, Pstatic, Ptime, lengths)`` returns
            a 3-tuple; only the first element is kept.
        P_tensor: 3-D tensor with samples along dim 1.
        P_time_tensor: Time tensor with samples along dim 1; per-sample
            lengths are the counts of entries ``> 0`` along dim 0.
        P_static_tensor: Static features with samples along dim 0, or None.
        batch_size: Number of samples per forward pass.
        n_classes: Width of the output buffer.
        static: Pass None to run without static features. As in
            ``evaluate_standard``, ``P_static_tensor`` is then ignored
            (previously it was still fed to the model when not None).

    Returns:
        CPU tensor of shape ``(N, n_classes)`` with the stacked outputs.
    """
    model.eval()
    device = next(model.parameters()).device
    P_tensor = P_tensor.to(device)
    P_time_tensor = P_time_tensor.to(device)
    if static is None or P_static_tensor is None:
        P_static_tensor = None
    else:
        P_static_tensor = P_static_tensor.to(device)

    _, N, _ = P_tensor.shape
    out = torch.zeros(N, n_classes)

    # Outputs are detached anyway, so skip building the autograd graph.
    with torch.no_grad():
        # A single loop covers the full batches and the trailing partial one.
        for start in range(0, N, batch_size):
            stop = min(start + batch_size, N)
            P = P_tensor[:, start:stop, :]
            Ptime = P_time_tensor[:, start:stop]
            Pstatic = None if P_static_tensor is None else P_static_tensor[start:stop]
            lengths = torch.sum(Ptime > 0, dim=0)
            batch_out, _, _ = model.forward(P, Pstatic, Ptime, lengths)
            out[start:stop] = batch_out.detach().cpu()
    return out


def evaluate_standard(model, P_tensor, P_time_tensor, P_static_tensor, batch_size=100, n_classes=2, static=1):
    """Run the model once on the whole input and return its first output.

    Inputs are moved to the device of the model's parameters. When
    ``static`` is None the model receives None instead of the static
    features. ``batch_size`` and ``n_classes`` are accepted but not used.
    Unlike ``evaluate``, this neither switches the model to eval mode nor
    detaches the result.
    """
    target = next(model.parameters()).device
    series = P_tensor.to(target)
    times = P_time_tensor.to(target)
    static_input = None if static is None else P_static_tensor.to(target)

    # Per-sample length = number of time entries greater than zero along dim 0.
    seq_lengths = (times > 0).sum(dim=0)
    out, _, _ = model.forward(series, static_input, times, seq_lengths)
    return out

