import numpy as np
from munkres import Munkres
from sklearn import metrics
import torch
import torch.nn.functional as F
from torchmetrics import Metric


def cluster_acc(y_true, y_pred):
    """Compute clustering accuracy and F1 scores under the best cluster-to-class matching.

    Predicted cluster ids are arbitrary, so they are first matched one-to-one to
    the true classes with the Hungarian (Munkres) algorithm, choosing the
    matching that maximises the number of agreeing samples. The predictions are
    then relabelled through that matching and scored against the true labels.

    The code is adapted from https://github.com/Tiger101010/DAEGC/blob/main/DAEGC/evaluation.py
    and is similar to https://github.com/karenlatong/AGC-master/blob/master/metrics.py

    When the number of distinct classes and clusters differ, the contingency
    matrix is padded with all-zero rows/columns to make it square. No artificial
    samples are added, so the scores are computed over the real samples only.
    Surplus clusters end up matched to placeholder labels that never occur in
    the true labels, so their samples count as errors.

    Args:
        y_true: 1-D array-like of ground-truth class labels.
        y_pred: 1-D array-like of predicted cluster labels, same length as ``y_true``.

    Returns:
        tuple: ``(acc, f1_macro, f1_micro)`` computed by ``sklearn.metrics`` on
        the relabelled predictions.

    Raises:
        ValueError: If the inputs are empty or differ in length.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()

    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            "y_true and y_pred must have the same length, got "
            f"{y_true.shape[0]} and {y_pred.shape[0]}"
        )
    if y_true.shape[0] == 0:
        raise ValueError("y_true and y_pred must not be empty")

    # Remap the labels to contiguous indices [0, n_classes-1] and [0, n_clusters-1]
    true_classes, true_idx = np.unique(y_true, return_inverse=True)
    pred_clusters, pred_idx = np.unique(y_pred, return_inverse=True)

    # Square contingency matrix: entry (i, j) counts samples of class i in cluster j.
    # Rows/columns beyond the real classes/clusters stay zero and act as padding.
    size = max(len(true_classes), len(pred_clusters))
    contingency = np.zeros((size, size), dtype=int)
    np.add.at(contingency, (true_idx, pred_idx), 1)

    # Munkres minimises total cost, so negate the counts to maximise agreement
    assignment = Munkres().compute((-contingency).tolist())

    # Translate every cluster index into the class index it was matched with
    cluster_to_class = np.arange(size)
    for class_idx, cluster_idx in assignment:
        cluster_to_class[cluster_idx] = class_idx
    new_predict = cluster_to_class[pred_idx]

    acc = metrics.accuracy_score(true_idx, new_predict)
    f1_macro = metrics.f1_score(true_idx, new_predict, average="macro")
    f1_micro = metrics.f1_score(true_idx, new_predict, average="micro")

    return acc, f1_macro, f1_micro


class FuzzyMutualInformation(Metric):
    """Compute the Normalized Fuzzy Mutual Information between soft cluster
    assignments and class labels.

    Joint and marginal probability mass is accumulated batch by batch in
    ``update``; ``compute`` turns the accumulated sums into probabilities,
    evaluates the fuzzy mutual information and normalizes it by the geometric
    mean of the cluster and class entropies.

    Args:
    - num_classes (int): Number of classes.
    - num_clusters (int): Number of clusters.
    - eps (float): Small value to prevent division by zero and log(0).
    - dist_sync_on_step (bool): Forwarded unchanged to the torchmetrics ``Metric`` base class.
    """

    def __init__(
        self,
        num_classes: int,
        num_clusters: int,
        eps: float = 1e-10,
        dist_sync_on_step: bool = False,
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.num_classes = num_classes
        self.num_clusters = num_clusters
        self.eps = eps

        # States for accumulating counts/probabilities; all are summed across processes
        self.add_state(
            "joint_probs_numerator",
            default=torch.zeros(num_clusters, num_classes),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "cluster_probs_sum",
            default=torch.zeros(num_clusters, 1),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "class_probs_sum", default=torch.zeros(1, num_classes), dist_reduce_fx="sum"
        )
        self.add_state("total_samples", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, soft_clusters: torch.Tensor, class_labels: torch.Tensor):
        """Update the metric states with the current batch data.

        Args:
        - soft_clusters (torch.Tensor): Soft cluster assignments of shape (batch_size, num_clusters)
        - class_labels (torch.Tensor): Integer class labels of shape (batch_size,)
        """
        batch_size = soft_clusters.size(0)

        # Ensure soft_clusters are normalized (each row sums to ~1)
        soft_clusters = soft_clusters / (
            soft_clusters.sum(dim=1, keepdim=True) + self.eps
        )

        # One-hot encode class labels; F.one_hot only accepts int64 index
        # tensors, so cast first to also accept other integer dtypes
        class_probs = F.one_hot(
            class_labels.long(), num_classes=self.num_classes
        ).float()

        # Update total samples
        self.total_samples += batch_size

        # Update joint probabilities numerator
        batch_joint_probs = torch.matmul(
            soft_clusters.T, class_probs
        )  # Shape: (num_clusters, num_classes)
        self.joint_probs_numerator += batch_joint_probs

        # Update marginal probabilities sums
        self.cluster_probs_sum += soft_clusters.sum(
            dim=0, keepdim=True
        ).T  # Shape: (num_clusters, 1)
        self.class_probs_sum += class_probs.sum(
            dim=0, keepdim=True
        )  # Shape: (1, num_classes)

    def compute(self):
        """Compute the Normalized Fuzzy Mutual Information from the accumulated states.

        Returns:
        - FNMI (torch.Tensor): Scalar tensor holding the fuzzy mutual information
          divided by ``sqrt(H_cluster * H_class) + eps``. Only the normalized
          value is returned.
        """
        eps = self.eps

        # Normalize accumulated counts to probabilities
        joint_probs = (
            self.joint_probs_numerator / self.total_samples
        )  # Shape: (num_clusters, num_classes)
        cluster_probs = (
            self.cluster_probs_sum / self.total_samples
        )  # Shape: (num_clusters, 1)
        class_probs_mean = (
            self.class_probs_sum / self.total_samples
        )  # Shape: (1, num_classes)

        # Compute the product of marginals P(cluster) * P(class)
        marginal_prod = (
            torch.matmul(cluster_probs, class_probs_mean) + eps
        )  # Shape: (num_clusters, num_classes)

        # Compute mutual information matrix; eps keeps log() finite for empty cells
        mutual_info_matrix = joint_probs * torch.log(
            (joint_probs + eps) / marginal_prod
        )

        # Compute the (unnormalized) fuzzy mutual information
        FMI = mutual_info_matrix.sum()

        # Compute entropies of the two marginals
        H_cluster = -(cluster_probs * torch.log(cluster_probs + eps)).sum()
        H_class = -(class_probs_mean * torch.log(class_probs_mean + eps)).sum()

        # Normalize by the geometric mean of the entropies
        denominator = torch.sqrt(H_cluster * H_class) + eps
        FNMI = FMI / denominator

        return FNMI


class FuzzyClusterCosine(Metric):
    """Compute the cosine similarity between soft cluster assignments and class labels.

    The similarity is taken between the sample-by-sample agreement matrices
    ``U @ U.T`` (soft clusters) and ``V @ V.T`` (one-hot class labels), viewed as
    flat vectors. This metric accumulates the soft cluster assignments and
    class labels over all batches, and computes the cosine similarity at the end.

    Note:
        - Every batch passed to ``update`` is kept, so memory grows linearly
          with the number of samples seen.

    Args:
        num_clusters (int): Number of clusters (width of the soft assignment matrix).
        num_classes (int): Number of classes used for the one-hot encoding.
        dist_sync_on_step (bool): Forwarded unchanged to the torchmetrics ``Metric`` base class.
    """

    def __init__(
        self, num_clusters: int, num_classes: int, dist_sync_on_step: bool = False
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.num_clusters = num_clusters
        self.num_classes = num_classes

        # Initialize states as empty tensors; batches are concatenated onto them
        self.add_state(
            "U_all", default=torch.empty(0, num_clusters), dist_reduce_fx="cat"
        )
        self.add_state(
            "V_all", default=torch.empty(0, dtype=torch.long), dist_reduce_fx="cat"
        )

    def update(self, U: torch.Tensor, V: torch.Tensor):
        """Update the metric states with the current batch data.

        Args:
            U (torch.Tensor): Soft cluster assignments of shape (batch_size, num_clusters).
            V (torch.Tensor): Class labels of shape (batch_size,).
        """
        # Ensure that states are on the same device as input tensors
        device = U.device

        if self.U_all.device != device:
            self.U_all = self.U_all.to(device)
            self.V_all = self.V_all.to(device)

        # Detach tensors to avoid unnecessary computation graph retention
        U = U.detach()
        V = V.detach()

        # Concatenate current batch to the accumulated tensors
        self.U_all = torch.cat([self.U_all, U], dim=0)
        self.V_all = torch.cat([self.V_all, V], dim=0)

    def compute(self):
        """Compute the cosine similarity between soft cluster assignments
        and one-hot representations of class labels.

        The agreement matrices ``U @ U.T`` and ``V @ V.T`` have shape
        (n_samples, n_samples) and are never materialised. The sums over them
        are evaluated through the equivalent identities

            sum((U U^T) * (V V^T)) = ||U^T V||_F^2
            sum((U U^T) ** 2)      = ||U^T U||_F^2
            sum((V V^T) ** 2)      = ||V^T V||_F^2

        which only need (num_clusters x num_classes)-sized intermediates.

        Returns:
            torch.Tensor: Scalar cosine similarity; it lies between 0 and 1 when
            the soft assignments are non-negative. Returns 0 when the
            denominator is zero (e.g. no samples were accumulated).
        """
        U = self.U_all  # Shape: (n_samples, n_clusters)
        V = self.V_all.long()  # Shape: (n_samples,)

        # Convert class labels to one-hot encoding, in U's dtype so matmul accepts both
        V_one_hot = F.one_hot(V, num_classes=self.num_classes).to(U.dtype)

        # Small Gram / cross matrices that replace the n_samples x n_samples products
        cross = torch.matmul(U.T, V_one_hot)  # Shape: (n_clusters, n_classes)
        gram_U = torch.matmul(U.T, U)  # Shape: (n_clusters, n_clusters)
        gram_V = torch.matmul(V_one_hot.T, V_one_hot)  # Shape: (n_classes, n_classes)

        # Compute numerator and denominator
        numerator = torch.sum(cross**2)
        denominator = torch.sqrt(torch.sum(gram_U**2) * torch.sum(gram_V**2))

        # Handle zero denominator
        if denominator == 0:
            return torch.tensor(0.0, device=U.device)

        COS = numerator / denominator

        return COS

# Smoke-test cluster_acc on a few hand-written label pairs
if __name__ == "__main__":
    for y_true, y_pred in (
        # More distinct true labels (5) than predicted labels (4)
        (np.array([4, 3, 1, 1, 2, 5]), np.array([7, 0, 1, 1, 2, 2])),
        # More distinct predicted labels (4) than true labels (3)
        (np.array([0, 0, 1, 1, 2, 2]), np.array([0, 0, 1, 2, 2, 7])),
        # Fewer distinct predicted labels (2) than true labels (3)
        (np.array([0, 0, 1, 1, 2, 2]), np.array([0, 0, 1, 1, 1, 1])),
    ):
        acc, f1_macro, f1_micro = cluster_acc(y_true, y_pred)
        print(f"Accuracy: {acc}, F1 Macro: {f1_macro}, F1 Micro: {f1_micro}")
