import numpy as np
import sys
import torch
import torch.nn as nn

import loss_function
import metric


def compute_loss(loss_fn, y_pred, z, e, z_max, args):
    """Dispatch to the ``loss_function`` helper selected by ``loss_fn``.

    Args:
        loss_fn: Loss name. Recognized values: ``'RPS'``, ``'DeepHit'``,
            ``'DRSA'``, ``'Brier'``, ``'logarithmic_pwl'``,
            ``'logarithmic_simple_pwl'``, ``'Brier_pwl'``,
            ``'ProperRPS_pwl'``, ``'SurvivalCRPS_pwl'``, ``'Portnoy_pwl'``.
        y_pred, z, e, z_max: Forwarded unchanged to the selected helper.
            (Presumably model predictions, observed times, event
            indicators and the time horizon -- TODO confirm against
            ``loss_function``.)
        args: Config namespace. Only some losses read from it:
            ``DeepHit_alpha`` (DeepHit), ``DRSA_alpha`` (DRSA) and
            ``withoutEM`` (logarithmic_pwl, Brier_pwl, ProperRPS_pwl,
            Portnoy_pwl).

    Returns:
        The selected helper's result. For DeepHit and DRSA, the weighted
        sum of several helper terms.

    Raises:
        SystemExit: If ``loss_fn`` is not a recognized name. The message
            goes to ``sys.exit``, so the process exits with a non-zero
            status and the failure is visible to callers and schedulers.
    """
    if loss_fn == 'RPS':
        return loss_function.RPS(y_pred, z, e, z_max)
    if loss_fn == 'DeepHit':
        alpha = args.DeepHit_alpha
        # Lz + Lcensored, plus the kernel term scaled by alpha.
        loss = 0.0
        loss += loss_function.Lz(y_pred, z, e, z_max)
        loss += loss_function.Lcensored(y_pred, z, e, z_max)
        loss += alpha * loss_function.kernel_loss(y_pred, z, e, z_max)
        return loss
    if loss_fn == 'DRSA':
        alpha = args.DRSA_alpha
        # alpha weights Lz; (1 - alpha) weights the uncensored and
        # censored terms.
        loss = 0.0
        loss += alpha * loss_function.Lz(y_pred, z, e, z_max)
        loss += (1-alpha) * loss_function.Luncensored(y_pred, z, e, z_max)
        loss += (1-alpha) * loss_function.Lcensored(y_pred, z, e, z_max)
        return loss
    if loss_fn == 'Brier':
        return loss_function.Brier(y_pred, z, e, z_max)
    if loss_fn == 'logarithmic_pwl':
        return loss_function.logarithmic_pwl(y_pred, z, e, z_max, args.withoutEM)
    if loss_fn == 'logarithmic_simple_pwl':
        return loss_function.logarithmic_simple_pwl(y_pred, z, e, z_max)
    if loss_fn == 'Brier_pwl':
        return loss_function.Brier_pwl(y_pred, z, e, z_max, args.withoutEM)
    if loss_fn == 'ProperRPS_pwl':
        return loss_function.ProperRPS_pwl(y_pred, z, e, z_max, args.withoutEM)
    if loss_fn == 'SurvivalCRPS_pwl':
        return loss_function.SurvivalCRPS_pwl(y_pred, z, e, z_max)
    if loss_fn == 'Portnoy_pwl':
        return loss_function.Portnoy_pwl(y_pred, z, e, z_max, args.withoutEM)

    # Passing the message to sys.exit writes it to stderr and exits with
    # status 1; a bare sys.exit() would report success (status 0).
    # The f-string also tolerates non-str names instead of raising TypeError.
    sys.exit(f'Unknown loss function name: {loss_fn}')

class SurvivalLoss(nn.Module):
    """``nn.Module`` wrapper that evaluates a named loss via ``compute_loss``.

    The loss name, time horizon and config namespace are fixed at
    construction. Each ``forward`` call passes them along with the batch.
    """

    def __init__(self, args, z_max, lf_name):
        """Store the loss configuration.

        Args:
            args: Config namespace forwarded to ``compute_loss``.
            z_max: Forwarded to ``compute_loss`` as its ``z_max`` argument.
            lf_name: Loss name forwarded to ``compute_loss`` as ``loss_fn``.
        """
        super().__init__()
        self.lf_name = lf_name
        self.z_max = z_max
        self.args = args
        # Initialized to None here; not read anywhere in this class.
        self.pred_train = None

    def forward(self, y_pred, z, e, train_val_test):
        """Return ``compute_loss`` evaluated on the first column of ``z``.

        ``train_val_test`` is accepted for call-site compatibility but is
        not used by this method.
        """
        # Only column 0 of z is passed on -- assumes z is at least 2-D;
        # TODO confirm what the remaining columns hold.
        first_column = z[:, 0]
        return compute_loss(
            loss_fn=self.lf_name,
            y_pred=y_pred,
            z=first_column,
            e=e,
            z_max=self.z_max,
            args=self.args,
        )
