import math
import torch
# Argument convention for the losses below: (out, time, delta) = (out, time_train, delta_train)
def l2loss(out, time, delta):
    """Return the delta-weighted squared error between ``log(time)`` and ``out``.

    Computes ``sum_i delta_i * (log(time_i) - out_i) ** 2``.

    Args:
        out: Predictions, shape ``(n, 1)`` or ``(n,)``.
        time: Tensor with ``n`` entries; its logarithm is taken, so entries
            must be positive for a finite result.
        delta: Tensor with ``n`` per-observation weights (presumably the
            event indicator, 1 = observed / 0 = censored -- confirm with caller).

    Returns:
        0-dim tensor holding the summed (not averaged) loss.
    """
    n = len(delta)
    log_time = torch.log(time.view(n, 1))
    weights = delta.view(n, 1)

    # A 1-D ``out`` would broadcast against the (n, 1) column into an (n, n)
    # matrix and silently sum over all pairs; lift it to a column first.
    if out.dim() == 1:
        out = out.unsqueeze(1)

    residual = log_time - out
    return (residual * residual * weights).sum()
def eaftloss(out, time, delta):
    """Negative kernel-smoothed log-likelihood for the AFT or extended hazard (EH) model.

    The number of columns of ``out`` selects the model: a single column (or a
    1-D tensor) is treated as AFT with ``g1 = out``; two or more columns are
    treated as EH with ``g1 = out[:, 0]`` and ``g2 = out[:, 1]`` (any further
    columns are ignored).

    With ``R_i = g1_i + log(time_i)``, bandwidth ``h = 1.30 * n ** (-1/5)``,
    the Gaussian kernel ``K`` (``normal_density``) and the standard normal CDF
    ``Phi``, the returned value is ``-(lead + density + cdf)`` where::

        density =  (1/n) * sum_j delta_j * log[(1/(n*h)) * sum_i delta_i * K((R_i - R_j)/h)]
        cdf     = -(1/n) * sum_j delta_j * log[(1/n) * sum_i w_i * Phi((R_i - R_j)/h)]
                  with w_i = 1 (AFT) or w_i = exp(g2_i - g1_i) (EH)
        lead    = -(1/n) * sum_i delta_i * log(time_i)                  (AFT)
                =  (1/n) * sum_i delta_i * g2_i - (1/n) * sum_i delta_i * R_i   (EH)

    Args:
        out: Model output, shape ``(n,)``, ``(n, 1)`` (AFT) or ``(n, k >= 2)`` (EH).
        time: Tensor with ``n`` entries; its logarithm is taken, so entries
            must be positive for a finite result.
        delta: Tensor with ``n`` entries used as 0/1 weights (presumably the
            event indicator, 1 = observed / 0 = censored -- confirm with caller).

    Returns:
        0-dim tensor holding the loss.
    """
    if out.dim() == 1:
        # A flat vector of predictions is a single-output (AFT) model.
        out = out.unsqueeze(1)
    n_rows, n_outputs = out.size()

    if n_outputs == 1:
        n = len(delta)
        g1, g2 = out, None
    else:
        n = n_rows
        g1, g2 = out[:, 0:1], out[:, 1:2]

    # Bandwidth 1.30 * n^(-1/5); a noted alternative is 1.59 * n^(-1/3).
    h = 1.30 * math.pow(n, -0.2)

    log_time = torch.log(time.view(n, 1))
    # R = g(X_i) + log(O_i)
    R = g1 + log_time
    # Follow R's dtype/device so float64 or non-CPU inputs work in both branches.
    delta = delta.view(n, 1).to(R.dtype)

    # Z[i, j] = (R_i - R_j) / h, built by broadcasting instead of ones-matrix products.
    Z = (R - R.t()) / h

    # Dk[j] = (1 / (n h)) * sum_i delta_i * K(Z[i, j])
    Dk = (delta * normal_density(Z)).sum(dim=0) / (n * h)
    # For a row with delta_j = 0 the kernel sum can underflow to exactly 0, and
    # 0 * log(0) would turn the whole loss (and its gradient) into NaN.  Rows
    # with delta_j = 1 always contain their own K(0) term, so the clamp never
    # alters a value that contributes to the loss.
    Dk = torch.clamp(Dk, min=torch.finfo(Dk.dtype).tiny)
    density_term = (delta.t() * torch.log(Dk)).sum() / n

    # Phi(Z[i, j]) via the standard normal CDF in R's dtype/device.
    std_normal = torch.distributions.normal.Normal(R.new_zeros(1), R.new_ones(1))
    P = std_normal.cdf(Z)

    if g2 is None:
        weighted_cdf = P.sum(dim=0) / n
        lead_term = -(delta * log_time).sum() / n
    else:
        # w_i = exp(g2_i - g1_i), broadcast along the columns of P.
        weighted_cdf = (torch.exp(g2 - g1) * P).sum(dim=0) / n
        lead_term = (delta * g2).sum() / n - (delta * R).sum() / n

    cdf_term = -(delta * torch.log(weighted_cdf).view(n, 1)).sum() / n

    return -(lead_term + density_term + cdf_term)


def normal_density(a):
    """Return the standard normal density ``0.3989423 * exp(-a**2 / 2)`` elementwise.

    Args:
        a: Tensor of evaluation points.

    Returns:
        Tensor of the same shape as ``a``.
    """
    # 0.3989423 is 1 / sqrt(2 * pi) rounded to 7 decimals.
    b = 0.3989423 * torch.exp(-0.5 * torch.pow(a, 2.0))
    return b

def coxloss(out, time, delta):
    """Return the negative Cox partial log-likelihood divided by ``n``.

    Computes ``-(1/n) * sum_j delta_j * (out_j - log(sum_{i: time_i >= time_j} exp(out_i)))``.
    Tied times are included in each other's risk sets.

    Args:
        out: Predicted scores, shape ``(n, 1)`` or ``(n,)``.
        time: Tensor with ``n`` entries that are compared to form the risk sets.
        delta: Tensor with ``n`` per-observation weights (presumably the
            event indicator, 1 = observed / 0 = censored -- confirm with caller).

    Returns:
        0-dim tensor holding the loss.
    """
    n = len(delta)
    time = time.view(n, 1)
    delta = delta.view(n, 1)
    if out.dim() == 1:
        out = out.unsqueeze(1)

    # at_risk[i, j] is True when time_i >= time_j, i.e. i belongs to the risk
    # set of j.  Broadcasting keeps this independent of dtype and device.
    at_risk = time >= time.t()

    # log(sum_i exp(out_i)) over each risk set, evaluated with logsumexp so
    # that large scores do not overflow in exp().  Entries outside the risk
    # set are masked to -inf (they contribute exp(-inf) = 0); every column
    # keeps at least its own diagonal entry.
    neg_inf = out.new_full((1,), float("-inf"))
    masked_scores = torch.where(at_risk, out, neg_inf)
    log_risk_sum = torch.logsumexp(masked_scores, dim=0)

    return -((out - log_risk_sum.view(n, 1)) * delta).sum() / n

