"""
Raindrop model adapted for CD dataset
Modified version of Raindrop.py to support CD dataset
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import time
from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix, average_precision_score, precision_score, recall_score, f1_score
from baselines.Raindrop.code.models_rd import *
from baselines.Raindrop.code.utils_rd import *
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

# Import the original Raindrop functions
from baselines.Raindrop.code.Raindrop import one_hot, generate_global_structure, diffuse

import argparse
def _str2bool(value):
    """Convert a command-line token into a real ``bool``.

    Without an explicit ``type``, argparse keeps option values as strings, so
    ``--withmissingratio True`` would yield ``'True'`` (which is not ``== True``)
    and ``--reverse False`` would yield the truthy string ``'False'``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Command-line interface. Unknown arguments are tolerated (parse_known_args) so the
# script can be launched by wrappers that pass extra options.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='CD', choices=['P12', 'P19', 'eICU', 'PAM', 'CD'])
# Boolean flags accept an explicit value ("--withmissingratio True") or no value at all
# ("--withmissingratio"), which means True.
parser.add_argument('--withmissingratio', type=_str2bool, nargs='?', const=True, default=False,
                    help='if True, missing ratio ranges from 0 to 0.5; if False, missing ratio =0')
parser.add_argument('--splittype', type=str, default='random', choices=['random', 'age', 'gender'], help='only use for P12 and P19')
parser.add_argument('--reverse', type=_str2bool, nargs='?', const=True, default=False,
                    help='if True,use female, older for tarining; if False, use female or younger for training')
parser.add_argument('--feature_removal_level', type=str, default='no_removal', choices=['no_removal', 'set', 'sample'],
                    help='use this only when splittype==random; otherwise, set as no_removal')
parser.add_argument('--predictive_label', type=str, default='mortality', choices=['mortality', 'LoS'],
                    help='use this only with P12 dataset (mortality or length of stay)')
parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
parser.add_argument('--use_cached_dataset', action='store_true', help='Use cached PSV dataset (P12 only)')
parser.add_argument('--cached_dataset_dir', type=str, default='/tmp', help='Directory with cached PSV files (P12)')
parser.add_argument('--split_pkl_path', type=str, default='P12_data_splits/split_{split_idx}.pkl', help='Split PKL path template with {split_idx}')
parser.add_argument('--los_threshold_days', type=int, default=3, help='LoS threshold in days for P12')
args, unknown = parser.parse_known_args()

# Reproducibility: seed Python, NumPy and PyTorch (CPU and, if present, CUDA)
# with the same user-supplied value, then pin cuDNN to deterministic mode.
import random

for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
if torch.cuda.is_available():
    for _cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _cuda_seed_fn(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(f'Using seed: {args.seed}')

# Name used in checkpoint file names and the directory checkpoints are written to.
arch = 'raindrop'
model_path = '/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/models/'

dataset = args.dataset
print('Dataset used: ', dataset)

# Root data directory of every supported dataset.
_dataset_roots = {
    'P12': '/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/P12data',
    'P19': '/ngc/projects2/predict_r/research/projects/0054_GNAN_biomarker_trajectories/Raindrop/P19data',
    'eICU': '/ngc/projects2/predict_r/research/projects/0054_GNAN_biomarker_trajectories/Raindrop/eICUdata',
    'PAM': '/ngc/projects2/predict_r/research/projects/0054_GNAN_biomarker_trajectories/Raindrop/PAMdata',
    'CD': '/ngc/projects2/predict_r/research/projects/0054_GNAN_biomarker_trajectories/Raindrop/CDdata',
}
if dataset in _dataset_roots:
    base_path = _dataset_roots[dataset]

# Options forwarded to the data-splitting helper.
baseline = False  # always False for Raindrop
split = args.splittype  # possible values: 'random', 'age', 'gender'
reverse = args.reverse  # False or True
feature_removal_level = args.feature_removal_level  # 'set', 'sample'

print('args.dataset, args.splittype, args.reverse, args.withmissingratio, args.feature_removal_level',
      args.dataset, args.splittype, args.reverse, args.withmissingratio, args.feature_removal_level)

# While missing_ratio > 0, feature_removal_level is automatically used.
# The comparison is deliberately explicit: a non-boolean value such as the
# string 'True' does not enable the sweep.
missing_ratios = [0.1, 0.2, 0.3, 0.4, 0.5] if args.withmissingratio == True else [0]
print('missing ratio list', missing_ratios)

sensor_wise_mask = False

def pad_or_truncate_time_dim(tensor_3d, target_T):
    """Return ``tensor_3d`` with dimension 1 forced to length ``target_T``.

    A shorter input is zero-padded at the end of dimension 1, a longer one is cut,
    and an input that already has the right length is returned unchanged.
    Dimensions 0 and 2 are never touched, so the same helper serves both the
    value tensors and the timestamp tensors.
    """
    current_T = tensor_3d.size(1)
    if current_T == target_T:
        return tensor_3d
    if current_T > target_T:
        return tensor_3d[:, :target_T, :]
    pad = torch.zeros((tensor_3d.size(0), target_T - current_T, tensor_3d.size(2)),
                      dtype=tensor_3d.dtype, device=tensor_3d.device)
    return torch.cat([tensor_3d, pad], dim=1)


# The script used to define a second, identical helper under this name; keep it as an alias.
pad_or_truncate_time_dim_2d = pad_or_truncate_time_dim

# Datasets that come with static (per-patient) features; PAM is the only one without.
datasets_with_static = ('P12', 'P19', 'eICU', 'CD')
has_static = dataset in datasets_with_static

# Main experiment: for every missing ratio, train and test one model per data split
# (and per run), then report mean +/- std of the test metrics over the splits.
for missing_ratio in missing_ratios:
    num_epochs = 20
    learning_rate = 0.0001  # 0.001 works slightly better, sometimes 0.0001 better, depends on settings and datasets

    # Per-dataset input sizes and the `static` flag handed to the evaluation helpers.
    if dataset == 'P12':
        d_static = 9
        d_inp = 36
        static_info = 1
    elif dataset == 'P19':
        d_static = 6
        d_inp = 34
        static_info = 1
    elif dataset == 'eICU':
        d_static = 399
        d_inp = 14
        static_info = 1
    elif dataset == 'PAM':
        d_static = 0
        d_inp = 17
        static_info = None
    elif dataset == 'CD':
        d_static = 1  # sex only
        d_inp = 17    # 17 biomarkers
        static_info = 1

    # Model hyper-parameters.
    d_ob = 4
    d_model = d_inp * d_ob
    nhid = 2 * d_model
    nlayers = 2
    nhead = 2
    dropout = 0.2

    if dataset == 'P12':
        max_len = 215
        n_classes = 2
    elif dataset == 'P19':
        max_len = 60
        n_classes = 2
    elif dataset == 'eICU':
        max_len = 300
        n_classes = 2
    elif dataset == 'PAM':
        max_len = 600
        n_classes = 8
    elif dataset == 'CD':
        max_len = 50  # Maximum time steps per patient
        n_classes = 2

    aggreg = 'mean'

    MAX = 100

    n_runs = 1
    n_splits = 5
    subset = False

    # Test metrics (in percent), one row per split and one column per run.
    acc_arr = np.zeros((n_splits, n_runs))
    auprc_arr = np.zeros((n_splits, n_runs))
    auroc_arr = np.zeros((n_splits, n_runs))
    precision_arr = np.zeros((n_splits, n_runs))
    recall_arr = np.zeros((n_splits, n_runs))
    F1_arr = np.zeros((n_splits, n_runs))

    for k in tqdm(range(n_splits), desc="Processing splits", unit="split"):
        split_idx = k + 1
        print(f'\nSplit id: {split_idx}')

        if dataset == 'P12':
            if subset == True:
                split_path = '/splits/phy12_split_subset' + str(split_idx) + '.npy'
            else:
                split_path = '/splits/phy12_split' + str(split_idx) + '.npy'
        elif dataset == 'P19':
            split_path = '/splits/phy19_split' + str(split_idx) + '_new.npy'
        elif dataset == 'eICU':
            split_path = '/splits/eICU_split' + str(split_idx) + '.npy'
        elif dataset == 'PAM':
            split_path = '/splits/PAM_split_' + str(split_idx) + '.npy'
        elif dataset == 'CD':
            split_path = '/splits/cd_split' + str(split_idx) + '.npy'

        # prepare the data:
        # If using cached PSV dataset for P12, set environment variables expected by get_data_split
        if dataset == 'P12' and (getattr(args, 'use_cached_dataset', False) or os.environ.get('USE_CACHED_DATASET', '0').lower() in ('1', 'true', 'yes')):
            os.environ['USE_CACHED_DATASET'] = '1'
            os.environ['CACHED_PSV_DIR'] = str(getattr(args, 'cached_dataset_dir', '/tmp'))
            split_pkl = str(getattr(args, 'split_pkl_path', 'P12_data_splits/split_{split_idx}.pkl')).replace('{split_idx}', str(split_idx))
            os.environ['SPLIT_PKL_PATH'] = split_pkl
            os.environ['LOS_THRESHOLD_DAYS'] = str(getattr(args, 'los_threshold_days', 3))

        Ptrain, Pval, Ptest, ytrain, yval, ytest = get_data_split(base_path, split_path, split_type=split, reverse=reverse,
                                                                  baseline=baseline, dataset=dataset,
                                                                  predictive_label=args.predictive_label)
        print(len(Ptrain), len(Pval), len(Ptest), len(ytrain), len(yval), len(ytest))

        if has_static:
            # Support variable-length sequences: compute stats on padded training set and align val/test to train length.
            # Descriptive names are used here on purpose: a bare `F` would shadow `torch.nn.functional as F`.
            n_features = Ptrain[0]['arr'].shape[1]
            n_static = len(Ptrain[0]['extended_static'])
            longest_train_len = max(len(p['arr']) for p in Ptrain)

            Ptrain_tensor_tmp = np.zeros((len(Ptrain), longest_train_len, n_features))
            Ptrain_static_tensor_tmp = np.zeros((len(Ptrain), n_static))
            for i, train_patient in enumerate(Ptrain):
                arr_i = train_patient['arr']
                if len(arr_i) > 0:
                    Ptrain_tensor_tmp[i, :len(arr_i), :] = arr_i
                Ptrain_static_tensor_tmp[i] = train_patient['extended_static']

            # Normalisation statistics come from the training set only and are reused for val/test.
            mf, stdf = getStats(Ptrain_tensor_tmp)
            ms, ss = getStats_static(Ptrain_static_tensor_tmp, dataset=dataset)

            Ptrain_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor = tensorize_normalize(Ptrain, ytrain, mf,
                                                                                                         stdf, ms, ss)
            Pval_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor = tensorize_normalize(Pval, yval, mf, stdf, ms, ss)
            Ptest_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor = tensorize_normalize(Ptest, ytest, mf, stdf, ms,
                                                                                              ss)

            # Align val/test time length to training time length
            train_T = Ptrain_tensor.size(1)
            Pval_tensor = pad_or_truncate_time_dim(Pval_tensor, train_T)
            Ptest_tensor = pad_or_truncate_time_dim(Ptest_tensor, train_T)
            Pval_time_tensor = pad_or_truncate_time_dim(Pval_time_tensor, train_T)
            Ptest_time_tensor = pad_or_truncate_time_dim(Ptest_time_tensor, train_T)
        elif dataset == 'PAM':
            # PAM processing (unchanged from original)
            Ptrain_tensor = np.zeros((len(Ptrain), max_len, d_inp))
            Pval_tensor = np.zeros((len(Pval), max_len, d_inp))
            Ptest_tensor = np.zeros((len(Ptest), max_len, d_inp))

            for i, sample in enumerate(Ptrain):
                Ptrain_tensor[i] = sample['arr']
            for i, sample in enumerate(Pval):
                Pval_tensor[i] = sample['arr']
            for i, sample in enumerate(Ptest):
                Ptest_tensor[i] = sample['arr']

            mf, stdf = getStats(Ptrain_tensor)
            Ptrain_tensor, _, Ptrain_time_tensor, ytrain_tensor = tensorize_normalize_other(Ptrain, ytrain, mf, stdf)
            Pval_tensor, _, Pval_time_tensor, yval_tensor = tensorize_normalize_other(Pval, yval, mf, stdf)
            Ptest_tensor, _, Ptest_time_tensor, ytest_tensor = tensorize_normalize_other(Ptest, ytest, mf, stdf)

        # Model initialization and training
        global_structure = torch.ones(d_inp, d_inp)

        # remove part of variables in validation and test set
        if missing_ratio > 0:
            # Only the first half of the last dimension is treated as features here.
            num_all_features = int(Pval_tensor.shape[2] / 2)
            num_missing_features = round(missing_ratio * num_all_features)
            if feature_removal_level == 'sample':
                # A different random subset of features is zeroed for every sample.
                for i, patient in enumerate(Pval_tensor):
                    idx = np.random.choice(num_all_features, num_missing_features, replace=False)
                    patient[:, idx] = torch.zeros(Pval_tensor.shape[1], num_missing_features)
                    Pval_tensor[i] = patient
                for i, patient in enumerate(Ptest_tensor):
                    idx = np.random.choice(num_all_features, num_missing_features, replace=False)
                    patient[:, idx] = torch.zeros(Ptest_tensor.shape[1], num_missing_features)
                    Ptest_tensor[i] = patient
            elif feature_removal_level == 'set':
                # The same features are zeroed for all samples: the first entries of the saved score file.
                density_score_indices = np.load('./baselines/saved/IG_density_scores_' + dataset + '.npy', allow_pickle=True)[:, 0]
                idx = density_score_indices[:num_missing_features].astype(int)
                Pval_tensor[:, :, idx] = torch.zeros(Pval_tensor.shape[0], Pval_tensor.shape[1], num_missing_features)
                Ptest_tensor[:, :, idx] = torch.zeros(Ptest_tensor.shape[0], Ptest_tensor.shape[1], num_missing_features)

        # Swap the first two dimensions so the time dimension comes first.
        Ptrain_tensor = Ptrain_tensor.permute(1, 0, 2)
        Pval_tensor = Pval_tensor.permute(1, 0, 2)
        Ptest_tensor = Ptest_tensor.permute(1, 0, 2)

        Ptrain_time_tensor = Ptrain_time_tensor.squeeze(2).permute(1, 0)
        Pval_time_tensor = Pval_time_tensor.squeeze(2).permute(1, 0)
        Ptest_time_tensor = Ptest_time_tensor.squeeze(2).permute(1, 0)

        # Use training sequence length for model init to match ob-propagation in/out dims
        train_seq_len = Ptrain_tensor.size(0)

        for m in tqdm(range(n_runs), desc=f"Runs for split {split_idx}", unit="run", leave=False):
            print(f'- - Run {m + 1} - -')

            if has_static:
                model = Raindrop_v2(d_inp, d_model, nhead, nhid, nlayers, dropout, train_seq_len,
                                    d_static, MAX, 0.5, aggreg, n_classes, global_structure,
                                    sensor_wise_mask=sensor_wise_mask)
            elif dataset == 'PAM':
                model = Raindrop_v2(d_inp, d_model, nhead, nhid, nlayers, dropout, train_seq_len,
                                    d_static, MAX, 0.5, aggreg, n_classes, global_structure,
                                    sensor_wise_mask=sensor_wise_mask, static=False)

            # Check if CUDA is available
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            print(f"Using device: {device}")

            model = model.to(device)

            criterion = torch.nn.CrossEntropyLoss().to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
            # mode='max' because the scheduler is stepped with the validation AUPRC.
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                                   patience=1, threshold=0.0001, threshold_mode='rel',
                                                                   cooldown=0, min_lr=1e-8, eps=1e-08)

            idx_0 = np.where(ytrain == 0)[0]
            idx_1 = np.where(ytrain == 1)[0]

            # Batch sampling strategy: 2 = class-balanced batches, 3 = uniform random batches.
            if has_static:
                strategy = 2
            elif dataset == 'PAM':
                strategy = 3

            n0, n1 = len(idx_0), len(idx_1)
            # Class-1 indices are repeated three times before batching.
            expanded_idx_1 = np.concatenate([idx_1, idx_1, idx_1], axis=0)
            expanded_n1 = len(expanded_idx_1)

            batch_size = 16
            half_batch = int(batch_size / 2)
            if strategy == 1:
                n_batches = 10
            elif strategy == 2:
                # Every batch takes half its samples from each class, so the scarcer pool limits the epoch.
                K0 = n0 // half_batch
                K1 = expanded_n1 // half_batch
                n_batches = int(min(K0, K1))
            elif strategy == 3:
                n_batches = 30

            if n_batches <= 0:
                # Fail early with a clear message; otherwise the epoch loop below would stop
                # on an undefined `outputs` when computing the training metrics.
                raise ValueError(
                    f'No training batch can be formed for split {split_idx}: '
                    f'{n0} class-0 samples, {n1} class-1 samples, batch size {batch_size}')

            best_auc_val = 0.0
            checkpoint_path = model_path + arch + '_' + str(split_idx) + '.pt'
            print('Stop epochs: %d, Batches/epoch: %d, Total batches: %d' % (num_epochs, n_batches, num_epochs * n_batches))

            start = time.time()
            for epoch in tqdm(range(num_epochs), desc=f"Training (Split {split_idx})", unit="epoch"):
                model.train()

                if strategy == 2:
                    # Reshuffle both index pools in place at the start of every epoch.
                    np.random.shuffle(expanded_idx_1)
                    I1 = expanded_idx_1
                    np.random.shuffle(idx_0)
                    I0 = idx_0

                # Progress bar for batches within each epoch
                batch_pbar = tqdm(range(n_batches), desc=f"Epoch {epoch+1}/{num_epochs}", leave=False, unit="batch")

                for n in batch_pbar:
                    if strategy == 1:
                        idx = random_sample(idx_0, idx_1, batch_size)
                    elif strategy == 2:
                        idx0_batch = I0[n * half_batch:(n + 1) * half_batch]
                        idx1_batch = I1[n * half_batch:(n + 1) * half_batch]
                        idx = np.concatenate([idx0_batch, idx1_batch], axis=0)
                    elif strategy == 3:
                        idx = np.random.choice(list(range(Ptrain_tensor.shape[1])), size=int(batch_size), replace=False)

                    if has_static:
                        P, Ptime, Pstatic, y = Ptrain_tensor[:, idx, :].to(device), Ptrain_time_tensor[:, idx].to(device), \
                                               Ptrain_static_tensor[idx].to(device), ytrain_tensor[idx].squeeze().to(device)
                    elif dataset == 'PAM':
                        P, Ptime, Pstatic, y = Ptrain_tensor[:, idx, :].to(device), Ptrain_time_tensor[:, idx].to(device), \
                                               None, ytrain_tensor[idx].squeeze().to(device)

                    # Per-sample count of time steps with a positive timestamp.
                    lengths = torch.sum(Ptime > 0, dim=0)

                    # The second and third return values of the model are not used by this script.
                    outputs, _, _ = model.forward(P, Pstatic, Ptime, lengths)

                    optimizer.zero_grad()
                    loss = criterion(outputs, y)
                    loss.backward()
                    optimizer.step()

                # Training metrics are computed on the last batch of the epoch only.
                if has_static:
                    train_probs = torch.squeeze(torch.sigmoid(outputs))
                    train_probs = train_probs.cpu().detach().numpy()
                    train_y = y.cpu().detach().numpy()
                    train_auroc = roc_auc_score(train_y, train_probs[:, 1])
                    train_auprc = average_precision_score(train_y, train_probs[:, 1])
                elif dataset == 'PAM':
                    train_probs = torch.squeeze(nn.functional.softmax(outputs, dim=1))
                    train_probs = train_probs.cpu().detach().numpy()
                    train_y = y.cpu().detach().numpy()
                    train_auroc = roc_auc_score(one_hot(train_y), train_probs)
                    train_auprc = average_precision_score(one_hot(train_y), train_probs)

                if epoch == 0 or epoch == num_epochs - 1:
                    # Use every class label so the matrix is also valid for multi-class datasets.
                    print(confusion_matrix(train_y, np.argmax(train_probs, axis=1), labels=list(range(n_classes))))

                # Validation (runs after every epoch)
                model.eval()
                with torch.no_grad():
                    print(f"\nValidating epoch {epoch+1}/{num_epochs}...")
                    out_val = evaluate_standard(model, Pval_tensor, Pval_time_tensor, Pval_static_tensor, static=static_info)
                    out_val = out_val.detach().cpu()

                    # For binary classification, we need to apply sigmoid and get the positive class probability
                    out_val_sigmoid = torch.sigmoid(out_val)
                    out_val_np = out_val_sigmoid.numpy()

                    # Ensure yval is properly flattened for AUC calculation
                    yval_flat = yval.ravel()  # This will flatten any multi-dimensional array

                    # Calculate validation loss using the raw logits
                    # Move out_val back to device for loss calculation
                    out_val_device = out_val.to(device)
                    val_loss = criterion(out_val_device, torch.from_numpy(yval_flat).long().to(device))

                    if has_static:
                        auc_val = roc_auc_score(yval_flat, out_val_np[:, 1])
                        aupr_val = average_precision_score(yval_flat, out_val_np[:, 1])
                    elif dataset == 'PAM':
                        auc_val = roc_auc_score(one_hot(yval), out_val)
                        aupr_val = average_precision_score(one_hot(yval), out_val)

                    print("Validation: Epoch %d,  val_loss:%.4f, aupr_val: %.2f, auc_val: %.2f" % (epoch,
                                                                                                    val_loss.item(),
                                                                                                    aupr_val * 100,
                                                                                                    auc_val * 100))

                    # The learning rate follows validation AUPRC; the checkpoint follows validation AUROC.
                    scheduler.step(aupr_val)
                    if auc_val > best_auc_val:
                        best_auc_val = auc_val
                        print(
                            "**[S] Epoch %d, aupr_val: %.4f, auc_val: %.4f **" % (
                            epoch, aupr_val * 100, auc_val * 100))
                        torch.save(model.state_dict(), checkpoint_path)

            end = time.time()
            time_elapsed = end - start
            print('Total Time elapsed: %.3f mins' % (time_elapsed / 60.0))

            # Testing: reload the checkpoint with the best validation AUROC and score the test set.
            print(f"\nTesting split {split_idx}...")
            model.load_state_dict(torch.load(checkpoint_path))
            model.eval()

            with torch.no_grad():
                out_test = evaluate(model, Ptest_tensor, Ptest_time_tensor, Ptest_static_tensor, n_classes=n_classes, static=static_info).numpy()
                ypred = np.argmax(out_test, axis=1)

                # Softmax over classes. Subtracting the row maximum leaves the result unchanged
                # but keeps np.exp from overflowing (inf / inf = nan) on large logits.
                shifted_logits = out_test - np.max(out_test, axis=1, keepdims=True)
                exp_scores = np.exp(shifted_logits)
                probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)

                # Ensure ytest is properly flattened
                ytest_flat = ytest.ravel()
                acc = np.sum(ytest_flat == ypred.ravel()) / ytest_flat.shape[0]

                if has_static:
                    auc = roc_auc_score(ytest_flat, probs[:, 1])
                    aupr = average_precision_score(ytest_flat, probs[:, 1])
                elif dataset == 'PAM':
                    auc = roc_auc_score(one_hot(ytest), probs)
                    aupr = average_precision_score(one_hot(ytest), probs)
                    precision = precision_score(ytest, ypred, average='macro', )
                    recall = recall_score(ytest, ypred, average='macro', )
                    F1 = f1_score(ytest, ypred, average='macro', )
                    print('Testing: Precision = %.2f | Recall = %.2f | F1 = %.2f' % (precision * 100, recall * 100, F1 * 100))

                print('Testing: AUROC = %.2f | AUPRC = %.2f | Accuracy = %.2f' % (auc * 100, aupr * 100, acc * 100))
                print('classification report', classification_report(ytest_flat, ypred))
                print(confusion_matrix(ytest_flat, ypred, labels=list(range(n_classes))))

            # store
            acc_arr[k, m] = acc * 100
            auprc_arr[k, m] = aupr * 100
            auroc_arr[k, m] = auc * 100
            if dataset == 'PAM':
                # Macro precision / recall / F1 are only computed for PAM.
                precision_arr[k, m] = precision * 100
                recall_arr[k, m] = recall * 100
                F1_arr[k, m] = F1 * 100

    # pick best performer for each split based on max AUPRC
    idx_max = np.argmax(auprc_arr, axis=1)
    acc_vec = [acc_arr[k, idx_max[k]] for k in range(n_splits)]
    auprc_vec = [auprc_arr[k, idx_max[k]] for k in range(n_splits)]
    auroc_vec = [auroc_arr[k, idx_max[k]] for k in range(n_splits)]
    if dataset == 'PAM':
        precision_vec = [precision_arr[k, idx_max[k]] for k in range(n_splits)]
        recall_vec = [recall_arr[k, idx_max[k]] for k in range(n_splits)]
        F1_vec = [F1_arr[k, idx_max[k]] for k in range(n_splits)]

    print("missing ratio:{}, split type:{}, reverse:{}, using baseline:{}".format(missing_ratio, split, reverse,
                                                                                  baseline))
    print('args.dataset, args.splittype, args.reverse, args.withmissingratio, args.feature_removal_level',
          args.dataset, args.splittype, args.reverse, args.withmissingratio, args.feature_removal_level)

    # display mean and standard deviation
    mean_acc, std_acc = np.mean(acc_vec), np.std(acc_vec)
    mean_auprc, std_auprc = np.mean(auprc_vec), np.std(auprc_vec)
    mean_auroc, std_auroc = np.mean(auroc_vec), np.std(auroc_vec)
    print('------------------------------------------')
    print('Accuracy = %.1f +/- %.1f' % (mean_acc, std_acc))
    print('AUPRC    = %.1f +/- %.1f' % (mean_auprc, std_auprc))
    print('AUROC    = %.1f +/- %.1f' % (mean_auroc, std_auroc))
    if dataset == 'PAM':
        mean_precision, std_precision = np.mean(precision_vec), np.std(precision_vec)
        mean_recall, std_recall = np.mean(recall_vec), np.std(recall_vec)
        mean_F1, std_F1 = np.mean(F1_vec), np.std(F1_vec)
        print('Precision = %.1f +/- %.1f' % (mean_precision, std_precision))
        print('Recall    = %.1f +/- %.1f' % (mean_recall, std_recall))
        print('F1        = %.1f +/- %.1f' % (mean_F1, std_F1))

if __name__ == '__main__':
    # Banner printed when the file is executed as a script: a title line
    # followed by a 50-character rule.
    print("Raindrop CD - Modified version for CD dataset", "=" * 50, sep="\n")