import torch
import math


def accuracy(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Return the fraction of predicted labels that match ``targets``.

    ``predictions`` holds per-class scores along dim 1; the predicted label is
    their argmax, reshaped to ``targets.shape`` before comparison.

    The denominator is the total number of target elements (``numel``), so
    the result stays in [0, 1] for per-element targets (e.g. ``(N, H, W)``
    label maps) as well as for plain ``(N,)`` / ``(N, 1)`` targets.

    Returns:
        A 0-d float tensor. Empty ``targets`` yields NaN (0 / 0).
    """
    predicted = predictions.argmax(dim=1).view(targets.shape)
    n_correct = (predicted == targets).sum().float()
    # Divide by the element count, not targets.size(0): the comparison counts
    # matches over every element, so size(0) over-counts multi-dim targets.
    return n_correct / targets.numel()


def is_correct(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Return an element-wise boolean mask of correct predictions.

    The predicted label is the argmax of ``predictions`` over dim 1, reshaped
    to ``targets.shape``; each output element is True where it equals the
    corresponding target.
    """
    predicted_labels = predictions.argmax(dim=1)
    return predicted_labels.view(targets.shape).eq(targets)


def prob_accuracy(accuracy: float, n_class: int, n_data: int):
    """
    Calculate the probability that random prediction achieves exactly this accuracy.

    The accuracy is converted to a count of correct predictions with
    ``round(n_data * accuracy)`` (Python's round-half-to-even). That count is
    passed to ``prob_n_correct``, which treats each prediction as independently
    correct with probability ``1 / n_class``.
    """
    assert 0. <= accuracy <= 1.
    return prob_n_correct(round(accuracy * n_data), n_class, n_data)


def prob_higher_accuracy(accuracy: float, n_class: int, n_data: int):
    """
    Calculate the probability that random prediction reaches at least this accuracy.

    The accuracy is converted to a count ``n_correct = round(n_data * accuracy)``.
    The result is the tail sum of ``prob_n_correct(n, n_class, n_data)`` for
    every ``n`` from ``n_correct`` up to ``n_data`` inclusive. The observed
    count itself is included, so the value can serve as a one-sided
    "better than chance" p-value.
    """
    assert 0. <= accuracy <= 1.
    n_correct = round(n_data * accuracy)
    # math.fsum returns a correctly rounded total of many small, differently
    # scaled terms (a naive left-to-right sum accumulates rounding error), and
    # the generator avoids materialising an intermediate list.
    return math.fsum(
        prob_n_correct(n, n_class, n_data)
        for n in range(n_correct, n_data + 1))


def prob_n_correct(n_correct: int, n_class: int, n_data: int):
    """
    Probability of exactly ``n_correct`` hits out of ``n_data`` random guesses.

    Each guess is independently correct with probability ``p = 1 / n_class``,
    i.e. the binomial probability ``C(n_data, n_correct) * p**k * (1-p)**(n-k)``.

    The computation is done in log space: multiplying the exact binomial
    coefficient (a Python int) by a float overflows once the coefficient
    exceeds ~1e308 (e.g. ``n_data=2000``). It also avoids the powers
    underflowing to 0 before the product is formed.
    """
    assert n_correct <= n_data
    n_wrong = n_data - n_correct
    if n_class == 1:
        # p == 1: every guess is correct; log1p(-1) would be a domain error.
        return 1.0 if n_wrong == 0 else 0.0
    p = 1 / n_class
    # math.log accepts arbitrarily large ints, so the coefficient stays exact
    # until its logarithm is taken.
    log_prob = (math.log(math.comb(n_data, n_correct))
                + n_correct * math.log(p)
                + n_wrong * math.log1p(-p))
    return math.exp(log_prob)
    
