import matplotlib.pyplot as plt 
import numpy as np
import pandas as pd
import sys
import torch

import cox
import kaplan_meier
import parameters_util
import survival_analysis
import torch_dataset
import train_model


def plot_Cox(ax, cph, features, col_names, time_points_pred, label='Cox'):
    """Plot the subject-averaged survival curve predicted by a Cox model.

    Args:
        ax: matplotlib Axes to draw on.
        cph: fitted model exposing ``predict_survival_function(df, times)``
            (presumably a lifelines ``CoxPHFitter`` -- TODO confirm).
        features: 2-D array of covariates, one row per subject.
        col_names: column names for the frame handed to ``cph``; it needs
            ``features.shape[1] + 2`` entries because two all-zero columns
            are appended after the covariates.
        time_points_pred: 1-D sequence of evaluation times. At least two
            entries are required because ``time_points_pred[1] / 2`` is used
            as the plotting shift.
        label: legend label for the curve.
    """
    # Two zero-filled placeholder columns appended after the covariates.
    # NOTE(review): presumably duration/event, so the frame matches the
    # layout the model was trained on -- confirm against the caller.
    dummy_y = np.zeros((features.shape[0], 2))
    cox_data_test = np.concatenate([features, dummy_y], 1)
    cox_data_test_df = pd.DataFrame(data=cox_data_test, columns=col_names)

    # One row per evaluation time, one column per subject.
    y_surv_df = cph.predict_survival_function(cox_data_test_df, time_points_pred)
    # Explicit writable copy: under pandas Copy-on-Write, ``.values`` may be
    # a read-only view and the in-place assignment below would raise.
    y_surv = y_surv_df.to_numpy(copy=True)
    # Pin survival at the first evaluation time to exactly 1.
    y_surv[0, :] = 1.0
    # Average over subjects -> one value per evaluation time.
    y_surv_mean = np.mean(y_surv, 1)

    # Shift every point by half of time_points_pred[1] (a half-bin shift if
    # the grid is uniform and starts at 0), matching the other plots here.
    times = np.asarray(time_points_pred)
    x = times + times[1] / 2.0
    ax.plot(x, y_surv_mean, label=label)

def plot_KaplanMeier(ax, y, e):
    """Draw the Kaplan-Meier survival estimate as a red step curve.

    Args:
        ax: matplotlib Axes to draw on.
        y: presumably observed (possibly censored) times; forwarded unchanged
            to ``kaplan_meier.compute_KaplanMeier``.
        e: presumably event indicators; forwarded unchanged to
            ``kaplan_meier.compute_KaplanMeier``.

    Nothing is drawn when the estimator returns no time points.
    """
    time_points, pi = kaplan_meier.compute_KaplanMeier(y, e)
    n_steps = len(time_points)
    if n_steps == 0:
        return

    # Level before the first step is 1.0; surv[i + 1] is the level reached
    # after the drop at time_points[i].
    surv = np.insert(pi, 0, 1.0)
    # Each time point contributes two vertices, (t_i, surv[i]) and
    # (t_i, surv[i + 1]); connecting consecutive vertices yields vertical
    # drops joined by horizontal segments, i.e. a staircase.
    step_x = np.repeat(np.asarray(time_points), 2)
    step_y = np.column_stack([surv[:n_steps], surv[1:n_steps + 1]]).ravel()

    cmap = plt.get_cmap("tab10")
    ax.plot(step_x, step_y, label='Kaplan-Meier', color=cmap(3))  # red

def plot_mean_prediction(ax, model, features, time_points_pred, label='Fix-256'):
    """Plot the subject-averaged survival curve of a binned-probability model.

    ``model.predict(features)`` must return a torch tensor whose axis 1
    indexes bins. Assumed shape is ``(n_subjects, n_bins)`` of per-bin event
    probabilities -- TODO confirm. Survival after bin ``k`` is computed as
    ``1 - sum(p[:k + 1])`` and then averaged over axis 0.

    Args:
        ax: matplotlib Axes to draw on.
        model: object exposing ``predict(features)`` returning a torch tensor.
        features: input passed unchanged to ``model.predict``.
        time_points_pred: 1-D sequence of time points, at least two long;
            ``time_points_pred[1] / 2`` is added to every point before
            plotting.
        label: legend label for the curve (defaults to the historical
            'Fix-256').
    """
    y_pred = model.predict(features)
    # Leave autograd and move to host memory before converting to NumPy.
    probs = y_pred.detach().cpu().numpy()
    y_surv = 1.0 - np.cumsum(probs, axis=1)
    y_surv_mean = np.mean(y_surv, 0)

    times = np.asarray(time_points_pred)
    x = times + times[1] / 2.0
    # The final bin is deliberately excluded from the plot.
    # NOTE(review): presumably an open-ended/overflow bin -- confirm.
    ax.plot(x[:-1], y_surv_mean[:-1], label=label)

def plot_prediction(ax, model, features, time_points_pred, label='Fix-256'):
    """Plot one predicted survival curve per subject.

    ``model.predict(features)`` must return a torch tensor whose axis 1
    indexes bins. Assumed shape is ``(n_subjects, n_bins)`` of per-bin event
    probabilities -- TODO confirm. Each subject's survival after bin ``k`` is
    ``1 - sum(p[:k + 1])``.

    Only the first curve carries ``label``. A later ``ax.legend()`` therefore
    shows a single entry instead of one identical entry per subject.

    Args:
        ax: matplotlib Axes to draw on.
        model: object exposing ``predict(features)`` returning a torch tensor.
        features: input passed unchanged to ``model.predict``.
        time_points_pred: 1-D sequence of time points, at least two long;
            ``time_points_pred[1] / 2`` is added to every point before
            plotting.
        label: legend label for the set of curves (defaults to the
            historical 'Fix-256').
    """
    y_pred = model.predict(features)
    # Leave autograd and move to host memory before converting to NumPy.
    probs = y_pred.detach().cpu().numpy()
    y_surv = 1.0 - np.cumsum(probs, axis=1)

    times = np.asarray(time_points_pred)
    x = times + times[1] / 2.0
    # Transpose so each subject becomes a column, which matplotlib draws as
    # its own line. The final bin is excluded, as in the mean-curve plot.
    lines = ax.plot(x[:-1], y_surv.T[:-1], label=label)
    # A string label is applied to every line; hide the duplicates from the
    # legend (labels starting with '_' are ignored by Axes.legend).
    for line in lines[1:]:
        line.set_label('_nolegend_')

def plot_prediction_Cox(ax, cph, features, col_names, time_points_pred, label='Cox'):
    """Plot one Cox-model survival curve per subject.

    Only the first curve carries ``label``. A later ``ax.legend()`` therefore
    shows a single entry instead of one identical entry per subject.

    Args:
        ax: matplotlib Axes to draw on.
        cph: fitted model exposing ``predict_survival_function(df, times)``
            (presumably a lifelines ``CoxPHFitter`` -- TODO confirm).
        features: 2-D array of covariates, one row per subject.
        col_names: column names for the frame handed to ``cph``; it needs
            ``features.shape[1] + 2`` entries because two all-zero columns
            are appended after the covariates.
        time_points_pred: 1-D sequence of evaluation times, at least two
            long; ``time_points_pred[1] / 2`` is added to every point before
            plotting.
        label: legend label for the set of curves.
    """
    # Two zero-filled placeholder columns appended after the covariates.
    # NOTE(review): presumably duration/event to match the training layout.
    dummy_y = np.zeros((features.shape[0], 2))
    cox_data_test = np.concatenate([features, dummy_y], 1)
    cox_data_test_df = pd.DataFrame(data=cox_data_test, columns=col_names)

    # One row per evaluation time, one column per subject.
    y_surv_df = cph.predict_survival_function(cox_data_test_df, time_points_pred)
    # Explicit writable copy: under pandas Copy-on-Write, ``.values`` may be
    # a read-only view and the in-place assignment below would raise.
    y_surv = y_surv_df.to_numpy(copy=True)
    # Pin survival at the first evaluation time to exactly 1.
    y_surv[0, :] = 1.0

    times = np.asarray(time_points_pred)
    x = times + times[1] / 2.0
    # Each column (subject) is drawn as its own line.
    lines = ax.plot(x, y_surv, label=label)
    # A string label is applied to every line; hide the duplicates from the
    # legend (labels starting with '_' are ignored by Axes.legend).
    for line in lines[1:]:
        line.set_label('_nolegend_')
