# https://github.com/YvanYin/Metric3D/blob/main/mono/utils/avg_meter.py
import numpy as np
import torch


def reformat_input(x):
    """Return ``x`` as a float32 ``torch.Tensor``.

    Accepts an existing tensor, a NumPy array, or anything else that
    ``torch.as_tensor`` understands (NumPy scalars, Python numbers, nested
    lists). Tensors and NumPy arrays are wrapped without copying when
    possible, so the result may share memory with ``x`` if ``x`` is already
    float32.
    """
    # torch.from_numpy only accepts np.ndarray; as_tensor covers that case
    # (sharing memory the same way), passes tensors through unchanged, and
    # also handles lists and scalars.
    return torch.as_tensor(x).to(torch.float)


def absrel_pnt(pred, target, mask):
    """Mean relative point error ``||pred - target|| / ||target||`` over a mask.

    Args:
        pred: 3-D predicted points; the norm is taken over the last axis
            (presumably ``(H, W, 3)`` point maps -- the code only checks
            ``dim() == 3``).
        target: ground-truth points with the same layout as ``pred``.
        mask: 2-D validity weights matching the first two axes of ``pred``;
            the heatmap is weighted by it and its sum is the denominator.

    Returns:
        ``(err_heatmap, err)``: the per-pixel weighted error as a NumPy array
        and its sum divided by ``mask.sum()`` as a Python float, or
        ``(None, None)`` when the mask sums to zero.
    """
    pred, target, mask = reformat_input(pred), reformat_input(target), reformat_input(mask)
    assert pred.dim() == 3 and target.dim() == 3 and mask.dim() == 2
    if mask.sum() == 0:
        return None, None

    dist_gt = torch.norm(target, dim=-1)
    dist_err = torch.norm(pred - target, dim=-1)
    err_heatmap = dist_err / (dist_gt + 1e-10) * mask
    # NaN/inf points outside the mask survive the multiply (NaN * 0 is NaN)
    # and would poison the sum; force masked-out pixels to exactly zero.
    err_heatmap = torch.where(mask != 0, err_heatmap, torch.zeros_like(err_heatmap))
    err = err_heatmap.sum() / mask.sum()
    return err_heatmap.cpu().numpy(), err.item()


def absrel(pred, target, mask):
    """Mean absolute relative error between two 2-D maps over a mask.

    Both maps are multiplied by ``mask`` and then hard-zeroed wherever
    ``mask < .5``, so those pixels contribute ``0 / 1e-10 == 0``.

    Args:
        pred: 2-D predicted map (tensor or NumPy array).
        target: 2-D ground-truth map.
        mask: 2-D validity mask; its sum is the averaging denominator.

    Returns:
        ``(err_heatmap, err)``: per-pixel ``|target - pred| / target`` (with a
        1e-10 guard in the denominator) as a NumPy array, and its sum divided
        by ``mask.sum()`` as a float; ``(None, None)`` if the mask sums to 0.
    """
    pred = reformat_input(pred)
    target = reformat_input(target)
    mask = reformat_input(mask)
    assert pred.dim() == 2 and target.dim() == 2 and mask.dim() == 2
    if mask.sum() == 0:
        return None, None

    outside = mask < .5
    gt = target * mask
    est = pred * mask
    # Multiplying by 0 leaves NaN/inf as NaN, so zero these pixels explicitly.
    gt[outside] = 0
    est[outside] = 0

    rel_err = (gt - est).abs() / (gt + 1e-10)  # (H, W)
    return rel_err.cpu().numpy(), (rel_err.sum() / mask.sum()).item()


def rmse(pred, target, mask):
    """Root-mean-square error between two 2-D maps over a mask.

    Args:
        pred: 2-D predicted map (tensor or NumPy array).
        target: 2-D ground-truth map.
        mask: 2-D validity mask; pixels with ``mask < .5`` are zeroed in both
            maps and its sum is the averaging denominator.

    Returns:
        ``(err_heatmap, err)``: the per-pixel squared error as a NumPy array
        and ``sqrt(sum / mask.sum())`` as a float; ``(None, None)`` if the
        mask sums to zero.
    """
    pred = reformat_input(pred)
    target = reformat_input(target)
    mask = reformat_input(mask)
    assert pred.dim() == 2 and target.dim() == 2 and mask.dim() == 2
    if mask.sum() == 0:
        return None, None

    outside = mask < .5
    gt = target * mask
    est = pred * mask
    # Explicit zeroing so NaN/inf outside the mask cannot leak through.
    gt[outside] = 0
    est[outside] = 0

    sq_err = (gt - est).pow(2)  # (H, W)
    return sq_err.cpu().numpy(), (sq_err.sum() / mask.sum()).sqrt().item()


def rmse_log(pred, target, mask):
    """Root-mean-square error of base-10 logs between two 2-D maps over a mask.

    Args:
        pred: 2-D predicted map (tensor or NumPy array).
        target: 2-D ground-truth map.
        mask: 2-D validity mask; pixels with ``mask < .5`` are zeroed in both
            maps, the log difference is multiplied by it, and its sum is the
            averaging denominator.

    Returns:
        ``(err_heatmap, err)``: per-pixel squared log10 difference as a NumPy
        array and ``sqrt(sum / mask.sum())`` as a float; ``(None, None)`` if
        the mask sums to zero.
    """
    pred = reformat_input(pred)
    target = reformat_input(target)
    mask = reformat_input(mask)
    assert pred.dim() == 2 and target.dim() == 2 and mask.dim() == 2
    if mask.sum() == 0:
        return None, None

    outside = mask < .5
    gt = target * mask
    est = pred * mask
    gt[outside] = 0
    est[outside] = 0

    # The 1e-10 offset keeps log10 finite on the zeroed pixels.
    log_diff = torch.log10(est + 1e-10) - torch.log10(gt + 1e-10)
    sq_err = (log_diff * mask).pow(2)  # (H, W)
    return sq_err.cpu().numpy(), torch.sqrt(sq_err.sum() / mask.sum()).item()


def _delta_accuracy(pred, target, mask, threshold):
    """Shared implementation of the delta-threshold accuracy metrics.

    For each pixel the ratio ``max(t / pred, pred / t)`` is computed, where
    ``t = target * mask`` (1e-10 guards in both denominators).

    Args:
        pred: 2-D predicted map (tensor or NumPy array).
        target: 2-D ground-truth map.
        mask: 2-D validity mask.
        threshold: a pixel counts as a hit when its ratio is below this.

    Returns:
        ``(err_heatmap, accuracy)``:
        - ``err_heatmap`` is ``(ratio - 1) * mask`` as a NumPy array, forced
          to 0 where ``mask == 0``.
        - ``accuracy`` is the number of hits among pixels with
          ``mask >= .5``, divided by ``mask.sum()``.
        Returns ``(None, None)`` when the mask sums to zero.
    """
    pred, target, mask = reformat_input(pred), reformat_input(target), reformat_input(mask)
    assert pred.dim() == 2 and target.dim() == 2 and mask.dim() == 2
    if mask.sum() == 0:
        return None, None

    t_m = target * mask
    gt_pred = t_m / (pred + 1e-10)  # (H, W)
    pred_gt = pred / (t_m + 1e-10)  # (H, W)
    ratio_max = torch.maximum(gt_pred, pred_gt)  # (H, W)

    err_heatmap = (ratio_max - 1) * mask  # (H, W)
    # pred is not masked, so pixels outside the mask can hold inf/NaN ratios
    # (e.g. pred / 1e-10 overflowing); NaN * 0 stays NaN, so zero them here.
    err_heatmap = torch.where(mask != 0, err_heatmap, torch.zeros_like(err_heatmap))

    # Only pixels with mask >= .5 may count as hits. The original code got the
    # same effect by setting the ratio to 99999 elsewhere.
    valid = mask >= .5
    hits = torch.sum(valid & (ratio_max < threshold))
    return err_heatmap.cpu().numpy(), (hits / mask.sum()).item()


def delta1(pred, target, mask):
    """Fraction of masked pixels with ``max(t/p, p/t) < 1.25``.

    Returns ``(err_heatmap, accuracy)`` or ``(None, None)`` for an empty mask.
    """
    return _delta_accuracy(pred, target, mask, 1.25)


def delta2(pred, target, mask):
    """Fraction of masked pixels with ``max(t/p, p/t) < 1.25 ** 2``.

    Returns ``(err_heatmap, accuracy)`` or ``(None, None)`` for an empty mask
    (previously ``(None, (None, None, None))``, inconsistent with the float
    returned otherwise).
    """
    return _delta_accuracy(pred, target, mask, 1.25 ** 2)


def delta3(pred, target, mask):
    """Fraction of masked pixels with ``max(t/p, p/t) < 1.25 ** 3``.

    Returns ``(err_heatmap, accuracy)`` or ``(None, None)`` for an empty mask.
    """
    return _delta_accuracy(pred, target, mask, 1.25 ** 3)


def delta0125(pred, target, mask):
    """Fraction of masked pixels with ``max(t/p, p/t) < 1.25 ** 0.125``.

    Returns ``(err_heatmap, accuracy)`` or ``(None, None)`` for an empty mask.
    """
    return _delta_accuracy(pred, target, mask, 1.25 ** 0.125)

def log10(pred, target, mask):
    """Mean absolute base-10 log error between two 2-D maps over a mask.

    Args:
        pred: 2-D predicted map (tensor or NumPy array).
        target: 2-D ground-truth map.
        mask: 2-D validity mask; pixels with ``mask < .5`` are zeroed in both
            maps, the log difference is multiplied by it, and its sum is the
            averaging denominator.

    Returns:
        ``(err_heatmap, err)``: per-pixel ``|log10 pred - log10 target|`` as a
        NumPy array and its sum divided by ``mask.sum()`` as a float;
        ``(None, None)`` if the mask sums to zero.
    """
    pred = reformat_input(pred)
    target = reformat_input(target)
    mask = reformat_input(mask)
    assert pred.dim() == 2 and target.dim() == 2 and mask.dim() == 2
    if mask.sum() == 0:
        return None, None

    outside = mask < .5
    gt = target * mask
    est = pred * mask
    gt[outside] = 0
    est[outside] = 0

    # The 1e-10 offset keeps log10 finite on the zeroed pixels.
    log_diff = torch.log10(est + 1e-10) - torch.log10(gt + 1e-10)
    abs_err = (log_diff * mask).abs()  # (H, W)
    return abs_err.cpu().numpy(), (abs_err.sum() / mask.sum()).item()


# Registry mapping metric name -> metric function. Each function takes
# (pred, target, mask) and returns (err_heatmap, value). absrel_pnt and
# delta0125 are defined in this module but not registered here.
standard_metric_funcs = {
    "absrel": absrel,
    "rmse": rmse,
    "rmse_log": rmse_log,
    "delta1": delta1,
    "delta2": delta2,
    "delta3": delta3,
    "log10": log10,
}

