import json
import numpy as np
import os
import sys
from sklearn.model_selection import KFold, train_test_split

import cox
import datamodule
import kaplan_meier
import neural_network
import postprocessing
import train_model


def save_prediction(trainer, datamodule):
    """Predict on the current fold's train + validation samples and save to JSON.

    The output file is ``prediction/<dataset_name>_logarithmic_<fold>.json``.
    It maps ``str(sample_index)`` to the list of predicted values for that
    sample.

    Args:
        trainer: Object exposing ``predict(dataloaders=..., ckpt_path=...)``
            (a PyTorch Lightning trainer in this project). The ``'best'``
            checkpoint is used.
        datamodule: Data module providing ``index_list`` (an iterable of index
            sequences, presumably the training folds -- TODO confirm),
            ``val_indices[fold]``, ``fold``, ``dataset_name`` and
            ``predict_dataloader(index)``.

    Raises:
        ValueError: If the number of predicted rows differs from the number of
            requested indices. In that case predictions could not be matched
            to their samples.
    """
    # Training indices first, then this fold's validation indices. The
    # predict dataloader is built from exactly this order.
    index = [idx for chunk in datamodule.index_list for idx in chunk]
    index.extend(datamodule.val_indices[datamodule.fold])

    predict_dataloader = datamodule.predict_dataloader(index)
    y_pred = trainer.predict(dataloaders=predict_dataloader, ckpt_path='best')

    # trainer.predict returns one output per batch. Flatten every batch so
    # that samples beyond the first batch also receive their predictions.
    rows = [row for batch in y_pred for row in batch]
    if len(rows) != len(index):
        raise ValueError(
            'got %d prediction rows for %d indices' % (len(rows), len(index)))

    # .item() turns each element into a plain Python number for json.
    pred_dict = {
        str(sample): [value.item() for value in row]
        for sample, row in zip(index, rows)
    }

    os.makedirs('prediction', exist_ok=True)
    filename = os.path.join(
        'prediction',
        '%s_logarithmic_%d.json' % (datamodule.dataset_name, datamodule.fold))
    with open(filename, 'w') as f:
        json.dump(pred_dict, f, indent=2)

def train_and_validate(datamodule, nn_param):
    """Train a model for the current fold and return its validation result.

    Args:
        datamodule: Data module for the current fold. It must provide
            ``val_dataloader()``.
        nn_param: Dict of hyper-parameters. ``nn_param['model']`` equal to
            ``'Cox'`` or ``'Kaplan-Meier'`` delegates to those modules.
            Otherwise ``nn_param.get('use_pytorch_lightning', False)`` picks
            the Lightning trainer over the plain training loop.

    Returns:
        The delegated module's result for the baseline models. Otherwise the
        output of ``trainer.validate(...)`` (Lightning) or of
        ``neural_network.test(...)`` (plain loop).
    """
    if nn_param['model'] == 'Cox':
        return cox.train_and_validate(datamodule, nn_param)
    if nn_param['model'] == 'Kaplan-Meier':
        return kaplan_meier.train_and_validate(datamodule, nn_param)

    # Training runs before the validation dataloader is requested.
    if nn_param.get('use_pytorch_lightning', False):
        trainer = train_model.execute_lightning(datamodule, nn_param)
        val_loss = trainer.validate(dataloaders=datamodule.val_dataloader())
        print('val_loss', val_loss)
    else:
        model = train_model.execute(datamodule, nn_param)
        val_loss = neural_network.test(model, datamodule.val_dataloader())
    return val_loss

def train_and_test(datamodule, args):
    """Train a model for the current fold and evaluate it on the test split.

    Args:
        datamodule: Data module for the current fold. It must provide
            ``test_dataloader()`` and, for the plain training loop,
            ``set_test_batch_size(bs)``.
        args: Run configuration (attribute access, e.g. an argparse
            namespace). ``args.model`` selects the ``'Cox'`` /
            ``'Kaplan-Meier'`` baselines. ``args.use_pytorch_lightning``
            selects the trainer. The optional ``args.batch_size_test_list``
            (default ``[128]``) lists the test batch sizes for the plain loop.

    Returns:
        dict of test metrics. For the plain loop each key is suffixed with the
        zero-padded test batch size, e.g. ``'<metric>_00128'``.
    """
    if args.model == 'Cox':
        return cox.train_and_test(datamodule, args)
    if args.model == 'Kaplan-Meier':
        return kaplan_meier.train_and_test(datamodule, args)

    if args.use_pytorch_lightning:
        trainer = train_model.execute_lightning(datamodule, args)
        test_dataloader = datamodule.test_dataloader()
        results = trainer.test(dataloaders=test_dataloader, ckpt_path='best')
        # results[0] holds the metrics of the single test dataloader.
        return dict(results[0])

    # The plain training loop reads its hyper-parameters from ``args`` too,
    # the same way the Lightning branch does.
    model = train_model.execute(datamodule, args)
    batch_sizes = getattr(args, 'batch_size_test_list', None)
    if batch_sizes is None:
        batch_sizes = [128]

    test_results = {}
    for bs_test in batch_sizes:
        print('test batch size %d' % bs_test)
        datamodule.set_test_batch_size(bs_test)
        test_dataloader = datamodule.test_dataloader()
        results = neural_network.test(model, test_dataloader)
        print(results)
        for key, value in results.items():
            test_results["{0}_{1:05d}".format(key, bs_test)] = value
    return test_results

def train_and_predict(datamodule, args):
    """Train a model for the current fold and produce predictions.

    Args:
        datamodule: Data module providing ``predict_dataloader()`` and
            ``test_dataloader()``.
        args: Run configuration (attribute access). Uses ``args.model``,
            ``args.use_pytorch_lightning`` and, on the Lightning path,
            ``args.save_prediction``.

    Returns:
        The delegated module's result for the ``'Cox'`` / ``'Kaplan-Meier'``
        baselines. On the Lightning path, a tuple ``(y_pred[0],
        test_results[0])``. On the plain loop, the output of
        ``neural_network.predict``.
        NOTE(review): the return shape differs between paths -- confirm that
        callers handle both.
    """
    if args.model == 'Cox':
        return cox.train_and_predict(datamodule, args)
    if args.model == 'Kaplan-Meier':
        return kaplan_meier.train_and_predict(datamodule, args)

    if args.use_pytorch_lightning:
        trainer = train_model.execute_lightning(datamodule, args)
        predict_dataloader = datamodule.predict_dataloader()
        y_pred = trainer.predict(dataloaders=predict_dataloader, ckpt_path='best')
        test_dataloader = datamodule.test_dataloader()
        test_results = trainer.test(dataloaders=test_dataloader, ckpt_path='best')

        if args.save_prediction:
            save_prediction(trainer, datamodule)
        # NOTE(review): y_pred[0] is only the first predict batch. This is
        # complete only if the predict dataloader yields one batch -- confirm.
        return y_pred[0], test_results[0]

    # The plain training loop reads its hyper-parameters from ``args`` too.
    model = train_model.execute(datamodule, args)
    predict_dataloader = datamodule.predict_dataloader()
    return neural_network.predict(predict_dataloader, model)

def evaluate(args):
    """Build the data module from ``args`` and train/test the configured model.

    If ``args.cross_validation > 0``, the model is trained and tested on each
    of folds 0-4 and the per-fold scores go to
    ``postprocessing.write_test_scores``. Otherwise a single train/test run is
    performed and its scores are printed.
    """
    dm = datamodule.DataModule(args)

    if args.cross_validation > 0:
        print('five-fold cross validation')
        scores = []
        for fold in range(5):
            dm.set_fold(fold)
            scores.append(train_and_test(dm, args))
        print(dm.dataset_name)
        print(args)
        postprocessing.write_test_scores(scores, dm.dataset_name)
    else:
        ret_scores = train_and_test(dm, args)
        print(ret_scores)
