# @Author  : Peizhao Li
# @Contact : peizhaoli05@gmail.com

import numpy as np
from scipy.optimize import linear_sum_assignment
import torch
from scipy.linalg import sqrtm
import torch.nn.functional as F
from itertools import combinations

def _infer_device(*modules):
    """Return the device of the first parameter found among ``modules``.

    Falls back to CUDA when it is available, otherwise to CPU, if none of the
    given objects exposes any parameters.
    """
    for module in modules:
        try:
            return next(module.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            continue
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _collect_predictions(loader, encoder, dfc, device):
    """Run one loader through ``encoder`` -> ``dfc``; return (argmax indices, labels)."""
    outputs = []
    labels = []
    for img, label in loader:
        # Only the first element of the encoder's output is fed to ``dfc``.
        feat = dfc(encoder(img.to(device))[0])
        outputs.append(feat.detach())
        labels.append(label)
    # ``max(1)[1]`` is the index of the largest entry along dim 1 for each sample.
    return torch.cat(outputs).max(1)[1], torch.cat(labels).long()


def predict(data_loader, encoder, dfc):
    """Compute per-sample argmax predictions for two data loaders.

    Both models are switched to eval mode (and left in eval mode), and the
    forward passes run under ``torch.no_grad()``.

    Inputs are moved to the device on which the models' parameters live
    instead of unconditionally calling ``.cuda()``, so the function also works
    on CPU-only machines and with models placed on a non-default GPU.

    :param data_loader: indexable holding two iterables (``data_loader[0]`` and
        ``data_loader[1]``), each yielding ``(img, label)`` batches
    :param encoder: callable with ``eval()``; its output is indexed with ``[0]``
    :param dfc: callable with ``eval()`` applied to ``encoder(img)[0]``;
        assumed to return a (batch, num_outputs) tensor -- TODO confirm
    :return: ``(pred_1, labels_1, pred_2, labels_2)`` where ``pred_*`` are the
        argmax indices along dim 1 of the concatenated ``dfc`` outputs and
        ``labels_*`` are the concatenated labels cast to ``long``
    """
    encoder.eval()
    dfc.eval()
    device = _infer_device(encoder, dfc)

    with torch.no_grad():
        pred_1, labels_1 = _collect_predictions(data_loader[0], encoder, dfc, device)
        pred_2, labels_2 = _collect_predictions(data_loader[1], encoder, dfc, device)

    return pred_1, labels_1, pred_2, labels_2


def cluster_accuracy(y_true, y_predicted, cluster_number=None):
    """
    Calculate clustering accuracy after using the linear_sum_assignment function in SciPy to
    determine reassignments.

    Inputs are converted with ``np.asarray`` and flattened, so plain Python
    lists and NumPy arrays are both accepted.

    :param y_true: true cluster numbers, integer array-like, 0-indexed
    :param y_predicted: predicted cluster numbers, integer array-like, 0-indexed
    :param cluster_number: number of clusters, if None then calculated from input
    :return: reassignment dictionary (predicted cluster -> true cluster), clustering accuracy
    :raises ValueError: if the two inputs do not contain the same number of elements
    """
    y_true = np.asarray(y_true).ravel()
    y_predicted = np.asarray(y_predicted).ravel()
    if y_true.size != y_predicted.size:
        raise ValueError(
            "y_true and y_predicted must have the same number of elements, "
            "got {} and {}".format(y_true.size, y_predicted.size)
        )

    if cluster_number is None:
        # Labels are assumed to be 0-indexed, so the largest label + 1 is the count.
        cluster_number = int(max(y_predicted.max(), y_true.max())) + 1

    # count_matrix[p, t] = number of samples predicted as p whose true label is t.
    count_matrix = np.zeros((cluster_number, cluster_number), dtype=np.int64)
    # np.add.at accumulates correctly when the same (p, t) pair occurs repeatedly.
    np.add.at(count_matrix, (y_predicted, y_true), 1)

    # linear_sum_assignment minimises cost, so flip the counts to maximise agreement.
    row_ind, col_ind = linear_sum_assignment(count_matrix.max() - count_matrix)
    reassignment = dict(zip(row_ind, col_ind))
    accuracy = count_matrix[row_ind, col_ind].sum() / y_predicted.size

    return reassignment, accuracy


def entropy(input):
    """Return ``-sum(p * log(p + 1e-5))`` of ``input``, reduced over dim 0.

    The small constant inside the logarithm keeps zero entries from producing
    ``log(0)``.

    :param input: tensor of values to reduce
    :return: tensor with dim 0 summed out
    """
    eps = 1e-5
    neg_p_log_p = -input * torch.log(input + eps)
    return neg_p_log_p.sum(dim=0)


def balance(predicted, size_0, num_sens=2, k=10):
    """Compute the balance and per-group cluster entropies of an assignment.

    The first ``size_0`` entries of ``predicted`` are counted as group 0 and
    the remaining entries as group 1. For every cluster the ratio between the
    two group counts is taken in both directions; the balance is the smallest
    such ratio. Empty (cluster, group) cells are replaced by ``1e-5`` before
    the division so the ratios stay finite.

    Counting uses ``torch.bincount`` rather than a per-element Python loop.
    Labels outside ``[0, k)`` are rejected explicitly (negative labels are no
    longer silently wrapped around to the last clusters).

    :param predicted: 1-D integer cluster assignments (tensor, ndarray or list)
    :param size_0: number of leading entries that belong to group 0
    :param num_sens: number of columns of the count table; only columns 0 and 1
        are ever filled, so values below 2 fail
    :param k: number of clusters
    :return: ``(balance, entropy_group_0, entropy_group_1)`` as 0-d NumPy arrays
    :raises IndexError: if ``size_0`` is outside ``[0, len(predicted)]`` or a
        label is outside ``[0, k)``
    """
    # Work on a flat CPU copy so a GPU tensor is transferred once, not per element.
    assignments = torch.as_tensor(predicted).detach().cpu().long().reshape(-1)
    total = assignments.numel()

    if not 0 <= size_0 <= total:
        raise IndexError(
            "size_0={} is out of range for {} predictions".format(size_0, total)
        )
    if total > 0 and (int(assignments.min()) < 0 or int(assignments.max()) >= k):
        raise IndexError("cluster labels must lie in [0, {})".format(k))

    split = int(size_0)
    count = torch.zeros((k, num_sens))
    count[:, 0] = torch.bincount(assignments[:split], minlength=k).to(count.dtype)
    count[:, 1] = torch.bincount(assignments[split:], minlength=k).to(count.dtype)

    # Avoid division by zero for clusters that miss one of the groups.
    count[count == 0] = 1e-5

    balance_0 = torch.min(count[:, 0] / count[:, 1])
    balance_1 = torch.min(count[:, 1] / count[:, 0])

    # Entropy of each group's distribution over the clusters.
    en_0 = entropy(count[:, 0] / torch.sum(count[:, 0]))
    en_1 = entropy(count[:, 1] / torch.sum(count[:, 1]))

    return min(balance_0, balance_1).numpy(), en_0.numpy(), en_1.numpy()

def calc_FFDC(list_1, list_2, num_clusters):
    """Compute a per-cluster distance score between the features of two groups.

    For each cluster the features of both groups are centred on the cluster's
    joint mean, and three terms are combined as
    ``sqrt(mean_gap + spread_gap + z_f)``:

    * ``mean_gap``   -- squared distance between the two groups' centred means
    * ``spread_gap`` -- squared difference of the two groups' scaled Frobenius terms
    * ``z_f``        -- scaled sum of squares of all centred features

    If any cluster of either group has fewer than two samples (including a
    cluster with no feature chunks at all), the sentinel ``([inf], [inf])`` is
    returned. Previously a cluster with no chunks raised inside ``torch.cat``
    before the size guard was reached.

    The chunk counts of both inputs are printed, as before.

    :param list_1: per-cluster lists of feature tensors for group 1; the chunks
        of one cluster are concatenated along dim 0
    :param list_2: same as ``list_1`` for group 2
    :param num_clusters: number of leading clusters to evaluate
    :return: ``(FFDC, Z_F)`` -- two lists with one float per evaluated cluster,
        or ``([inf], [inf])`` for the degenerate case described above
    """
    print([len(i) for i in list_1])
    print([len(i) for i in list_2])

    # torch.cat rejects an empty chunk list, so treat that as "too few samples".
    if any(len(chunks) == 0 for chunks in list_1) or any(len(chunks) == 0 for chunks in list_2):
        return [np.inf], [np.inf]

    feats_1 = [torch.cat(chunks, 0) for chunks in list_1]
    feats_2 = [torch.cat(chunks, 0) for chunks in list_2]

    # The (n - 1) denominators below need at least two samples per group and cluster.
    if min(len(f) for f in feats_1) < 2 or min(len(f) for f in feats_2) < 2:
        return [np.inf], [np.inf]

    FFDC = []
    Z_F = []
    for idx in range(num_clusters):
        group_1 = feats_1[idx].detach()
        group_2 = feats_2[idx].detach()

        # Joint mean of both groups in this cluster, kept as a (1, dim) row.
        centroid = (
            torch.cat((group_1, group_2), 0).mean(0).view(1, -1).cpu().numpy().astype(np.double)
        )
        centred_1 = group_1.cpu().numpy().astype(np.double) - centroid
        centred_2 = group_2.cpu().numpy().astype(np.double) - centroid

        n_1 = centred_1.shape[0]
        n_2 = centred_2.shape[0]

        # Column-wise sums of the centred features of each group.
        c_1 = centred_1.sum(0)
        c_2 = centred_2.sum(0)

        stacked = np.concatenate((centred_1, centred_2))
        # NOTE(review): by operator precedence this is z**2 * (n_2 - 1) / (n_1 - 1),
        # which is asymmetric in the two groups. Dividing by
        # ((n_1 - 1) * (n_2 - 1)) may have been intended -- confirm against the
        # reference formula. The original arithmetic is kept unchanged.
        z_f = np.sum(stacked ** 2 / (n_1 - 1) * (n_2 - 1))

        mean_gap = np.sum((c_1 / n_1 - c_2 / n_2) ** 2)

        # NOTE(review): the column *sum* (c_1 / c_2), not the mean, is subtracted
        # here -- confirm this is intended. Behaviour is kept unchanged.
        uh_f = np.sqrt(np.sum((centred_1 - c_1) ** 2))
        vh_f = np.sqrt(np.sum((centred_2 - c_2) ** 2))
        spread_gap = (uh_f / np.sqrt(n_1 - 1) - vh_f / np.sqrt(n_2 - 1)) ** 2

        FFDC.append(np.sqrt(mean_gap + spread_gap + z_f))
        Z_F.append(z_f)

    return FFDC, Z_F
