import numpy as np

wandb = False

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix, average_precision_score, precision_score, recall_score, f1_score
from baselines.Raindrop.code.baselines.models import TransformerModel, TransformerModel2, SEFT
from baselines.Raindrop.code.baselines.utils_phy12 import *

if wandb:
    # Rebinding ``wandb`` from the boolean switch to the real module keeps the
    # later ``if wandb:`` checks truthy while giving access to the W&B API.
    import wandb

    # wandb.offline
    # os.environ['WANDB_SILENT']="true"

    # Credentials must not live in source control: take the API key from the
    # environment. When WANDB_API_KEY is unset, ``key=None`` leaves the
    # credential lookup to wandb.login() itself.
    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    config = wandb.config
    # run = wandb.init(project='WBtest', config={'wandb_nb':'wandb_three_in_one_hm'})
    run = wandb.init(project='Raindrop', entity='XZ', config={'wandb_nb':'wandb_three_in_one_hm'})

# Provisional seed; torch is seeded again from --seed once the CLI is parsed.
torch.manual_seed(1)

import argparse


def str2bool(value):
    """Convert a command-line token into a real ``bool``.

    argparse hands option values over as strings, and every non-empty string
    (including ``'False'``) is truthy, so boolean options need an explicit
    converter.

    Args:
        value: a ``bool`` (returned unchanged) or a textual flag such as
            ``'True'``, ``'false'``, ``'1'``, ``'no'`` (case-insensitive).

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='P12', choices=['P12', 'P19', 'eICU', 'PAM', 'CD']) #
parser.add_argument('--splittype', type=str, default='random', choices=['random', 'age', 'gender'], help='only use for P12 and P19')
# Boolean options go through str2bool so that "--withmissingratio True" really
# yields True and "--reverse False" really yields False.
parser.add_argument('--withmissingratio', type=str2bool, default=False, help='if True, missing ratio ranges from 0 to 0.5; if False, missing ratio =0') #
parser.add_argument('--reverse', type=str2bool, default=False, help='if True,use female, older for tarining; if False, use female or younger for training') #
parser.add_argument('--feature_removal_level', type=str, default='no_removal', choices=['no_removal', 'set', 'sample'],
                    help='use this only when splittype==random; otherwise, set as no_removal') #
parser.add_argument('--predictive_label', type=str, default='mortality', choices=['mortality', 'LoS'],
                    help='use this only with P12 dataset (mortality or length of stay)')
parser.add_argument('--seed', type=int, default=1, help='Random seed for reproducibility')
parser.add_argument('--quick_test', action='store_true', help='Run with only 100 samples for quick testing')
parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs')
parser.add_argument('--use_cached_dataset', action='store_true', help='Use cached PSV dataset (P12 only)')
parser.add_argument('--cached_dataset_dir', type=str, default='/tmp', help='Directory with cached PSV files (P12)')
parser.add_argument('--split_pkl_path', type=str, default='P12_data_splits/split_{split_idx}.pkl', help='Split PKL path template with {split_idx}')
parser.add_argument('--los_threshold_days', type=int, default=3, help='LoS threshold in days for P12')
# parse_known_args tolerates options meant for a wrapper/launcher; they are
# collected in ``unknown`` instead of aborting the run.
args, unknown = parser.parse_known_args()

# Seed every random number generator used by the script (Python, NumPy, torch
# CPU and, when present, CUDA) from --seed so that runs can be repeated.
import random

for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
if torch.cuda.is_available():
    for _seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        _seed_fn(args.seed)

# Ask cuDNN for deterministic kernels and switch off its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(f'Using seed: {args.seed}')


def one_hot(y_, n_values=None):
    """Encode integer class labels as one-hot rows.

    e.g.: [[5], [0], [3]] --> [[0, 0, 0, 0, 0, 1], [1, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0]]

    Args:
        y_: labels as a 1-D sequence or an (N, 1) array; each entry is
            converted with ``int()``.
        n_values: number of columns of the result. When omitted it is inferred
            as ``max(label) + 1``. Pass the true class count when the labels
            at hand (e.g. a mini-batch) may not contain the highest class,
            otherwise the result is narrower than the score matrix it is
            compared against.

    Returns:
        A float array of shape ``(N, n_values)`` with a single 1 per row.

    Raises:
        ValueError: if ``y_`` is empty and ``n_values`` is not given.
        IndexError: if a label is not smaller than ``n_values``.
    """
    # asarray lets plain lists through; reshape keeps the (N,) / (N, 1) contract.
    labels = np.asarray(y_)
    labels = labels.reshape(len(labels))
    indices = np.array([int(x) for x in labels], dtype=np.int32)

    if n_values is None:
        if indices.size == 0:
            raise ValueError('one_hot() needs n_values when there are no labels to infer it from')
        n_values = int(indices.max()) + 1

    # Row i of the identity matrix is the one-hot vector of class i.
    return np.eye(n_values)[indices]


# Identifiers that only appear in the (commented-out) checkpoint/result paths.
arch = 'standard'
model_path = '../../models/'

dataset = args.dataset
print('Dataset used: ', dataset)
print('args.dataset, args.splittype, args.reverse, args.withmissingratio, args.feature_removal_level',
      args.dataset, args.splittype, args.reverse, args.withmissingratio, args.feature_removal_level)

# Data root of every supported dataset.
_DATA_ROOTS = {
    'P12': '/home/dcm.aau.dk/km20bf/Biomarker_FeatureGroup_GNAN/baselines/Raindrop/P12data',
    'P19': '../../P19data',
    'eICU': '../../eICUdata',
    'PAM': '../../PAMdata',
    'CD': '/ngc/projects2/predict_r/research/projects/0054_GNAN_biomarker_trajectories/Raindrop/CDdata',
}
if dataset in _DATA_ROOTS:
    base_path = _DATA_ROOTS[dataset]

# Device detection
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

baseline = True
# split: 'random', 'age' or 'gender' ('age' not possible for dataset 'eICU');
# reverse: False or True
split, reverse = args.splittype, args.reverse
feature_removal_level = args.feature_removal_level  # 'set', 'sample'

# The comparison against True is kept on purpose: only a genuine boolean True
# switches the missing-ratio sweep on, any other value runs without removal.
missing_ratios = [0.1, 0.2, 0.3, 0.4, 0.5] if args.withmissingratio == True else [0]

for missing_ratio in missing_ratios:
    num_epochs = args.epochs

    # Per-dataset settings, one row per dataset:
    # (learning_rate, d_static, d_inp, static_info, max_len, n_classes)
    dataset_settings = {
        'P12': (0.001, 9, 36, 1, 215, 2),
        'P19': (0.001, 6, 34, 1, 60, 2),
        'eICU': (0.001, 399, 14, 1, 300, 2),
        'PAM': (0.01, 0, 17, None, 600, 8),
        # CD: static = sex only, 17 biomarkers, at most 50 time steps per
        # patient, binary classification
        'CD': (0.001, 1, 17, 1, 50, 2),
    }
    if dataset in dataset_settings:
        learning_rate, d_static, d_inp, static_info, max_len, n_classes = dataset_settings[dataset]

    # Model hyper-parameters shared by all datasets.
    d_model = d_inp
    nhid = 2 * d_model
    nlayers, nhead = 2, 1
    dropout = 0.3
    aggreg = 'mean'
    MAX = 100

    n_runs, n_splits = 1, 5
    subset = False

    # One cell per (split, run) for every reported test metric.
    acc_arr, auprc_arr, auroc_arr, precision_arr, recall_arr, F1_arr = (
        np.zeros((n_splits, n_runs)) for _ in range(6)
    )

    # Progress bar for splits
    split_pbar = tqdm(range(n_splits), desc="Processing splits", unit="split")
    for k in split_pbar:
        split_idx = k + 1
        print('Split id: %d' % split_idx)

        # Split index file of the current fold, relative to base_path.
        p12_stem = '/splits/phy12_split_subset' if subset == True else '/splits/phy12_split'
        split_files = {
            'P12': p12_stem + str(split_idx) + '.npy',
            'P19': '/splits/phy19_split' + str(split_idx) + '_new.npy',
            'eICU': '/splits/eICU_split' + str(split_idx) + '.npy',
            'PAM': '/splits/PAM_split_' + str(split_idx) + '.npy',
            'CD': '/splits/cd_split' + str(split_idx) + '.npy',
        }
        if dataset in split_files:
            split_path = split_files[dataset]

        # If using cached PSV dataset for P12, set environment variables expected by get_data_split
        if dataset == 'P12' and args.use_cached_dataset:
            split_pkl = str(args.split_pkl_path).replace('{split_idx}', str(split_idx))
            os.environ.update({
                'USE_CACHED_DATASET': '1',
                'CACHED_PSV_DIR': str(args.cached_dataset_dir),
                'SPLIT_PKL_PATH': split_pkl,
                'LOS_THRESHOLD_DAYS': str(args.los_threshold_days),
            })

        Ptrain, Pval, Ptest, ytrain, yval, ytest = get_data_split(
            base_path, split_path, split_type=split, reverse=reverse, baseline=baseline,
            dataset=dataset, predictive_label=args.predictive_label)
        print(len(Ptrain), len(Pval), len(Ptest), len(ytrain), len(yval), len(ytest))

        # ===== Sanity A) Verify we're actually training on the sparse dataset =====
        import json

        def _density_stats(Plist):
            """Summarise how densely filled the ``'arr'`` matrices of a patient list are.

            The density of one patient is the fraction of non-zero entries in
            ``p['arr']`` (0.0 for an array without entries).

            Returns:
                ``(mean_density, median_density, count, shape_hint)`` where
                ``shape_hint`` is the shape of the first patient's array
                (presumably [T, F] for every patient - not checked here).
                An empty list yields ``(0.0, 0.0, 0, None)``.
            """
            count = len(Plist)
            if count == 0:
                return 0.0, 0.0, 0, None

            def _density(arr):
                return np.count_nonzero(arr) / arr.size if arr.size > 0 else 0.0

            first_shape = tuple(Plist[0]['arr'].shape)
            per_patient = np.array([_density(p['arr']) for p in Plist], dtype=float)
            return float(per_patient.mean()), float(np.median(per_patient)), count, first_shape

        def _peek_patient(p):
            """Return a small JSON-friendly description of one patient's ``'arr'`` matrix.

            The result holds the 2-D shape, the total number of non-zero
            entries and the first (at most ten) row indices at which column 0
            is non-zero.
            """
            arr = p['arr']
            n_rows, n_cols = arr.shape
            first_col_hits = np.flatnonzero(arr[:, 0] != 0)
            return {
                "shape": [int(n_rows), int(n_cols)],
                "nonzeros_total": int(np.count_nonzero(arr)),
                "first_feature_first10_nonzero_idx": first_col_hits[:10].tolist(),
            }

        # One row per partition, in reporting order:
        # (key, density-line template, tag used in the peek line, patients).
        # The templates carry the hand-aligned spacing of the log lines.
        partitions = (
            ("train", "TRAIN mean dens={0:.4f} med={1:.4f} n={2} shape_hint={3}", "train", Ptrain),
            ("val", "VAL   mean dens={0:.4f}   med={1:.4f}   n={2}   shape_hint={3}", " val ", Pval),
            ("test", "TEST  mean dens={0:.4f}  med={1:.4f}  n={2}  shape_hint={3}", "test ", Ptest),
        )
        density = {key: _density_stats(patients) for key, _, _, patients in partitions}

        print("=== DATA DENSITY CHECK (pre-normalization) ===")
        print(f"base_path: {base_path}")
        for key, template, _, _ in partitions:
            print(f"Split {split_idx}: " + template.format(*density[key]))

        # Optional: peek at the first patient of each split for quick manual inspection
        peeks = {}
        for key, _, tag, patients in partitions:
            peeks[key] = None
            if density[key][2] > 0:
                peek = _peek_patient(patients[0])
                peeks[key] = peek
                print(f"[peek {tag} p0] shape={peek['shape']} nonzeros={peek['nonzeros_total']} "
                      f"first_feat_nonzero_idx[:10]={peek['first_feature_first10_nonzero_idx']}")

        # Persist a JSON snapshot so you can compare across runs/splits
        summary = {"base_path": base_path, "split_idx": int(split_idx)}
        for key, _, _, _ in partitions:
            mean_density, median_density, n_patients, shape_hint = density[key]
            summary[key] = {"mean_density": mean_density, "median_density": median_density,
                            "n": int(n_patients), "shape_hint": shape_hint}
        for key, _, _, _ in partitions:
            summary[f"peek_{key}_p0"] = peeks[key]

        summary_path = f"sparsity_summary_split{split_idx}.json"
        with open(summary_path, "w") as f:
            json.dump(summary, f, indent=2)
        print(f"[saved] {summary_path}")
        # ===== End Sanity A) =====
        
        # Quick test mode: keep only the first 100 samples of every partition.
        # The cut is applied only when the training set exceeds that size.
        if args.quick_test:
            max_samples = 100
            if len(Ptrain) > max_samples:
                print(f"Quick test mode: Limiting data to {max_samples} samples each")
                Ptrain, Pval, Ptest = (part[:max_samples] for part in (Ptrain, Pval, Ptest))
                ytrain, yval, ytest = (labels[:max_samples] for labels in (ytrain, yval, ytest))
                print(f"After limiting: train={len(Ptrain)}, val={len(Pval)}, test={len(Ptest)}")

        # Compute normalisation statistics on the training partition only, then
        # turn all three partitions into normalised tensors with those statistics.
        if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
            # Support variable-length sequences by padding training set to T_max for statistics.
            # NOTE: the feature count is deliberately not called ``F`` - that
            # name is the module alias of torch.nn.functional in this file.
            n_features = Ptrain[0]['arr'].shape[1]
            n_static = len(Ptrain[0]['extended_static'])
            T_max = max(len(p['arr']) for p in Ptrain)

            Ptrain_tensor_tmp = np.zeros((len(Ptrain), T_max, n_features))
            Ptrain_static_tensor_tmp = np.zeros((len(Ptrain), n_static))

            for i, patient in enumerate(Ptrain):
                arr_i = patient['arr']
                t_i = len(arr_i)  # never exceeds T_max, which is the maximum of these lengths
                if t_i > 0:
                    Ptrain_tensor_tmp[i, :t_i, :] = arr_i
                Ptrain_static_tensor_tmp[i] = patient['extended_static']

            mf, stdf = getStats(Ptrain_tensor_tmp)
            ms, ss = getStats_static(Ptrain_static_tensor_tmp, dataset=dataset)

            Ptrain_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor = tensorize_normalize(
                Ptrain, ytrain, mf, stdf, ms, ss)
            Pval_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor = tensorize_normalize(
                Pval, yval, mf, stdf, ms, ss)
            Ptest_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor = tensorize_normalize(
                Ptest, ytest, mf, stdf, ms, ss)
            print(Ptrain_tensor.shape, Ptrain_static_tensor.shape, Ptrain_time_tensor.shape, ytrain_tensor.shape)
        elif dataset == 'PAM':
            # PAM is passed to the helpers as-is; no static statistics are computed.
            mf, stdf = getStats(Ptrain)
            Ptrain_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor = tensorize_normalize_other(
                Ptrain, ytrain, mf, stdf)
            Pval_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor = tensorize_normalize_other(
                Pval, yval, mf, stdf)
            Ptest_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor = tensorize_normalize_other(
                Ptest, ytest, mf, stdf)

        # Blank out a share of the variables in the validation and test tensors
        # (the training tensor is not touched). Following the original inline
        # notes, the first half of the last axis holds values and the second
        # half the matching masks, so both halves are zeroed for each feature.
        if missing_ratio > 0:
            num_all_features = int(Pval_tensor.shape[2] / 2)
            num_missing_features = round(missing_ratio * num_all_features)
            if feature_removal_level == 'sample':
                # A fresh random feature subset is drawn for every patient;
                # validation patients are processed before test patients.
                for eval_tensor in (Pval_tensor, Ptest_tensor):
                    for i, patient in enumerate(eval_tensor):
                        idx = np.random.choice(num_all_features, num_missing_features, replace=False)
                        blank = torch.zeros(eval_tensor.shape[1], num_missing_features)
                        patient[:, idx] = blank  # values
                        patient[:, idx + num_all_features] = blank  # masks
                        eval_tensor[i] = patient
            elif feature_removal_level == 'set':
                # The same features are removed for everybody: the leading
                # entries of the stored ranking (column 0 holds feature indices).
                ranking = np.load('saved/IG_density_scores_' + dataset + '.npy', allow_pickle=True)
                density_score_indices = ranking[:, 0]
                idx = density_score_indices[:num_missing_features].astype(int)
                for eval_tensor in (Pval_tensor, Ptest_tensor):
                    blank = torch.zeros(eval_tensor.shape[0], eval_tensor.shape[1], num_missing_features)
                    eval_tensor[:, :, idx] = blank  # values
                    eval_tensor[:, :, idx + num_all_features] = blank  # masks

        # Swap the first two axes of the feature tensors (axis 2 stays last).
        Ptrain_tensor, Pval_tensor, Ptest_tensor = (
            tensor.permute(1, 0, 2) for tensor in (Ptrain_tensor, Pval_tensor, Ptest_tensor)
        )

        # Drop the trailing axis of the time tensors and swap the remaining two.
        Ptrain_time_tensor, Pval_time_tensor, Ptest_time_tensor = (
            times.squeeze(2).permute(1, 0)
            for times in (Ptrain_time_tensor, Pval_time_tensor, Ptest_time_tensor)
        )

        # Align model max_len to the actual training sequence length
        train_seq_len = Ptrain_tensor.shape[0]

        # Train, validate and test one SEFT model per run on the current split
        # and store the test metrics in the (split, run) result arrays.
        for m in tqdm(range(n_runs), desc=f"Runs for split {split_idx}", unit="run", leave=False):
            print('- - Run %d - -' % (m + 1))
            if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                model = SEFT(d_inp, d_model, nhead, nhid, nlayers, dropout, train_seq_len,
                             d_static, MAX, 0.5, aggreg, n_classes)
            elif dataset == 'PAM':
                model = SEFT(d_inp, d_model, nhead, nhid, nlayers, dropout, train_seq_len,
                             d_static, MAX, 0.5, aggreg, n_classes, static=False)

            model = model.to(device)

            criterion = torch.nn.CrossEntropyLoss().to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
            # mode='max': the scheduler is stepped with the validation AUPRC.
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1,
                                                                   patience=1, threshold=0.001, threshold_mode='rel',
                                                                   cooldown=0, min_lr=1e-8, eps=1e-08)

            idx_0 = np.where(ytrain == 0)[0]
            idx_1 = np.where(ytrain == 1)[0]

            # Batch sampling strategy: 1 = random_sample helper, 2 = class-balanced
            # batches with 3x upsampled positives, 3 = uniform random batches.
            if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                strategy = 2
            elif dataset == 'PAM':
                strategy = 3

            n0, n1 = len(idx_0), len(idx_1)
            expanded_idx_1 = np.concatenate([idx_1, idx_1, idx_1], axis=0)
            expanded_n1 = len(expanded_idx_1)

            batch_size = 16
            half_batch = int(batch_size / 2)
            if strategy == 1:
                n_batches = 10  # number of batches to process per epoch
            elif strategy == 2:
                n_batches = int(min(n0 // half_batch, expanded_n1 // half_batch))
            elif strategy == 3:
                n_batches = 30

            # Small datasets can yield zero full batches; always run at least one
            # (possibly smaller) batch per epoch.
            n_batches = max(n_batches, 1)

            best_aupr_val = 0.0
            print('Stop epochs: %d, Batches/epoch: %d, Total batches: %d' % (num_epochs, n_batches, num_epochs * n_batches))

            # ``time`` is not imported at the top of this file; importing it here
            # keeps the timer independent of what the star import re-exports.
            import time
            start = time.time()
            if wandb:
                wandb.watch(model)

            # Progress bar for epochs
            epoch_pbar = tqdm(range(num_epochs), desc=f"Training (Split {split_idx})", unit="epoch")
            for epoch in epoch_pbar:
                model.train()

                if strategy == 2:
                    # Reshuffle both index pools in place at the start of every epoch.
                    np.random.shuffle(expanded_idx_1)
                    I1 = expanded_idx_1
                    np.random.shuffle(idx_0)
                    I0 = idx_0

                # Progress bar for batches within each epoch
                batch_pbar = tqdm(range(n_batches), desc=f"Epoch {epoch+1}/{num_epochs}", leave=False, unit="batch")
                for n in batch_pbar:
                    if strategy == 1:
                        idx = random_sample(idx_0, idx_1, batch_size)
                    elif strategy == 2:
                        # Each batch takes half_batch negative and half_batch positive samples.
                        idx0_batch = I0[n * half_batch:(n + 1) * half_batch]
                        idx1_batch = I1[n * half_batch:(n + 1) * half_batch]
                        idx = np.concatenate([idx0_batch, idx1_batch], axis=0)
                    elif strategy == 3:
                        idx = np.random.choice(list(range(Ptrain_tensor.shape[1])), size=int(batch_size), replace=False)
                        # idx = random_sample_8(ytrain, batch_size)   # to balance dataset

                    if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                        P, Ptime, Pstatic, y = Ptrain_tensor[:, idx, :].to(device), Ptrain_time_tensor[:, idx].to(device), \
                                               Ptrain_static_tensor[idx].to(device), ytrain_tensor[idx].to(device)
                    elif dataset == 'PAM':
                        P, Ptime, Pstatic, y = Ptrain_tensor[:, idx, :].to(device), Ptrain_time_tensor[:, idx].to(device), \
                                               None, ytrain_tensor[idx].to(device)

                    outputs = evaluate_standard(model, P, Ptime, Pstatic, static=static_info)

                    optimizer.zero_grad()
                    loss = criterion(outputs, y.squeeze())
                    loss.backward()
                    optimizer.step()

                    # Training metrics of the last batch of each epoch are only
                    # consumed by the W&B logger, so they are computed only when
                    # logging is enabled. This also avoids a scikit-learn error
                    # on a batch that happens to contain a single class.
                    if wandb and n == n_batches - 1:
                        if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                            train_probs = torch.squeeze(torch.sigmoid(outputs))
                            train_probs = train_probs.cpu().detach().numpy()
                            train_probs = np.nan_to_num(train_probs)

                            train_y = y.cpu().detach().numpy()
                            train_auroc = roc_auc_score(train_y, train_probs[:, 1])
                            train_auprc = average_precision_score(train_y, train_probs[:, 1])
                        elif dataset == 'PAM':
                            train_probs = torch.squeeze(nn.functional.softmax(outputs, dim=1))
                            train_probs = train_probs.cpu().detach().numpy()
                            train_probs = np.nan_to_num(train_probs)
                            train_y = y.cpu().detach().numpy()
                            train_auroc = roc_auc_score(one_hot(train_y), train_probs)
                            train_auprc = average_precision_score(one_hot(train_y), train_probs)

                        wandb.log({ "train_loss": loss.item(), "train_auprc": train_auprc, "train_auroc": train_auroc})

                # Validation (after every epoch)
                model.eval()
                with torch.no_grad():
                    print(f"\nValidating epoch {epoch+1}/{num_epochs}...")
                    out_val = evaluate_standard(model, Pval_tensor, Pval_time_tensor, Pval_static_tensor, static=static_info)

                    # For validation loss, use raw logits
                    val_loss = criterion(out_val, torch.from_numpy(yval.ravel()).long().to(device))

                    # For metrics, use sigmoid output
                    out_val_sigmoid = torch.squeeze(torch.sigmoid(out_val))
                    out_val_np = out_val_sigmoid.detach().cpu().numpy()
                    out_val_np = np.nan_to_num(out_val_np)

                    if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                        auc_val = roc_auc_score(yval.ravel(), out_val_np[:, 1])
                        aupr_val = average_precision_score(yval.ravel(), out_val_np[:, 1])
                    elif dataset == 'PAM':
                        auc_val = roc_auc_score(one_hot(yval), out_val_np)
                        aupr_val = average_precision_score(one_hot(yval), out_val_np)

                    print("Validation: Epoch %d,  val_loss:%.4f, aupr_val: %.2f, auc_val: %.2f" % (epoch,
                      val_loss.item(), aupr_val * 100, auc_val * 100))

                    if wandb:
                        wandb.log({"val_loss": val_loss.item(), "val_auprc": aupr_val, "val_auroc": auc_val})

                    scheduler.step(aupr_val)
                    if aupr_val > best_aupr_val:
                        best_aupr_val = aupr_val
                        print(
                            "**[S] Epoch %d, aupr_val: %.4f, auc_val: %.4f **" % (epoch, aupr_val * 100, auc_val * 100))
                        # torch.save(model.state_dict(), model_path + arch + '_' + str(split_idx) + '.pt')

            end = time.time()
            time_elapsed = end - start
            print('Total Time elapsed: %.3f mins' % (time_elapsed / 60.0))

            # Testing. Checkpointing is commented out, so the weights of the
            # last epoch (not of the best validation epoch) are evaluated.
            print(f"\nTesting split {split_idx}...")
            # model.load_state_dict(torch.load(model_path + arch + '_' + str(split_idx) + '.pt'))
            model.eval()

            with torch.no_grad():
                out_test = evaluate(model, Ptest_tensor, Ptest_time_tensor, Ptest_static_tensor, n_classes=n_classes, static=static_info).numpy()
                out_test = np.nan_to_num(out_test)
                ypred = np.argmax(out_test, axis=1)

                # Softmax over the class axis.
                denoms = np.sum(np.exp(out_test), axis=1).reshape((-1, 1))
                probs = np.exp(out_test) / denoms

                acc = np.sum(ytest.ravel() == ypred.ravel()) / ytest.shape[0]

                if dataset == 'P12' or dataset == 'P19' or dataset == 'eICU' or dataset == 'CD':
                    auc = roc_auc_score(ytest.ravel(), probs[:, 1])
                    aupr = average_precision_score(ytest.ravel(), probs[:, 1])
                elif dataset == 'PAM':
                    auc = roc_auc_score(one_hot(ytest), probs)
                    aupr = average_precision_score(one_hot(ytest), probs)
                    precision = precision_score(ytest, ypred, average='macro', )
                    recall = recall_score(ytest, ypred, average='macro', )
                    F1 = f1_score(ytest, ypred, average='macro', )
                    print('Testing: Precision = %.2f | Recall = %.2f | F1 = %.2f' % (precision * 100, recall * 100, F1 * 100))

                print('Testing: AUROC = %.2f | AUPRC = %.2f | Accuracy = %.2f' % (auc * 100, aupr * 100, acc * 100))

                # Ensure ytest and ypred are properly flattened for classification_report
                ytest_flat = ytest.ravel()
                ypred_flat = ypred.ravel()
                print('classification report', classification_report(ytest_flat, ypred_flat))
                print(confusion_matrix(ytest_flat, ypred_flat, labels=list(range(n_classes))))

            # store (metrics are kept as percentages)
            acc_arr[k, m] = acc * 100
            auprc_arr[k, m] = aupr * 100
            auroc_arr[k, m] = auc * 100
            if dataset == 'PAM':
                precision_arr[k, m] = precision * 100
                recall_arr[k, m] = recall * 100
                F1_arr[k, m] = F1 * 100

    # pick best performer for each split based on max AUPRC
    idx_max = np.argmax(auprc_arr, axis=1)
    split_rows = np.arange(n_splits)
    acc_vec = acc_arr[split_rows, idx_max]
    auprc_vec = auprc_arr[split_rows, idx_max]
    auroc_vec = auroc_arr[split_rows, idx_max]
    if dataset == 'PAM':
        precision_vec = precision_arr[split_rows, idx_max]
        recall_vec = recall_arr[split_rows, idx_max]
        F1_vec = F1_arr[split_rows, idx_max]

    print("split type:{}, reverse:{}, using baseline:{}, missing ratio:{}".format(split, reverse, baseline, missing_ratio))
    print('args.dataset, args.splittype, args.reverse, args.withmissingratio, args.feature_removal_level',
          args.dataset, args.splittype, args.reverse, args.withmissingratio, args.feature_removal_level)

    # display mean and standard deviation across splits
    mean_acc, std_acc = np.mean(acc_vec), np.std(acc_vec)
    mean_auprc, std_auprc = np.mean(auprc_vec), np.std(auprc_vec)
    mean_auroc, std_auroc = np.mean(auroc_vec), np.std(auroc_vec)
    print('------------------------------------------')
    print('Accuracy = %.1f +/- %.1f' % (mean_acc, std_acc))
    print('AUPRC    = %.1f +/- %.1f' % (mean_auprc, std_auprc))
    print('AUROC    = %.1f +/- %.1f' % (mean_auroc, std_auroc))
    if dataset == 'PAM':
        mean_precision, std_precision = np.mean(precision_vec), np.std(precision_vec)
        mean_recall, std_recall = np.mean(recall_vec), np.std(recall_vec)
        mean_F1, std_F1 = np.mean(F1_vec), np.std(F1_vec)
        print('Precision = %.1f +/- %.1f' % (mean_precision, std_precision))
        print('Recall    = %.1f +/- %.1f' % (mean_recall, std_recall))
        print('F1        = %.1f +/- %.1f' % (mean_F1, std_F1))

    # Mark the run as finished - but only after the last missing ratio. The W&B
    # run is created once at start-up, so finishing it after the first ratio
    # would leave the remaining ratios logging to a closed run.
    if wandb and missing_ratio == missing_ratios[-1]:
        wandb.finish()

    # save in numpy file
    # np.save('./results/' + arch + '_phy12_setfunction.npy', [acc_vec, auprc_vec, auroc_vec])
