# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Dict, List, Sequence

import torch
import torch.nn.functional as F

def multi_attribute_ce_loss(logits_list, targets, dataset_attr_names, nn_type="encoder"):
    """Average cross-entropy over several categorical attributes, plus per-attribute top-1 accuracy.

    Args:
        logits_list: sequence of per-attribute logits; element ``i`` is a
            ``[B, C_i]`` tensor scored against column ``i`` of ``targets``.
        targets: ``[B, A]`` tensor of ground-truth class indices with at least
            ``len(logits_list)`` columns.
        dataset_attr_names: attribute names, paired in order with the entries of
            ``logits_list`` to build the metric keys.
        nn_type: prefix used in the metric keys (e.g. ``"encoder"``).

    Returns:
        tuple: ``(loss, separate_metrics)`` where ``loss`` is the mean of the
        per-attribute cross-entropy losses and ``separate_metrics`` maps
        ``f"{nn_type}_{attr_name}_acc1"`` to that attribute's top-1 accuracy
        (as returned by ``accuracy_at_k``, i.e. in percent).
    """
    losses = []
    list_of_acc1 = []

    # Iterate over however many attribute heads were given rather than a
    # hard-coded count, so the loss works for any number of attributes.
    for i, logits_i in enumerate(logits_list):
        targets_i = targets[:, i]  # [B] class indices for attribute i

        losses.append(F.cross_entropy(logits_i, targets_i))
        list_of_acc1.append(accuracy_at_k(logits_i, targets_i, top_k=(1,))[0])

    # Names and accuracies are paired in order; zip stops at the shorter one.
    separate_metrics = {
        f"{nn_type}_{attr_name_i}_acc1": acc1
        for attr_name_i, acc1 in zip(dataset_attr_names, list_of_acc1)
    }

    return sum(losses) / len(losses), separate_metrics

# def custom_multi_attribute_loss_regression_only(logits, targets):
#     """
#     Calculates the Mean Squared Error (MSE) loss for all attributes,
#     treating them all as continuous regression problems.

#     Args:
#         logits (torch.Tensor): The raw output from the neural network (predictions).
#                                Shape: [batch_size, num_attributes]
#         targets (torch.Tensor): The true values for each attribute.
#                                 Shape: [batch_size, num_attributes]

#     Returns:
#         torch.Tensor: The total scalar MSE loss.
#     """
#     logits_float = logits.float()
#     targets_float = targets.float()
    
#     loss = F.mse_loss(logits_float, targets_float)

#     return loss

def compute_balanced_accuracy(tps, fns, tns, fps):
    """Compute balanced accuracy from confusion-matrix counts.

    Balanced accuracy is the mean of the positive-class recall
    ``TP / (TP + FN)`` and the negative-class recall ``TN / (TN + FP)``.
    A recall whose denominator is zero is reported as 0.

    Args:
        tps: tensor of true-positive counts.
        fns: tensor of false-negative counts.
        tns: tensor of true-negative counts.
        fps: tensor of false-positive counts.

    Returns:
        torch.Tensor: element-wise balanced accuracy, same shape as the inputs.
    """

    def _safe_recall(hits, misses):
        # hits / (hits + misses), falling back to 0 where the denominator is 0.
        support = hits + misses
        return torch.where(support > 0, hits / support, torch.zeros_like(hits))

    positive_recall = _safe_recall(tps, fns)
    negative_recall = _safe_recall(tns, fps)
    return (positive_recall + negative_recall) / 2.0

def multi_label_metrics(outputs: torch.Tensor, targets: torch.Tensor, dataset_attr_names, nn_type="encoder"):
    """
    Calculates various metrics for multi-label classification using only PyTorch.

    Logits are turned into binary predictions with ``sigmoid(logit) >= 0.5``
    and compared element-wise with the ground truth. All computation runs
    under ``torch.no_grad()``.

    Args:
        outputs (torch.Tensor): (batch_size, num_labels), raw logits from the model.
        targets (torch.Tensor): (batch_size, num_labels), binary ground truth labels (0 or 1).
        dataset_attr_names: attribute names. Not used by the computation; kept
            for interface compatibility.
        nn_type (str): head that produced ``outputs``; one of ``"encoder"``,
            ``"proj"`` or ``"linear"``. Used as prefix of the per-attribute keys.

    Returns:
        tuple: ``(exact_match_ratio, hamming_score, average_jaccard_index, separate_metrics)``.
        The first three are 0-d tensors. ``separate_metrics`` is a dict mapping
        ``f"{nn_type}_tps"``, ``f"{nn_type}_tns"``, ``f"{nn_type}_fps"`` and
        ``f"{nn_type}_fns"`` to (num_labels,) float tensors of per-attribute
        confusion counts for this batch.

    Raises:
        ValueError: if ``nn_type`` is not one of the supported values.
    """
    if nn_type not in ("encoder", "proj", "linear"):
        raise ValueError(
            f"nn_type must be one of 'encoder', 'proj' or 'linear', got {nn_type!r}"
        )

    with torch.no_grad():
        # Convert logits to probabilities, then threshold at 0.5.
        probs = torch.sigmoid(outputs)
        preds = (probs >= 0.5).float()

        batch_size, num_labels = targets.shape

        # --- Separate (per-attribute) metrics ---
        # Confusion-matrix counts per attribute, summed over the batch
        # dimension in one vectorized pass. The caller accumulates them
        # (e.g. to compute balanced accuracy).
        pred_is_pos = preds == 1
        pred_is_neg = preds == 0
        true_is_pos = targets == 1
        true_is_neg = targets == 0

        separate_metrics = {
            f"{nn_type}_tps": (pred_is_pos & true_is_pos).float().sum(dim=0),
            f"{nn_type}_tns": (pred_is_neg & true_is_neg).float().sum(dim=0),
            f"{nn_type}_fps": (pred_is_pos & true_is_neg).float().sum(dim=0),
            f"{nn_type}_fns": (pred_is_neg & true_is_pos).float().sum(dim=0),
        }

        # --- Aggregate metrics ---

        # 1. Exact match ratio (subset accuracy): fraction of samples whose
        #    predictions match the targets on every label.
        exact_match_ratio = (preds == targets).all(dim=1).float().mean()

        # 2. Hamming loss = total mismatches / total label slots; score = 1 - loss.
        mismatches = (preds != targets).float().sum()
        hamming_loss = mismatches / (batch_size * num_labels)
        hamming_score = 1.0 - hamming_loss

        # 3. Jaccard index (IoU) per sample, TP / (TP + FP + FN), then averaged.
        intersection = (preds * targets).sum(dim=1)
        union = (preds + targets).sum(dim=1) - intersection
        # A sample with neither true nor predicted labels (union == 0) scores 0.
        jaccard_per_sample = torch.where(
            union != 0, intersection / union, torch.tensor(0.0, device=union.device)
        )
        average_jaccard_index = jaccard_per_sample.mean()

    return exact_match_ratio, hamming_score, average_jaccard_index, separate_metrics

def accuracy_at_k(
    outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5)
) -> List[torch.Tensor]:
    """Computes the accuracy over the k top predictions for the specified values of k.

    Values of k larger than the number of classes are clamped to the number of
    classes. With valid class-index targets, the accuracy for such a k is then 100%.

    Args:
        outputs (torch.Tensor): (batch_size, num_classes) output of a classifier
            (logits or probabilities).
        targets (torch.Tensor): (batch_size,) ground truth class indices.
        top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over.
            Defaults to (1, 5).

    Returns:
        List[torch.Tensor]: one 1-element float tensor per requested k, holding
        the top-k accuracy in percent (0-100).
    """

    with torch.no_grad():
        # topk raises if k exceeds the class dimension; clamp so that e.g. the
        # default top-5 still works on a head with fewer than 5 classes.
        maxk = min(max(top_k), outputs.size(1))
        batch_size = targets.size(0)

        _, pred = outputs.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch_size)
        # reshape (not view) so non-contiguous targets, e.g. a column slice, also work.
        correct = pred.eq(targets.reshape(1, -1).expand_as(pred))

        res = []
        for k in top_k:
            # Each sample contributes at most one True across the first k rows.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float:
    """Average one metric across steps, weighting each step by its batch size.

    Args:
        outputs (List[Dict]): per-step output dicts (e.g. from validation steps).
        key (str): key of the metric to average.
        batch_size_key (str): key holding each step's batch size.

    Returns:
        The batch-size-weighted mean, with ``.squeeze(0)`` applied. The stored
        metric values must therefore be tensor-like; a shape-(1,) tensor
        metric yields a 0-d tensor.
    """

    weighted_sum = 0
    total_count = 0
    for step_output in outputs:
        step_size = step_output[batch_size_key]
        weighted_sum += step_size * step_output[key]
        total_count += step_size
    return (weighted_sum / total_count).squeeze(0)
