import json
import os
import sys

import numpy as np
import torch

import datamodule
import kaplan_meier
import loss_function
import metric
import survival_analysis


def cross_validation(args):
    """Run five-fold cross validation and write per-fold test metrics to JSON.

    For each fold ``0..4`` this selects the fold on the data module and calls
    ``survival_analysis.train_and_predict`` to obtain predictions ``f_pred``
    and a ``test_score`` mapping (its ``'test_loss'`` entry is recorded).
    It then computes D-calibration, the logarithmic score
    (``loss_function.logarithmic_simple_pwl``) and KM-calibration
    (``metric.KMcal``) on that fold's test labels.

    Results are written to
    ``json/<dataset>_<model-or-loss>_<num_bin>[_<DeepHit_alpha>][_withoutEM].json``.
    The output directory is created up front, so a missing directory cannot
    discard the results of a full cross-validation run.

    Args:
        args: parsed options. Fields read here are ``loss_function``,
            ``model``, ``num_bin`` and ``withoutEM``. ``DeepHit_alpha`` is
            read only when ``loss_function == 'DeepHit'``. ``args`` is also
            forwarded to ``datamodule.DataModule`` and
            ``survival_analysis.train_and_predict``.

    Never returns: the function terminates the process with ``sys.exit()``.
    """
    dm = datamodule.DataModule(args)

    # Build the output path and create its directory before the (expensive)
    # training loop, so path/permission problems surface immediately.
    arg1 = args.model if args.model == 'Kaplan-Meier' else args.loss_function
    filename = 'json/{0}_{1}_{2}'.format(dm.dataset_name, arg1, args.num_bin)
    if args.loss_function == 'DeepHit':
        filename += '_{0}'.format(args.DeepHit_alpha)
    filename += '_withoutEM.json' if args.withoutEM else '.json'
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    print('five-fold cross validation')
    results = {
        'f_mean': [],
        'test_score': [],
        'logarithmic_score': [],
        'D-calibration': [],
        'KM-calibration': []
    }

    # Portnoy_pwl output is in quantile-regression form (hence Dcal_qr_pwl and
    # the later qr2dr conversion); every other loss uses Dcal_dr_pwl.
    is_quantile_regression = args.loss_function == 'Portnoy_pwl'
    dcal_metric = metric.Dcal_qr_pwl if is_quantile_regression else metric.Dcal_dr_pwl

    for fold in range(5):
        dm.set_fold(fold)
        f_pred, test_score = survival_analysis.train_and_predict(dm, args)

        # Test labels of the current fold as a float32 tensor. Columns 0 and 1
        # are passed as separate arguments to every metric; presumably they
        # are observed time and event indicator -- confirm in datamodule.
        label_np = dm.y[dm.test_indices[dm.fold]]
        label = torch.from_numpy(label_np.astype(np.float32)).clone()

        dcal = dcal_metric(f_pred, label[:, 0], label[:, 1], dm.y_max)
        print('dcal', dcal)
        results['D-calibration'].append(dcal.item())

        if is_quantile_regression:
            # Convert to distribution regression so the remaining metrics
            # receive the same representation as the other losses.
            # NOTE(review): qr2dr is neither imported nor defined in this
            # file as shown, so this line raises NameError unless it is
            # provided elsewhere -- confirm its module and import it.
            f_pred = qr2dr(f_pred, dm.y_max)

        results['f_mean'].append(f_pred.mean(axis=0).tolist())
        results['test_score'].append(test_score['test_loss'])

        log_score = loss_function.logarithmic_simple_pwl(f_pred, label[:, 0],
                                                         label[:, 1], dm.y_max)
        print('log_score', log_score)
        results['logarithmic_score'].append(log_score.item())

        kmcal = metric.KMcal(f_pred, label[:, 0], label[:, 1], dm.y_max)
        print('kmcal', kmcal)
        results['KM-calibration'].append(kmcal.item())

    print('Writing '+filename)
    with open(filename, 'w') as f:
        json.dump(results, f, indent=4)

    sys.exit()

def KaplanMeier_on_test(args):
    """Estimate Kaplan-Meier curves on each fold's test labels and write them to JSON.

    For each fold ``0..4`` this selects the fold on the data module and passes
    that fold's test labels to ``kaplan_meier.estimate_empirical_distribution``
    (with ``dm.y_max`` and ``args.num_bin``). The first returned value is
    collected under ``'f_mean'``.

    Results are written to ``json/<dataset>_Kaplan-Meier_test_<num_bin>.json``.
    The output directory is created before the loop if it does not exist.

    Args:
        args: parsed options. ``num_bin`` is read here, and ``args`` is also
            forwarded to ``datamodule.DataModule``.

    Never returns: the function terminates the process with ``sys.exit()``.
    """
    dm = datamodule.DataModule(args)

    print('compute Kaplan-Meier curves on test data')

    # Resolve the output path and make sure its directory exists before doing
    # any work, so a missing json/ directory fails fast instead of at write time.
    filename = 'json/{0}_Kaplan-Meier_test_{1}.json'.format(dm.dataset_name, args.num_bin)
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    results = {
        'f_mean': []
    }
    for fold in range(5):
        dm.set_fold(fold)

        # Test labels of the current fold as a float32 tensor. Presumably
        # column 0 is observed time and column 1 the event indicator --
        # confirm in datamodule.
        label_np = dm.y[dm.test_indices[dm.fold]]
        label = torch.from_numpy(label_np.astype(np.float32)).clone()

        e_dist, _ = kaplan_meier.estimate_empirical_distribution(label[:, 0],
                                                                 label[:, 1],
                                                                 dm.y_max,
                                                                 args.num_bin)

        results['f_mean'].append(e_dist.tolist())

    print('Writing '+filename)
    with open(filename, 'w') as f:
        json.dump(results, f, indent=4)

    sys.exit()
