import numpy as np
from torch import Tensor
import torch

class AverageDrop:
    """Running "Average Drop" score, expressed in percent.

    Each :meth:`update` call adds ``100 * relu(target - preds) / target``,
    summed over every element of the batch, to a running total.
    :meth:`compute` divides that total by the number of instances seen.
    An instance is one entry along dimension 0 of ``preds``.

    Args:
        eps: Lower bound applied to ``target`` in the denominator, so a zero
            target yields a finite score instead of ``nan``. Targets at or
            above ``eps`` are divided exactly as before.
    """

    def __init__(self, eps: float = 1e-8):
        self.eps = eps
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated state."""
        # One float per update() call: that batch's summed percentage drop.
        self.running_scores = []
        # Total instances (size of dim 0) across all update() calls.
        self.number_instances = 0

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Accumulate the drop of ``preds`` relative to ``target``.

        Args:
            preds: Scores to compare; must have the same shape as ``target``.
            target: Reference scores. Presumably non-negative (e.g.
                probabilities) -- TODO confirm with callers. Values below
                ``eps`` are clamped to ``eps`` in the denominator.

        Raises:
            ValueError: If ``preds`` and ``target`` differ in shape.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if preds.shape != target.shape:
            raise ValueError(
                "preds and target must have the same shape, "
                f"got {tuple(preds.shape)} and {tuple(target.shape)}"
            )
        # relu keeps only decreases (elements where preds < target).
        # The clamp prevents 0/0 -> nan for zero targets. Without it, a single
        # nan would poison the running total for every later compute().
        drop = 100 * torch.relu(target - preds) / target.clamp_min(self.eps)
        self.running_scores.append(drop.sum().item())
        # NOTE(review): every element is summed but only dim 0 is counted, so
        # inputs with extra dimensions are not averaged over them -- confirm
        # this is intended.
        self.number_instances += preds.shape[0]

    def compute(self) -> float:
        """Return the mean percentage drop, or ``nan`` if nothing was seen."""
        if self.number_instances == 0:
            return float("nan")
        return float(np.array(self.running_scores).sum() / self.number_instances)
    
class AverageGain:
    """Running "Average Gain" score, expressed in percent.

    Each :meth:`update` call adds ``100 * relu(preds - target) / (1 - target + eps)``,
    summed over every element of the batch, to a running total.
    :meth:`compute` divides that total by the number of instances seen.
    An instance is one entry along dimension 0 of ``preds``.

    Args:
        eps: Constant added to the ``1 - target`` denominator so that a
            target of exactly 1 does not divide by zero.
    """

    def __init__(self, eps: float = 1e-8):
        self.eps = eps
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated state."""
        # One float per update() call: that batch's summed percentage gain.
        self.running_scores = []
        # Total instances (size of dim 0) across all update() calls.
        self.number_instances = 0

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Accumulate the gain of ``preds`` relative to ``target``.

        Args:
            preds: Scores to compare; must have the same shape as ``target``.
            target: Reference scores. Presumably in [0, 1] (e.g.
                probabilities) -- TODO confirm with callers.

        Raises:
            ValueError: If ``preds`` and ``target`` differ in shape.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if preds.shape != target.shape:
            raise ValueError(
                "preds and target must have the same shape, "
                f"got {tuple(preds.shape)} and {tuple(target.shape)}"
            )
        # relu keeps only increases (elements where preds > target). The gain
        # is normalised by the remaining headroom ``1 - target``.
        gain = 100 * torch.relu(preds - target) / (1 - target + self.eps)
        self.running_scores.append(gain.sum().item())
        # NOTE(review): every element is summed but only dim 0 is counted --
        # confirm this is intended for inputs with extra dimensions.
        self.number_instances += preds.shape[0]

    def compute(self) -> float:
        """Return the mean percentage gain, or ``nan`` if nothing was seen."""
        if self.number_instances == 0:
            return float("nan")
        return float(np.array(self.running_scores).sum() / self.number_instances)
    
class AverageIncrease:
    """Running percentage of elements where ``preds`` is strictly greater than ``target``.

    The count of such elements is divided by the number of instances seen.
    An instance is one entry along dimension 0 of ``preds``.
    """

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated state."""
        # Plain-int count of elements where preds > target.
        self.number_of_increases = 0
        # Total instances (size of dim 0) across all update() calls.
        self.number_instances = 0

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Accumulate how many elements of ``preds`` exceed ``target``.

        Raises:
            ValueError: If ``preds`` and ``target`` differ in shape.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if preds.shape != target.shape:
            raise ValueError(
                "preds and target must have the same shape, "
                f"got {tuple(preds.shape)} and {tuple(target.shape)}"
            )
        # .item() keeps the counter a plain int rather than letting it silently
        # become a tensor (on whatever device the inputs live on).
        self.number_of_increases += int((preds > target).sum().item())
        # NOTE(review): every element is counted but only dim 0 is added here,
        # so the result can exceed 100 for inputs with extra dimensions --
        # confirm this is intended.
        self.number_instances += preds.shape[0]

    def compute(self) -> float:
        """Return the percentage of increases, or ``nan`` if nothing was seen."""
        if self.number_instances == 0:
            return float("nan")
        return 100 * self.number_of_increases / self.number_instances
