import math
import torch
import numpy as np
import torch.nn.functional as F

from typing import Optional, Union, Tuple
from typing_extensions import Literal
from precision_recall import (
    compute_precision_and_recall,
    average_precision_and_recall_curves,
)


# Value passed as `dim` to `compute_precision_and_recall` for each averaging
# mode of `compute_precision_recall_metrics` (inputs there are documented as
# [n_samples, n_labels]).
DIMS_DICT = dict(
    micro=(0, 1),  # both axes at once
    per_sample=1,  # presumably reduces over labels -> one curve per sample
    per_label=0,  # presumably reduces over samples -> one curve per label
)


def convert_logits_to_probs(
    logits: torch.Tensor,
) -> torch.Tensor:
    """
    Map raw scores to probabilities unless they already look like probabilities.

    Heuristic: if every value lies in [0, 1] the tensor is assumed to already
    hold probabilities and is returned unchanged (same object). Otherwise a
    sigmoid is applied element-wise. Genuine logits that all happen to fall
    inside [0, 1] are therefore NOT converted.

    Args:
        logits (torch.Tensor): Scores of any shape.

    Returns:
        torch.Tensor: Tensor of the same shape as `logits`.
    """
    # torch.min / torch.max raise on empty tensors. An empty input has nothing
    # to convert (e.g. after every row/column was filtered out), so pass it
    # through unchanged instead of crashing.
    if logits.numel() == 0:
        return logits

    if torch.min(logits) < 0 or torch.max(logits) > 1:
        return torch.sigmoid(logits)

    return logits

def select_present_labels(
    logits: torch.Tensor,
    labels: torch.Tensor,
    clusters: Optional[torch.Tensor],
    remove_null_labels: bool,
    remove_null_samples: bool,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Optionally drop label columns and sample rows that contain no positive label.

    Columns are filtered first, then rows. Dropping all-zero columns never
    changes a row sum, so the order does not affect which rows are kept.

    Args:
        logits (torch.Tensor): [n_samples, n_labels] scores, filtered in
            lockstep with `labels`.
        labels (torch.Tensor): [n_samples, n_labels] targets. A column/row is
            kept when its sum is > 0.
        clusters (torch.Tensor, optional): [n_samples] per-sample cluster ids,
            filtered along with the rows. May be `None`.
        remove_null_labels (bool): Drop label columns whose sum is not > 0.
        remove_null_samples (bool): Drop sample rows whose sum is not > 0.

    Returns:
        Tuple of the filtered `(logits, labels, clusters)`. `clusters` stays
        `None` if it was `None`.
    """
    if remove_null_labels:
        present_label_mask = torch.sum(labels, dim=0) > 0
        logits = logits[:, present_label_mask]
        labels = labels[:, present_label_mask]

    if remove_null_samples:
        present_sample_mask = torch.sum(labels, dim=1) > 0
        logits = logits[present_sample_mask]
        labels = labels[present_sample_mask]
        if clusters is not None:
            clusters = clusters[present_sample_mask]

    return logits, labels, clusters

def compute_weights_from_clusters(
    clusters: Optional[torch.Tensor] = None,
    averaging_type: Optional[Literal["micro", "per_sample", "per_label"]] = "per_sample",
    average_before_metric: Optional[bool] = True,
    eps: Optional[float] = 1e-12,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
    """
    Build inverse-cluster-size sample weights and route them to the stage that uses them.

    Each sample is weighted by 1 / (size of its cluster + eps).

    Args:
        clusters (torch.Tensor, optional): Integer cluster id per sample
            (passed to `F.one_hot`). `None` disables weighting.
        averaging_type (str, optional): "micro", "per_sample" or "per_label".
        average_before_metric (bool, optional): Whether curves are averaged
            before the metric is computed.
        eps (float, optional): Added to cluster sizes before inverting.

    Returns:
        `(precision_recall_weights, averaging_weights)`:
        - `(None, None)` when `clusters` is None, for "micro", or for
          "per_sample" with a truthy `average_before_metric`.
        - `(None, w)` for "per_sample" otherwise, where `w` has one entry per sample.
        - `(w.unsqueeze(1), None)` for any other averaging type (e.g. "per_label").
    """
    no_weighting = (
        clusters is None
        or averaging_type == "micro"
        or (averaging_type == "per_sample" and average_before_metric)
    )
    if no_weighting:
        return None, None

    membership = F.one_hot(clusters)  # one column per cluster id
    n_samples = membership.shape[0]
    cluster_sizes = membership.sum(dim=0, keepdim=True)
    inverse_sizes = 1 / (cluster_sizes + eps)
    # Zero any per-cluster weight below 1 / n_samples. A non-empty cluster has
    # at most n_samples members, so this should rarely if ever fire.
    # NOTE(review): confirm the intended purpose of this cut-off.
    inverse_sizes = inverse_sizes.masked_fill(inverse_sizes < 1 / n_samples, 0)
    # Each sample picks up the weight of the single cluster it belongs to.
    sample_weights = (membership * inverse_sizes).sum(dim=1)

    if averaging_type == "per_sample":
        return None, sample_weights
    return sample_weights.unsqueeze(1), None

def compute_f_beta_score(
    precision: torch.Tensor,
    recall: torch.Tensor,
    beta: float,
    eps: float = 1e-12,
) -> torch.Tensor:
    """
    Element-wise F-beta score.

    Computes (1 + beta^2) * P * R / (beta^2 * P + R + eps). `eps` keeps the
    result at 0 instead of NaN when both precision and recall are 0.

    Args:
        precision (torch.Tensor): Precision values.
        recall (torch.Tensor): Recall values, broadcastable with `precision`.
        beta (float): Weight of recall relative to precision.
        eps (float, optional): Added to the denominator. Defaults to 1e-12.

    Returns:
        torch.Tensor: F-beta score with the broadcast shape of the inputs.
    """
    beta_sq = beta ** 2
    numerator = (1 + beta_sq) * precision * recall
    denominator = beta_sq * precision + recall + eps
    return numerator / denominator


def compute_fmax_score(
    precision: torch.Tensor,
    recall: torch.Tensor,
    dim: Optional[int] = -1,
    eps: Optional[float] = 1e-12,
) -> torch.Tensor:
    """
    Fmax: the largest F1 score along `dim` of a precision/recall curve.

    Args:
        precision (torch.Tensor): Precision values.
        recall (torch.Tensor): Recall values, same layout as `precision`.
        dim (int, optional): Axis along which the maximum is taken. Defaults to -1.
        eps (float, optional): Forwarded to `compute_f_beta_score`.

    Returns:
        torch.Tensor: The maximum F1 values, with `dim` reduced away.
    """
    f1_values = compute_f_beta_score(precision, recall, 1, eps)
    return f1_values.max(dim=dim).values

def compute_auc(
    precision: torch.Tensor,
    recall: torch.Tensor,
    dim: Optional[int] = -1,
    eps: Optional[float] = 1e-12,
) -> torch.Tensor:
    """
    Step-wise area under a precision/recall curve sampled along `dim`.

    Computes sum_i precision[i] * (recall[i] - recall[i + 1]) over consecutive
    points. This is positive when recall is non-increasing along `dim`;
    presumably the curves are ordered that way by the caller — TODO confirm.

    Args:
        precision (torch.Tensor): Precision values.
        recall (torch.Tensor): Recall values, same layout as `precision`.
        dim (int, optional): Axis holding the curve points. Defaults to -1.
        eps (float, optional): Unused. It is accepted so this function can be
            called with the same keyword arguments as `compute_fmax_score`.

    Returns:
        torch.Tensor: The area, with `dim` reduced away.
    """
    n_segments = precision.shape[dim] - 1
    left_precision = precision.narrow(dim, 0, n_segments)
    left_recall = recall.narrow(dim, 0, n_segments)
    right_recall = recall.narrow(dim, 1, n_segments)
    recall_drop = left_recall - right_recall
    return (left_precision * recall_drop).sum(dim=dim)

# Maps the `metric_type` argument of `compute_precision_recall_metrics` to the
# function that turns precision/recall curves into a score.
METRICS_DICT = dict(
    fmax=compute_fmax_score,
    auprc=compute_auc,
)

def compute_precision_recall_metrics(
    logits: torch.Tensor,
    labels: torch.Tensor,
    metric_type: Optional[Literal["fmax", "auprc"]] = "fmax",
    averaging: Optional[Literal["micro", "per_sample", "per_label"]] = "per_sample",
    average_before_metric: Optional[bool] = True,
    n_thresholds: Optional[int] = 100,
    equidistant_thresholds: Optional[bool] = True,
    clusters: Optional[torch.Tensor] = None,
    remove_null_labels: Optional[bool] = None,
    remove_null_samples: Optional[bool] = None,
    eps: Optional[float] = 1e-12,
) -> float:
    """
    Compute an Fmax or AUPRC score from multi-label predictions.

    Args:
        `logits` (torch.Tensor): [n_samples, n_labels] raw scores or probabilities.
            If every value lies in [0, 1] they are used as probabilities;
            otherwise a sigmoid is applied. The check runs on the full tensor,
            before any filtering.
        `labels` (torch.Tensor): [n_samples, n_labels]
        `metric_type` (str, optional): 'fmax' or 'auprc'. Defaults to 'fmax'.
        `averaging` (str, optional): How the data is averaged after precision and
            recall have been computed: 'micro', 'per_sample' or 'per_label'.
            Defaults to 'per_sample'.
        `average_before_metric` (bool, optional): If `True`, averages the precision
            and recall curves and then computes the metric once. If `False`,
            computes the metric per curve and then averages the scores.
            Ignored for 'micro'. Defaults to `True`.
        `n_thresholds` (int, optional): Number of thresholds at which precision and
            recall are evaluated. Defaults to 100.
        `equidistant_thresholds` (bool, optional): If `True`, the thresholds are
            linearly distributed between 0 and 1. No other option is supported
            yet. Defaults to `True`.
        `clusters` (torch.Tensor, optional): [n_samples] integer cluster id of each
            sample. If provided, samples are weighted by 1 / cluster_size.
            Ignored for 'micro' averaging. Defaults to `None`.
        `remove_null_labels` (bool, optional): If `True`, label columns with no
            positive entry are removed. If `None`, this is `True` when
            `averaging` is 'per_label' and `False` otherwise. Defaults to `None`.
        `remove_null_samples` (bool, optional): If `True`, sample rows with no
            positive label are removed. If `None`, this is `True` when
            `averaging` is 'per_sample' and `False` otherwise. Defaults to `None`.
        `eps` (float, optional): Small constant added to denominators to avoid
            division by zero. Defaults to 1e-12.

    Returns:
        (float): The metric value as a Python float, detached from autograd.

    Raises:
        KeyError: If `metric_type` or `averaging` is not a supported value.
    """
    # Look up both settings up front so an unsupported value fails before any work.
    metric = METRICS_DICT[metric_type]
    reduction_dim = DIMS_DICT[averaging]

    if remove_null_labels is None:
        remove_null_labels = averaging == "per_label"
    if remove_null_samples is None:
        remove_null_samples = averaging == "per_sample"

    # Decide logits-vs-probabilities on the full tensor, BEFORE rows/columns are
    # dropped. Otherwise, removing the only out-of-range entries would make the
    # remaining raw logits be mistaken for probabilities. It also avoids
    # min/max on an empty selection. Sigmoid is element-wise, so selected
    # values are otherwise identical.
    probs = convert_logits_to_probs(logits)
    probs, labels, clusters = select_present_labels(
        probs, labels, clusters, remove_null_labels, remove_null_samples
    )

    precision_recall_weights, averaging_weights = compute_weights_from_clusters(
        clusters, averaging, average_before_metric, eps
    )
    precision, recall, predictions_sum = compute_precision_and_recall(
        probs,
        labels,
        dim=reduction_dim,
        n_thresholds=n_thresholds,
        equidistant_thresholds=equidistant_thresholds,
        weights=precision_recall_weights,
        eps=eps,
    )

    if averaging == "micro":
        # Both axes (0, 1) were reduced together, so a single curve remains.
        return metric(precision, recall, dim=0, eps=eps).detach().item()

    if average_before_metric:
        # Merge the curves first, then score once. Clusters are forwarded only
        # for per-sample averaging. For per-label averaging, the cluster
        # weights were already passed to compute_precision_and_recall as
        # `weights`.
        precision, recall = average_precision_and_recall_curves(
            precision,
            recall,
            predictions_sum,
            clusters=clusters if averaging == "per_sample" else None,
            eps=eps,
        )
        return metric(precision, recall, dim=0, eps=eps).detach().item()

    # Score each curve separately, then average the scores (weighted when the
    # per-sample cluster weights apply).
    scores = metric(precision, recall, dim=1, eps=eps)
    if averaging_weights is None:
        return torch.mean(scores).detach().item()

    weighted_total = torch.sum(averaging_weights * scores)
    return (weighted_total / (torch.sum(averaging_weights) + eps)).detach().item()