# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Dict, List, Sequence, Tuple

import torch


def accuracy_at_k(
    outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5)
) -> Sequence[torch.Tensor]:
    """Computes the accuracy over the k top predictions for the specified values of k.

    A sample counts as correct at k when its target label is among the k highest-scoring
    classes. Values of k larger than the number of classes are clamped to the number of
    classes instead of raising. Every sample whose target is a valid class index then
    counts as correct for such k.

    Args:
        outputs (torch.Tensor): output of a classifier (logits or probabilities), with
            classes along dim 1, i.e. shape (B, C).
        targets (torch.Tensor): ground truth labels, one per sample (B elements).
        top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over.
            Defaults to (1, 5).

    Returns:
        Sequence[torch.Tensor]: one single-element tensor (shape (1,)) per k, holding the
            accuracy in percent, in the same order as ``top_k``.
    """

    with torch.no_grad():
        batch_size = targets.size(0)
        # topk raises if asked for more entries than there are classes; clamp so that e.g.
        # the default top-5 also works on datasets with fewer than 5 classes.
        maxk = min(max(top_k), outputs.size(1))

        _, pred = outputs.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, B): row i holds every sample's i-th best class
        correct = pred.eq(targets.view(1, -1).expand_as(pred))

        res = []
        for k in top_k:
            # Slicing past maxk rows just yields all rows, which is the correct top-k count.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float:
    """Computes the mean of the values of a key weighted by the batch size.

    Each entry contributes ``out[batch_size_key] * out[key]``, and the total is divided
    by the sum of the batch sizes.

    Args:
        outputs (List[Dict]): list of dicts containing the outputs of a validation step.
        key (str): key of the metric of interest.
        batch_size_key (str): key of batch size values.

    Returns:
        float: weighted mean of the values of a key. If the metric values are tensors,
            a tensor is returned with its leading dim squeezed (a shape-(1,) input becomes
            a 0-dim tensor). Plain Python numbers yield a plain float.

    Raises:
        ZeroDivisionError: if ``outputs`` is empty or the integer batch sizes sum to zero.
    """

    value = 0
    n = 0
    for out in outputs:
        value += out[batch_size_key] * out[key]
        n += out[batch_size_key]
    value = value / n
    # Single-element metric tensors carry a leading dim of size 1; drop it. Plain numbers
    # have no such dim (and no .squeeze), so they are returned unchanged.
    if isinstance(value, torch.Tensor):
        return value.squeeze(0)
    return value

#############
# Sparsity metrics -- this project is about sparsity, so these live alongside the accuracy metrics.
#############

def batch_sparsity_metric(tensor_data: torch.Tensor, epsilon: float = 1e-12) -> Tuple[float, float, float]:
    """
    Calculates the batch-normalized sparsity metric
    S_j = (||z_j||_1)^2 / (B * (||z_j||_2)^2) for each column z_j of a 2-D PyTorch tensor.

    For a column with k equal-magnitude nonzero entries out of B rows, S_j = k / B. By
    Cauchy-Schwarz the metric lies in [0, 1]; lower values mean the feature is active
    for fewer samples. An all-zero column yields 0 thanks to ``epsilon``.

    Args:
        tensor_data (torch.Tensor): The input tensor of shape (B, D), where B is batch size
                                    and D is the number of columns/features. Integer or
                                    boolean inputs are converted to float.
        epsilon (float): A small value added to the denominator to prevent division by zero.
                         Default is 1e-12.

    Returns:
        Tuple[float, float, float]: max, mean, min sparsity metric across dimensions.

    Raises:
        TypeError: if ``tensor_data`` is not a torch.Tensor.
        ValueError: if ``tensor_data`` is not 2-dimensional or is empty.
    """
    if not isinstance(tensor_data, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor.")
    if tensor_data.ndim != 2:
        raise ValueError("Input tensor must be 2-dimensional (B x D).")
    if tensor_data.numel() == 0:
        # B == 0 would divide by zero below; D == 0 leaves nothing to reduce over.
        raise ValueError("Input tensor must be non-empty.")

    B = tensor_data.shape[0]

    # This is a metric: no autograd graph is needed. torch.linalg.norm requires a
    # floating dtype, so integer/bool activations are promoted.
    z = tensor_data.detach()
    if not torch.is_floating_point(z):
        z = z.float()

    # L1 norm and squared L2 norm of each column, each a 1D tensor of shape (D,)
    l1_norm_per_column = torch.linalg.norm(z, ord=1, dim=0)
    l2_norm_sq_per_column = torch.linalg.norm(z, ord=2, dim=0) ** 2

    # epsilon keeps all-zero columns finite (0 / epsilon = 0)
    sparsity_metric_per_column = (l1_norm_per_column**2) / (l2_norm_sq_per_column + epsilon)

    # normalize by batch size so the metric is in [0, 1] regardless of B
    sparsity_metric_per_column = sparsity_metric_per_column / B

    sparsity_metric_max = sparsity_metric_per_column.max().item()
    sparsity_metric_mean = sparsity_metric_per_column.mean().item()
    sparsity_metric_min = sparsity_metric_per_column.min().item()

    return sparsity_metric_max, sparsity_metric_mean, sparsity_metric_min


def embedding_sparsity_metric(embeddings: torch.Tensor, epsilon: float = 1e-12) -> Tuple[float, float, float]:
    """
    Calculates the embedding sparsity metric S_i = (1/D) * (||z_i||_1 / ||z_i||_2)^2 for each
    row (embedding) z_i of a PyTorch tensor. This measures how sparsely features are
    activated within each embedding.

    For a row with k equal-magnitude nonzero entries out of D, S_i = k / D. By
    Cauchy-Schwarz the metric lies in [0, 1]; lower values mean sparser embeddings. An
    all-zero row yields 0 thanks to ``epsilon``.

    Args:
        embeddings (torch.Tensor): The input tensor of shape (B, D), where B is batch size
                                   (number of embeddings) and D is the number of
                                   dimensions/features per embedding. Integer or boolean
                                   inputs are converted to float.
        epsilon (float): A small value added to the L2 norm before division to prevent
                         division by zero. Default is 1e-12.

    Returns:
        Tuple[float, float, float]: (max_sparsity, mean_sparsity, min_sparsity)
                                     calculated over the B embedding sparsity values.

    Raises:
        TypeError: if ``embeddings`` is not a torch.Tensor.
        ValueError: if ``embeddings`` is not 2-dimensional or is empty.
    """
    if not isinstance(embeddings, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor.")
    if embeddings.ndim != 2:
        raise ValueError("Input tensor must be 2-dimensional (B x D).")
    if embeddings.numel() == 0:
        # D == 0 would divide by zero below; B == 0 leaves nothing to reduce over.
        raise ValueError("Input tensor must be non-empty.")

    D = embeddings.shape[1]

    # This is a metric: no autograd graph is needed. torch.linalg.norm requires a
    # floating dtype, so integer/bool activations are promoted.
    z = embeddings.detach()
    if not torch.is_floating_point(z):
        z = z.float()

    # L1 and L2 norm of each row (embedding), each a 1D tensor of shape (B,)
    l1_norm_per_row = torch.linalg.norm(z, ord=1, dim=1)
    l2_norm_per_row = torch.linalg.norm(z, ord=2, dim=1)

    # epsilon keeps all-zero rows finite (0 / epsilon = 0)
    ratio = l1_norm_per_row / (l2_norm_per_row + epsilon)

    metric_per_embedding = (1.0 / D) * (ratio**2)

    max_sparsity = metric_per_embedding.max().item()
    mean_sparsity = metric_per_embedding.mean().item()
    min_sparsity = metric_per_embedding.min().item()

    return max_sparsity, mean_sparsity, min_sparsity


def count_avg_nonzero_elements_per_dimension(tensor: torch.Tensor) -> float:
    """Return the per-column nonzero fraction, averaged over columns.

    For each column (index along dim 1 of a (B, D) tensor), count the nonzero rows and
    divide by the number of rows B. Then average those fractions. Despite the name, the
    result is a fraction in [0, 1], not a count. Mathematically it equals the overall
    fraction of nonzero entries.

    Args:
        tensor (torch.Tensor): input activations, presumably of shape (B, D).

    Returns:
        float: mean over columns of (nonzero rows / B); NaN if B == 0.
    """
    nonzero_mask = tensor.ne(0)
    num_rows = tensor.shape[0]
    per_column_fraction = nonzero_mask.sum(dim=0) / num_rows
    return per_column_fraction.mean().item()


def count_avg_nonzero_elements_per_sample(tensor: torch.Tensor) -> float:
    """Return the per-row nonzero fraction, averaged over rows (samples).

    For each row of a (B, D) tensor, count the nonzero entries and divide by the number
    of columns D. Then average those fractions over the B rows. Despite the name, the
    result is a fraction in [0, 1], not a count. Mathematically it equals the overall
    fraction of nonzero entries.

    Args:
        tensor (torch.Tensor): input activations, assumed 2-D (B, D); ``dim=1`` must exist.

    Returns:
        float: mean over rows of (nonzero entries / D); NaN if D == 0.
    """
    nonzero_mask = tensor.ne(0)
    num_cols = tensor.shape[1]
    per_row_fraction = nonzero_mask.sum(dim=1) / num_cols
    return per_row_fraction.mean().item()


# active_feature_fraction is kept separate from the metrics above because it is useful for energy-based sparsity optimization.
def active_feature_fraction(tensor_data: torch.Tensor, threshold: float = 1e-3) -> float:
    """
    Computes the fraction of active entries in a batch of features.

    An entry counts as active when its magnitude is strictly greater than ``threshold``.
    With the default threshold, small-but-nonzero values are therefore treated as inactive.
    Every element is counted, so the input shape only matters through ``numel()``.

    Args:
        tensor_data (torch.Tensor): batch of features, typically (B, D).
        threshold (float): magnitude an entry must exceed to count as active.
            Default is 1e-3.

    Returns:
        float: number of active entries divided by the total number of entries.

    Raises:
        ZeroDivisionError: if ``tensor_data`` has no elements.
    """
    is_active = tensor_data.abs() > threshold
    n_active = torch.count_nonzero(is_active).item()
    n_total = tensor_data.numel()
    return n_active / n_total
