# -*- coding:utf-8 -*-
import os
import math
import argparse
import warnings
import time
import utils as utils
import tqdm as tqdm
from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix, average_precision_score, precision_score, recall_score, f1_score

warnings.filterwarnings("ignore")


def _build_parser():
    """Build the command-line parser for the classification experiments.

    The options are grouped into run/optimization settings, model and
    patching hyper-parameters, and logging/model-selection switches.
    """
    ap = argparse.ArgumentParser()

    # Run / optimization settings.
    ap.add_argument('--dataset', type=str, default='physionet',
                    choices=['P12', 'P19', 'physionet', 'mimic3', "PAM"])
    ap.add_argument('--cuda', type=str, default='1')
    ap.add_argument('--epochs', type=int, default=10)
    ap.add_argument('--batch_size', type=int, default=16)
    ap.add_argument('--lr', type=float, default=1e-3)

    # Model / patching hyper-parameters.
    ap.add_argument('--history', type=int, default=48,
                    help="number of hours (months for ushcn and ms for activity) as historical window")
    ap.add_argument('--nhead', type=int, default=1, help="heads in Transformer")
    ap.add_argument('--nlayer', type=int, default=1, help="# of layer in GAT")
    ap.add_argument('-ps', '--patch_size', type=float, default=6, help="window size for a patch")
    ap.add_argument('--stride', type=float, default=6, help="period stride for patch sliding")
    ap.add_argument('-hd', '--hid_dim', type=int, default=64, help="Hidden dim of node embeddings")
    ap.add_argument('--alpha', type=float, default=1, help="Proportion of Time decay")
    ap.add_argument('--res', type=float, default=1, help="Res")

    # Regularization, logging and model selection.
    ap.add_argument('--dropout', type=float, default=0.1)
    ap.add_argument('--logmode', type=str, default="a", help='File mode of logging.')
    ap.add_argument('--model', type=str)
    ap.add_argument('--irr_emb', action='store_true')
    ap.add_argument('--mode', type=str, default=False)
    return ap


# Parse only the options declared above; any unrecognized arguments are
# collected in `unknown` instead of raising an error.
parser = _build_parser()
args, unknown = parser.parse_known_args()
print(args)

# Configure CUDA through environment variables before the model/utils imports
# that follow and before CUDA is first initialized: expose only the requested
# GPU(s), and give cuBLAS a fixed workspace so deterministic algorithms can be
# enforced.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': args.cuda,
    'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
})
#from baselines.models.RTI_class import *
from model.ours_class_selfn import *
#from model.hipatch import *
from utils import *

# Training device: the first visible GPU when CUDA is available, otherwise CPU.
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Deterministic kernels so that runs with the same seed are reproducible.
# On CUDA >= 10.2, torch.use_deterministic_algorithms(True) additionally needs
# the CUBLAS_WORKSPACE_CONFIG environment variable to be set.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# Directory for per-fold checkpoints. exist_ok=True avoids the
# check-then-create race when several runs are launched concurrently.
model_path = './models_classification/'
os.makedirs(model_path, exist_ok=True)

# Short aliases for frequently used command-line arguments.
dataset = args.dataset
batch_size = args.batch_size
learning_rate = args.lr
num_epochs = args.epochs

# Number of levels in a pairwise (binary) hierarchy built over the patches.
def layer_of_patches(n_patch):
    """Return the number of levels needed to merge ``n_patch`` patches pairwise into one.

    An odd count is first padded to the next even number, and each halving adds
    one level. A single patch counts as one level. In closed form this is
    ``1 + ceil(log2(n_patch))``, computed exactly with integer ``bit_length``.

    Args:
        n_patch: positive integer patch count. Integral floats such as ``4.0``
            are also accepted.

    Returns:
        int: number of levels (always >= 1).

    Raises:
        ValueError: if ``n_patch`` is not a positive integer. Such inputs never
            reach the single-patch base case of the halving scheme.
    """
    if n_patch < 1 or n_patch != int(n_patch):
        raise ValueError(f"n_patch must be a positive integer, got {n_patch!r}")
    return 1 + (int(n_patch) - 1).bit_length()
# Count trainable parameters of a model.
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad == False`` (frozen weights) are excluded.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


print('Dataset used: ', dataset)

# Per-dataset constants: data root, variable count, static-feature count,
# timestamp count, number of classes and history window. 'start' is defined
# only for the datasets that carry it.
_DATASET_SPECS = {
    'P12': {'base_path': '../data/P12', 'start': 0, 'variables_num': 36,
            'd_static': 9, 'timestamp_num': 215, 'n_class': 2, 'history': 48},
    'physionet': {'base_path': '../data/physionet', 'start': 4, 'variables_num': 36,
                  'd_static': 9, 'timestamp_num': 215, 'n_class': 2, 'history': 48},
    'P19': {'base_path': '../data/P19', 'variables_num': 34,
            'd_static': 6, 'timestamp_num': 60, 'n_class': 2, 'history': 60},
    'PAM': {'base_path': '../data/PAM', 'variables_num': 17,
            'd_static': 0, 'timestamp_num': 600, 'n_class': 8, 'history': 60},
    'mimic3': {'base_path': '../data/mimic3', 'start': 0, 'variables_num': 16,
               'd_static': 0, 'timestamp_num': 292, 'n_class': 2, 'history': 48},
}

# Publish the selected dataset's constants both as module globals and on
# `args` (d_static, n_class, history), overriding any --history given.
_spec = _DATASET_SPECS.get(dataset)
if _spec is not None:
    base_path = _spec['base_path']
    if 'start' in _spec:
        start = _spec['start']
    variables_num = _spec['variables_num']
    d_static = args.d_static = _spec['d_static']
    timestamp_num = _spec['timestamp_num']
    n_class = args.n_class = _spec['n_class']
    args.history = _spec['history']

# Per-fold test metrics (in percent), accumulated across runs: accuracy for
# every dataset, AUPRC/AUROC for binary tasks, and macro precision/recall/F1
# for multi-class tasks.
acc_arr, precision_arr, recall_arr, f1_arr, auroc_arr, auprc_arr = [], [], [], [], [], []

# The log file name encodes dataset, model, hidden size, layers and heads;
# runs with the irregular embedding also include the mode.
log_path = (
    "logs/{}_{}_{}_{}hdims_{}nlayers_{}nheads.log".format(args.dataset, args.model, args.mode, args.hid_dim, args.nlayer, args.nhead)
    if args.irr_emb
    else "logs/{}_{}_{}hdims_{}nlayers_{}nheads.log".format(args.dataset, args.model, args.hid_dim, args.nlayer, args.nhead)
)
if not os.path.exists("logs/"):
    utils.makedirs("logs/")
logger = utils.get_logger(logpath=log_path, filepath=os.path.abspath(__file__), mode=args.logmode)
logger.info(args)

# Run five experiments: fold k uses random seed k and data split k + 1.
# All tensors, the loss and checkpoint loading use args.device (never a
# hard-coded .cuda()), so the CPU fallback selected at startup works.
for k in range(5):
    # Seed the torch (CPU and CUDA) and numpy RNGs for this fold.
    torch.manual_seed(k)
    torch.cuda.manual_seed(k)
    np.random.seed(k)
    split_idx = k + 1
    logger.info(f"Using split: {split_idx}")

    # Split file for this fold, passed to get_data_split together with base_path
    # (mimic3 passes an empty split path).
    if dataset == 'P12':
        split_path = '/splits/phy12_split' + str(split_idx) + '.npy'
    elif dataset == 'physionet':
        split_path = '/splits/phy12_split' + str(split_idx) + '.npy'
    elif dataset == 'P19':
        split_path = '/splits/phy19_split' + str(split_idx) + '_new.npy'
    elif dataset == 'PAM':
        split_path = '/splits/PAM_split_' + str(split_idx) + '.npy'
    elif dataset == 'mimic3':
        split_path = ''

    # Prepare data and split the dataset
    Ptrain, Pval, Ptest, ytrain, yval, ytest = get_data_split(base_path, split_path, dataset=dataset)
    print(len(Ptrain), len(Pval), len(Ptest), len(ytrain), len(yval), len(ytest))

    # Model configuration: number of patches covering the history window with
    # the given patch size and stride, and the layer count derived from it.
    args.ndim = variables_num
    args.npatch = int(math.ceil((args.history - args.patch_size) / args.stride)) + 1
    args.patch_layer = layer_of_patches(args.npatch)
    args.scale_patch_size = args.patch_size / args.history
    args.task = 'classification'

    # Normalize with training-set statistics and extract the model inputs.
    if dataset == 'P19' or dataset == 'physionet':
        T, F = Ptrain[0]['arr'].shape
        D = len(Ptrain[0]['extended_static'])
        Ptrain_tensor = np.zeros((len(Ptrain), T, F))
        Ptrain_static_tensor = np.zeros((len(Ptrain), D))

        for i in range(len(Ptrain)):
            Ptrain_tensor[i] = Ptrain[i]['arr']
            Ptrain_static_tensor[i] = Ptrain[i]['extended_static']

        # Calculate mean and standard deviation of variables in the training set
        mf, stdf = getStats(Ptrain_tensor)
        ms, ss = getStats_static(Ptrain_static_tensor, dataset=dataset)

        if args.irr_emb or args.model in ['patchtst', 'timexer', 'patchmixer', 'tsmixer']:
            Ptrain_tensor, Ptrain_mask_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor, maxlen \
                = tensorize_normalize_extract_feature_patch(Ptrain, ytrain, mf, stdf, ms, ss, args)
            Pval_tensor, Pval_mask_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor, maxlen \
                = tensorize_normalize_extract_feature_patch(Pval, yval, mf, stdf, ms, ss, args)
            Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor, maxlen \
                = tensorize_normalize_extract_feature_patch(Ptest, ytest, mf, stdf, ms, ss, args)
        else:
            Ptrain_tensor, Ptrain_mask_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor, maxlen \
                = tensorize_normalize_extract_feature(Ptrain, ytrain, mf, stdf, ms, ss, args)
            Pval_tensor, Pval_mask_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor, maxlen \
                = tensorize_normalize_extract_feature(Pval, yval, mf, stdf, ms, ss, args)
            Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor, maxlen \
                = tensorize_normalize_extract_feature(Ptest, ytest, mf, stdf, ms, ss, args)

    elif dataset == 'P12':
        T, F = Ptrain[0]['arr'].shape
        D = len(Ptrain[0]['extended_static'])
        Ptrain_tensor = np.zeros((len(Ptrain), T, F))
        Ptrain_static_tensor = np.zeros((len(Ptrain), D))

        for i in range(len(Ptrain)):
            Ptrain_tensor[i] = Ptrain[i]['arr']
            Ptrain_static_tensor[i] = Ptrain[i]['extended_static']

        # Calculate mean and standard deviation of variables in the training set
        mf, stdf = getStats(Ptrain_tensor)
        ms, ss = getStats_static(Ptrain_static_tensor, dataset=dataset)

        # Unlike P19/physionet, P12 uses a dedicated *_patchtst tensorizer for
        # the patch-based baselines.
        if args.irr_emb:
            Ptrain_tensor, Ptrain_mask_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor, maxlen \
                = tensorize_normalize_extract_feature_patch(Ptrain, ytrain, mf, stdf, ms, ss, args)
            Pval_tensor, Pval_mask_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor, maxlen \
                = tensorize_normalize_extract_feature_patch(Pval, yval, mf, stdf, ms, ss, args)
            Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor, maxlen \
                = tensorize_normalize_extract_feature_patch(Ptest, ytest, mf, stdf, ms, ss, args)
        elif args.model in ['patchtst', 'timexer', 'patchmixer', 'tsmixer']:
            Ptrain_tensor, Ptrain_mask_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor, maxlen \
                = tensorize_normalize_extract_feature_patchtst(Ptrain, ytrain, mf, stdf, ms, ss, args)
            Pval_tensor, Pval_mask_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor, maxlen \
                = tensorize_normalize_extract_feature_patchtst(Pval, yval, mf, stdf, ms, ss, args)
            Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor, maxlen \
                = tensorize_normalize_extract_feature_patchtst(Ptest, ytest, mf, stdf, ms, ss, args)
        else:
            Ptrain_tensor, Ptrain_mask_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor, maxlen \
                = tensorize_normalize_extract_feature(Ptrain, ytrain, mf, stdf, ms, ss, args)
            Pval_tensor, Pval_mask_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor, maxlen \
                = tensorize_normalize_extract_feature(Pval, yval, mf, stdf, ms, ss, args)
            Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor, maxlen \
                = tensorize_normalize_extract_feature(Ptest, ytest, mf, stdf, ms, ss, args)

    elif dataset == 'PAM':
        T, F = Ptrain[0].shape
        D = 1

        Ptrain_tensor = Ptrain
        Ptrain_static_tensor = np.zeros((len(Ptrain), D))

        mf, stdf = getStats(Ptrain)

        if args.irr_emb or args.model in ['patchtst', 'timexer', 'patchmixer', 'tsmixer']:
            Ptrain_tensor, Ptrain_mask_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor, maxlen \
                = tensorize_normalize_extract_feature_pam_patch(Ptrain, ytrain, mf, stdf, args)
            Pval_tensor, Pval_mask_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor, maxlen \
                = tensorize_normalize_extract_feature_pam_patch(Pval, yval, mf, stdf, args)
            Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor, maxlen \
                = tensorize_normalize_extract_feature_pam_patch(Ptest, ytest, mf, stdf, args)
        else:
            Ptrain_tensor, Ptrain_mask_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor, maxlen \
                = tensorize_normalize_extract_feature_pam(Ptrain, ytrain, mf, stdf, args)
            Pval_tensor, Pval_mask_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor, maxlen \
                = tensorize_normalize_extract_feature_pam(Pval, yval, mf, stdf, args)
            Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor, maxlen \
                = tensorize_normalize_extract_feature_pam(Ptest, ytest, mf, stdf, args)

    elif dataset == 'mimic3':
        T, F = timestamp_num, variables_num
        Ptrain_tensor = np.zeros((len(Ptrain), T, F))
        for i in range(len(Ptrain)):
            Ptrain_tensor[i][:Ptrain[i][4]] = Ptrain[i][2]

        # Calculate mean and standard deviation of variables in the training set
        mf, stdf = getStats(Ptrain_tensor)

        Ptrain_tensor, Ptrain_mask_tensor, Ptrain_static_tensor, Ptrain_time_tensor, ytrain_tensor, maxlen \
            = tensorize_normalize_exact_feature_mimic3_patch(Ptrain, ytrain, mf, stdf, args)
        Pval_tensor, Pval_mask_tensor, Pval_static_tensor, Pval_time_tensor, yval_tensor, maxlen \
            = tensorize_normalize_exact_feature_mimic3_patch(Pval, yval, mf, stdf, args)
        Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor, ytest_tensor, maxlen \
            = tensorize_normalize_exact_feature_mimic3_patch(Ptest, ytest, mf, stdf, args)

    # NOTE(review): each tensorize_* call above overwrites `maxlen`, so this is
    # the value returned for the test split. Confirm that is intended rather
    # than, e.g., the train-split value or the max over all splits.
    if args.irr_emb:
        args.maxlen = 0
    else:
        args.maxlen = maxlen

    # Load the model
    if args.model in ['hipatch']:
        model = hipatch(args).to(args.device)
    else:
        model = Model(args).to(args.device)

    logger.info(model)
    logger.info(f'parameters: {count_parameters(model)}')

    # Cross-entropy loss, Adam optimizer
    criterion = torch.nn.CrossEntropyLoss().to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Validation labels as a flat LongTensor, built once per fold instead of
    # once per epoch. ravel() matches how yval is flattened for the metrics.
    yval_labels = torch.from_numpy(yval.ravel()).long().to(args.device)

    # Binary datasets: each batch is half class 0 and half class 1, with the
    # class-1 indices repeated three times (upsampling the minority class). An
    # epoch ends when either pool is exhausted. PAM uses plain shuffled batches.
    half_batch = int(batch_size / 2)
    if args.dataset != 'PAM':
        idx_0 = np.where(ytrain == 0)[0]
        idx_1 = np.where(ytrain == 1)[0]
        n0, n1 = len(idx_0), len(idx_1)
        expanded_idx_1 = np.concatenate([idx_1, idx_1, idx_1], axis=0)
        expanded_n1 = len(expanded_idx_1)
        K0 = n0 // half_batch
        K1 = expanded_n1 // half_batch
        n_batches = np.min([K0, K1])
    else:
        all_indices = np.arange(len(ytrain))
        n_batches = len(all_indices) // batch_size

    best_val_epoch = 0
    best_loss_val = 100.0
    save_time = None  # set whenever a new best checkpoint is written

    # logger.info('Stop epochs: %d, Batches/epoch: %d, Total batches: %d' % (num_epochs, n_batches, num_epochs * n_batches))

    # Training loop
    for epoch in range(num_epochs):
        # Early stopping: quit once more than 10 epochs have passed without a
        # new best validation loss.
        if epoch - best_val_epoch > 10:
            break

        # Training
        model.train()

        # Reshuffle the index pools (same RNG call order as before).
        if args.dataset != 'PAM':
            np.random.shuffle(expanded_idx_1)
            I1 = expanded_idx_1
            np.random.shuffle(idx_0)
            I0 = idx_0
        else:
            np.random.shuffle(all_indices)

        loss = None  # stays None if this epoch has no full batch (n_batches == 0)
        st = time.time()
        for n in range(n_batches):
            # Indices of the current mini-batch.
            if args.dataset != 'PAM':
                idx0_batch = I0[n * half_batch:(n + 1) * half_batch]
                idx1_batch = I1[n * half_batch:(n + 1) * half_batch]
                idx = np.concatenate([idx0_batch, idx1_batch], axis=0)
            else:
                idx = all_indices[n * batch_size:(n + 1) * batch_size]

            P = Ptrain_tensor[idx].to(args.device)
            P_mask = Ptrain_mask_tensor[idx].to(args.device)
            P_static = Ptrain_static_tensor[idx].to(args.device) if d_static != 0 else None
            P_time = Ptrain_time_tensor[idx].to(args.device)
            y = ytrain_tensor[idx].to(args.device)

            # Forward pass, then backward pass and parameter update.
            outputs = model.classification(P, P_time, P_mask, P_static)
            optimizer.zero_grad()
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
        train_time = time.time() - st
        logger.info('training time per epoch :{:.3f}s'.format(train_time))

        # Loss of the last mini-batch, reported as train_loss below.
        train_loss = loss.item() if loss is not None else float('nan')

        # Validation
        model.eval()
        with torch.no_grad():
            st = time.time()
            out_val = evaluate_model_patch(model, Pval_tensor, Pval_mask_tensor, Pval_static_tensor, Pval_time_tensor, n_classes=n_class, batch_size=batch_size)
            val_time = time.time() - st
            logger.info("val time per epoch :{:.2f}s".format(val_time))

            # Validation loss on the raw logits (CrossEntropyLoss applies
            # log-softmax internally).
            val_loss = criterion(out_val.to(args.device), yval_labels)
            val_loss_value = val_loss.item()

            # Apply softmax for predictions and metrics
            out_val = torch.softmax(out_val, dim=1).detach().cpu().numpy()
            y_val_pred = np.argmax(out_val, axis=1)

            acc_val = np.sum(yval.ravel() == y_val_pred.ravel()) / yval.shape[0]
            if args.n_class == 2:
                auc_val = roc_auc_score(yval, out_val[:, 1])
                aupr_val = average_precision_score(yval, out_val[:, 1])
                logger.info(
                    "Validation: Epoch %d, train_loss:%.4f, val_loss:%.4f, acc_val: %.2f, aupr_val: %.2f, auc_val: %.2f" %
                    (epoch, train_loss, val_loss_value, acc_val * 100, aupr_val * 100, auc_val * 100))
            else:
                prec_val = precision_score(yval.ravel(), y_val_pred, average='macro')
                rec_val = recall_score(yval.ravel(), y_val_pred, average='macro')
                f1_val = f1_score(yval.ravel(), y_val_pred, average='macro')
                logger.info(
                    "Validation: Epoch %d, train_loss:%.4f, val_loss:%.4f, acc_val: %.2f, prec_val: %.2f, rec_val: %.2f, f1_val: %.2f" %
                    (epoch, train_loss, val_loss_value, acc_val * 100, prec_val * 100, rec_val * 100, f1_val*100))

            # Save the model weights with the best loss on the validation set
            if val_loss_value < best_loss_val:
                best_val_epoch = epoch
                best_loss_val = val_loss_value
                save_time = str(int(time.time()))
                torch.save(model.state_dict(), model_path + '_' + dataset + '_' + save_time + '_' + str(k) + '.pt')

                # if dataset == 'PAM':
                #     if not os.path.exists("tsne/"):
                #         utils.makedirs("tsne/")
                #     if args.irr_emb:
                #         tsne_filename = "tsne/{}_{}_{}_{}hdims_{}nlayers_{}nheads_train_epoch{}.png".format(
                #             args.dataset, args.model, args.mode, args.hid_dim, args.nlayer, args.nhead, best_val_epoch)
                #     else:
                #         tsne_filename = "tsne/{}_{}_{}hdims_{}nlayers_{}nheads_train_epoch{}.png".format(
                #             args.dataset, args.model, args.hid_dim, args.nlayer, args.nhead, best_val_epoch)
                #     t_sne_visualize(model, ytrain_tensor.detach().cpu().numpy(), Ptrain_tensor, Ptrain_mask_tensor, Ptrain_static_tensor,
                #                     Ptrain_time_tensor, batch_size=batch_size, save_path=tsne_filename)
                #     if args.irr_emb:
                #         tsne_filename = "tsne/{}_{}_{}_{}hdims_{}nlayers_{}nheads_test_epoch{}.png".format(
                #             args.dataset, args.model, args.mode, args.hid_dim, args.nlayer, args.nhead, best_val_epoch)
                #     else:
                #         tsne_filename = "tsne/{}_{}_{}hdims_{}nlayers_{}nheads_test_epoch{}.png".format(
                #             args.dataset, args.model, args.hid_dim, args.nlayer, args.nhead, best_val_epoch)
                #     y_test = ytest.copy()
                #     t_sne_visualize(model, y_test, Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor,
                #                     batch_size=batch_size, save_path=tsne_filename)

    # Testing: restore the checkpoint with the lowest validation loss.
    if save_time is None:
        logger.warning(f"No model was saved during training for fold {k+1}. Skipping testing.")
        continue

    model.eval()
    model_filename = model_path + '_' + dataset + '_' + save_time + '_' + str(k) + '.pt'
    if not os.path.exists(model_filename):
        logger.error(f"Model file not found: {model_filename}. Skipping testing.")
        continue
    # map_location restores the checkpoint onto whichever device this run uses.
    model.load_state_dict(torch.load(model_filename, map_location=args.device))
    logger.info(f"Loaded model for testing: {model_filename}")
    with torch.no_grad():
        st = time.time()
        out_test = evaluate_model_patch(model, Ptest_tensor, Ptest_mask_tensor, Ptest_static_tensor, Ptest_time_tensor, n_classes=n_class, batch_size=batch_size)
        test_time = time.time() - st
        logger.info("test time per epoch :{:.2f}s".format(test_time))

        probs = torch.softmax(out_test, dim=1).detach().cpu().numpy()
        ypred = np.argmax(probs, axis=1)
        y_test_real = ytest.ravel()

        acc = np.sum(y_test_real == ypred) / y_test_real.shape[0]
        if args.n_class == 2:
            auc = roc_auc_score(y_test_real, probs[:, 1])
            aupr = average_precision_score(y_test_real, probs[:, 1])
            logger.info('Testing: AUROC = %.2f | AUPRC = %.2f | Accuracy = %.2f' % (auc * 100, aupr * 100, acc * 100))
        else:
            prec = precision_score(y_test_real, ypred, average='macro')
            rec = recall_score(y_test_real, ypred, average='macro')
            f1 = f1_score(y_test_real, ypred, average='macro')
            logger.info('Testing: PRECISION = %.2f | RECALL = %.2f | F1 = %.2f | Accuracy = %.2f' % (prec * 100, rec * 100, f1*100, acc * 100))

    # Record this fold's test metrics (in percent) for the final summary.
    acc_arr.append(acc * 100)
    if args.n_class == 2:
        auprc_arr.append(aupr * 100)
        auroc_arr.append(auc * 100)
    else:
        precision_arr.append(prec * 100)
        recall_arr.append(rec * 100)
        f1_arr.append(f1 * 100)

# Summarize the per-fold test metrics as mean ± standard deviation (np.std,
# ddof=0). A fold can be skipped during testing (no checkpoint saved, or the
# checkpoint file is missing), so the summary covers only the folds that
# completed. np.mean/np.std of an empty list are NaN.
n_completed = len(acc_arr)
mean_acc, std_acc = np.mean(acc_arr), np.std(acc_arr)
mean_auprc, std_auprc = np.mean(auprc_arr), np.std(auprc_arr)
mean_auroc, std_auroc = np.mean(auroc_arr), np.std(auroc_arr)
mean_prec, std_prec = np.mean(precision_arr), np.std(precision_arr)
mean_rec, std_rec = np.mean(recall_arr), np.std(recall_arr)
mean_f1, std_f1 = np.mean(f1_arr), np.std(f1_arr)
logger.info('------------------------------------------')
print('args.dataset', args.dataset)
# 5 matches the number of folds run by the experiment loop.
logger.info('Completed folds: %d/%d' % (n_completed, 5))
if n_completed == 0:
    logger.warning('No fold reached testing; the summary metrics below are NaN.')
if args.n_class == 2:
    logger.info('AUPRC    = %.1f±%.1f' % (mean_auprc, std_auprc))
    logger.info('AUROC    = %.1f±%.1f' % (mean_auroc, std_auroc))
else:
    logger.info('PRECISION    = %.1f±%.1f' % (mean_prec, std_prec))
    logger.info('RECALL    = %.1f±%.1f' % (mean_rec, std_rec))
    logger.info('F1    = %.1f±%.1f' % (mean_f1, std_f1))
logger.info('Accuracy = %.1f±%.1f' % (mean_acc, std_acc))