import torch
from utils import *

def GetAccuracyTensor(model, X, thresh=.5):
    """Build a per-sample binary prediction indicator tensor.

    Args:
        model: Callable applied to ``X``; its output is indexed as
            ``prob[:, 1]`` and must have exactly 2 entries along dim 1
            (presumably per-class probabilities for a binary task — confirm).
        X: Input passed straight to ``model``.
        thresh: Decision threshold applied to ``prob[:, 1]``.

    Returns:
        float64 tensor of shape ``(M, 2, 1)``. Column 0 is 1.0 where
        ``prob[:, 1] < thresh`` and column 1 is 1.0 where
        ``prob[:, 1] >= thresh``. ``M`` is the number of elements in
        ``prob[:, 1]``, which is the batch size for a 2-D model output.

    Raises:
        AssertionError: if the model output does not have size 2 along dim 1.
    """
    prob = model(X)
    assert prob.shape[1]==2, "Function designed for binary classification."

    # Flatten first so higher-rank outputs are handled like the original
    # reshape((-1, 1)) did.
    positive_score = prob[:, 1].reshape(-1)
    predicted_negative = positive_score < thresh
    predicted_positive = positive_score >= thresh

    indicator = torch.stack((predicted_negative, predicted_positive), dim=1)
    # Trailing singleton axis keeps the (M, 2, 1) layout callers expect.
    return indicator.double().unsqueeze(-1)

def GetAccuracyMultiTensor(model, X, set_Y):
    """Build a one-hot tensor of the model's argmax predictions.

    Args:
        model: Callable applied to ``X``; its output is reduced with
            ``argmax(dim=1)`` (presumably per-class scores — confirm).
        X: Input passed straight to ``model``.
        set_Y: Forwarded unchanged to ``OneHotEncode`` (imported from
            ``utils``); presumably the collection of class labels — verify
            against that helper.

    Returns:
        float64 tensor of shape ``(N, K, 1)``, where ``(N, K)`` are the first
        two dimensions of the tensor returned by ``OneHotEncode``.
    """
    prob = model(X)
    # Reuse the single forward pass. The previous version called model(X)
    # a second time and discarded this result, which doubled the compute
    # and any forward-time side effects.
    acc_tensor = OneHotEncode(prob.argmax(dim=1), set_Y).double()
    return acc_tensor.reshape((acc_tensor.shape[0], acc_tensor.shape[1], 1))

def GetPrecisionRecallTensor(model, X, thresh=.5):
    """Build a per-sample tensor that flags only positive predictions.

    Args:
        model: Callable applied to ``X``; its output is indexed as
            ``prob[:, 1]`` and must have exactly 2 entries along dim 1
            (presumably per-class probabilities for a binary task — confirm).
        X: Input passed straight to ``model``.
        thresh: Decision threshold applied to ``prob[:, 1]``.

    Returns:
        float64 tensor of shape ``(M, 2, 1)``. Column 0 is always 0.0, and
        column 1 is 1.0 where ``prob[:, 1] >= thresh`` (else 0.0). ``M`` is
        the number of elements in ``prob[:, 1]``.

    Raises:
        AssertionError: if the model output does not have size 2 along dim 1.
    """
    prob = model(X)
    assert prob.shape[1]==2, "Function designed for binary classification."

    positive_flag = (prob[:, 1] >= thresh).reshape(-1).double()
    # The negative-prediction column is zeroed outright rather than computed
    # and then overwritten.
    zero_column = torch.zeros_like(positive_flag)
    stacked = torch.stack((zero_column, positive_flag), dim=1)
    return stacked.unsqueeze(-1)

def GetLogLossTensor(model, X):
    """Return the element-wise negative log of the model's output.

    Args:
        model: Callable applied to ``X``; its output is treated as
            probabilities (presumably shape ``(N, C)`` — confirm).
        X: Input passed straight to ``model``.

    Returns:
        float64 tensor of shape ``(N, C, 1)`` holding ``-log(p + eps)``, with
        ``eps = 1e-20`` guarding against ``log(0)``.
    """
    eps = 1e-20
    p = model(X)
    # Promote to float64 *before* adding eps and taking the log. In float16,
    # 1e-20 rounds to 0, so zero probabilities would still produce
    # -log(0) = inf. Computing in float64 also avoids losing precision
    # before the final cast.
    loss_tensor = -torch.log(p.double() + eps)
    return loss_tensor.reshape((loss_tensor.shape[0], loss_tensor.shape[1], 1))