import numpy as np
import sys
import torch

import kaplan_meier
import util


def Lz(y_pred, y, e, y_max):
    '''
    Compute -log f(t|x) [Lz loss in the DRSA paper], summed over the
    uncensored samples of the batch.
    y_pred: [batch_size, n_bin] predictions of event time (one value per bin)
    y: [batch_size] event time or last observation time in the dataset
    e: [batch_size] censored (0) or uncensored (1)
    y_max: maximum time of y
    Returns a scalar tensor; censored samples do not contribute.
    '''
    eps = 1e-7  # keeps log() finite when the selected bin has zero mass
    last_bin = y_pred.shape[1] - 1
    observed = e.bool().view(-1)
    # Scale event times onto bin indices 0..n_bin-1 (truncating toward zero).
    bins = ((y[observed] / y_max) * last_bin).to(torch.long).view(-1, 1)
    # Predicted value of the bin in which each event actually happened.
    likelihood = y_pred[observed].gather(1, bins)
    return -torch.log(likelihood + eps).sum()

def Luncensored(y_pred, y, e, y_max):
    '''
    Compute -log F(t|x) [L_{uncensored} loss in the DRSA paper], summed over
    the uncensored samples of the batch.
    y_pred: [batch_size, n_bin] predictions of event time (one value per bin)
    y: [batch_size] event time or last observation time in the dataset
    e: [batch_size] censored (0) or uncensored (1)
    y_max: maximum time of y
    Returns a scalar tensor; censored samples do not contribute.
    '''
    eps = 1e-7  # keeps log() finite when the cumulative mass is zero
    last_bin = y_pred.shape[1] - 1
    observed = e.bool().view(-1)
    # Scale event times onto bin indices 0..n_bin-1 (truncating toward zero).
    bins = ((y[observed] / y_max) * last_bin).to(torch.long).view(-1, 1)
    # Cumulative prediction up to and including the event bin.
    cdf = y_pred[observed].cumsum(dim=1)
    cdf_at_event = cdf.gather(1, bins)
    return -torch.log(cdf_at_event + eps).sum()

def Lcensored(y_pred, y, e, y_max):
    '''
    Compute -log S(t|x) [L_{censored} loss in the DRSA paper], summed over
    the censored samples of the batch.
    y_pred: [batch_size, n_bin] predictions of event time (one value per bin)
    y: [batch_size] event time or last observation time in the dataset
    e: [batch_size] censored (0) or uncensored (1)
    y_max: maximum time of y
    Returns a scalar tensor; uncensored samples do not contribute.
    '''
    eps = 1e-7  # keeps log() finite when the cumulative mass reaches one
    last_bin = y_pred.shape[1] - 1
    censored = ~e.bool().view(-1)
    # Scale censoring times onto bin indices 0..n_bin-1 (truncating).
    bins = ((y[censored] / y_max) * last_bin).to(torch.long).view(-1, 1)
    # Cumulative prediction up to and including the censoring bin.
    cdf = y_pred[censored].cumsum(dim=1)
    cdf_at_censoring = cdf.gather(1, bins)
    # Survival is one minus the cumulative prediction.
    return -torch.log(1.0 + eps - cdf_at_censoring).sum()

def Brier(y_pred, y, e, y_max):
    '''
    Compute Brier score with >2 bins
    (in discrete setting, t=c means uncensored and progressive rounding)
    y_pred: [batch_size, n_bin] predictions of event time (one value per bin)
    y: [batch_size] event time or last observation time in the dataset
    e: [batch_size] censored (0) or uncensored (1)
    y_max: maximum time of y
    Returns a scalar tensor: the summed squared error against the one-hot
    event bin for uncensored samples, plus the squared predicted survival
    at the censoring bin for censored samples.
    '''
    n_bin = y_pred.shape[1]
    observed = e.bool().view(-1)
    censored = ~observed
    # Scale times onto bin indices 0..n_bin-1 (truncating toward zero).
    bins = ((y / y_max) * (n_bin-1)).to(torch.long)

    # Uncensored part: squared distance to the one-hot encoded event bin.
    target = torch.nn.functional.one_hot(bins[observed], num_classes=n_bin)
    err = y_pred[observed] - target
    sq_err_observed = err * err

    # Censored part: squared survival (1 - cumulative prediction) at the
    # censoring bin.
    cdf = y_pred[censored].cumsum(dim=1)
    survival = 1 - cdf.gather(1, bins[censored].view(-1, 1))
    sq_err_censored = survival * survival

    return torch.sum(sq_err_observed) + torch.sum(sq_err_censored)

def kernel_loss(y_pred, y, e, y_max):
    '''
    Compute kernel loss over ordered pairs of uncensored samples.
    For every pair (i, j) with y[j] > y[i] the loss adds
    exp((S_i(y[i]) - S_j(y[i])) / sigma), where S_k(t) is the predicted
    survival of sample k at time t, linearly interpolated inside a bin.
    y_pred: [batch_size, n_bin] predictions of event time (one value per bin)
    y: [batch_size] event time or last observation time in the dataset
    e: [batch_size] censored (0) or uncensored (1)
    y_max: maximum time of y
    Returns a scalar tensor (0.0 when fewer than two samples are uncensored).
    '''
    sigma = 1.0

    observed = torch.t(e).view(-1) > 0.0
    t_obs = y[observed]
    n_obs = len(t_obs)
    if n_obs <= 1:
        # No pair of uncensored samples to compare.
        return torch.tensor(0.0, requires_grad=True)

    # later[i, j] is True when sample j's event is strictly after sample i's.
    t_row = t_obs.repeat((n_obs, 1))
    t_col = t_obs.view((-1, 1)).repeat((1, n_obs))
    later = torch.gt(t_row, t_col)

    # surv[i, j]: predicted survival of sample j at the event time of
    # sample i.
    probs = y_pred[observed, :]
    last_bin = probs.shape[1] - 1
    rows = []
    for t in t_obs:
        k = int((t / y_max) * last_bin)
        # Fractional position of t between bin k and bin k+1.
        frac = (t - y_max * k / last_bin) / (y_max / last_bin)
        cdf_lo = probs[:, :k].sum(1)
        cdf_hi = probs[:, :k+1].sum(1)
        rows.append(1.0 - ((1-frac) * cdf_lo + frac * cdf_hi))
    surv = torch.stack(rows, dim=0)

    # Each sample's own survival at its own event time, as a column so it
    # broadcasts along the row of competitors.
    own = torch.diag(surv, 0).view((-1, 1))
    penalties = torch.exp((own - surv) / sigma)
    return torch.sum(penalties[later])

def locate_z(z, e, z_max, boundaries, exclude_uncensored=True):
    '''
    Locate each time z inside the bins of a piecewise-linear distribution.
    z: [batch_size] event time or last observation time
    e: [batch_size] censored (0) or uncensored (1)
    z_max: maximum time; only used when boundaries is an int
    boundaries: either an int n_bin (n_bin equal-width bins over
        [0, z_max]) or a sorted 1-D tensor of bin boundaries
    exclude_uncensored: when True, ratio is returned for censored samples
        only; when False, for every sample
    Returns (idx, ratio):
        idx: [batch_size, 1] long tensor, bin index of every sample
        ratio: relative position of z inside its bin, measured from the
            bin's lower boundary (0 = lower boundary, 1 = upper boundary)
    A time equal to the right end of the range (z == z_max, or
    z == boundaries[-1]) is assigned to the last bin with ratio 1; without
    that clamp it would index one bin past the end. Times outside the range
    are not validated here.
    '''
    uncensored = e.bool()
    if isinstance(boundaries, int):
        n_bin = boundaries
        idx = torch.div(z * n_bin, z_max, rounding_mode='floor')
        # Keep the closed right end inside the last bin.
        idx = torch.clamp(idx, max=n_bin - 1)
        if exclude_uncensored:
            ratio = z[~uncensored] * n_bin / z_max - idx[~uncensored]
        else:
            ratio = z * n_bin / z_max - idx
        idx = idx.to(torch.long).view(-1, 1)
    else:
        # right=True makes a time that sits exactly on a boundary belong to
        # the bin starting at that boundary.
        idx = torch.searchsorted(boundaries, z.view(-1, 1), right=True) - 1
        # Keep the closed right end inside the last bin so idx+1 stays a
        # valid boundary index.
        idx = torch.clamp(idx, max=len(boundaries) - 2)
        if exclude_uncensored:
            idx_censored = idx[~uncensored].view(-1)
            b_ub = boundaries[idx_censored + 1]
            b_lb = boundaries[idx_censored]
            ratio = (z[~uncensored] - b_lb) / (b_ub - b_lb)
        else:
            b_ub = boundaries[idx.view(-1) + 1]
            b_lb = boundaries[idx.view(-1)]
            ratio = (z - b_lb) / (b_ub - b_lb)
    return idx, ratio

def compute_censored(F_censored, idx, c_ratio):
    '''
    Evaluate a piecewise-linear CDF at the censoring times.
    F_censored: [n_censored, n_columns] CDF values at the bin boundaries
    idx: [n_censored, 1] long tensor, bin index of each censoring time
    c_ratio: [n_censored] relative position inside the bin (0 = lower
        boundary, 1 = upper boundary)
    Returns (Fc, Fc_ub): the interpolated CDF at the censoring time, clamped
    to [0, 0.999999], and the CDF at the upper boundary of its bin.
    '''
    lower = F_censored.gather(1, idx).reshape(-1)
    upper = F_censored.gather(1, idx + 1).reshape(-1)
    # Linear interpolation between the two boundaries of the bin.
    interpolated = (1.0 - c_ratio) * lower + c_ratio * upper
    # Capped just below 1 so that callers can safely divide by (1 - Fc).
    return interpolated.clamp(min=0.0, max=0.999999), upper

def logarithmic_pwl(f_pred, z, e, z_max, withoutEM=False, f_boundaries=None):
    '''
    Logarithmic loss on per-bin predictions, averaged over the batch.
    f_pred: [batch_size, n_bin] per-bin predictions
    z: [batch_size] event time or last observation time
    e: [batch_size] censored (0) or uncensored (1)
    z_max: maximum time, used when bins are equal-width
    withoutEM: when False (default) the censored-sample weights alpha are
        detached, so no gradient flows through them
    f_boundaries: optional bin boundaries; equal-width bins when None
    Returns the summed loss divided by batch_size.
    '''
    observed = e.bool()
    censored = ~observed
    bins = f_pred.shape[1] if f_boundaries is None else f_boundaries
    idx, c_ratio = locate_z(z, e, z_max, bins)

    # Uncensored samples: negative log of the prediction for the event bin.
    f_event = f_pred[observed].gather(1, idx[observed])
    loss = 0.0 - torch.log(f_event + 0.000001).sum()

    # Censored samples: alpha is the share of the remaining mass (beyond the
    # censoring time) that still lies inside the censoring bin.
    F_cens = util.convert_f2F(f_pred[censored])
    Fc, Fc_ub = compute_censored(F_cens, idx[censored], c_ratio)
    rest_of_bin = Fc_ub - Fc
    alpha = rest_of_bin / (1.0 - Fc)
    if not withoutEM:
        alpha = alpha.detach()  # delete gradients
    inside_bin = alpha * torch.log(rest_of_bin + 0.000001)
    beyond_bin = (1.0 - alpha) * torch.log(1.000001 - Fc_ub)
    loss = loss - torch.sum(inside_bin + beyond_bin)

    return loss / f_pred.shape[0]

def logarithmic_simple_pwl(f_pred, z, e, z_max, f_boundaries=None):
    '''
    Simplified logarithmic loss on per-bin predictions, averaged over the
    batch. Censored samples are scored only by the mass beyond the upper
    boundary of their censoring bin.
    f_pred: [batch_size, n_bin] per-bin predictions
    z: [batch_size] event time or last observation time
    e: [batch_size] censored (0) or uncensored (1)
    z_max: maximum time, used when bins are equal-width
    f_boundaries: optional bin boundaries; equal-width bins when None
    Returns the summed loss divided by batch_size.
    '''
    observed = e.bool()
    censored = ~observed
    bins = f_pred.shape[1] if f_boundaries is None else f_boundaries
    idx, c_ratio = locate_z(z, e, z_max, bins)

    # Uncensored samples: negative log of the prediction for the event bin.
    f_event = f_pred[observed].gather(1, idx[observed])
    loss = 0.0 - torch.log(f_event + 0.000001).sum()

    # Censored samples: negative log of (1 - CDF at the bin's upper boundary).
    F_cens = util.convert_f2F(f_pred[censored])
    _, Fc_ub = compute_censored(F_cens, idx[censored], c_ratio)
    loss = loss - torch.log(1.000001 - Fc_ub).sum()

    return loss / f_pred.shape[0]

def Brier_pwl(f_pred, z, e, z_max, withoutEM=False, f_boundaries=None):
    '''
    Brier-type loss on per-bin predictions, averaged over the batch.
    Each bin is scored with coef*(f-1)^2 + (1-coef)*f^2, where coef is the
    one-hot event bin for uncensored samples and a soft target built from
    the prediction itself for censored samples.
    f_pred: [batch_size, n_bin] per-bin predictions
    z: [batch_size] event time or last observation time
    e: [batch_size] censored (0) or uncensored (1)
    z_max: maximum time, used when bins are equal-width
    withoutEM: when False (default) coef is detached, so no gradient flows
        through the soft targets
    f_boundaries: optional bin boundaries; equal-width bins when None
    Returns the summed loss divided by batch_size.
    '''
    n_bin = f_pred.shape[1]
    uncensored = e.bool()
    if f_boundaries is None:
        idx, c_ratio = locate_z(z, e, z_max, n_bin)
    else:
        idx, c_ratio = locate_z(z, e, z_max, f_boundaries)

    # One-hot encoding of the bin that contains z, for every sample.
    one_hot = torch.nn.functional.one_hot(idx.view(-1), num_classes=n_bin)
    one_hot = one_hot.to(torch.float)
    idx_censored = idx[~uncensored].view(-1)
    # Row k of upper_fill is 1 for the bins strictly after the censoring bin
    # of censored sample k and 0 elsewhere.
    upper_fill = np.tri(n_bin, n_bin, -1, dtype=np.float32).T[idx_censored]
    upper_fill = torch.from_numpy(upper_fill)
    # NOTE: coef is the same tensor object as one_hot, so the masked
    # assignments below also modify one_hot in place.
    coef = one_hot
    # presumably convert_f2F turns per-bin values into a CDF over the bin
    # boundaries -- confirm in util
    F_censored = util.convert_f2F(f_pred[~uncensored])
    Fc, Fc_ub = compute_censored(F_censored, idx_censored.view(-1,1),
                                c_ratio)
    # alpha: share of the mass beyond the censoring time that still lies in
    # the censoring bin.
    alpha = (Fc_ub - Fc) / (1.0 - Fc)
    coef[~uncensored] = one_hot[~uncensored] * alpha.view(-1,1)
    # beta: predictions of the later bins, rescaled by the mass beyond the
    # censoring time.
    beta = f_pred[~uncensored] * upper_fill / (1.0 - Fc.view(-1,1))
    coef[~uncensored] += beta
    if not withoutEM:
        coef = coef.detach()  # delete gradients
    diff1 = (f_pred - 1.0) * (f_pred - 1.0)  # squared error against target 1
    diff0 = f_pred * f_pred  # squared error against target 0
    loss = torch.sum(coef*diff1 + (1.0-coef)*diff0)

    return loss / f_pred.shape[0]

def ProperRPS_pwl(f_pred, z, e, z_max, withoutEM=False,
                    f_boundaries=None):
    '''
    Ranked-probability-score type loss on the CDF at the interior bin
    boundaries, averaged over the batch.
    Each interior boundary is scored with (1-coef1)*F^2 + coef1*(1-F)^2.
    f_pred: [batch_size, n_bin] per-bin predictions
    z: [batch_size] event time or last observation time
    e: [batch_size] censored (0) or uncensored (1)
    z_max: maximum time, used when bins are equal-width
    withoutEM: when False (default) the censored-sample weights alpha are
        detached, so no gradient flows through them
    f_boundaries: optional bin boundaries; equal-width bins when None
    Returns the summed loss divided by batch_size.
    '''
    n_bin = f_pred.shape[1]
    uncensored = e.bool()
    loss = 0.0
    if f_boundaries is None:
        idx, c_ratio = locate_z(z, e, z_max, n_bin)
    else:
        idx, c_ratio = locate_z(z, e, z_max, f_boundaries)

    # presumably convert_f2F returns the CDF at the n_bin+1 bin boundaries
    # -- confirm in util; [:,1:-1] then keeps the interior boundaries only
    F_pred = util.convert_f2F(f_pred)
    F_sq0 = F_pred[:,1:-1] * F_pred[:,1:-1]
    F_sq1 = (1.0 - F_pred[:,1:-1]) * (1.0 - F_pred[:,1:-1])

    # Row k of lower_fill is 1 for the bins strictly before the bin that
    # contains z[k] and 0 elsewhere.
    lower_fill = np.tri(n_bin, n_bin, -1, dtype=np.float32)[idx.view(-1)]
    lower_fill = torch.from_numpy(lower_fill.astype(np.float32))
    # Force the [batch_size, n_bin] shape regardless of what the indexing
    # above returned.
    lower_fill = lower_fill.view(f_pred.shape)

    Fc, _ = compute_censored(F_pred[~uncensored],
                             idx[~uncensored].view(-1,1),
                             c_ratio)
    Fc = Fc.view(-1,1)
    # alpha: CDF at each interior boundary, conditioned on exceeding the
    # censoring time.
    alpha = (F_pred[~uncensored,1:-1] - Fc) / (1 - Fc)
    if not withoutEM:
        alpha = alpha.detach()  # delete gradients
    # coef1 is 1 for interior boundaries at or above the upper boundary of
    # the bin containing z, 0 for those below.
    coef1 = 1.0 - lower_fill[:,:-1]
    # Censored samples replace the hard 1 by the conditional weight alpha.
    coef1[~uncensored] *= alpha
    loss = (1.0-coef1)*F_sq0 + coef1*F_sq1
    return torch.sum(loss) / f_pred.shape[0]

def SurvivalCRPS_pwl(f_pred, z, e, z_max, f_boundaries=None):
    '''
    Survival-CRPS type loss for a piecewise-linear CDF, averaged over the
    batch. The loss integrates F^2 from the start of the range up to z for
    every sample and, for uncensored samples only, (1-F)^2 from z to the end
    of the range. Inside each (partial) bin the integrand is approximated by
    the mean of its values at the two end points.
    f_pred: [batch_size, n_bin] per-bin predictions
    z: [batch_size] event time or last observation time
    e: [batch_size] censored (0) or uncensored (1)
    z_max: maximum time, used when bins are equal-width
    f_boundaries: optional 1-D tensor of bin boundaries; equal-width bins
        over [0, z_max] when None
    Returns the summed loss divided by batch_size.
    '''
    n_bin = f_pred.shape[1]
    uncensored = e.bool()
    loss = 0.0
    boundaries = n_bin if f_boundaries is None else f_boundaries
    # ratio is requested for every sample (exclude_uncensored=False); it is
    # the relative position of z inside its bin, from the lower boundary.
    idx, ratio = locate_z(z, e, z_max, boundaries, False)
    flat_idx = idx.view(-1)

    # presumably convert_f2F returns the CDF at the n_bin+1 bin boundaries
    # -- confirm in util
    F_pred = util.convert_f2F(f_pred)
    F_lo = F_pred[:, :-1]
    F_hi = F_pred[:, 1:]
    # Per-bin mean of F^2 and of (1-F)^2 over the two bin boundaries.
    F_sq_mean0 = (F_lo * F_lo + F_hi * F_hi) / 2.0
    F_sq_mean1 = ((1.0 - F_lo) * (1.0 - F_lo)
                  + (1.0 - F_hi) * (1.0 - F_hi)) / 2.0

    if f_boundaries is None:
        # Equal-width bins: a single scalar width serves every term.
        len_intervals = z_max / n_bin
        len_interval = len_intervals
        len_interval_uncensored = len_intervals
    else:
        len_intervals = f_boundaries[1:] - f_boundaries[:-1]
        # Width of the bin containing each z, as a flat [batch_size] vector
        # so that it multiplies ratio element-wise instead of broadcasting
        # to a [batch_size, batch_size] matrix.
        len_interval = len_intervals[flat_idx]
        len_interval_uncensored = len_interval[uncensored]

    # Bins entirely below z: integral of F^2.
    lower_fill = np.tri(n_bin, n_bin, -1, dtype=np.float32)[flat_idx]
    lower_fill = torch.from_numpy(lower_fill.astype(np.float32))
    loss += torch.sum(lower_fill * len_intervals * F_sq_mean0)

    # Bin containing z: CDF at z by linear interpolation between the bin's
    # lower and upper boundary (ratio is measured from the lower one).
    Fz_lb = torch.gather(F_pred, 1, idx).view(-1)
    Fz_ub = torch.gather(F_pred, 1, idx+1).view(-1)
    Fz = (1.0-ratio) * Fz_lb + ratio * Fz_ub
    Fz_sq_mean_l0 = (Fz_lb*Fz_lb + Fz*Fz)/2.0
    Fz_sq_mean_u1 = ((1.0-Fz_ub)*(1.0-Fz_ub) + (1.0-Fz)*(1.0-Fz))/2.0
    # Part of the bin below z (all samples) ...
    loss += torch.sum(len_interval * ratio * Fz_sq_mean_l0)
    # ... and the part above z (uncensored samples only).
    loss += torch.sum(len_interval_uncensored * (1.0-ratio[uncensored])
                      * Fz_sq_mean_u1[uncensored])

    # Bins entirely above z, uncensored samples only: integral of (1-F)^2.
    idx_uncensored = idx[uncensored].view(-1)
    upper_fill = np.tri(n_bin, n_bin, -1, dtype=np.float32).T[idx_uncensored]
    upper_fill = torch.from_numpy(upper_fill.astype(np.float32))
    loss += torch.sum(upper_fill * len_intervals * F_sq_mean1[uncensored])

    return loss / f_pred.shape[0]

def Portnoy_pwl(f_pred, z, e, z_max, withoutEM=False):
    '''
    Pinball (quantile) loss with Portnoy-style re-weighting of censored
    samples, averaged over the batch. The predictions are read as quantile
    values at n_bin+1 equally spaced levels between 0 and 1; only the
    interior levels are scored.
    f_pred: [batch_size, n_bin] per-bin predictions
    z: [batch_size] event time or last observation time
    e: [batch_size] censored (0) or uncensored (1)
    z_max: maximum time; scales the predictions and serves as the
        substitute time for the re-weighted part of censored samples
    withoutEM: when False (default) the weights w are detached, so no
        gradient flows through them
    Returns the summed loss divided by batch_size.
    '''
    n_bin = f_pred.shape[1]
    observed = e.bool()
    censored = ~observed
    loss = 0.0

    z = z.view(-1,1)
    taus = torch.linspace(0.0, 1.0, steps=n_bin+1)
    inner_taus = taus[1:-1]
    # presumably convert_f2F accumulates the per-bin values -- confirm in util
    F_pred = util.convert_f2F(f_pred) * z_max
    # above[k, j]: the j-th interior quantile prediction exceeds z[k].
    above = (F_pred[:,1:-1] > z)

    # Uncensored samples: plain pinball loss.
    above_obs = above[observed]
    resid_obs = z[observed] - F_pred[observed,1:-1]
    loss += torch.sum((resid_obs * (inner_taus-1.0))[above_obs])
    loss += torch.sum((resid_obs * inner_taus)[~above_obs])

    # Censored samples: weight w on the censoring time and (1-w) on z_max,
    # for the quantiles that exceed the censoring time.
    tau_c = util.time2quantiles_qr(F_pred[censored], taus, z[censored])
    w = ((inner_taus - tau_c) / (1 - tau_c))
    if not withoutEM:
        w = w.detach()  # delete gradients
    above_cens = above[censored]
    resid_cens = z[censored] - F_pred[censored,1:-1]
    resid_max = z_max - F_pred[censored,1:-1]
    loss += torch.sum((resid_cens * (inner_taus-1.0) * w)[above_cens])
    loss += torch.sum((resid_cens * inner_taus)[~above_cens])
    loss += torch.sum((resid_max * inner_taus * (1-w))[above_cens])

    return loss / f_pred.shape[0]
