import numpy as np
import sys
import torch
import torch.nn as nn

import loss_function
import util


def compute_loss(F_pred, G_pred, z, e, z_max):
    """Compute one player's term of the survival-game loss.

    The loss sums squared errors over the interior columns ``1:-1`` of
    ``F_pred``. Each error is divided by ``1 - G + EPS``, where ``G`` comes
    from the other player's prediction ``G_pred``:

    * For every row ``b`` and column ``j`` with ``j < idx[b]``:
      ``F[b, j+1]**2 / (1 - G_pred[b, j+1] + EPS)``.
    * For rows where ``e`` is truthy ("uncensored") and ``j >= idx[b]``:
      ``(1 - F[b, j+1])**2 / (1 - Gu[b] + EPS)``.

    Here ``Gu`` is the value returned by ``loss_function.compute_censored``
    for that row.

    Args:
        F_pred: 2-D tensor of shape ``(batch, n_bin + 1)``; only columns
            ``1:-1`` enter the loss.
        G_pred: Tensor with the same layout as ``F_pred``; used as the
            denominator weight.
        z: Per-sample values passed to ``loss_function.locate_z``
            (presumably observed times — confirm against that helper).
        e: Per-sample indicator. ``e.bool()`` selects the rows that receive
            the ``(1 - F)**2`` term.
        z_max: Upper bound forwarded to ``loss_function.locate_z``.

    Returns:
        Scalar tensor: the sum of both terms over all rows and columns.
    """
    EPS = 0.000001  # keeps the 1 - G denominators away from zero
    n_bin = F_pred.shape[1] - 1
    uncensored = e.bool()
    # NOTE(review): idx appears to be an integer bin index per sample and
    # ratio a within-bin position; both come from loss_function.locate_z —
    # confirm against that helper.
    idx, ratio = loss_function.locate_z(z, e, z_max, n_bin, False)

    # Interior columns only; the first and last columns are excluded.
    F_mid = F_pred[:, 1:-1]
    F_sq0 = F_mid * F_mid
    F_sq1 = (1.0 - F_mid) * (1.0 - F_mid)

    # lower_fill[b, j] == 1 when j < idx[b], else 0: rows of a strictly
    # lower-triangular matrix (same as np.tri(n_bin, n_bin, -1)).
    # It is built directly in torch on F_pred's device. The previous NumPy
    # round-trip always produced a CPU tensor, which breaks GPU inputs.
    tri = torch.tril(
        torch.ones(n_bin, n_bin, dtype=torch.float32, device=F_pred.device),
        diagonal=-1,
    )
    rows = idx.view(-1).long().to(F_pred.device)
    lower_fill = tri[rows].view((F_pred.shape[0], n_bin))

    Gu, _ = loss_function.compute_censored(G_pred[uncensored],
                                            idx[uncensored].view(-1, 1),
                                            ratio[uncensored])

    # lower_fill[:, :-1] has n_bin - 1 columns, aligned with F_mid's columns.
    temp1 = F_sq1[uncensored] * (1 - lower_fill[uncensored, :-1]) / (1.0 - Gu.view(-1, 1) + EPS)
    temp2 = F_sq0 * lower_fill[:, :-1] / (1.0 - G_pred[:, 1:-1] + EPS)
    return torch.sum(temp1) + torch.sum(temp2)

class SurvivalGameLoss(nn.Module):
    """Symmetric two-player loss built from ``compute_loss``.

    ``forward`` converts both raw predictions with ``util.convert_f2F`` and
    returns the sum of two calls to ``compute_loss``:

    * ``(F, G)`` with the event indicator ``e``;
    * ``(G, F)`` with the roles swapped and ``e`` complemented.
    """

    def __init__(self, args, z_max):
        """Store configuration.

        Args:
            args: Run configuration. It is stored on ``self.args`` but not
                read anywhere in this class.
            z_max: Upper bound forwarded to ``compute_loss`` on every call.
        """
        super().__init__()
        self.args = args
        self.z_max = z_max

    @staticmethod
    def _flip_events(e):
        """Return the complement of indicator ``e``: ``~e`` for bool, else ``1 - e``."""
        # PyTorch rejects ``1 - e`` on bool tensors, so negate those instead.
        # Numeric dtypes keep the original ``1 - e`` result unchanged.
        if e.dtype == torch.bool:
            return ~e
        return 1 - e

    def forward(self, f_pred, g_pred, z, e, train_val_test):
        """Return ``compute_loss(F, G, z, e) + compute_loss(G, F, z, 1 - e)``.

        Args:
            f_pred: Raw prediction for player F, converted with
                ``util.convert_f2F`` (presumably per-bin masses turned into a
                cumulative form — confirm against ``util``).
            g_pred: Raw prediction for player G, converted the same way.
            z: Per-sample values forwarded to ``compute_loss``.
            e: Per-sample event indicator (numeric or bool tensor).
            train_val_test: Accepted for interface compatibility; unused here.

        Returns:
            Scalar tensor loss.
        """
        F_pred = util.convert_f2F(f_pred)
        G_pred = util.convert_f2F(g_pred)
        loss1 = compute_loss(F_pred, G_pred, z, e, self.z_max)
        loss2 = compute_loss(G_pred, F_pred, z, self._flip_events(e), self.z_max)
        return loss1 + loss2
