import torch

import kaplan_meier
import util


def Dcal_dr_pwl(f_pred, z, e, z_max, numC=20):
    """Squared-error calibration loss for binned (distribution-regression) predictions.

    Each sample's observed time ``z`` is mapped to a quantile of the
    predicted CDF via ``util.time2quantiles_dr``. The mapping uses
    ``n_bin + 1`` equally spaced time boundaries on ``[0, z_max]``.
    The quantiles are then accumulated into ``numC`` equal-width bins on
    ``[0, 1]``:

    * an uncensored sample (``e`` true) adds 1 to the bin containing its
      quantile;
    * a censored sample with quantile ``v`` spreads one unit of mass
      uniformly over ``[v, 1]``. The bin containing ``v`` gets
      ``(upper - v) / (1 - v)``, and every bin lying entirely above ``v``
      gets ``width / (1 - v)``.

    The bin masses are divided by the total number of samples. They are
    then compared with the uniform target ``1 / numC``.

    Args:
        f_pred: predictions with ``n_bin = f_pred.shape[1]`` columns;
            presumably a per-sample distribution over the time bins -- TODO confirm.
        z: observed times, one per sample.
        e: event indicator, cast to bool; true marks an uncensored sample.
        z_max: upper end of the time grid.
        numC: number of calibration bins on ``[0, 1]``.

    Returns:
        Scalar tensor ``sum((c - 1/numC) ** 2)``, where ``c`` holds the
        normalized bin masses.
    """
    n_bin = f_pred.shape[1]
    uncensored = e.bool()

    # Build the time grid on the same device as the predictions so that
    # GPU inputs do not mix with a CPU tensor.
    boundaries = torch.linspace(0.0, 1.0, steps=n_bin + 1,
                                device=f_pred.device) * z_max
    F_pred = util.convert_f2F(f_pred)

    quantiles = util.time2quantiles_dr(F_pred, boundaries, z).view(-1, 1)

    # Uncensored samples: one unit of mass per sample in its bin.
    # NOTE(review): torch.histc provides no gradient, so this term does not
    # contribute to backprop -- confirm this is intended.
    c = torch.histc(quantiles[uncensored].view(-1), bins=numC,
                    min=0.0, max=1.0)

    # Censored samples, shape (M, 1), broadcast against the numC bins.
    v = quantiles[~uncensored]
    # A censored quantile of exactly 1 makes 1 - v zero, and (x / 0) * 0 is
    # NaN, which would poison the whole histogram. Pull such values just
    # below 1 so their mass lands fully in the last bin (the v -> 1 limit).
    v = v.clamp(max=1.0 - torch.finfo(v.dtype).eps)
    c_boundaries = torch.linspace(0.0, 1.0, steps=numC + 1,
                                  device=quantiles.device)
    # Bin that contains v.
    v_in_C = ((c_boundaries[:-1] <= v) & (v < c_boundaries[1:]))
    v_in_C = v_in_C.to(torch.float)
    # Bins lying entirely above v.
    v_leq_C = (v < c_boundaries[:-1]).to(torch.float)
    temp1 = ((c_boundaries[1:] - v) / (1 - v)) * v_in_C
    temp2 = ((c_boundaries[1:] - c_boundaries[:-1]) / (1 - v)) * v_leq_C
    c += torch.sum(temp1 + temp2, 0)
    c /= quantiles.shape[0]

    diff = c - 1 / numC
    return torch.sum(diff * diff)

def Dcal_qr_pwl(f_pred, z, e, z_max, numC=20):
    """Squared-error calibration loss for quantile-regression predictions.

    ``util.convert_f2F(f_pred) * z_max`` gives, for each sample, an
    increasing sequence of predicted times (the "knots"). It is assumed
    that the sequence has ``n_bin + 1`` entries and that knot ``k``
    corresponds to quantile level ``k / n_bin``, mirroring the
    ``n_bin + 1`` boundaries used by ``Dcal_dr_pwl`` -- TODO confirm
    against ``util.convert_f2F``.

    Each observed time ``z`` is located between two knots, and its
    quantile is interpolated linearly between their levels. The quantiles
    are then accumulated into ``numC`` equal-width bins on ``[0, 1]``:

    * an uncensored sample (``e`` true) adds 1 to its bin;
    * a censored sample with quantile ``v`` spreads one unit of mass
      uniformly over ``[v, 1]``.

    The bin masses are divided by the total number of samples and compared
    with ``1 / numC``.

    Args:
        f_pred: predictions with ``n_bin = f_pred.shape[1]`` columns.
        z: observed times, one per sample.
        e: event indicator, cast to bool; true marks an uncensored sample.
        z_max: scale applied to the cumulative predictions.
        numC: number of calibration bins on ``[0, 1]``.

    Returns:
        Scalar tensor ``sum((c - 1/numC) ** 2)``.
    """
    n_bin = f_pred.shape[1]
    uncensored = e.bool()

    F_pred = util.convert_f2F(f_pred) * z_max
    n_knots = F_pred.shape[1]
    z_flat = z.reshape(-1)

    idx = torch.searchsorted(F_pred, z_flat.view(-1, 1), right=True)
    # With right=True, a z at or beyond the last knot (e.g. z == z_max)
    # yields idx == n_knots, which would make the gather below go out of
    # range. A z below the first knot yields idx == 0. Clamp the index so
    # both idx - 1 and idx are valid; the quantile is clamped further down.
    idx = idx.clamp(1, n_knots - 1)
    F_lb = torch.gather(F_pred, 1, idx - 1).view(-1)
    F_ub = torch.gather(F_pred, 1, idx).view(-1)

    # Linear interpolation inside [F_lb, F_ub]. A zero-width interval (only
    # reachable through the clamped index) is treated as fully traversed.
    # The denominator is clamped so the unused branch of torch.where cannot
    # inject inf/NaN gradients.
    width = F_ub - F_lb
    ratio = torch.where(
        width > 0,
        (z_flat - F_lb) / width.clamp_min(torch.finfo(width.dtype).tiny),
        torch.ones_like(width),
    )
    # The knot index maps to quantile level index / n_bin. The original
    # divided by numC, which is only correct when numC == n_bin.
    quantiles = ((idx - 1).view(-1) + ratio) / n_bin
    quantiles = quantiles.clamp(0.0, 1.0).view(-1, 1)

    # NOTE(review): torch.histc provides no gradient, so this term does not
    # contribute to backprop -- confirm this is intended.
    c = torch.histc(quantiles[uncensored].view(-1), bins=numC,
                    min=0.0, max=1.0)

    # Censored samples, shape (M, 1), broadcast against the numC bins.
    v = quantiles[~uncensored]
    # v == 1 would make 1 - v zero and turn the histogram into NaN. Pull it
    # just below 1 so its mass lands fully in the last bin.
    v = v.clamp(max=1.0 - torch.finfo(v.dtype).eps)
    c_boundaries = torch.linspace(0.0, 1.0, steps=numC + 1,
                                  device=quantiles.device)
    # Bin that contains v.
    v_in_C = ((c_boundaries[:-1] <= v) & (v < c_boundaries[1:]))
    v_in_C = v_in_C.to(torch.float)
    # Bins lying entirely above v.
    v_leq_C = (v < c_boundaries[:-1]).to(torch.float)
    temp1 = ((c_boundaries[1:] - v) / (1 - v)) * v_in_C
    temp2 = ((c_boundaries[1:] - c_boundaries[:-1]) / (1 - v)) * v_leq_C
    c += torch.sum(temp1 + temp2, 0)
    c /= quantiles.shape[0]

    diff = c - 1 / numC
    return torch.sum(diff * diff)

def KMcal(f_pred, z, e, z_max):
    """Divergence between an empirical distribution and the batch-mean prediction.

    ``kaplan_meier.estimate_empirical_distribution`` returns a per-bin
    empirical distribution and a split index. The two are compared in two
    parts:

    * bins before the split (the "valid" KM region) are compared bin by
      bin with ``sum(p * (log p - log q))``;
    * bins from the split onward (the "invalid" region) are first summed
      into one mass on each side, and the same ``p * (log p - log q)`` term
      is applied to those totals.

    A small epsilon keeps every logarithm finite.

    Args:
        f_pred: predictions with one column per bin; averaged over dim 0.
        z: observed times, forwarded to the estimator.
        e: event indicators, forwarded to the estimator.
        z_max: upper end of the time grid, forwarded to the estimator.

    Returns:
        Scalar tensor: the bin-wise term plus the lumped tail term.
    """
    num_bins = f_pred.shape[1]
    km_dist, split = kaplan_meier.estimate_empirical_distribution(
        z, e, z_max, num_bins)
    mean_pred = f_pred.mean(0)
    eps = 1e-7

    # Bin-by-bin comparison where the empirical estimate is usable.
    km_head = km_dist[:split]
    pred_head = mean_pred[:split]
    head_term = (km_head
                 * (torch.log(km_head + eps) - torch.log(pred_head + eps))).sum()

    # Everything past the split is compared as a single lumped mass.
    tail_mass = km_dist[split:].sum()
    pred_tail_mass = mean_pred[split:].sum()
    tail_term = tail_mass * (torch.log(tail_mass + eps)
                             - torch.log(pred_tail_mass + eps))

    return head_term + tail_term
