import matplotlib.pyplot as plt
import numpy as np
import random
import sys
import torch
from lifelines import KaplanMeierFitter

import loss_function
import survival_analysis_loss
import torch_dataset
import train_model


def compute_KaplanMeier(z, e):
    """Compute the Kaplan-Meier survival estimate at each uncensored event time.

    Args:
        z: 1-D array-like of observed times (NumPy array or CPU torch tensor).
        e: 1-D array-like aligned with ``z``; entries ``> 0`` mark uncensored
            events, all other entries are treated as censored.

    Returns:
        ``(time_points, survival)`` where ``time_points`` are the sorted
        distinct uncensored event times and ``survival[k]`` is the cumulative
        product of ``1 - d_j / n_j`` for ``j <= k`` (``d_j`` = number of events
        at ``time_points[j]``, ``n_j`` = number of observations with
        ``z >= time_points[j]``). If there are no uncensored events, two empty
        arrays are returned.
    """
    z = np.asarray(z)
    e = np.asarray(e)

    # distinct uncensored event times and the number of events at each
    time_points, counts = np.unique(z[e > 0], return_counts=True)
    if len(time_points) == 0:
        return time_points, counts

    # number at risk n_j = #{z >= time_points[j]}: sorting once and using a
    # left-sided binary search is O(n log n) instead of O(n * m).
    z_sorted = np.sort(z)
    ni = len(z_sorted) - np.searchsorted(z_sorted, time_points, side='left')
    survival = np.cumprod(1 - counts / ni)

    return time_points, survival

def estimate_empirical_distribution(z, e, z_max, n_bin):
    """Discretise the Kaplan-Meier curve of ``(z, e)`` into ``n_bin`` bins.

    Args:
        z: 1-D torch tensor (or NumPy array) of observed times.
        e: event indicators aligned with ``z``; ``> 0`` marks an uncensored
            event and ``== 0`` a censored observation.
        z_max: upper end of the time axis; ``[0, z_max]`` is split into
            ``n_bin`` equal-width bins.
        n_bin: number of bins.

    Returns:
        ``(empirical_dist, time_KM_invalid_index)``:
        ``empirical_dist`` is a float32 torch tensor of length ``n_bin`` whose
        entry ``k`` is the drop of the KM survival curve across bin ``k``; the
        survival at ``z_max`` is forced to 0, so the entries telescope to the
        survival at t=0. ``time_KM_invalid_index`` is
        ``int(t / z_max * n_bin)`` where ``t`` is the smallest censored time
        strictly after the last KM event time, or the last KM event time if no
        such censored time exists (``min(z)`` if every observation is
        censored). NOTE(review): presumably the first bin where the KM
        estimate is no longer supported by data -- confirm with callers.
    """
    # estimate KM curve
    all_censored = False
    time_points_KM, survival_rates_KM = compute_KaplanMeier(z, e)
    if len(time_points_KM) == 0:  # if all data points are censored
        # z.min() works for both torch tensors and NumPy arrays.
        time_points_KM = np.array([z.min()])
        all_censored = True
    if time_points_KM[0] == 0.0:
        # Move an event at t=0 just above zero so that the threshold t=0
        # (searchsorted with side='right' below) still maps to survival 1.0.
        time_points_KM[0] = 0.0000001
    # Prepend S=1.0 (before the first event) and append S=0.0; in the
    # all-censored case survival_rates_KM is empty, giving [1.0, 0.0].
    survival_rates_KM = np.append(1.0, survival_rates_KM)
    survival_rates_KM = np.append(survival_rates_KM, 0.0)

    # compute valid index
    z_censored = z[e == 0.0]
    temp = z_censored[z_censored > time_points_KM[-1]]
    if all_censored or len(temp) == 0:
        time_KM_invalid_index = int((time_points_KM[-1] / z_max) * n_bin)
    else:
        time_KM_invalid_index = int((temp.min() / z_max) * n_bin)

    # compute empirical distribution from KM curve
    time_points_threshold = np.linspace(0, z_max, n_bin + 1)
    # index i into survival_rates_KM = number of event times <= threshold
    indices_tp_threshold = np.searchsorted(time_points_KM,
                                           time_points_threshold,
                                           side='right')
    survival_rates = np.take(survival_rates_KM, indices_tp_threshold)
    survival_rates[-1] = 0.0  # all remaining mass goes into the last bin
    empirical_dist = survival_rates[:-1] - survival_rates[1:]
    empirical_dist = torch.from_numpy(empirical_dist.astype(np.float32)).clone()
    return empirical_dist, time_KM_invalid_index

def compute_KMtn_KW21(y, n_bin, y_max):
    """Return the normalised histogram of ``y[:, 0]`` as a ``1 x n_bin`` row.

    ``[0, y_max]`` is split into ``n_bin - 1`` equal-width bins and a trailing
    zero is appended so that the row has ``n_bin`` entries. Counts are divided
    by the number of rows of ``y``; values outside ``[0, y_max]`` are dropped
    by ``np.histogram`` but still count in that denominator. Only column 0 of
    ``y`` is read, so censoring information is ignored.
    """
    n_rows = y.shape[0]
    counts = np.histogram(y[:, 0], bins=n_bin - 1, range=(0.0, y_max))[0]
    row = np.concatenate((counts, [0]))[np.newaxis, :]
    return row / n_rows

def plot(dataset, filename, exit=True):
    """Plot the lifelines Kaplan-Meier curve of ``dataset.original_y`` to a file.

    Column 0 of ``dataset.original_y`` is passed to lifelines as the
    durations and column 1 as ``event_observed``.

    Args:
        dataset: object exposing a 2-D ``original_y`` array.
        filename: output path passed to ``plt.savefig``.
        exit: if True (the default, matching the previous behaviour), call
            ``sys.exit()`` after saving.
    """
    kmf = KaplanMeierFitter()
    kmf.fit(dataset.original_y[:, 0], event_observed=dataset.original_y[:, 1])
    kmf.plot_survival_function()
    print('Write ' + filename)
    plt.savefig(filename)
    if exit:
        sys.exit()

def plot_df(df, x, y, censored, filename, exit=True):
    """Plot one lifelines Kaplan-Meier curve per group of ``df`` to a file.

    Args:
        df: DataFrame to plot.
        x: grouping key for ``df.groupby``; each group gets its own curve
            labelled ``'class=<group name>'``.
        y: column holding the durations.
        censored: column passed to lifelines as ``event_observed``; despite
            its name it is used as the event indicator, not a censoring flag.
        filename: output path passed to ``plt.savefig``.
        exit: if True (the default, matching the previous behaviour), call
            ``sys.exit()`` after saving.
    """
    for name, group in df.groupby(x):
        kmf = KaplanMeierFitter()
        kmf.fit(group[y], event_observed=group[censored], label='class=' + str(name))
        kmf.plot_survival_function()
    print('Write ' + filename)
    plt.savefig(filename)
    if exit:
        sys.exit()

def plot_numpy(dm, y_pred=None, model_info=None, show_histogram=False,
               use_test_data=False, filename=None, legend_loc='lower left',
               add_title=False, exit=True):
    """Plot the Kaplan-Meier curve of ``dm``'s labels, optionally with predictions.

    Args:
        dm: data object. Attributes read: ``y`` and ``test_indices`` when
            ``use_test_data``, otherwise ``original_y`` (or, if absent,
            ``original_y_train``); ``y_max`` when ``y_pred`` is given; and
            ``dataset_name`` when ``add_title``. Column 0 of the label arrays
            is the time, column 1 the event indicator.
        y_pred: optional 2-D array; each row is plotted as one curve of
            ``y_pred.shape[1]`` points over ``[0, dm.y_max]``, truncated at
            the last Kaplan-Meier event time.
        model_info: optional sequence paired row-by-row with ``y_pred``
            (``zip`` semantics: only as many rows as entries are plotted).
            Its entries are not otherwise used. If None, every row is plotted.
        show_histogram: unused.
        use_test_data: plot ``dm.y[dm.test_indices[0]]`` instead of the
            original labels.
        filename: save the figure to this path; if None, show it instead.
        legend_loc: matplotlib legend location.
        add_title: if True, use ``dm.dataset_name`` as the figure title.
        exit: if True, call ``sys.exit()`` when done.
    """
    # extract data
    if use_test_data:
        # NOTE(review): fold 0 is hard-coded here, whereas train_and_predict
        # uses dm.fold -- confirm this is intended.
        t = dm.y[dm.test_indices[0], 0]
        uncensored = dm.y[dm.test_indices[0], 1]
    elif hasattr(dm, 'original_y'):
        t = dm.original_y[:, 0]
        uncensored = dm.original_y[:, 1]
    else:
        print('original_y is not found, and so original_y_train is used instead')
        t = dm.original_y_train[:, 0]
        uncensored = dm.original_y_train[:, 1]
    t_max = t.max()

    # compute Kaplan-Meier survival curve
    time_points, pi = compute_KaplanMeier(t, uncensored)
    # With no uncensored events there is no last event time; use t_max.
    t_max_KM = time_points[-1] if len(time_points) > 0 else t_max

    # setup graph
    cmap = plt.get_cmap("tab10")
    plt.figure(figsize=(5, 3.75))
    if add_title:
        plt.title(dm.dataset_name)
    plt.xlabel('Time')
    plt.ylabel('Survival rate')
    plt.xlim([0, t_max])
    plt.ylim([0, 1])

    # plot prediction
    if y_pred is not None:
        n_bin = y_pred.shape[1]
        boundaries = np.linspace(0.0, dm.y_max, n_bin)
        # shift each boundary by half of the spacing y_max / (n_bin - 1)
        x = boundaries + 0.5 * dm.y_max / (n_bin - 1)
        # only draw predictions up to the last KM event time
        n_max = int((t_max_KM / dm.y_max) * (n_bin - 1))
        if model_info is None:
            rows = y_pred
        else:
            rows = [row for row, _ in zip(y_pred, model_info)]
        for row in rows:
            plt.plot(x[:n_max], row[:n_max])

    # plot Kaplan-Meier curve as a step function: each event time appears
    # twice, going from the survival before the event to the survival after.
    pi = np.insert(pi, 0, 1.0)
    x_km = np.repeat(time_points, 2)
    y_km = np.stack([pi[:-1], pi[1:]], axis=1).ravel()
    plt.plot(x_km, y_km, label='Kaplan-Meier', color=cmap(3))  # red
    plt.legend(loc=legend_loc)

    # show graph
    if filename is None:
        plt.show()
    else:
        print('Writing %s' % filename)
        plt.savefig(filename)
    if exit:
        sys.exit()

def plot_numpy_q(dm, F_pred=None, use_test_data=False, filename=None,
                 legend_loc='lower left', add_title=False, exit=True):
    """Plot the Kaplan-Meier curve of ``dm``'s labels, optionally with a quantile curve.

    Args:
        dm: data object. Attributes read: ``y`` and ``test_indices`` when
            ``use_test_data``, otherwise ``original_y`` (or, if absent,
            ``original_y_train``); ``y_max`` when ``F_pred`` is given; and
            ``dataset_name`` when ``add_title``. Column 0 of the label arrays
            is the time, column 1 the event indicator.
        F_pred: optional 1-D array of length ``n_bin``, drawn (label
            'Portnoy') through the points ``(0, 1)`` and
            ``(F_pred[k] * dm.y_max, 1 - (k + 1) / n_bin)``. That is,
            ``F_pred[k]`` is read as the time, in units of ``dm.y_max``, at
            which survival reaches ``1 - (k + 1) / n_bin``.
        use_test_data: plot ``dm.y[dm.test_indices[0]]`` instead of the
            original labels.
        filename: save the figure to this path; if None, show it instead.
        legend_loc: matplotlib legend location.
        add_title: if True, use ``dm.dataset_name`` as the figure title.
        exit: if True, call ``sys.exit()`` when done.
    """
    # extract data
    if use_test_data:
        # NOTE(review): fold 0 is hard-coded here, whereas train_and_predict
        # uses dm.fold -- confirm this is intended.
        t = dm.y[dm.test_indices[0], 0]
        uncensored = dm.y[dm.test_indices[0], 1]
    elif hasattr(dm, 'original_y'):
        t = dm.original_y[:, 0]
        uncensored = dm.original_y[:, 1]
    else:
        print('original_y is not found, and so original_y_train is used instead')
        t = dm.original_y_train[:, 0]
        uncensored = dm.original_y_train[:, 1]
    t_max = t.max()

    # compute Kaplan-Meier survival curve
    time_points, pi = compute_KaplanMeier(t, uncensored)

    # setup graph
    cmap = plt.get_cmap("tab10")
    plt.figure(figsize=(5, 3.75))
    if add_title:
        plt.title(dm.dataset_name)
    plt.xlabel('Time')
    plt.ylabel('Survival rate')
    plt.xlim([0, t_max])
    plt.ylim([0, 1])

    # plot prediction
    if F_pred is not None:
        n_bin = F_pred.shape[0]
        boundaries = np.linspace(0.0, 1.0, n_bin + 1)
        y = 1.0 - boundaries
        x = np.append(0.0, F_pred) * dm.y_max
        plt.plot(x, y, label='Portnoy')

    # plot Kaplan-Meier curve as a step function: each event time appears
    # twice, going from the survival before the event to the survival after.
    pi = np.insert(pi, 0, 1.0)
    x_km = np.repeat(time_points, 2)
    y_km = np.stack([pi[:-1], pi[1:]], axis=1).ravel()
    plt.plot(x_km, y_km, label='Kaplan-Meier', color=cmap(3))  # red
    plt.legend(loc=legend_loc)

    # show graph
    if filename is None:
        plt.show()
    else:
        print('Writing %s' % filename)
        plt.savefig(filename)
    if exit:
        sys.exit()

def plot_samples_numpy(dataset, num_sample, size_sample, filename=None, exit=True):
    """Overlay Kaplan-Meier curves of random resamples of ``dataset.original_y``.

    Each of the ``num_sample`` curves is computed from ``size_sample`` row
    indices drawn uniformly with replacement (``np.random.randint``).
    Column 0 of ``original_y`` is the time and column 1 the event indicator.
    Both axes are fixed to ``[0, 1]``. NOTE(review): this presumably assumes
    times normalised to ``[0, 1]`` -- confirm.

    Args:
        dataset: object exposing a 2-D ``original_y`` array.
        num_sample: number of resampled curves to draw.
        size_sample: number of rows per resample.
        filename: save the figure to this path; if None, show it instead.
        exit: if True (the default, matching the previous behaviour), call
            ``sys.exit()`` when done.
    """
    # setup graph
    fig = plt.figure()
    ax1 = fig.add_subplot(111)
    ax1.set_xlim([0, 1])
    ax1.set_ylim([0, 1])

    # plot Kaplan-Meier curves
    n_rows = dataset.original_y.shape[0]
    for j in range(num_sample):
        idx = np.random.randint(n_rows, size=size_sample)
        t = dataset.original_y[idx, 0]
        uncensored = dataset.original_y[idx, 1]
        time_points, pi = compute_KaplanMeier(t, uncensored)
        # step function: each event time appears twice, going from the
        # survival before the event to the survival after it
        pi = np.insert(pi, 0, 1.0)
        x = np.repeat(time_points, 2)
        y = np.stack([pi[:-1], pi[1:]], axis=1).ravel()
        ax1.plot(x, y, label=('sample%d' % j))
    # build the legend once, after all labelled curves exist
    if num_sample > 0:
        ax1.legend(loc='lower left')

    # show graph
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename)
    if exit:
        sys.exit()

def train_and_predict(dm, args):
    """Predict the training-fold KM empirical distribution for every test row.

    The training rows ``dm.y[dm.train_indices[dm.fold]]`` (column 0 = time,
    column 1 = event indicator) are discretised into ``args.num_bin`` bins
    over ``[0, dm.y_max]`` by ``estimate_empirical_distribution``; the
    resulting distribution is repeated once per row of
    ``dm.y[dm.test_indices[dm.fold]]``.

    Returns:
        ``(predictions, info)``: a ``(num_test, num_bin)`` tensor and a dict
        with a placeholder ``'test_loss'`` of 0.0.
    """
    train_np = dm.y[dm.train_indices[dm.fold]]
    train_labels = torch.from_numpy(train_np.astype(np.float32)).clone()
    times = train_labels[:, 0]
    events = train_labels[:, 1]
    empirical_dist, _ = estimate_empirical_distribution(
        times, events, dm.y_max, args.num_bin)

    test_rows = dm.y[dm.test_indices[dm.fold]]
    predictions = torch.tile(empirical_dist, (test_rows.shape[0], 1))
    return predictions, {'test_loss': 0.0}

def train_and_test(dataset, nn_param):
    """Evaluate a covariate-free baseline that predicts one fixed distribution.

    After ``train_model.transform_data`` prepares ``dataset``, the empirical
    distribution of the training labels is obtained from
    ``loss_function.estimate_empirical_dist_KM`` and used as the prediction
    for every test sample. For each test batch size in
    ``nn_param['batch_size_test_list']`` (default ``[128]``), every loss in
    ``nn_param['loss_function_test']`` is averaged over the test batches.
    The test loader shuffles (``shuffle=True``), so batch composition varies
    between runs.

    Returns:
        dict mapping ``'<loss name>_<batch size as 5 digits>'`` to the mean
        per-batch loss.
    """
    n_bin = nn_param['n_bin']
    train_model.transform_data(dataset, nn_param)
    empirical_dist, _ = loss_function.estimate_empirical_dist_KM(
        dataset.y_train[:, 0], dataset.y_train[:, 1],
        dataset.y_max, n_bin)

    test_set = torch_dataset.TorchDataset(dataset.x_test, dataset.y_test)
    y_max = dataset.y_max
    results = {}
    for batch_size in nn_param.get('batch_size_test_list', [128]):
        loader = torch.utils.data.DataLoader(test_set,
                                             num_workers=0,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             drop_last=False)

        loss_sums = {}
        n_batches = 0
        for _inputs, labels in loader:
            times = labels[:, 0]
            events = labels[:, 1]
            # same distribution for every sample in the batch
            tiled = np.tile(empirical_dist, (times.shape[0], 1))
            y_pred = torch.from_numpy(tiled.astype(np.float32)).clone()

            for name, loss_fn_def in nn_param['loss_function_test'].items():
                batch_loss = 0.0
                for loss_fn, param in loss_fn_def.items():
                    batch_loss += survival_analysis_loss.compute_loss(
                        loss_fn, y_pred, times, events, param, y_max, nn_param)
                if name not in loss_sums:
                    loss_sums[name] = batch_loss
                else:
                    loss_sums[name] += batch_loss
            n_batches += 1

        results.update({
            f"{name}_{batch_size:05d}": float(total) / n_batches
            for name, total in loss_sums.items()
        })
    return results
