import math
import torch
import numpy as np
import torch.nn.functional as F

from typing import Optional, Union, Tuple




def reshape_probs_and_labels_for_multidim(
    probs: torch.Tensor,
    labels: torch.Tensor,
    dim: Tuple[int],
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    
    remaining_dims = [d for d in range(len(probs.shape)) if d not in dim]
    permutation = remaining_dims + list(dim)
    new_shape = [probs.shape[d] for d in remaining_dims] + [-1]
    probs = torch.permute(probs, permutation).reshape(new_shape)
    labels = torch.permute(labels, permutation).reshape(new_shape)
    
    return probs.unsqueeze(-1), labels.unsqueeze(-1), len(probs.shape) - 1


def format_probs_and_labels(
    probs: torch.Tensor,
    labels: torch.Tensor,
    dim: Union[int, Tuple[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    
    if type(dim) != int:
        return reshape_probs_and_labels_for_multidim(probs, labels, dim)
    
    return probs.unsqueeze(-1), labels.unsqueeze(-1), dim


def format_weights(
    weights: Union[torch.Tensor, None],
    probs: torch.Tensor,
    dim: int,
) -> torch.Tensor:
    if weights is not None:
        if len(weights.shape) == 1:
            assert weights.shape[0] == probs.shape[dim], \
            (
                "If you provide 1D weights, please make sure"
                "their size is equal to the products of the dimensions"
                "on which precision and recall are computed."
                f"Provided weights are of shape {weights.shape[0]}"
                f"and did not match dimension of size {probs.shape[dim]}."
            )
        
        else:
            assert np.prod(weights.shape) == probs.shape[dim], \
            (
                "Please make sure that the product of the weights dimensions"
                "is equal to the product of the dimensions on which"
                "precision and recall are computed."
                f"Provided weights are of shape {np.prod(weights.shape[0])} once flattened"
                f"and did not match dimension of size {probs.shape[dim]}."
            )
    
    else:
        weights = torch.ones(probs.shape[dim], device=probs.device)
    
    new_shape = [1] * len(probs.shape)
    new_shape[dim] = probs.shape[dim]
    #repeat_shape = probs.shape
    #repeat_shape[dim] = 1
    return weights.reshape(new_shape) #.repeat(repeat_shape) # weights.shape == probs.shape


def compute_precision_weights_from_clusters(
    one_hot_clusters: torch.Tensor,
    predictions_sum: torch.Tensor,
    eps: Optional[float] = 1e-12,
) -> torch.Tensor:
    
    one_hot_clusters = one_hot_clusters.unsqueeze(2)
    one_hot_clusters_positive_predictions_sum = one_hot_clusters * (predictions_sum > 0).unsqueeze(1)
    cluster_sizes_positive_predictions_sum = one_hot_clusters_positive_predictions_sum.sum(dim=0, keepdim=True)
    weights_per_cluster_and_threshold = 1 / (cluster_sizes_positive_predictions_sum + eps)
    weights_per_cluster_and_threshold[weights_per_cluster_and_threshold < 1 / one_hot_clusters.shape[0]] = 0
    return torch.sum(one_hot_clusters * weights_per_cluster_and_threshold, dim=1) * (predictions_sum > 0)


def compute_recall_weights_from_clusters(
    one_hot_clusters: torch.Tensor,
    eps: Optional[float] = 1e-12,
) -> torch.Tensor:
    
    weights = 1 / (one_hot_clusters.sum(dim=0, keepdim=True) + eps)
    weights[weights < 1 / one_hot_clusters.shape[0]] = 0
    return torch.sum(one_hot_clusters * weights, dim=1)


def compute_weights_from_clusters_averaging(
    clusters: torch.Tensor,
    predictions_sum: torch.Tensor,
    eps: Optional[float] = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the precision and recall weights used to average curves per cluster.

    Args:
        `clusters` (torch.Tensor): [n_samples] cluster index of each sample
            (passed to `F.one_hot`, so it must be an index tensor).
        `predictions_sum` (torch.Tensor): [n_samples, n_thresholds] positive-prediction counts.
        `eps` (float, optional): Denominator guard forwarded to both weight helpers.

    Returns:
        (Tuple[torch.Tensor, torch.Tensor]): Precision weights of shape
        [n_samples, n_thresholds] and recall weights of shape [n_samples, 1].
    """
    memberships = F.one_hot(clusters)  # [n_samples, n_clusters]
    # Recall weights do not depend on the threshold; the extra axis lets them broadcast.
    recall_weights = compute_recall_weights_from_clusters(memberships, eps)[:, None]
    precision_weights = compute_precision_weights_from_clusters(memberships, predictions_sum, eps)
    return precision_weights, recall_weights


def retrieve_thresholds_from_probs(
    probs: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """
    Build a threshold grid from the probabilities of the positive examples.

    Returns a 1D tensor (dtype and device of `probs`) holding 0, then the sorted
    probabilities of every element whose label equals 1 (duplicates kept), then 1.
    """
    positive_probs = probs.reshape(-1)[labels.reshape(-1) == 1]
    bound_options = dict(dtype=probs.dtype, device=probs.device)
    return torch.cat([
        torch.zeros(1, **bound_options),
        positive_probs.sort().values,
        torch.ones(1, **bound_options),
    ])


def retrieve_thresholds(
    probs: torch.Tensor,
    labels: torch.Tensor,
    n_thresholds: Optional[int] = 100,
    equidistant_thresholds: Optional[bool] = True,
) -> torch.Tensor:
    
    if equidistant_thresholds:
        thresholds = torch.linspace(0, 1, n_thresholds, device=probs.device)
    
    else:
        thresholds = retrieve_thresholds_from_probs(probs, labels)
    
    new_shape = [1] * (len(probs.shape) - 1) + [thresholds.shape[0]]
    return thresholds.reshape(new_shape)


def compute_precision_and_recall(
    probs: torch.Tensor,
    labels: torch.Tensor,
    dim: Optional[Union[int, Tuple[int]]] = 0,
    n_thresholds: Optional[int] = 100,
    equidistant_thresholds: Optional[bool] = True,
    weights: Optional[torch.Tensor] = None,
    eps: Optional[float] = 1e-12,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute precision and recall curves with a procedure inspired by micro-AUPRC computations.

    On top of a plain micro-AUPRC computation, this procedure allows:
        - reducing over several dimensions at once (they are flattened into a single dimension first);
        - weighting each label along the reduced dimension(s).

    Args:
        `probs` (torch.Tensor): Probabilities between 0 and 1 used for prediction.
        `labels` (torch.Tensor): The ground truth labels.
        `dim` (Union[int, Tuple[int]], optional): The dimension(s) on which precision and recall are computed. Defaults to 0.
        `n_thresholds` (int, optional): Number of thresholds, used only when `equidistant_thresholds` is `True`. Defaults to 100.
        `equidistant_thresholds` (bool, optional): If `True`, thresholds are evenly spaced between 0 and 1.
            Otherwise they are the sorted probabilities of the positive labels, bracketed by 0 and 1
            (see `retrieve_thresholds_from_probs`). Defaults to `True`.
        `weights` (torch.Tensor, optional): Weights specific to each label along the dimension(s) given by `dim`.
            Putting weights at this stage is debatable. Defaults to None (uniform weights).
        `eps` (float, optional): Added to the denominators to avoid division by zero. Defaults to 1e-12.

    Returns:
        (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): Precision, recall, and the weighted number of
        positive predictions. In each, the reduced dimension(s) are gone and a trailing threshold axis is added.
    """
    probs, labels, dim = format_probs_and_labels(probs, labels, dim)
    thresholds = retrieve_thresholds(probs, labels, n_thresholds, equidistant_thresholds)
    weights = format_weights(weights, probs, dim)

    def weighted_total(values: torch.Tensor) -> torch.Tensor:
        # Weighted sum along the reduction dimension.
        return torch.sum(weights * values, dim=dim)

    # `probs` has a trailing singleton axis, so this comparison yields one
    # binary prediction map per threshold along the last axis.
    predictions = probs >= thresholds
    hits = predictions * labels

    labels_sum = weighted_total(labels)
    predictions_sum = weighted_total(predictions)
    hits_sum = weighted_total(hits)

    precision = hits_sum / (predictions_sum + eps)
    recall = hits_sum / (labels_sum + eps)
    return precision, recall, predictions_sum


def average_precision_and_recall_curves(
    precision: torch.Tensor,
    recall: torch.Tensor,
    predictions_sum: torch.Tensor,
    clusters: Optional[torch.Tensor] = None,
    eps: Optional[float] = 1e-12,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Average precision and recall curves over dimension 0.

    Args:
        `precision` (torch.Tensor): [n_samples, n_thresholds] or [n_labels, n_thresholds]
        `recall` (torch.Tensor): [n_samples, n_thresholds] or [n_labels, n_thresholds]
        `predictions_sum` (torch.Tensor): [n_samples, n_thresholds] or [n_labels, n_thresholds];
            only used when `clusters` is given
        `clusters` (torch.Tensor): [n_samples] or [n_labels] cluster index of each row
        `eps` (float, optional): Added to the final denominators to avoid division by zero

    Returns:
        (Tuple[torch.Tensor, torch.Tensor]): The averaged precision and recall, each of size [n_thresholds]
    """
    if clusters is not None:
        precision_weights, recall_weights = compute_weights_from_clusters_averaging(clusters, predictions_sum)
        averaged_precision = (precision_weights * precision).sum(dim=0) / (precision_weights.sum(dim=0) + eps)
        averaged_recall = (recall_weights * recall).sum(dim=0) / (recall_weights.sum(dim=0) + eps)
        return averaged_precision, averaged_recall

    # Without clusters, precision is averaged only over the rows whose recall is
    # non-zero at that threshold, whereas recall is a plain mean over all rows.
    n_rows_with_hits = (recall > 0).sum(dim=0)
    return precision.sum(dim=0) / (n_rows_with_hits + eps), recall.mean(dim=0)







































"""
def compute_precision_and_recall(
    logits: torch.Tensor,
    labels: torch.Tensor,
    dim: Optional[Union[int, Tuple[int]]] = 0,
    weights: Optional[torch.Tensor] = None,
):
    Compute precision and recall using an procedure inspired by micro-AUPRC computations
    This procedure adds the possibility to compute precision and recall:
        - on any number of dimensions at the same time (they are flattened into one dimension before computations)
        - with weights specific to each label

    Args:
        `logits` (torch.Tensor): The output of the model
        `labels` (torch.Tensor): The ground truth labels
        `dim` (Optional[Union[int, Tuple[int]]], optional): The dimension(s) on which precision and recall should be computed. Defaults to 0.
        `weights` (Optional[torch.Tensor], optional): The weights specific to each label on the dimensions given by `dim`. Defaults to None.

    Returns:
        (Tuple[torch.Tensor, torch.Tensor]): The precision and recall evaluated at every possible threshold.
    
    logits, labels, dim = format_logits_and_labels(logits, labels, dim)
    weights = format_weights(weights, logits, dim)
    
    order = logits.argsort(descending=True, dim=dim)
    weighted_labels = torch.gather(weights * labels, dim, order)
    weights = torch.gather(weights, dim, order)
    weighted_labels_cumsum = weighted_labels.cumsum(dim=dim)
    precision = weighted_labels_cumsum / weights.cumsum(dim=dim)
    recall = weighted_labels_cumsum / weighted_labels.sum(dim=dim, keepdims=True)
    
    return precision, recall
"""