import math
import numpy as np
import torch

def compute_mae(pred, target):
    '''
    Mean absolute error computed separately for every batch element.

    :param pred: (batch_size, forecasting_length, multidim)
    :param target: (batch_size, forecasting_length, multidim)
    :return: (batch_size,) tensor; entry b is the MAE averaged over the
             forecasting_length and multidim axes of batch element b
    '''
    abs_err = torch.abs(pred - target)
    return torch.mean(abs_err, dim=(1, 2))

# The mathematical formulation can be found in https://arxiv.org/pdf/2202.11316.pdf and https://arxiv.org/pdf/1704.04110.pdf
def compute_rho(rou: float, samples: torch.Tensor, labels: torch.Tensor, relative = False):
    '''
    Accumulate the rho-quantile loss over every non-zero label.

    The empirical rou-quantile prediction y_hat of each target is taken from the
    Monte-Carlo samples. Each non-zero target y then contributes
    2 * rou * (y - y_hat) if y > y_hat, else 2 * (1 - rou) * (y_hat - y).

    :param rou: quantile level, expected in [0, 1]
    :param samples: (sample_times, batch_size, forecasting_length, multidim)
    :param labels: (batch_size, forecasting_length, multidim)
    :param relative: bool; if True the second returned value is the number of
                     non-zero labels, otherwise the absolute value of the sum of
                     the non-zero labels
    :return: [numerator, denominator] -- numerator / denominator is the normalized loss
    '''
    numerator = 0
    denominator = 0
    # labels: (batch, length, dim) -> (batch * dim, length), one row per series
    labels = labels.permute(0, 2, 1).flatten(0, 1)
    # samples: (S, batch, length, dim) -> (S, batch * dim, length), rows aligned with labels
    samples = samples.permute(0, 1, 3, 2).flatten(1, 2)
    pred_samples = samples.shape[0]
    # The k-th largest sample, with k = ceil(S * (1 - rou)), is the empirical
    # rou-quantile. k is clamped to >= 1 so that rou == 1 selects the maximum
    # sample instead of indexing into an empty top-k result.
    rou_th = max(1, math.ceil(pred_samples * (1 - rou)))
    for t in range(labels.shape[1]):
        nonzero_mask = labels[:, t] != 0
        if not nonzero_mask.any():
            # No non-zero targets at this step, so there is nothing to score.
            continue
        target = labels[:, t][nonzero_mask]
        # topk returns values in descending order, so the last row holds the k-th largest.
        rou_pred = torch.topk(samples[:, nonzero_mask, t], dim=0, k=rou_th)[0][-1, :]
        diff = target - rou_pred
        numerator += 2 * (torch.sum(rou * diff[target > rou_pred]) - torch.sum(
            (1 - rou) * diff[target <= rou_pred])).item()
        denominator += torch.sum(target).item()
    if relative:
        return [numerator, torch.sum(labels != 0).item()]
    else:
        return [numerator, np.abs(denominator)]
