from typing import Dict, List, Sequence

import torch


def accuracy_at_k(
    outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5)
) -> List[torch.Tensor]:
    """Computes the accuracy over the k top predictions for the specified values of k.

    A sample counts as correct at k if its target label is among the k highest-scoring
    classes. Values of k larger than the number of classes are clamped to the number of
    classes instead of making ``topk`` raise (e.g. the default top-5 on a 3-class problem).

    Args:
        outputs (torch.Tensor): output of a classifier (logits or probabilities), scored
            along dim 1, i.e. indexed as (batch, num_classes).
        targets (torch.Tensor): ground truth labels, one per sample in the batch.
        top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over.
            Defaults to (1, 5).

    Returns:
        List[torch.Tensor]: one single-element float tensor per requested k, holding the
        accuracy at that k as a percentage.
    """

    with torch.no_grad():
        num_classes = outputs.size(1)
        # topk raises "selected index k out of range" when k > num_classes; clamp so the
        # default (1, 5) still works on problems with fewer than 5 classes.
        maxk = min(max(top_k), num_classes)
        batch_size = targets.size(0)

        # After the transpose, row i holds every sample's i-th best class index: (maxk, batch).
        _, pred = outputs.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(targets.view(1, -1).expand_as(pred))

        res = []
        for k in top_k:
            # Slicing beyond maxk just takes all rows, consistent with the clamp above.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float:
    """Computes the mean of the values of a key weighted by the batch size.

    Entries that lack either ``key`` or ``batch_size_key`` are skipped.

    Args:
        outputs (List[Dict]): list of dicts containing the outputs of a validation step.
        key (str): key of the metric of interest.
        batch_size_key (str): key of batch size values.

    Returns:
        float: weighted mean of the values of a key. If the metric values are tensors,
        the result is a tensor with its leading dimension squeezed (e.g. a shape-(1,)
        metric yields a 0-d tensor); plain Python numbers are returned as-is. Returns
        NaN when no entry contains both keys.
    """

    value = 0
    n = 0
    for out in outputs:
        # Skip step outputs that did not log this metric or its batch size.
        if batch_size_key not in out or key not in out:
            continue

        value += out[batch_size_key] * out[key]
        n += out[batch_size_key]

    if n == 0:
        return float('NaN')

    value = value / n
    # Tensor metrics (such as the shape-(1,) tensors from accuracy_at_k) get their leading
    # dim dropped; Python scalars have no ``squeeze`` and would otherwise raise.
    if isinstance(value, torch.Tensor):
        value = value.squeeze(0)
    return value


def CKA(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Computes linear Centered Kernel Alignment (CKA) between two representations.

    Uses linear-kernel Gram matrices ``X @ X.T`` and ``Y @ Y.T``, so both inputs must
    have the same number of rows (samples); their column counts may differ.

    Args:
        X (torch.Tensor): first representation, one row per sample.
        Y (torch.Tensor): second representation, same number of rows as ``X``.

    Returns:
        torch.Tensor: 0-d tensor with the CKA similarity.
    """

    def HSIC(X, Y) -> torch.Tensor:
        # Linear-kernel Gram matrices, each (n, n) where n is the number of rows.
        GX = X @ X.T
        GY = Y @ Y.T

        n = GX.shape[0]
        # Centering matrix H = I - (1/n) * 11^T. It is built in the Gram dtype; a default
        # float32 eye would make the matmul fail for float64/half inputs.
        H = torch.eye(n, device=X.device, dtype=GX.dtype) - (1 / n)

        # Unnormalized HSIC estimate; a constant factor such as 1/(n-1)^2 would cancel
        # in the CKA ratio below, so it is omitted.
        return torch.trace(GX @ H @ GY @ H)

    return HSIC(X, Y) / torch.sqrt(HSIC(X, X) * HSIC(Y, Y))
