import numpy as np
import pandas as pd
import sys
import torch
from lifelines import CoxPHFitter
from lifelines.datasets import load_rossi

import survival_analysis_loss
import torch_dataset
import train_model


def generate_col_names(dataset):
    """Return the DataFrame column names used for the Cox model.

    One name per feature column of ``dataset.x_train`` ('x0', 'x1', ...),
    followed by 'y' (used as the duration column) and 'e' (used as the
    event column).
    """
    n_features = dataset.x_train.shape[1]
    feature_names = [f'x{idx}' for idx in range(n_features)]
    return feature_names + ['y', 'e']

def train(dataset, nn_param):
    """Fit a Cox proportional-hazards model on the training split.

    ``train_model.transform_data(dataset, nn_param)`` is called first and its
    return value is ignored, so it presumably updates ``dataset`` in place
    (TODO confirm in train_model).

    The training frame is the columns of ``dataset.x_train`` followed by those
    of ``dataset.y_train``. ``y_train`` must therefore contribute exactly two
    columns, matching the trailing 'y' (duration) and 'e' (event) names from
    ``generate_col_names``; otherwise the DataFrame construction fails.

    Returns:
        The fitted ``CoxPHFitter`` (penalizer=0.01).
    """
    train_model.transform_data(dataset, nn_param)
    col_names = generate_col_names(dataset)
    cox_data_train = np.concatenate([dataset.x_train, dataset.y_train], 1)
    cox_data_train_df = pd.DataFrame(data=cox_data_train, columns=col_names)
    cph = CoxPHFitter(penalizer=0.01)
    cph.fit(cox_data_train_df, duration_col='y', event_col='e')
    return cph

def _survival_pmf(cph, inputs, col_names, time_points):
    """Convert Cox survival curves for one batch into per-bin probability mass.

    Args:
        cph: fitted ``CoxPHFitter``.
        inputs: batch of covariate rows (one row per subject); the two
            trailing 'y'/'e' columns are filled with zeros since they are the
            duration/event columns, not model covariates.
        col_names: names from ``generate_col_names``.
        time_points: grid starting at 0 (one entry per output column).

    Returns:
        float32 ndarray of shape (batch, len(time_points)). Column k (except
        the last) is S(t_k) - S(t_{k+1}); the last column is S(t_last), the
        mass beyond the grid. S(t_0) is forced to 1, so each row sums to 1.
    """
    n_rows = inputs.shape[0]
    covariates = np.concatenate([inputs, np.zeros((n_rows, 2))], 1)
    frame = pd.DataFrame(data=covariates, columns=col_names)
    # The returned frame is (time, subject). Take an explicit copy: under
    # pandas copy-on-write `.values` may be a read-only view, and the next
    # line writes into the array.
    surv = cph.predict_survival_function(frame, time_points).to_numpy(copy=True)
    surv[0, :] = 1.0
    # Appending S=0 makes the last bin collect all mass after the grid.
    surv = np.vstack((surv, np.zeros(surv.shape[1])))
    return (surv[:-1, :] - surv[1:, :]).T.astype(np.float32)


def predict(dataset, cph, nn_param):
    """Predict discretised event-time distributions for the test split.

    Args:
        dataset: reads ``x_test``, ``y_test``, ``y_max`` (plus whatever
            ``generate_col_names`` reads).
        cph: fitted ``CoxPHFitter`` from ``train``.
        nn_param: dict; reads ``'n_bin'`` and optional ``'batch_size_test'``
            (default 128).

    Returns:
        ``(predictions, labels)``: ``predictions`` is a float32 ndarray with
        ``n_bin`` columns; ``labels`` concatenates the label batches. Both are
        in test-set order (the loader does not shuffle).

    Raises:
        ValueError: if the test loader yields no batches.
    """
    # prepare DataLoader
    batch_size = nn_param.get('batch_size_test', 128)
    dataset_test = torch_dataset.TorchDataset(dataset.x_test,
                                              dataset.y_test)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              drop_last=False)

    # prediction
    col_names = generate_col_names(dataset)
    # n_bin grid points on [0, y_max]: n_bin-1 intervals plus a tail bin.
    time_points_pred = np.linspace(0, dataset.y_max, nn_param['n_bin'])
    list_pred = []
    list_label = []
    for inputs, labels in test_loader:
        list_pred.append(_survival_pmf(cph, inputs, col_names, time_points_pred))
        list_label.append(labels)

    if not list_pred:
        raise ValueError("test set is empty; nothing to predict")
    predictions = np.concatenate(list_pred)
    labels = np.concatenate(list_label)

    return predictions, labels


def train_and_predict(dataset, nn_param):
    """Fit the Cox baseline and return only its test-set predictions."""
    cph = train(dataset, nn_param)
    y_pred, _ = predict(dataset, cph, nn_param)
    return y_pred


def train_and_test(dataset, nn_param):
    """Fit the Cox baseline and report test losses for several batch sizes.

    Args:
        dataset: passed to ``train``; also reads ``x_test``, ``y_test``,
            ``y_max``.
        nn_param: dict; reads ``'n_bin'``, ``'loss_function_test'`` (maps a
            result name to a ``{loss_fn: param}`` dict whose losses are
            summed per batch) and optional ``'batch_size_test_list'``
            (default [128]).

    Returns:
        dict mapping ``"<name>_<batch size zero-padded to 5 digits>"`` to the
        mean of that loss over test batches. Every batch is weighted equally,
        including a smaller final batch.
    """
    # train
    cph = train(dataset, nn_param)

    # test
    col_names = generate_col_names(dataset)
    y_max = dataset.y_max
    # n_bin grid points on [0, y_max]: n_bin-1 intervals plus a tail bin.
    time_points_pred = np.linspace(0, y_max, nn_param['n_bin'])
    dataset_test = torch_dataset.TorchDataset(dataset.x_test, dataset.y_test)
    test_results = {}
    for bs_test in nn_param.get('batch_size_test_list', [128]):
        # NOTE(review): shuffle=True makes batch composition (and thus any
        # batch-dependent loss) draw from torch's RNG; seed it if results
        # must be reproducible.
        test_dataloader = torch.utils.data.DataLoader(dataset_test,
                                                      num_workers=0,
                                                      batch_size=bs_test,
                                                      shuffle=True,
                                                      drop_last=False)

        # predict
        total_loss = {}
        count = 0
        for inputs, labels in test_dataloader:
            y = labels[:, 0]
            e = labels[:, 1]
            y_pred = torch.from_numpy(
                _survival_pmf(cph, inputs, col_names, time_points_pred))

            for name, loss_fn_def in nn_param['loss_function_test'].items():
                temp = 0.0
                for loss_fn, param in loss_fn_def.items():
                    temp += survival_analysis_loss.compute_loss(
                        loss_fn, y_pred, y, e, param, y_max, nn_param)
                if name in total_loss:
                    total_loss[name] += temp
                else:
                    total_loss[name] = temp
            count += 1

        for key, value in total_loss.items():
            test_results["{0}_{1:05d}".format(key, bs_test)] = float(value) / count
    return test_results
