#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import sys
import os
from sksurv.metrics import concordance_index_censored



# NOTE(review): autograd anomaly detection is switched on globally at import
# time, so every backward pass in this process runs with it. PyTorch documents
# this mode as a debugging aid that slows execution noticeably — presumably a
# leftover from debugging; consider disabling it for full training runs.
torch.autograd.set_detect_anomaly(True)

import MSRL_model as CL_model
from loader import *
from utils import *
from utils_survival import *



# Command-line interface of the training script. Help texts that mention a
# default use argparse's ``%(default)s`` placeholder so they cannot drift from
# the actual default values.
parser = argparse.ArgumentParser(description='MAE pre-training')
parser.add_argument('root', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--train', default='../data/EGFR_train_5folds.csv', type=str, metavar='PATH',
                    help='path to the train split file (default: %(default)s)')
# NOTE(review): same default file as --train; presumably the dataset separates
# train/test rows using --fold — confirm.
parser.add_argument('--test', default='../data/EGFR_train_5folds.csv', type=str, metavar='PATH',
                    help='path to the test split file (default: %(default)s)')
parser.add_argument('--test-only', action='store_true',
                    help='')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    metavar='N',
                    help='mini-batch size (default: %(default)s), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument("--warmup-epochs", default=50, type=int,
                    help="Number of epochs for the linear learning-rate warm up.")
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
                    metavar='LR', help='initial (base) learning rate', dest='lr')
parser.add_argument('--min-lr', default=1e-6, type=float,
                    help='minimum learning rate')
parser.add_argument('--momentum', default=0.999, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', type=float, default=0.04, help="""Initial value of the
                        weight decay. With ViT, a smaller value at the beginning of training works well.""")
parser.add_argument('--weight-decay-end', type=float, default=0.4, help="""Final value of the
                        weight decay. We use a cosine schedule for WD and using a larger decay by
                        the end of training improves performance for ViTs.""")
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: %(default)s)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--weighted-sample', action='store_true',
                    help='use weighted random sampling for the training set')
parser.add_argument('--num-classes', default=3, type=int,
                    help='number of classes (default: %(default)s)')
parser.add_argument('--fold', default=0, type=int,
                    help='fold of cv')

# additional configs:
parser.add_argument('--pretrained', default='', type=str,
                    help='path to pretrained checkpoint with graph learner / graph encoder weights')
parser.add_argument('--slide-encoder-pth', default='', type=str,
                    help='path to pretrained slide encoder checkpoint')
parser.add_argument('--init-graph', default='', type=str,
                    help='path to init_graph checkpoint')
parser.add_argument('--rna-features', default='', type=str,
                    help='path to RNA features')
parser.add_argument('--lars', action='store_true',
                    help='Use LARS')

# gigapath specific configs:
parser.add_argument('--model-arch', default='gigapath_slide_enc12l768d_base', type=str, metavar='MODEL',
                    help='Name of model to train')
parser.add_argument('--max-size', default=2048, type=int,
                    help='images input size')
parser.add_argument('--patch-in-chans', default=1536, type=int,
                    help='in_chans')
parser.add_argument('--rna-in-chans', default=4999, type=int,
                    help='in_chans')

# graph learner configs:
parser.add_argument('--in-chans', default=768, type=int,
                    help='in_chans')
parser.add_argument('--leaner-layers', default=2, type=int,
                    help='number of layers in the graph learner')
parser.add_argument('--leaner-k', default=5, type=int,
                    help='number of knn neighbors (rescaled by training-set size at runtime)')
parser.add_argument('--dropedge-rate', default=0.5, type=float,
                    help='dropedge rate')
parser.add_argument('--sparse', action='store_true',
                    help='use sparse matrix')
parser.add_argument('--activation-learner', default='relu', type=str,
                    help='activation function in graph learner')

# graph encoder configs:
parser.add_argument('--nlayers', default=3, type=int,
                    help='number of layers in the graph encoder')
parser.add_argument('--hidden-dim', default=384, type=int,
                    help='hidden dim in the graph encoder')
parser.add_argument('--rep-dim', default=256, type=int,
                    help='rep dim in the graph encoder')
parser.add_argument('--dropout', default=.5, type=float,
                    help='dropout rate in the graph encoder')


parser.add_argument('--norm_pix_loss', action='store_true',
                    help='Use (per-patch) normalized pixels as targets for computing loss')
parser.set_defaults(norm_pix_loss=False)

parser.add_argument('--save-path', default='../exp_results/bs1_075_vit_L_p16/',
                    help='Path where save the model checkpoint')
parser.add_argument('--reg', type=float, default=1e-5, help='L2-regularization weight decay (default: %(default)s)')
parser.add_argument('--alpha_surv', type=float, default=0.0, help='How much to weigh uncensored patients')
parser.add_argument('--reg_type', type=str, choices=['None', 'omic', 'pathomic'], default='None', help='Which network submodules to apply L1-Regularization (default: %(default)s)')
parser.add_argument('--lambda_reg', type=float, default=1e-4, help='L1-Regularization Strength (default: %(default)s)')
parser.add_argument('--task_type', type=str, default='survival', help='task type (default: %(default)s)')
parser.add_argument('--bag_loss', type=str, choices=['svm', 'ce', 'ce_surv', 'nll_surv', 'cox_surv'], default='nll_surv', help='slide-level loss function (default: %(default)s)')


def main():
    """Entry point: parse arguments, prepare output dirs and launch the workers.

    Creates ``checkpoints/``, ``checkpoint-matrix/`` and ``checkpoint_roc/``
    under ``--save-path`` and stores ``--save-path`` itself as
    ``args.checkpoint_csv`` for the per-epoch record file.

    When ``--seed`` is given, the Python, NumPy and torch RNGs are seeded and
    cuDNN is switched to deterministic mode. Rank and world size are read from
    ``SLURM_PROCID`` / ``SLURM_NPROCS`` when those environment variables are
    set; otherwise the ``--rank`` / ``--world-size`` values are kept.
    With ``--multiprocessing-distributed`` one ``main_worker`` is spawned per
    visible GPU; otherwise ``main_worker`` runs in this process.
    """
    args = parser.parse_args()
    args.checkpoint = os.path.join(args.save_path, "checkpoints")
    args.checkpoint_matrix = os.path.join(args.save_path, "checkpoint-matrix")
    args.checkpoint_roc = os.path.join(args.save_path, "checkpoint_roc")
    args.checkpoint_csv = args.save_path
    print('============================================================')
    print(f"FOLD {args.fold}")

    # os.path.join always yields a non-empty path here, so create all three.
    for out_dir in (args.checkpoint, args.checkpoint_matrix, args.checkpoint_roc):
        os.makedirs(out_dir, exist_ok=True)

    if args.seed is not None:
        random.seed(args.seed)
        # NumPy has its own global RNG; seed it too so numpy-based randomness
        # is reproducible as well.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    # SLURM settings: take rank / world size from the scheduler when launched
    # through SLURM; outside SLURM keep the command-line values instead of
    # failing with a KeyError.
    if "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
    if "SLURM_NPROCS" in os.environ:
        args.world_size = int(os.environ["SLURM_NPROCS"])

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def _load_pretrained_weights(model, args):
    """Initialise graph learners, omic adaptor and graph encoders from ``args.pretrained``.

    Must run before the model is wrapped in (Distributed)DataParallel because
    sub-modules are addressed directly by attribute name. The checkpoint is
    read through the keys 'fused_graph_learner', 'omic_adaptor' and
    'GCL_model'.
    """
    if not os.path.isfile(args.pretrained):
        print("=> no checkpoint found at '{}'".format(args.pretrained))
        return

    print("=> loading checkpoint '{}'".format(args.pretrained))
    # 'cuda:None' is not a valid device string: fall back to CPU when no GPU is
    # pinned (the model itself is still on the CPU at this point).
    loc = 'cuda:{}'.format(args.gpu) if args.gpu is not None else 'cpu'
    checkpoint = torch.load(args.pretrained, map_location=loc)

    # Online and target graph learners both start from the fused learner weights.
    msg_leaner = model.online_graph_learner.load_state_dict(checkpoint['fused_graph_learner'])
    print("=> missing_keys", msg_leaner.missing_keys)
    print("=> online_graph_learner loaded\n")

    msg_leaner = model.target_omic_adaptor.load_state_dict(checkpoint['omic_adaptor'])
    print("=> missing_keys", msg_leaner.missing_keys)
    print("=> target_omic_adaptor loaded\n")

    msg_leaner = model.target_graph_learner.load_state_dict(checkpoint['fused_graph_learner'])
    print("=> missing_keys", msg_leaner.missing_keys)
    print("=> target_graph_learner loaded\n")

    # Keep only the GNN layers of the fused encoder and strip the
    # "fused_encoder." prefix so the keys match the standalone graph encoders.
    graph_encoder_state_dict = {
        k.replace("fused_encoder.", ""): v
        for k, v in checkpoint['GCL_model'].items()
        if k.startswith("fused_encoder.gnn_encoder_layers")
    }

    msg_encoder = model.target_graph_encoder.load_state_dict(graph_encoder_state_dict)
    print("=> missing_keys", msg_encoder.missing_keys)
    print("=> target_graph_encoder loaded\n")

    msg_encoder = model.online_graph_encoder.load_state_dict(graph_encoder_state_dict)
    print("=> missing_keys", msg_encoder.missing_keys)
    print("=> online_graph_encoder loaded")

    print("=> loaded pre-trained model '{}'".format(args.pretrained))


def _build_survival_criterion(args):
    """Return the survival loss selected by ``--bag_loss``.

    Raises:
        NotImplementedError: for the accepted-but-unsupported choices
            ('svm', 'ce').
    """
    if args.bag_loss == 'ce_surv':
        return CrossEntropySurvLoss(alpha=args.alpha_surv)
    if args.bag_loss == 'nll_surv':
        return NLLSurvLoss(alpha=args.alpha_surv)
    if args.bag_loss == 'cox_surv':
        return CoxSurvLoss(device=args.gpu)
    raise NotImplementedError("--bag_loss {!r} is not supported".format(args.bag_loss))


def main_worker(gpu, ngpus_per_node, args):
    """Train (or only evaluate) the survival model inside one process.

    Args:
        gpu: local GPU index for this process, or None.
        ngpus_per_node: number of CUDA devices visible on this node.
        args: parsed command-line namespace. It is mutated in place: gpu,
            rank, leaner_k, start_epoch (on resume), and batch_size / workers
            when running one GPU per DDP process.

    Steps: optional process-group init, dataset and model construction,
    optional pretrained-weight loading, (Distributed)DataParallel wrapping,
    loaders / AdamW optimizer (optionally LARS-wrapped) / cosine lr and wd
    schedules, optional resume. Then either a single validation pass
    (``--evaluate``) or ``args.epochs`` training epochs. Each epoch appends a
    row to ``record.csv``, and a checkpoint is saved every 5 epochs.

    Raises:
        NotImplementedError: for ``--test-only``. The model is sized from the
            training dataset, so it cannot be built without it.
    """
    args.gpu = gpu

    if args.test_only:
        raise NotImplementedError(
            "--test-only is not supported: the model is built from the training "
            "dataset (omic_sizes, leaner_k). Use --evaluate together with --resume.")

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        torch.distributed.barrier()

    # Suppress printing if not master. Accept keyword arguments too, so calls
    # such as print(..., end='') do not raise TypeError on non-master ranks.
    if args.multiprocessing_distributed and args.rank != 0:
        def print_pass(*_args, **_kwargs):
            pass
        builtins.print = print_pass

    # Datasets: the training set is always needed because the model reads
    # omic_sizes from it and leaner_k is rescaled by its size.
    train_dataset = TCGAKDataset_survival(
        args.root,
        args.train,
        set='train',
        max_size=args.max_size,
        init_graph=args.rna_features,
        args=args)

    valid_dataset = TCGAKDataset_survival(
        args.root,
        args.test,
        set='test',
        max_size=args.max_size,
        aug=False,
        args=args)

    print("train:", len(train_dataset))
    print("test:", len(valid_dataset))
    # Scale the kNN neighbourhood with the training-set size (655 = reference size).
    args.leaner_k = int(args.leaner_k * (len(train_dataset) / 655))
    model = CL_model.CL_model(args, train_dataset.omic_sizes, slide_encoder_pth=args.slide_encoder_pth, init_graph=args.init_graph, class_num=args.num_classes)

    # load from pre-trained, before DistributedDataParallel constructor
    if args.pretrained:
        _load_pretrained_weights(model, args)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs.
        # (There is no --arch option, so no per-architecture special case.)
        model = torch.nn.DataParallel(model).cuda()
    print(model)

    if args.weighted_sample:
        print('activate weighted sampling')
        if args.distributed:
            train_sampler = DistributedWeightedSampler(
                train_dataset, train_dataset.get_weights(), args.world_size, args.rank)
        else:
            train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                train_dataset.get_weights(), len(train_dataset), replacement=True
            )
    else:
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        else:
            train_sampler = None

    # Shuffle only when no sampler controls the order (DataLoader rejects
    # shuffle=True combined with a sampler).
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # Optimizer must exist before it can be wrapped by LARS.
    params_groups = get_params_groups(model)
    optimizer = torch.optim.AdamW(params_groups)
    if args.lars:
        print("=> use LARS optimizer.")
        from apex.parallel.LARC import LARC
        optimizer = LARC(optimizer=optimizer, trust_coefficient=.001, clip=False)

    lr_schedule = cosine_scheduler(
        args.lr * (args.batch_size * max(args.world_size, 1)) / 32.,  # linear scaling rule
        args.min_lr,
        args.epochs, len(train_loader),
        # NOTE(review): warmup is hard-coded to 20 epochs and --warmup-epochs
        # (default 50) is never read; kept to preserve existing schedules —
        # confirm which value is intended.
        warmup_epochs=20,
    )
    wd_schedule = cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs, len(train_loader),
    )

    survival_criterion = _build_survival_criterion(args)

    # NOTE(review): nn.BCELoss() reduces to a scalar ('mean'), so the per-entry
    # class-balancing weights applied in train() cancel out; they would only
    # take effect with reduction='none'. Left unchanged to keep the objective.
    criterion = nn.BCELoss().cuda(args.gpu)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Load on CPU; load_state_dict copies the tensors onto the
                # devices of the already-placed parameters.
                checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    recorder = Record_survival(os.path.join(args.checkpoint_csv, 'record.csv'))

    if args.evaluate:
        validate(val_loader, model, survival_criterion, criterion, optimizer,
                 lr_schedule, wd_schedule, args.start_epoch, args, reg_fn=None)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train_record = train(train_loader, model, survival_criterion, criterion, optimizer, lr_schedule, wd_schedule, epoch, args, reg_fn=None)

        # evaluate on validation set
        val_record = validate(val_loader, model, survival_criterion, criterion, optimizer, lr_schedule, wd_schedule, epoch, args, reg_fn=None)
        recorder.update([str(epoch)] + list(train_record) + list(val_record))

        # Save every 5 epochs, only from the first process of each node.
        if (epoch + 1) % 5 == 0:
            if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                    and args.rank % ngpus_per_node == 0):
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.checkpoint, epoch))


def train(train_loader, model, survival_criterion, criterion, optimizer, lr_schedule, wd_schedule, epoch, args, reg_fn, gc=10):
    """Run one training epoch and return formatted epoch statistics.

    Each iteration:

    * sets the learning rate of every param group, and the weight decay of
      the first group only, from the per-iteration schedules;
    * runs the model on two views of each slide plus the omic features;
    * minimises the sum of four terms:
      - a ``calc_lower_bound`` loss between the two slide embeddings;
      - a structure loss between the two adjacency outputs (``criterion``),
        weighted so zero and non-zero entries of their union carry equal
        total weight;
      - a ``calc_lower_bound`` loss between the two graph embeddings;
      - the survival loss plus the optional ``reg_fn(model) * args.lambda_reg``.

    ``gc`` is accepted for interface compatibility but is not used.

    Returns:
        Tuple of strings: total loss, slide loss, structure loss, graph loss,
        mean survival loss, mean survival+regularisation loss and the
        training concordance index.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4f')
    slide_losses = AverageMeter('slide_loss', ':.4f')
    structure_losses = AverageMeter('structure_loss', ':.4f')
    graph_losses = AverageMeter('graph_loss', ':.4f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, slide_losses, structure_losses, graph_losses],
        prefix='Train: ')
    model.train()

    train_loss_surv, train_loss = 0., 0.
    # Collected per batch and concatenated after the epoch. Pre-allocating
    # len(loader) * batch_size slots leaves zero-filled fake samples whenever
    # the last batch is short, and those skew the c-index.
    batch_risks, batch_censorships, batch_event_times = [], [], []
    eps = 1e-7

    end = time.time()
    for i, (wsidata1, wsi_pos1, wsidata2, wsi_pos2, rna_features, slide_ids, labels, event_time, c) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        it = len(train_loader) * epoch + i  # global training iteration
        for g_idx, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if g_idx == 0:  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]

        # ``.cuda(None)`` targets the current CUDA device, so this also works
        # when args.gpu is None (DataParallel / DDP without a pinned device).
        view1 = wsidata1.float().cuda(args.gpu, non_blocking=True)
        coords1 = wsi_pos1.long().cuda(args.gpu, non_blocking=True)
        view2 = wsidata2.float().cuda(args.gpu, non_blocking=True)
        coords2 = wsi_pos2.long().cuda(args.gpu, non_blocking=True)
        omic_list = [[omic_feature.float().cuda(args.gpu, non_blocking=True) for omic_feature in f] for f in rna_features]
        c = c.float().cuda(args.gpu, non_blocking=True)
        labels = labels.long().cuda(args.gpu, non_blocking=True)

        x_1, sub_A_1, sub_G_1, x_2, sub_A_2, sub_G_2, logits, hazards, S, Y_hat = model(
            view1, coords1, view2, coords2, omic_list, slide_ids)
        b_size, _ = x_1.size()
        pos_eye = torch.eye(b_size).to(x_1.device)

        slide_embedding_loss = calc_lower_bound(x_1, x_2.detach(), pos_eye)

        # Balance zero / non-zero entries of the union adjacency: each group
        # receives half of the total weight.
        edge_mask = sub_A_1.bool() | sub_A_2.bool()
        zero_mask = ~edge_mask
        num_zeros = zero_mask.sum().to(dtype=torch.float32)
        num_non_zeros = edge_mask.sum().to(dtype=torch.float32)
        total = num_zeros + num_non_zeros
        a = total / (2 * num_zeros + eps)
        b = total / (2 * num_non_zeros + eps)
        weights = torch.where(zero_mask, a.expand_as(sub_A_1), b.expand_as(sub_A_1))
        # NOTE(review): if ``criterion`` reduces to a scalar (nn.BCELoss()
        # defaults to reduction='mean'), the weighting below cancels out.
        loss_ = criterion(sub_A_1, sub_A_2.detach())
        structure_loss = torch.sum(loss_ * weights) / torch.sum(weights)
        graph_embedding_loss = calc_lower_bound(sub_G_1, sub_G_2.detach(), pos_eye)

        survival_loss = survival_criterion(hazards=hazards, S=S, Y=labels, c=c)
        survival_loss_value = survival_loss.item()
        loss_reg = 0 if reg_fn is None else reg_fn(model) * args.lambda_reg
        # Plain float for bookkeeping so the running total never keeps an
        # autograd graph alive across iterations.
        loss_reg_value = loss_reg.item() if torch.is_tensor(loss_reg) else float(loss_reg)

        # Risk score = negated sum of the survival curve.
        risk = -torch.sum(S, dim=1).detach().cpu().numpy()
        batch_risks.append(risk.reshape(-1))
        batch_censorships.append(c.detach().cpu().numpy().reshape(-1))
        batch_event_times.append(np.asarray(event_time, dtype=np.float64).reshape(-1))

        train_loss_surv += survival_loss_value
        train_loss += survival_loss_value + loss_reg_value

        # NOTE(review): wsidata1[0].size(0) is the leading dimension of the
        # first sample in the batch, not b_size; it scales the survival loss and
        # weights the meters — confirm this is intended.
        n_items = wsidata1[0].size(0)
        loss = slide_embedding_loss + structure_loss + graph_embedding_loss + survival_loss / n_items + loss_reg
        # record loss
        slide_losses.update(slide_embedding_loss.item(), n_items)
        structure_losses.update(structure_loss.item(), n_items)
        graph_losses.update(graph_embedding_loss.item(), n_items)
        losses.update(loss.item(), n_items)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    train_loss_surv /= len(train_loader)
    train_loss /= len(train_loader)
    c_index = concordance_index_censored(
        (1 - np.concatenate(batch_censorships)).astype(bool),
        np.concatenate(batch_event_times),
        np.concatenate(batch_risks).astype(np.float64),
        tied_tol=1e-08)[0]
    print('[Train] train-loss={:.3f}, fused-loss={:.3f}, structure-loss={:.4f}, graph-loss={:.3f}, train_loss_surv: {:.4f}, train_loss: {:.4f}, train_c_index: {:.4f} \n'.format(losses.avg, slide_losses.avg, structure_losses.avg, graph_losses.avg, train_loss_surv, train_loss, c_index))

    return '{:.3f}'.format(losses.avg), '{:.3f}'.format(slide_losses.avg), '{:.4f}'.format(structure_losses.avg), '{:.3f}'.format(graph_losses.avg), '{:.3f}'.format(train_loss_surv), '{:.3f}'.format(train_loss), '{:.3f}'.format(c_index)

def validate(validate_loader, model, sl_criterion1, criterion, optimizer, lr_schedule, wd_schedule, epoch, args, reg_fn):
    """Evaluate ``model`` on ``validate_loader`` and return its concordance index.

    Only the first view and its coordinates are fed to the model. The second
    view, its coordinates and the omic inputs are passed as None. Each
    sample's risk score is the negated sum of its survival curve ``S``.

    ``sl_criterion1``, ``criterion``, ``optimizer``, ``lr_schedule``,
    ``wd_schedule``, ``epoch`` and ``reg_fn`` are unused. They keep the
    signature parallel to ``train``.

    Returns:
        A 1-tuple holding the c-index formatted with three decimals.
    """
    model.eval()

    # Collected per batch and concatenated afterwards, so a short last batch
    # does not add zero-filled fake samples to the c-index.
    batch_risks, batch_censorships, batch_event_times = [], [], []

    with torch.no_grad():
        for wsidata1, wsi_pos, slide_ids, labels, event_time, c in validate_loader:
            # ``.cuda(None)`` targets the current CUDA device, which covers
            # args.gpu being None (DataParallel / DDP without a pinned device).
            view1 = wsidata1.float().cuda(args.gpu, non_blocking=True)
            coords1 = wsi_pos.long().cuda(args.gpu, non_blocking=True)

            _, _, S, _ = model(view1, coords1, None, None, None, slide_ids)

            batch_risks.append((-torch.sum(S, dim=1)).cpu().numpy().reshape(-1))
            # Censorship and event times only feed the c-index; keep them on CPU.
            batch_censorships.append(np.asarray(c, dtype=np.float64).reshape(-1))
            batch_event_times.append(np.asarray(event_time, dtype=np.float64).reshape(-1))

    c_index = concordance_index_censored(
        (1 - np.concatenate(batch_censorships)).astype(bool),
        np.concatenate(batch_event_times),
        np.concatenate(batch_risks).astype(np.float64),
        tied_tol=1e-08)[0]
    print('[Val]val_c_index: {:.4f}'.format(c_index))
    return ('{:.3f}'.format(c_index),)


if __name__ == "__main__":
    # Script entry point: only start training when executed directly, not on import.
    main()
