import numpy as np
import torch
import torch.nn.functional as F


def compute_critical_score(net, loader, mode="l2", alpha=0.1, device="cpu"):
    """Return the conformal critical score of ``net`` on a calibration loader.

    A non-conformity score is computed for every sample yielded by ``loader``:

    * ``mode`` ``"l2"`` / ``"l1"``: ``|net(x).squeeze() - y|``.
    * ``mode`` ``"quantile"``: ``max(y - net(x)[:, 1], net(x)[:, 0] - y)``,
      i.e. column 0 is used as the lower bound and column 1 as the upper one.

    The result is the ``ceil((1 - alpha) * (n + 1))``-th smallest score, where
    ``n`` is the number of samples actually iterated (so loaders that use a
    sampler or ``drop_last`` are handled correctly).

    Args:
        net: Callable mapping a batch of inputs to predictions.
        loader: Iterable of batches; ``batch[0]`` are inputs and ``batch[1]``
            are targets. Targets are presumably 1-D of length ``batch_size``
            (NOTE(review): confirm against the datasets in use).
        mode: ``"l2"``, ``"l1"`` or ``"quantile"``.
        alpha: Miscoverage level in ``[0, 1]``.
        device: Device the inputs and targets are moved to.

    Returns:
        A 0-dim tensor holding the critical score. When the requested rank
        exceeds ``n`` (too few calibration samples for this ``alpha``) the
        score is ``+inf``, the value for which the finite-sample coverage
        guarantee still holds.

    Raises:
        ValueError: If ``mode`` is unknown, ``alpha`` is outside ``[0, 1]``,
            the loader yields no samples, or a batch does not produce exactly
            one score per sample.
    """
    if mode not in ("l2", "l1", "quantile"):
        raise ValueError(f"unknown mode: {mode!r}")
    if not 0 <= alpha <= 1:
        raise ValueError(f"alpha must lie in [0, 1], got {alpha!r}")

    batch_scores = []
    # Calibration needs no gradients; this avoids retaining an autograd graph
    # for the whole calibration set.
    with torch.no_grad():
        for batch in loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            outputs = net(inputs)  # single forward pass per batch
            if mode == "quantile":
                current = torch.max(targets - outputs[:, 1], outputs[:, 0] - targets)
            else:
                current = torch.abs(outputs.squeeze() - targets)
            if current.numel() != inputs.shape[0]:
                # Guards against silent broadcasting (e.g. (B,) vs (B, 1)).
                raise ValueError(
                    "expected one score per sample, got shape "
                    f"{tuple(current.shape)} for a batch of {inputs.shape[0]}"
                )
            batch_scores.append(current.reshape(-1))

    if not batch_scores:
        raise ValueError("loader yielded no samples")

    scores = torch.cat(batch_scores)
    n = scores.numel()
    rank = int(np.ceil((1 - alpha) * (n + 1)))  # 1-based order statistic
    if rank > n:
        inf_dtype = scores.dtype if scores.is_floating_point() else torch.get_default_dtype()
        return torch.full((), float("inf"), dtype=inf_dtype, device=scores.device)
    rank = max(rank, 1)  # alpha == 1 gives rank 0; fall back to the minimum
    # torch.sort is 0-indexed, so the rank-th smallest lives at rank - 1.
    return torch.sort(scores)[0][rank - 1]


def evaluate_coverage_length(net, loader, critical_score, mode="l2", device="cpu"):
    """Measure empirical coverage and mean interval length on ``loader``.

    The prediction interval for each sample is:

    * ``mode`` ``"l2"`` / ``"l1"``:
      ``[net(x) - critical_score, net(x) + critical_score]``.
    * ``mode`` ``"quantile"``:
      ``[net(x)[:, 0] - critical_score, net(x)[:, 1] + critical_score]``,
      with its length clipped at zero.

    Interval bounds are inclusive in both modes, and both averages are taken
    over the samples actually iterated (so loaders that use a sampler or
    ``drop_last`` are handled correctly).

    Args:
        net: Callable mapping a batch of inputs to predictions.
        loader: Iterable of batches; ``batch[0]`` are inputs and ``batch[1]``
            are targets.
        critical_score: Interval half-width correction; a Python number or a
            tensor broadcastable against the targets.
        mode: ``"l2"``, ``"l1"`` or ``"quantile"``.
        device: Device the inputs and targets are moved to.

    Returns:
        ``(coverage, length)``: the fraction of targets inside their interval
        (a 0-dim tensor) and the mean interval length (a tensor, or a Python
        float in the l1/l2 modes when ``critical_score`` is a Python number).

    Raises:
        ValueError: If ``mode`` is unknown or the loader yields no samples.
    """
    if mode not in ("l2", "l1", "quantile"):
        raise ValueError(f"unknown mode: {mode!r}")

    n_seen = 0
    coverage, length = 0., 0.
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for batch in loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            batch_size = inputs.shape[0]
            outputs = net(inputs)  # single forward pass per batch
            if mode == "quantile":
                lower, upper = outputs[:, 0], outputs[:, 1]
                covered = (targets >= lower - critical_score) & (
                    targets <= upper + critical_score
                )
                # relu clips intervals whose corrected bounds cross over.
                length += torch.sum(F.relu(upper - lower + 2 * critical_score))
            else:
                # Inclusive bound, consistent with the quantile branch.
                covered = torch.abs(targets - outputs.squeeze()) <= critical_score
                length += (2 * critical_score) * batch_size
            coverage += torch.sum(covered)
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("loader yielded no samples")

    coverage /= n_seen
    length /= n_seen
    return coverage, length