import math
import torch
import torch.nn.functional as F

def mse(pred, true):
    """Mean squared error between ``pred`` and ``true``.

    Thin wrapper around ``F.mse_loss`` with its default ``reduction='mean'``.
    The result is returned as a 0-d tensor (not a Python float), so it can
    still participate in autograd if the inputs require gradients.
    """
    return F.mse_loss(input=pred, target=true)

def rmse(pred, true):
    """Root mean squared error between ``pred`` and ``true`` as a Python float.

    Computes ``F.mse_loss`` (mean reduction) and takes its square root with
    ``math.sqrt``, so the result is detached from any autograd graph.
    """
    mean_sq = F.mse_loss(pred, true).item()
    return math.sqrt(mean_sq)

def accuracy(pred, true):
    """Fraction of rows whose arg-max index in ``pred`` matches that in ``true``.

    Both tensors are compared row-wise along dimension 1 (for example, class
    scores vs. one-hot targets — confirm against callers).

    Args:
        pred: Tensor with at least 2 dims; ``argmax`` is taken over dim 1.
        true: Tensor with the same leading dimension as ``pred``.

    Returns:
        Python float: matching rows divided by ``pred.shape[0]``.

    Raises:
        ZeroDivisionError: if ``pred`` has zero rows.
    """
    # Compute each argmax once for the whole batch. The previous version
    # recomputed both inside the per-row loop, which was O(n^2).
    hits = pred.argmax(dim=1) == true.argmax(dim=1)
    # Integer count divided by the row count: same arithmetic as before.
    return hits.sum().item() / pred.shape[0]

def direction_accuracy(pred, true, threshold=1):
    """Fraction of elements where ``pred`` and ``true`` fall on the same side of ``threshold``.

    An element counts as a hit when both values are strictly below
    ``threshold`` or both are strictly above it. Values exactly equal to
    ``threshold`` never count as a hit.

    Args:
        pred: Tensor of predictions.
        true: Tensor of targets, broadcastable with ``pred``.
        threshold: Pivot value separating the two directions (default ``1``,
            the value previously hard-coded).

    Returns:
        Python float in ``[0, 1]``; ``nan`` for empty input.
    """
    # Parenthesize each comparison. Without the parentheses, `pred<1 & true<1`
    # parses as the chained comparison `pred < (1 & true) < 1`.
    both_below = (pred < threshold) & (true < threshold)
    both_above = (pred > threshold) & (true > threshold)
    agree = both_below | both_above
    # torch.mean needs a floating dtype. float64 keeps the hit-count / n
    # division exact.
    return agree.double().mean().item()

def mape(pred, true):
    """Mean absolute percentage error, in percent, as a Python float.

    Computes ``mean(|(true - pred) / true|) * 100`` over all elements.

    Note:
        Any element where ``true == 0`` makes the result ``inf`` or ``nan``.
        This function does not mask such elements.
    """
    # .mean() already reduces over every element. The previous .view(-1) was
    # redundant and could raise on a non-contiguous elementwise result.
    rel_err = torch.abs((true - pred) / true)
    return (rel_err.mean() * 100).item()