import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def get_cost_matrix_from_probabilities(pred_probs1, pred_probs2, metric="correlation"):
    """
    Construct a cost matrix from the distances between the per-class probability columns of two models.

    Column ``i`` of ``pred_probs1`` (the probability model 1 assigns to class ``i`` for
    each of the N samples) is compared with column ``j`` of ``pred_probs2`` using
    ``scipy.spatial.distance.cdist``.

    Args:
        pred_probs1 (array_like): Prediction probability vectors from model 1, shape (N, num_classes1).
        pred_probs2 (array_like): Prediction probability vectors from model 2, shape (N, num_classes2).
            Both inputs must contain the same number of samples N.
        metric (str or callable): Any metric accepted by ``cdist`` (e.g. 'correlation',
            'euclidean', 'cosine', 'jensenshannon'). KL divergence is not a built-in
            ``cdist`` metric; pass a callable to use it.

    Returns:
        numpy.ndarray: Cost matrix of shape (num_classes1, num_classes2) where entry ``[i, j]``
        is the distance between class ``i`` of model 1 and class ``j`` of model 2. With the
        'correlation' metric, a column that is constant across all samples yields NaN entries.
    """
    # Accept lists / array-likes as well as ndarrays; `.T` below needs an ndarray.
    probs1 = np.asarray(pred_probs1)
    probs2 = np.asarray(pred_probs2)
    # Transpose so each row is one class's probability profile over the N samples.
    return cdist(probs1.T, probs2.T, metric=metric)


def get_cost_matrix(conf_matrix1, conf_matrix2):
    """
    Build the pairwise assignment-cost matrix between the rows of two matrices.

    Entry ``[i, j]`` is ``compatibility_score(conf_matrix1[i], conf_matrix2[j])``.

    Args:
        conf_matrix1: 2-D array; its rows index the rows of the result.
        conf_matrix2: 2-D array; its rows index the columns of the result.
            Rows of both matrices are expected to have equal length.

    Returns:
        numpy.ndarray: float array of shape (len(conf_matrix1), len(conf_matrix2)).
    """
    # Size the output from the row counts of BOTH inputs (not from conf_matrix1's own
    # shape), so rectangular inputs or differing row counts do not raise IndexError.
    cost_matrix = np.zeros((len(conf_matrix1), len(conf_matrix2)), dtype=float)
    for i, row1 in enumerate(conf_matrix1):
        for j, row2 in enumerate(conf_matrix2):
            cost_matrix[i, j] = compatibility_score(row1, row2)
    return cost_matrix


def compatibility_score(vec1, vec2):
    """
    Score how acceptable it is to pair row ``vec1`` with row ``vec2``.

    Returns:
        100 if both vectors are entirely zero (a dummy pairing),
        10**6 if some position is positive in ``vec1`` but not positive in ``vec2``
        (an impossible pairing), and 0 otherwise (a valid pairing).
    """
    if np.all(vec1 == 0) and np.all(vec2 == 0):
        return 100  # both rows empty: allow pairing them as dummies at a small cost

    # Every positive entry of vec1 must also be positive at the same position in vec2.
    violates = any(
        value > 0 and not (vec2[pos] > 0) for pos, value in enumerate(vec1)
    )
    return 10**6 if violates else 0


def fix_conf_matrix(conf_matrix):
    """
    Keep only the largest entry in each row, zero the rest, and return the transpose.

    On ties, the first maximal entry (``np.argmax`` semantics) is kept. The input is
    NOT modified; a copy is processed instead.

    Args:
        conf_matrix: 2-D numpy array, array-like, or torch.Tensor.

    Returns:
        The transposed, row-sparsified copy (same container type as a numpy/torch input;
        array-likes are returned as numpy arrays).
    """
    # Work on a copy: mutating the argument in place would silently clobber the
    # caller's matrix (e.g. the matrices passed to align_solutions).
    if isinstance(conf_matrix, torch.Tensor):
        conf = conf_matrix.clone()
    else:
        conf = np.array(conf_matrix, copy=True)

    n_rows, n_cols = conf.shape
    col_ids = np.arange(n_cols)  # loop-invariant, computed once
    for i in range(n_rows):
        conf[i] = np.where(col_ids == np.argmax(conf[i]), conf[i], 0)
    return conf.T


def get_permutation_matrix(row_indices, col_indices, size):
    """
    Build a dense (size, size) float matrix with a 1 at every (row, col) pair given.

    Args:
        row_indices: Iterable of row positions.
        col_indices: Iterable of column positions, paired element-wise with ``row_indices``.
        size (int): Side length of the square output.

    Returns:
        numpy.ndarray: float64 matrix of zeros with ones at the requested positions.
    """
    matrix = np.zeros(shape=(size, size))
    pairs = zip(row_indices, col_indices)
    for row, col in pairs:
        matrix[row, col] = 1
    return matrix


def apply_permutation(conf_matrix, permutation_matrix):
    """
    Left-multiply ``conf_matrix`` by ``permutation_matrix`` and return the transpose.

    Both arguments must be torch.Tensor, or both must be numpy.ndarray.

    Returns:
        ``(permutation_matrix @ conf_matrix).T`` in the same library as the inputs
        (``np.dot`` is used for numpy inputs).

    Raises:
        TypeError: If the two arguments are not both tensors or both ndarrays.
    """
    both_torch = isinstance(conf_matrix, torch.Tensor) and isinstance(
        permutation_matrix, torch.Tensor
    )
    if both_torch:
        return torch.matmul(permutation_matrix, conf_matrix).T

    both_numpy = isinstance(conf_matrix, np.ndarray) and isinstance(
        permutation_matrix, np.ndarray
    )
    if not both_numpy:
        raise TypeError(
            "Both conf_matrix and permutation_matrix must be either both torch.Tensor or both np.ndarray."
        )
    return np.dot(permutation_matrix, conf_matrix).T


def align_solutions(conf_matrix1, conf_matrix2):
    """
    Try to find a class permutation aligning two confusion matrices.

    Each matrix is passed through ``fix_conf_matrix``, and a pairwise cost matrix is built
    with ``get_cost_matrix``. The minimum-cost assignment is then solved with
    ``linear_sum_assignment``.

    Returns:
        tuple: ``(True, permutation_matrix)`` when no chosen pairing has the impossible-mapping
        cost (>= 10**6), otherwise ``(False, None)``.
    """
    fixed1 = fix_conf_matrix(conf_matrix1)
    fixed2 = fix_conf_matrix(conf_matrix2)

    costs = get_cost_matrix(fixed1, fixed2)
    rows, cols = linear_sum_assignment(costs)

    # 10**6 is the cost compatibility_score returns for an impossible pairing; if the
    # optimum still needs one, no valid alignment exists.
    if any(costs[r, c] >= 10**6 for r, c in zip(rows, cols)):
        return False, None

    return True, get_permutation_matrix(rows, cols, fixed1.shape[0])


def permutation_matrix_from_probabilities(probs1, probs2, metric="correlation"):
    """
    Find the class permutation that best matches two models' probability columns.

    Builds a distance-based cost matrix with ``get_cost_matrix_from_probabilities``. It then
    solves the assignment problem with ``linear_sum_assignment``, minimising total distance.

    Args:
        probs1: Probability array of shape (N, num_classes) for model 1.
        probs2: Probability array of shape (N, num_classes) for model 2.
        metric (str): Distance metric forwarded to ``scipy.spatial.distance.cdist``.

    Returns:
        numpy.ndarray: (n, n) float matrix, with n the number of rows of the cost matrix
        (model 1's classes). ``P[i, j] == 1`` when class ``i`` of model 1 is matched to
        class ``j`` of model 2.
    """
    cost_matrix = get_cost_matrix_from_probabilities(probs1, probs2, metric)

    # Distances such as 'correlation' are NaN for a constant column (e.g. a class that is
    # never predicted), and linear_sum_assignment rejects NaN entries with ValueError.
    # Give those pairings a cost worse than every finite one instead of crashing.
    nan_mask = np.isnan(cost_matrix)
    if nan_mask.any():
        finite_costs = cost_matrix[np.isfinite(cost_matrix)]
        worst = finite_costs.max() + 1.0 if finite_costs.size else 0.0
        cost_matrix = np.where(nan_mask, worst, cost_matrix)

    row_indices, col_indices = linear_sum_assignment(cost_matrix)

    n = cost_matrix.shape[0]
    permutation_matrix = np.zeros((n, n))
    permutation_matrix[row_indices, col_indices] = 1

    return permutation_matrix


def build_confusion_matrix(preds1, preds2, n_classes):
    """
    Count how often each pair of labels (preds1[k], preds2[k]) co-occurs.

    Args:
        preds1: Indexable sequence of integer labels whose elements support ``.item()``
            (e.g. a 1-D numpy array or torch tensor).
        preds2: Same as ``preds1``, aligned element-wise with it.
        n_classes (int): Number of classes; labels are used directly as indices.

    Returns:
        numpy.ndarray: (n_classes, n_classes) int matrix where ``[a, b]`` counts samples
        labelled ``a`` by the first source and ``b`` by the second.
    """
    counts = np.zeros((n_classes, n_classes), dtype=int)
    for idx, label1 in enumerate(preds1):
        label2 = preds2[idx]
        counts[label1.item(), label2.item()] += 1
    return counts


def permutation_matrix_from_predictions(preds1, preds2, n_classes):
    """
    Match labels of two prediction sets by maximising their co-occurrence counts.

    Uses ``build_confusion_matrix`` for the counts and ``linear_sum_assignment`` with
    ``maximize=True`` to choose the pairing with the largest total agreement.

    Returns:
        torch.Tensor: float32 (n_classes, n_classes) matrix with a 1 at each matched
        (label in preds1, label in preds2) position.
    """
    agreement = build_confusion_matrix(preds1, preds2, n_classes)
    rows, cols = linear_sum_assignment(agreement, maximize=True)

    assignment = np.zeros(shape=(n_classes, n_classes), dtype=np.float32)
    assignment[rows, cols] = 1
    return torch.tensor(assignment)


if __name__ == "__main__":
    # Demo: align two example confusion matrices, then reuse the resulting permutation
    # on probability vectors and on the rows of an example matrix K.
    conf_matrix1 = np.array(
        [
            [4, 3, 724, 0, 249],
            [9, 1, 1124, 0, 1],
            [1023, 3, 5, 0, 1],
            [1007, 0, 0, 0, 3],
            [2, 970, 4, 0, 6],
        ]
    )

    conf_matrix2 = np.array(
        [
            [155, 0, 0, 824, 1],
            [0, 0, 0, 1126, 9],
            [4, 0, 0, 2, 1026],
            [1, 0, 0, 0, 1009],
            [976, 0, 0, 4, 2],
        ]
    )

    probs = [
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 1.0],
    ]

    K = np.array([[0.2, 0.8], [0.1, 0.9], [0.4, 0.6], [0.7, 0.3], [0.5, 0.5]])

    conf_matrix1 = fix_conf_matrix(conf_matrix1)
    conf_matrix2 = fix_conf_matrix(conf_matrix2)

    cost_matrix = get_cost_matrix(conf_matrix1, conf_matrix2)
    row_indices, col_indices = linear_sum_assignment(cost_matrix)

    print("Optimal Alignment:")
    for i, j in zip(row_indices, col_indices):
        print(
            f"Class in Matrix 1: {i} -> Class in Matrix 2: {j} with cost {cost_matrix[i, j]}"
        )

    # 10**6 is the cost compatibility_score assigns to an impossible pairing.
    inf_in_assignment = any(
        cost_matrix[i, j] >= 10**6 for i, j in zip(row_indices, col_indices)
    )

    if inf_in_assignment:
        print("No assignment possible")
    else:
        print("Perfect match")

        permutation_matrix = get_permutation_matrix(
            row_indices, col_indices, conf_matrix1.shape[0]
        )
        print("Permutation Matrix:")
        print(permutation_matrix)

        aligned_conf_matrix2 = apply_permutation(conf_matrix2, permutation_matrix)
        print("Aligned Confusion Matrix 2:")
        print(aligned_conf_matrix2)

        for prob_vector in probs:
            # apply_permutation only accepts ndarray/Tensor pairs, so convert each list row.
            aligned_prob_vector = apply_permutation(
                np.asarray(prob_vector), permutation_matrix
            )
            print("Aligned Probability Vector:")
            print(aligned_prob_vector)

        indices = np.argmax(permutation_matrix, axis=1)
        print(indices, type(indices), K.shape)
        print("Rearranged K")
        print(K[indices])
