import math
from typing import List, Optional, Tuple

import numpy as np
import torch
from torchmetrics.functional import precision_recall_curve
from torchmetrics import AveragePrecision, F1Score


def compute_term_centric_auprc(logits: torch.tensor, labels: torch.tensor) -> Tuple[float, float]:
    """
    Compute the micro- and macro-averaged multilabel average precision.

    The number of labels handed to the metric is the size of the last dimension of `labels`.

    Returns:
        (micro_aupr, macro_aupr) as Python floats, in that order.
    """
    num_terms = labels.shape[-1]
    scores = []
    for averaging in ('micro', 'macro'):
        metric = AveragePrecision(task="multilabel", num_labels=num_terms, average=averaging)
        scores.append(metric(logits, labels).item())
    micro_score, macro_score = scores
    return micro_score, macro_score

def get_precision_recall(logits: torch.tensor, labels: torch.tensor) -> Tuple[float, float]:
    """
    Return the mean precision and the mean recall over the binary precision-recall curve.

    `labels` is cast to int before being handed to `precision_recall_curve`;
    the thresholds returned alongside the curve are discarded.
    """
    precision_values, recall_values, _unused_thresholds = precision_recall_curve(
        logits, labels.int(), task="binary"
    )
    mean_precision = precision_values.mean().item()
    mean_recall = recall_values.mean().item()
    return mean_precision, mean_recall

def retrieve_thresholds(logits: torch.Tensor, labels: torch.Tensor, n_thresholds: int=100) -> torch.Tensor:
    """
    Pick decision thresholds from the score values themselves.

    Instead of a linear mapping of the thresholds between 0 and 1,
    we extract the values of the scores that are most interesting for the change in recall
    (i.e. the highest values of each row + some a bit under)
    so that recall is somewhat evenly distributed between 0 and 1 in order to compute AUC with the best precision.

    Args:
        logits: 2D tensor of scores. The sentinel thresholds 1 and 0 are added at both ends,
            so the scores are presumably expected to be probabilities in [0, 1] (callers to confirm).
        labels: 2D tensor laid out like `logits`; only the per-row number of positives
            and the size of its last dimension are used.
        n_thresholds: target number of thresholds. The result may be shorter when there are
            few candidate scores, and slightly different because of rounding of the strides.

    Returns:
        1D tensor of thresholds in descending order, starting with 1. and ending with 0.
    """
    ascending_per_row = torch.sort(logits, dim=-1)[0]

    # Number of top scores per row feeding the first pool: three times the mean number
    # of positives per row, capped at half of the row width.
    mean_positives = torch.mean(torch.sum(labels, dim=-1).float()).item()
    mean_n_labels = min(int(labels.shape[-1] / 2), 3 * math.ceil(mean_positives))

    # Guard against n_thresholds <= 1, which would otherwise divide by zero (or go negative).
    n_steps = max(n_thresholds - 1, 1)

    # highest values per row to set two thirds of the threshold values
    # NOTE: when mean_n_labels is 0 the `-0:` slice keeps every column, so all scores are candidates.
    top_pool = torch.sort(torch.flatten(ascending_per_row[:, -mean_n_labels:]), descending=True)[0]
    # The stride must be at least 1: an empty pool would give a stride of 0, which is an invalid slice step.
    top_stride = max(1, math.ceil(1.5 * len(top_pool) / n_steps))
    thresholds1 = top_pool[::top_stride]

    # complete with what is left of the upper half of each row for the last third of the thresholds.
    # This pool is empty whenever mean_n_labels already covers the upper half (or is 0).
    lower_pool = torch.sort(
        torch.flatten(ascending_per_row[:, -int(logits.shape[-1] / 2) : -mean_n_labels]),
        descending=True,
    )[0]
    lower_stride = max(1, math.ceil(3 * len(lower_pool) / n_steps))
    thresholds2 = lower_pool[::lower_stride]

    thresholds = torch.sort(torch.cat([thresholds1, thresholds2]), descending=True)[0]

    upper_bound = torch.tensor([1.], device=logits.device)
    lower_bound = torch.tensor([0.], device=logits.device)
    return torch.cat([upper_bound, thresholds, lower_bound])

def precision_recall(probs: torch.Tensor, labels: torch.Tensor, n_thresholds: int=100, eps=1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute averaged precision and recall at a series of thresholds.

    Precision and recall are first computed per row (first dimension) by reducing over the
    columns (second dimension), then averaged over the rows that have at least one
    prediction (for precision) or at least one positive label (for recall).
    (used for Fmax, or for AUC by passing transposed inputs --> in that case rows are labels
    and columns are samples)

    Returns:
        (precisions, recalls), one value per threshold.
    """
    # thresholds taken from the scores so that recall is roughly evenly spread over [0, 1]
    cutoffs = retrieve_thresholds(probs, labels, n_thresholds)
    cutoffs = cutoffs.reshape((1, 1, cutoffs.shape[0]))                 # [1, 1, n_thresholds]

    scores = probs.reshape((probs.shape[0], probs.shape[1], 1))         # [n_rows, n_cols, 1]
    targets = labels.reshape((labels.shape[0], labels.shape[1], 1))     # [n_rows, n_cols, 1]

    predicted = scores >= cutoffs                                       # [n_rows, n_cols, n_thresholds]
    hits = predicted * targets                                          # [n_rows, n_cols, n_thresholds]

    hits_per_row = hits.sum(dim=-2)                                     # [n_rows, n_thresholds]
    predicted_per_row = predicted.sum(dim=-2)                           # [n_rows, n_thresholds]
    positives_per_row = targets.sum(dim=-2)                             # [n_rows, 1]

    # precision: mean over the rows that predict at least one column at the threshold
    row_precisions = hits_per_row / (predicted_per_row + eps)           # [n_rows, n_thresholds]
    n_rows_predicting = (predicted_per_row > 0).sum(dim=-2)             # [n_thresholds]
    precisions = row_precisions.sum(dim=-2) / (n_rows_predicting + eps)

    # recall: mean over the rows that have at least one positive label
    row_recalls = hits_per_row / (positives_per_row + eps)              # [n_rows, n_thresholds]
    n_rows_labelled = (positives_per_row > 0).sum(dim=-2)               # [1]
    recalls = row_recalls.sum(dim=-2) / (n_rows_labelled + eps)

    return precisions, recalls

def auc(precisions: torch.Tensor, recalls: torch.Tensor) -> float:
    """
    Area under the precision-recall curve by the trapezoid rule.

    Consecutive points are joined pairwise; recall is expected to be ordered from 0 to 1
    (a decreasing recall contributes negative area).
    """
    left_heights, right_heights = precisions[:-1], precisions[1:]
    left_edges, right_edges = recalls[:-1], recalls[1:]

    mean_heights = .5 * (right_heights + left_heights)
    widths = right_edges - left_edges
    trapezoids = mean_heights * widths
    return trapezoids.sum().item()

def precision_recall_with_clusters(
        probs: torch.Tensor, labels: torch.Tensor, clusters: torch.Tensor, n_thresholds: int=100, eps: float=1e-8
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute weighted precision and recall at a series of thresholds (used for AUPRC).

    `clusters` is broadcast along the second dimension of `probs` / `labels` and multiplies
    both the predictions and the labels, i.e. it acts as one weight per column.

    Args:
        probs: scores of size [n_labels, n_samples].
        labels: ground truth of size [n_labels, n_samples].
        clusters: per-sample weights of size [n_samples].
        n_thresholds: target number of thresholds, forwarded to `retrieve_thresholds`.
        eps: added to every denominator to avoid division by zero.

    Returns:
        (precision, recall) tensors with one value per threshold. Precision is averaged over
        the rows with a non-zero weighted prediction count at the threshold, recall over the
        rows with a non-zero weighted label count.
    """
    # retrieve thresholds so that recall is more or less evenly distributed in the range [0, 1]
    thresholds = retrieve_thresholds(probs, labels, n_thresholds)
    thresholds = thresholds.reshape((1, 1, thresholds.shape[0]))           # [1, 1, n_thresholds]

    probs = probs.reshape((probs.shape[0], probs.shape[1], 1))             # [n_labels, n_samples, 1]
    labels = labels.reshape((labels.shape[0], labels.shape[1], 1))         # [n_labels, n_samples, 1]
    clusters = clusters.reshape((1, clusters.shape[0], 1))                 # [1, n_samples, 1]

    preds = probs >= thresholds                                            # [n_labels, n_samples, n_thresholds]
    weighted_preds = clusters * preds                                      # [n_labels, n_samples, n_thresholds]

    # Each weighted reduction is computed once and reused for both the per-term ratio
    # and the count of terms that take part in the average.
    n_good_per_term = torch.sum(weighted_preds * labels, dim=-2)           # [n_labels, n_thresholds]
    preds_sum = torch.sum(weighted_preds, dim=-2)                          # [n_labels, n_thresholds]
    labels_sum = torch.sum(clusters * labels, dim=-2)                      # [n_labels, 1]

    precision_per_term = n_good_per_term / (preds_sum + eps)               # [n_labels, n_thresholds]
    recall_per_term = n_good_per_term / (labels_sum + eps)                 # [n_labels, n_thresholds]

    precision = torch.sum(precision_per_term, dim=-2) / (torch.sum(preds_sum > 0, dim=-2) + eps)  # [n_thresholds]
    recall = torch.sum(recall_per_term, dim=-2) / (torch.sum(labels_sum > 0, dim=-2) + eps)       # [n_thresholds]

    return precision, recall

def compute_presence(x):
    """ Count, along the second-to-last dimension, the entries of `x` that are strictly positive (helper meant to be mapped over a list of per-cluster tensors). """
    is_positive = x > 0
    return is_positive.sum(dim=-2)

def compute_sum(x):
    """ Sum `x` along its second-to-last dimension (helper meant to be mapped over a list of per-cluster tensors). """
    return x.sum(dim=-2)

def fmax(precisions: torch.Tensor, recalls: torch.Tensor, eps: float=1e-8) -> float:
    """
    Return the largest F1 score over paired precision / recall values.

    `eps` is added to the denominator so that a zero precision and recall yields 0 instead of NaN.
    """
    numerator = 2 * precisions * recalls
    denominator = precisions + recalls + eps
    f1_per_threshold = numerator / denominator
    return f1_per_threshold.max().item()

def compute_protein_centric_fmax_vanilla(probs: torch.tensor, labels: torch.tensor) -> float:
    """
    Return the best multilabel F1 score over the thresholds 0.00, 0.01, ..., 0.99.

    Both inputs are transposed before scoring, so the `num_labels` handed to `F1Score`
    is the size of the last dimension of the transposed labels.
    A fresh `F1Score` metric is built for every threshold.
    """
    transposed_probs, transposed_labels = probs.T, labels.T
    n_columns = transposed_labels.shape[-1]
    scores = [
        F1Score(task="multilabel", num_labels=n_columns, threshold=cutoff)(transposed_probs, transposed_labels)
        for cutoff in np.arange(0, 1, 0.01)
    ]
    return max(scores).item()

def compute_protein_centric_reweighted_fmax(
        probs: torch.Tensor,
        labels: torch.Tensor,
        clusters: np.ndarray,
        n_thresholds: Optional[int]=100,
        eps: Optional[float]=1e-8,
        do_reweighted_fmax: bool=False,
    ) -> Optional[float]:
    """
    Computes the Fmax score. Optionally averages precision and recall values within protein
    clusters before averaging across clusters, to smooth out biases due to different cluster sizes.

    Arguments:
        probs: [N_seq, N_labels], logits or probabilities. If any value lies outside [0, 1],
            a sigmoid is applied to the whole tensor first.
        labels: [N_seq, N_labels], the ground truth labels
        clusters: [N_seq], index of the cluster each protein belongs to.
            Only read when `do_reweighted_fmax` is True.
        n_thresholds: used to evaluate the precision-recall curves
        eps: to avoid division by zero during vectorization
        do_reweighted_fmax: if True, computes the Fmax score with cluster reweighting;
            if False, every protein counts equally.

    Returns:
        Fmax score (a single float in both modes).
    """
    if torch.min(probs) < 0 or torch.max(probs) > 1:
        probs = torch.sigmoid(probs)

    # thresholds may be of size < N_thresholds if N_seq < N_thresholds
    thresholds = retrieve_thresholds(probs, labels, n_thresholds)                          # [N_thresholds]
    thresholds = thresholds.reshape((1, 1, thresholds.shape[0]))                           # [1, 1, N_thresholds]

    probs = probs.reshape((probs.shape[0], probs.shape[1], 1))                             # [N_seq, N_labels, 1]
    labels = labels.reshape((labels.shape[0], labels.shape[1], 1))                         # [N_seq, N_labels, 1]
    preds = probs >= thresholds                                                            # [N_seq, N_labels, N_thresholds]

    n_true_pos = torch.sum(preds * labels, dim=-2)                                         # [N_seq, N_thresholds]
    preds_sum = torch.sum(preds, dim=-2)                                                   # [N_seq, N_thresholds]
    label_sum = torch.sum(labels, dim=-2)                                                  # [N_seq, 1]

    precision_per_seq = n_true_pos / (preds_sum + eps)                                     # [N_seq, N_thresholds]
    recall_per_seq = n_true_pos / (label_sum + eps)                                        # [N_seq, N_thresholds]

    if do_reweighted_fmax:
        # one boolean mask over the proteins per distinct cluster id
        cluster_masks = [clusters == c for c in np.unique(clusters)]

        # per cluster: summed precision / recall of its proteins ...
        precision_per_cluster = torch.stack([compute_sum(precision_per_seq[mask]) for mask in cluster_masks])  # [N_clusters, N_thresholds]
        recall_per_cluster = torch.stack([compute_sum(recall_per_seq[mask]) for mask in cluster_masks])        # [N_clusters, N_thresholds]
        # ... and the number of its proteins with at least one prediction / one label
        n_predicting = torch.stack([compute_presence(preds_sum[mask]) for mask in cluster_masks])              # [N_clusters, N_thresholds]
        n_labelled = torch.stack([compute_presence(label_sum[mask]) for mask in cluster_masks])                # [N_clusters, 1]

        # mean within each cluster, then mean over the clusters that contribute at all
        precisions = torch.sum(precision_per_cluster / (n_predicting + eps), dim=-2) / (torch.sum(n_predicting > 0, dim=-2) + eps)  # [N_thresholds]
        recalls = torch.sum(recall_per_cluster / (n_labelled + eps), dim=-2) / (torch.sum(n_labelled > 0, dim=-2) + eps)            # [N_thresholds]
    else:
        # Plain mean over proteins. The per-protein values are already ratios, so they must not
        # be divided by the raw prediction / label counts a second time.
        precisions = torch.sum(precision_per_seq, dim=-2) / (torch.sum(preds_sum > 0, dim=-2) + eps)  # [N_thresholds]
        recalls = torch.sum(recall_per_seq, dim=-2) / (torch.sum(label_sum > 0, dim=-2) + eps)        # [N_thresholds]

    return fmax(precisions, recalls)

def compute_metrics(
        logits: torch.tensor,
        labels: torch.tensor,
        do_reweighted_fmax: bool=False,
        clusters: Optional[np.ndarray]=None
    ) -> List[Optional[float]]:
    """
    Gather the evaluation metrics into a list.

    `labels` is cast to int before any metric is computed.

    Returns:
        [fmax, reweighted_fmax, macro_aupr, micro_aupr] when `do_reweighted_fmax` is True,
        otherwise [fmax, macro_aupr, micro_aupr].
    """
    int_labels = labels.int()

    results = [compute_protein_centric_fmax_vanilla(logits, int_labels)]

    if do_reweighted_fmax:
        results.append(
            compute_protein_centric_reweighted_fmax(logits, int_labels, clusters, do_reweighted_fmax=True)
        )

    # the helper returns (micro, macro); the list stores macro first
    micro_aupr_score, macro_aupr_score = compute_term_centric_auprc(logits, int_labels)
    results.extend([macro_aupr_score, micro_aupr_score])
    return results
