import math
from itertools import combinations
from scipy.special import comb

import torch

import shutil
import torch
import os

from matplotlib import pyplot as plt
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
import numpy as np
import argparse
import random
import copy
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score, accuracy_score, confusion_matrix

import metric
from dataloader import load_data, DatasetSplit, get_mask
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment

import torch.nn.functional as F


def batch_compute_similarity_tensor(tensor1, tensor2):
    """Batched cosine similarity between two sets of vectors.

    Args:
        tensor1: ``(batch, n1, dim)`` tensor.
        tensor2: ``(batch, n2, dim)`` tensor.

    Returns:
        ``(batch, n1, n2)`` tensor whose ``[b, i, j]`` entry is the cosine
        similarity of ``tensor1[b, i]`` and ``tensor2[b, j]``. A small epsilon
        (1e-8) is added to each norm so all-zero rows give 0 instead of nan.
    """
    eps = 1e-8
    unit1 = tensor1 / (tensor1.norm(p=2, dim=2, keepdim=True) + eps)
    unit2 = tensor2 / (tensor2.norm(p=2, dim=2, keepdim=True) + eps)
    return unit1.bmm(unit2.transpose(1, 2))


def all_combinations(input_vector):
    """Return every non-empty combination of ``input_vector`` as a sorted list.

    Combinations are ordered by size (1, 2, ..., len) and, within a size, in
    ``itertools.combinations`` order; each combination's elements are sorted
    ascending.
    """
    return [
        sorted(combo)
        for size in range(1, len(input_vector) + 1)
        for combo in combinations(input_vector, size)
    ]


def get_models_weights_list(valid_model_list, valid_dataset_list, valid_users, class_num, device):
    """Compute one aggregation score per client from per-view / fused mutual information.

    For client ``i`` the score is the sum, over every validation batch and every
    view ``v`` in ``model.have_view``, of
    ``exp(-metric.mutual_information(f_list[v], common_f))``, divided by
    ``len(model.have_view)``. Lower mutual information yields a larger score.

    NOTE(review): the sum runs over all batches but is divided only by the number
    of views, so a client with more validation batches gets a proportionally
    larger score. Confirm this is intended when the scores are used as weights.

    Args:
        valid_model_list: per-client models called as ``model(xs, model.have_view)``
            and returning a 5-tuple whose first two items are ``f_list`` and ``common_f``.
        valid_dataset_list: per-client iterables of ``(xs, ys)`` batches.
        valid_users: number of leading clients to score.
        class_num: unused; kept for interface compatibility.
        device: device each view is moved to (views are also cast to float32).

    Side effects:
        Puts each model in eval mode and overwrites the entries of each batch's
        ``xs`` list in place with the moved/cast tensors.

    Returns:
        list of float scores, one per client.
    """
    scores = []
    for client_idx in range(valid_users):
        net = valid_model_list[client_idx]
        net.eval()
        total = 0
        for views, _labels in valid_dataset_list[client_idx]:
            for view_idx, view in enumerate(views):
                views[view_idx] = view.to(device).to(torch.float32)
            view_feats, fused, _, _, _ = net(views, net.have_view)
            for view_id in net.have_view:
                total += math.exp(-metric.mutual_information(view_feats[view_id], fused))
        scores.append(total / len(net.have_view))
    return scores

def aggregate_models(model_list, weights=None):
    """Return a new model whose state is the weighted average of ``model_list``.

    Every entry of ``state_dict()`` (parameters and buffers) is averaged as
    ``sum_k state_k[key] * w_k / sum(weights)``. The input models are not
    modified; the result is a deep copy of ``model_list[0]`` with the averaged
    state loaded into it.

    Floating-point (and complex) entries are averaged in their own dtype.
    Integer/bool entries, such as BatchNorm's ``num_batches_tracked`` counter,
    are averaged in float64, rounded and cast back, because an in-place float
    update of an integer tensor is not allowed.

    Args:
        model_list: non-empty list of models with identical ``state_dict`` keys.
        weights: optional per-model weights, paired with ``model_list`` via
            ``zip``. They need not sum to 1 because they are normalised by their
            total. Defaults to uniform weights.

    Returns:
        The aggregated model.

    Raises:
        ValueError: if the weights sum to zero (including an empty ``model_list``).
    """
    if weights is None:
        weights = [1.0 / len(model_list)] * len(model_list)
    # Hoisted: the normaliser is the same for every key and every model.
    total = float(sum(weights))
    if total == 0:
        raise ValueError("aggregate_models: weights must not sum to zero")

    agg_model = copy.deepcopy(model_list[0])
    template = agg_model.state_dict()
    states = [model.state_dict() for model in model_list]

    new_state = {}
    with torch.no_grad():
        for key, ref in template.items():
            exact = ref.is_floating_point() or ref.is_complex()
            acc_dtype = ref.dtype if exact else torch.float64
            acc = torch.zeros(ref.shape, dtype=acc_dtype, device=ref.device)
            for state, weight in zip(states, weights):
                acc += state[key].to(device=ref.device, dtype=acc_dtype) * weight / total
            if not exact:
                acc = acc.round()
            new_state[key] = acc.to(ref.dtype)

    agg_model.load_state_dict(new_state)
    return agg_model


def get_label_by_KMeans(model, dataset, class_num, device):
    """Cluster the fused features of ``model`` over ``dataset`` with k-means++.

    Args:
        model: called as ``model(xs, model.have_view)``; its second output
            (``common_f``) is the feature that gets clustered.
        dataset: iterable of ``(xs, ys)`` batches where ``xs`` is a list of views.
        class_num: number of clusters.
        device: device each view is moved to (views are also cast to float32).

    Returns:
        Cluster ids from ``KMeans.predict``, one per sample, in dataset order.
    """
    model.eval()
    h_list = []
    # Inference only: without no_grad the autograd graph of every batch would be
    # kept alive until the features are concatenated.
    with torch.no_grad():
        for xs, _ in dataset:
            xs = [x.to(device).to(torch.float32) for x in xs]
            _, common_f, *_ = model(xs, model.have_view)
            h_list.append(common_f)

    local_hs = torch.cat(h_list, dim=0).detach().cpu().numpy()
    kmeans = KMeans(n_clusters=class_num, init='k-means++', n_init=100)
    kmeans.fit(local_hs)
    return kmeans.predict(local_hs)

def torch_kmeans(x, num_clusters, num_iters=100, init_centers=None):
    """Lloyd's k-means on a 2-D tensor, computed entirely on ``x.device``.

    Args:
        x: ``(N, D)`` data matrix (floating point).
        num_clusters: number of clusters ``K``.
        num_iters: maximum number of centre updates. The loop stops early once
            the assignment stops changing.
        init_centers: optional ``(K, D)`` starting centres. When ``None``, ``K``
            distinct rows of ``x`` are drawn at random.

    Returns:
        ``(labels, centers)`` where ``labels`` is an ``(N,)`` long tensor giving,
        for each row of ``x``, the index of its nearest centre in the returned
        ``(K, D)`` ``centers``. Clusters that receive no points keep their
        previous centre.

    Raises:
        ValueError: if ``x`` has fewer than ``K`` rows for random initialisation,
            or ``init_centers`` does not have ``K`` rows.
    """
    n_samples, _ = x.shape
    device = x.device

    with torch.no_grad():
        if init_centers is None:
            if n_samples < num_clusters:
                raise ValueError(
                    f"torch_kmeans: need at least {num_clusters} samples, got {n_samples}")
            indices = torch.randperm(n_samples, device=device)[:num_clusters]
            centers = x[indices].clone()
        else:
            centers = init_centers.clone().to(device=device, dtype=x.dtype)
            if centers.shape[0] != num_clusters:
                raise ValueError(
                    f"torch_kmeans: init_centers has {centers.shape[0]} rows, "
                    f"expected {num_clusters}")

        ones = torch.ones(n_samples, device=device, dtype=x.dtype)
        labels = torch.argmin(torch.cdist(x, centers), dim=1)
        for _ in range(num_iters):
            sums = torch.zeros_like(centers)
            sums.index_add_(0, labels, x)  # accumulate each point into its cluster's sum
            counts = torch.zeros(num_clusters, device=device, dtype=x.dtype)
            counts.index_add_(0, labels, ones)
            nonempty = (counts > 0).unsqueeze(1)
            # Empty clusters keep their previous centre instead of collapsing to the origin.
            centers = torch.where(nonempty, sums / counts.clamp(min=1).unsqueeze(1), centers)
            # Re-assign against the new centres so the returned labels match them.
            new_labels = torch.argmin(torch.cdist(x, centers), dim=1)
            if torch.equal(new_labels, labels):
                break
            labels = new_labels

    return labels, centers


def valid_list_gpu(valid_model_list, valid_dataset_list, valid_users, class_num, device):
    """Cluster every client's fused features on-device and score them jointly.

    Each client's ``common_f`` features are clustered with ``torch_kmeans``.
    Client 0 starts from random centres, and every later client warm-starts
    from client 0's final centres. Predictions and labels of all clients are
    then concatenated and scored globally.

    Args:
        valid_model_list: per-client models called as ``model(xs, model.have_view)``.
        valid_dataset_list: per-client iterables of ``(xs, ys)`` batches.
        valid_users: number of leading clients to evaluate.
        class_num: number of clusters.
        device: device models, views and labels are moved to.

    Returns:
        ``(acc, nmi, ari, init_centers)``. NOTE(review): ``acc`` is whatever
        ``compute_acc`` returns, which in this file is a
        ``(correct_count, accuracy)`` tuple rather than a float.
        ``init_centers`` are client 0's centres, or ``None`` if ``valid_users == 0``.
    """
    pred_chunks = []
    true_chunks = []
    shared_centers = None

    for client in range(valid_users):
        net = valid_model_list[client].to(device).eval()

        feat_batches = []
        label_batches = []
        with torch.no_grad():
            for batch_views, batch_labels in valid_dataset_list[client]:
                batch_views = [view.to(device).float() for view in batch_views]
                _, fused, *_ = net(batch_views, net.have_view)
                feat_batches.append(fused)
                label_batches.append(batch_labels.to(device))

        client_feats = torch.cat(feat_batches, dim=0)
        client_truth = torch.cat(label_batches, dim=0)

        client_preds, client_centers = torch_kmeans(
            client_feats, class_num, num_iters=100, init_centers=shared_centers
        )
        if client == 0:
            shared_centers = client_centers

        pred_chunks.append(client_preds)
        true_chunks.append(client_truth)

    merged_preds = torch.cat(pred_chunks, dim=0).cpu().numpy()
    merged_truth = torch.cat(true_chunks, dim=0).cpu().numpy()

    acc = compute_acc(merged_preds, merged_truth, class_num)
    nmi = normalized_mutual_info_score(merged_preds, merged_truth)
    ari = adjusted_rand_score(merged_preds, merged_truth)

    return acc, nmi, ari, shared_centers

def valid_list(valid_model_list, valid_dataset_list, valid_users, class_num, device):
    """Cluster every client's fused features with sklearn k-means and score them jointly.

    Client 0 is clustered with k-means++ (100 restarts). Every later client is
    clustered starting from client 0's centres. Predicted cluster ids and true
    labels of all clients are concatenated and scored globally.

    Args:
        valid_model_list: per-client models called as ``model(xs, model.have_view)``.
        valid_dataset_list: per-client iterables of ``(xs, ys)`` batches.
        valid_users: number of leading clients to evaluate.
        class_num: number of clusters.
        device: device each view is moved to (views are also cast to float32).

    Returns:
        ``(test_acc, test_nmi, test_ari, init_center)``. NOTE(review):
        ``test_acc`` is the ``(correct_count, accuracy)`` tuple returned by
        ``compute_acc``. ``init_center`` is client 0's ``cluster_centers_``
        (``[]`` if ``valid_users == 0``).
    """
    all_ys, all_labels = [], []
    init_center = []
    for an in range(valid_users):
        model = valid_model_list[an]
        model.eval()
        h_list, ys_list = [], []
        # Features are only clustered; don't build autograd graphs for them.
        with torch.no_grad():
            for xs, ys in valid_dataset_list[an]:
                xs = [x.to(device).to(torch.float32) for x in xs]
                _, common_f, *_ = model(xs, model.have_view)
                h_list.append(common_f)
                ys_list.append(ys)

        local_hs = torch.cat(h_list, dim=0).detach().cpu().numpy()
        all_ys.append(torch.cat(ys_list, dim=0))

        if an == 0:
            kmeans = KMeans(n_clusters=class_num, init='k-means++', n_init=100)
            kmeans.fit(local_hs)
            init_center = kmeans.cluster_centers_
        else:
            # sklearn runs a single init when explicit centres are given;
            # n_init=1 states that explicitly and avoids its warning.
            kmeans = KMeans(n_clusters=class_num, init=init_center, n_init=1)
            kmeans.fit(local_hs)
        all_labels.append(kmeans.predict(local_hs))

    global_ys = torch.cat(all_ys, dim=0).detach().cpu().numpy()
    global_labels = np.concatenate(all_labels, axis=0)

    test_acc = compute_acc(global_labels, global_ys, class_num)
    test_nmi = normalized_mutual_info_score(global_labels, global_ys)
    test_ari = adjusted_rand_score(global_labels, global_ys)

    return test_acc, test_nmi, test_ari, init_center


def valid_model(model, valid_dataset, nu, class_num, device):
    """Cluster one client's fused features and score them against its labels.

    On the first call (``model.init_center is None``) k-means++ with 100 restarts
    is fitted and its centres are cached on ``model.init_center``. Later calls
    start k-means from that cached value.

    Args:
        model: called as ``model(xs, model.have_view)``; its second output
            (``common_f``) is clustered. Must expose ``init_center``.
        valid_dataset: iterable of ``(xs, ys)`` batches, ``xs`` a list of views.
        nu: client index; unused (kept for interface compatibility).
        class_num: number of clusters.
        device: device each view is moved to (views are also cast to float32).

    Returns:
        The ``(correct_count, accuracy)`` tuple produced by ``compute_acc``.
    """
    model.eval()
    h_list, ys_list = [], []
    # Features are only clustered; don't build autograd graphs for them.
    with torch.no_grad():
        for xs, ys in valid_dataset:
            xs = [x.to(device).to(torch.float32) for x in xs]
            _, common_f, *_ = model(xs, model.have_view)
            h_list.append(common_f)
            ys_list.append(ys)

    local_hs = torch.cat(h_list, dim=0).detach().cpu().numpy()
    local_ys = torch.cat(ys_list, dim=0).detach().cpu().numpy()

    if model.init_center is None:
        kmeans = KMeans(n_clusters=class_num, init='k-means++', n_init=100)
        kmeans.fit(local_hs)
        model.init_center = kmeans.cluster_centers_
    else:
        # sklearn runs a single init when explicit centres are given;
        # n_init=1 states that explicitly and avoids its warning.
        kmeans = KMeans(n_clusters=class_num, init=model.init_center, n_init=1)
        kmeans.fit(local_hs)
    labels = kmeans.predict(local_hs)
    return compute_acc(labels, local_ys, class_num)


def compute_acc(y_pred, y_true, class_num):
    """Clustering accuracy under the best one-to-one cluster-to-class matching.

    Builds the contingency matrix ``w[pred, true]`` and uses the Hungarian
    algorithm to pick the mapping that maximises the number of matched samples.

    Args:
        y_pred: predicted cluster ids (non-negative ints; list, numpy array or
            CPU tensor).
        y_true: ground-truth class ids, same length as ``y_pred``.
        class_num: expected number of classes/clusters. Labels ``>= class_num``
            are accepted: the matrix grows to fit them, and clusters left
            unmatched count as errors.

    Returns:
        ``(correct_count, accuracy)``. Empty input returns ``(0, 0.0)``.

    Raises:
        ValueError: if the inputs differ in length or contain a negative label.
    """
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"compute_acc: y_pred and y_true differ in length ({y_pred.size} vs {y_true.size})")
    if y_pred.size == 0:
        return 0, 0.0
    if y_pred.min() < 0 or y_true.min() < 0:
        # Negative ids (e.g. a noise label of -1) would silently wrap around as indices.
        raise ValueError("compute_acc: labels must be non-negative integers")

    n_rows = max(class_num, int(y_pred.max()) + 1)
    n_cols = max(class_num, int(y_true.max()) + 1)
    w = np.zeros((n_rows, n_cols), dtype=np.int64)
    np.add.at(w, (y_pred, y_true), 1)  # vectorised contingency count

    row_ind, col_ind = linear_sum_assignment(w, maximize=True)
    correct_count = w[row_ind, col_ind].sum()
    accuracy = correct_count * 1.0 / y_pred.size
    return correct_count, accuracy


def clustering_acc(true_labels, pred_labels):
    """Clustering accuracy under the best one-to-one matching of label values.

    The confusion matrix is built over the union of the label values seen in
    either input, so arbitrary (non-contiguous) integer ids are handled
    correctly. The Hungarian algorithm then chooses the matching with the most
    agreements.

    Args:
        true_labels: ground-truth labels (list or array).
        pred_labels: predicted cluster labels, same length.

    Returns:
        ``(acc, correct_count)``. Empty input returns ``(0.0, 0)``.

    Raises:
        ValueError: if the inputs differ in length.
    """
    true_arr = np.asarray(true_labels).ravel()
    pred_arr = np.asarray(pred_labels).ravel()
    if true_arr.shape != pred_arr.shape:
        raise ValueError(
            f"clustering_acc: length mismatch ({true_arr.size} vs {pred_arr.size})")
    if true_arr.size == 0:
        return 0.0, 0

    # Map label values to dense indices; the original code mistook matrix
    # indices for label values, which broke non-contiguous labels.
    classes = np.union1d(true_arr, pred_arr)
    true_idx = np.searchsorted(classes, true_arr)
    pred_idx = np.searchsorted(classes, pred_arr)
    cm = np.zeros((classes.size, classes.size), dtype=np.int64)
    np.add.at(cm, (true_idx, pred_idx), 1)

    row_ind, col_ind = linear_sum_assignment(cm, maximize=True)
    correct_count = int(cm[row_ind, col_ind].sum())
    return correct_count / true_arr.size, correct_count


def compute_nmi(y_pred, y_true, class_num):
    """Normalised mutual information with geometric-mean normalisation.

    ``NMI = I(true; pred) / sqrt(H(true) * H(pred))`` using natural logs.
    Note that sklearn's ``normalized_mutual_info_score`` defaults to arithmetic-
    mean normalisation, so the two are not directly comparable.

    Degenerate cases:
    - Both partitions trivial (a single cluster each) or empty input: returns 1.0.
    - Exactly one partition trivial (zero entropy): returns 0.0.

    Args:
        y_pred: predicted cluster ids (non-negative ints).
        y_true: ground-truth class ids, same length.
        class_num: expected number of classes/clusters. Larger label values
            are accepted and simply enlarge the contingency table.

    Raises:
        ValueError: if the inputs differ in length or contain a negative label.
    """
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    if y_pred.shape != y_true.shape:
        raise ValueError("compute_nmi: y_pred and y_true must have the same length")
    n = y_true.size
    if n == 0:
        return 1.0
    if y_pred.min() < 0 or y_true.min() < 0:
        raise ValueError("compute_nmi: labels must be non-negative integers")

    n_true = max(class_num, int(y_true.max()) + 1)
    n_pred = max(class_num, int(y_pred.max()) + 1)
    confusion = np.zeros((n_true, n_pred), dtype=np.int64)
    np.add.at(confusion, (y_true, y_pred), 1)

    p_ij = confusion / n
    # Marginals from integer counts so a constant labelling gives exactly 1.0.
    p_i = confusion.sum(axis=1) / n  # P(y_true)
    p_j = confusion.sum(axis=0) / n  # P(y_pred)

    nz = p_ij > 0
    mi = np.sum(p_ij[nz] * np.log(p_ij[nz] / np.outer(p_i, p_j)[nz]))

    def entropy(p):
        # Only positive probabilities contribute; this avoids log(0) without a fudge term.
        p = p[p > 0]
        return -np.sum(p * np.log(p))

    h_true = entropy(p_i)
    h_pred = entropy(p_j)
    if h_true == 0 and h_pred == 0:
        return 1.0
    denom = np.sqrt(h_true * h_pred)
    if denom == 0:
        return 0.0
    return float(mi / denom)


def compute_ari(y_pred, y_true, class_num):
    """Adjusted Rand index between a predicted and a ground-truth partition.

    ``ARI = (sum_ij C(n_ij,2) - E) / (0.5*(sum_i C(a_i,2) + sum_j C(b_j,2)) - E)``
    with ``E = sum_i C(a_i,2) * sum_j C(b_j,2) / C(n,2)``, where ``n_ij`` is the
    contingency table and ``a``/``b`` are its row/column sums.

    Degenerate cases:
    - Fewer than two samples (no pairs to compare): returns 1.0.
    - Zero denominator: returns 1.0 if the numerator is also zero, else 0.0.

    Args:
        y_pred: predicted cluster ids (non-negative ints).
        y_true: ground-truth class ids, same length.
        class_num: expected number of classes/clusters. Larger label values
            are accepted and simply enlarge the contingency table.

    Raises:
        ValueError: if the inputs differ in length or contain a negative label.
    """
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    if y_pred.shape != y_true.shape:
        raise ValueError("compute_ari: y_pred and y_true must have the same length")
    n = y_true.size
    if n < 2:
        return 1.0
    if y_pred.min() < 0 or y_true.min() < 0:
        raise ValueError("compute_ari: labels must be non-negative integers")

    n_true = max(class_num, int(y_true.max()) + 1)
    n_pred = max(class_num, int(y_pred.max()) + 1)
    confusion = np.zeros((n_true, n_pred), dtype=np.int64)
    np.add.at(confusion, (y_true, y_pred), 1)

    def pairs(counts):
        # Vectorised C(k, 2) = k(k-1)/2, which is 0 for k in {0, 1}.
        counts = counts.astype(np.float64)
        return counts * (counts - 1) / 2.0

    sum_comb_confusion = pairs(confusion).sum()          # sum C(n_ij, 2)
    sum_comb_a = pairs(confusion.sum(axis=1)).sum()      # sum C(a_i, 2)
    sum_comb_b = pairs(confusion.sum(axis=0)).sum()      # sum C(b_j, 2)

    expected_index = sum_comb_a * sum_comb_b / (n * (n - 1) / 2.0)
    max_index = 0.5 * (sum_comb_a + sum_comb_b)

    numerator = sum_comb_confusion - expected_index
    denominator = abs(max_index - expected_index)
    if denominator == 0:
        return 1.0 if numerator == 0 else 0.0
    return numerator / denominator

def draw_scatter(data_tensor, labels, title, out_dir='./images'):
    """Embed ``data_tensor`` in 2-D with t-SNE and save a scatter plot coloured by ``labels``.

    The figure is written to ``<out_dir>/<title>.png``. ``out_dir`` defaults to
    ``./images`` and is created if it does not exist.

    Args:
        data_tensor: 2-D tensor of features (moved to CPU and detached).
        labels: per-row colour values passed to ``plt.scatter(c=...)``.
        title: plot title, also used as the file name stem.
        out_dir: output directory.
    """
    data_2d = TSNE(n_components=2).fit_transform(data_tensor.cpu().detach().numpy())
    # savefig fails with FileNotFoundError if the directory is missing.
    os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure(figsize=(8, 6))
    try:
        plt.title(title)
        plt.scatter(data_2d[:, 0], data_2d[:, 1], c=labels, cmap='tab10', marker='x')
        plt.savefig(os.path.join(out_dir, '{}.png'.format(title)))
    finally:
        # Release the figure even if plotting or saving fails, so repeated calls don't leak.
        plt.close(fig)


def valid_by_feature(y, z, class_num):
    """Cluster features ``z`` with k-means++ (20 restarts) and score them against ``y``.

    Args:
        y: ground-truth label tensor.
        z: feature tensor, one row per sample.
        class_num: number of clusters.

    Returns:
        The accuracy component of ``compute_acc`` (the match count is dropped).
    """
    feats = z.detach().cpu().numpy()
    clusterer = KMeans(n_clusters=class_num, init='k-means++', n_init=20)
    clusterer.fit(feats)
    predicted = clusterer.predict(feats)
    _, accuracy = compute_acc(predicted, y.detach().cpu().numpy(), class_num)
    return accuracy

def valid_by_labels(y, labels, class_num):
    """Score precomputed cluster ``labels`` against the ground-truth tensor ``y``.

    Returns:
        ``(correct_count, accuracy)`` as computed by ``compute_acc``.
    """
    truth = y.detach().cpu().numpy()
    matched, acc = compute_acc(labels, truth, class_num)
    return matched, acc



def valid_clients(clients):
    """Cluster every client's features from shared centres and score them jointly.

    k-means++ (100 restarts) is first fitted on client 0's features to obtain
    seed centres. Every client (including client 0) is then re-clustered
    starting from those centres. Each client's predicted labels are stored on
    ``client.label``, and all predictions and labels are pooled for scoring.

    Args:
        clients: non-empty sequence of client objects exposing ``args.class_num``
            (read from client 0) and ``get_y_h()`` returning ``(labels, features)``
            tensors.

    Returns:
        ``(correct_count, accuracy, nmi, ari)`` computed with this file's
        ``compute_acc``, ``compute_nmi`` and ``compute_ari``.
    """
    seed_client = clients[0]
    class_num = seed_client.args.class_num
    _, seed_feats = seed_client.get_y_h()

    seed_km = KMeans(n_clusters=class_num, init='k-means++', n_init=100)
    seed_km.fit(seed_feats.detach().cpu().numpy())
    centers = seed_km.cluster_centers_

    all_truth, all_pred = [], []
    for client in clients:
        truth, feats = client.get_y_h()
        feats_np = feats.detach().cpu().numpy()
        km = KMeans(n_clusters=class_num, init=centers)
        km.fit(feats_np)

        all_truth.extend(truth.detach().cpu().numpy())
        predicted = km.predict(feats_np)
        client.label = predicted
        all_pred.extend(predicted)

    matched, acc = compute_acc(all_pred, all_truth, class_num)
    nmi = compute_nmi(all_pred, all_truth, class_num)
    ari = compute_ari(all_pred, all_truth, class_num)
    return matched, acc, nmi, ari


def most_similar_per_sample(mat1: torch.Tensor,
                            mat2: torch.Tensor,
                            mat3: torch.Tensor) -> torch.Tensor:
    """Per row, pick the candidate vector that agrees most with the other two.

    For row ``i`` the three candidates are ``mat1[i]``, ``mat2[i]`` and
    ``mat3[i]``. Each candidate's agreement is the sum of its cosine
    similarities to the other two, and the candidate with the highest agreement
    is returned. Ties are resolved by ``torch.argmax``.

    Args:
        mat1, mat2, mat3: tensors of identical shape ``(n, d)``.

    Returns:
        ``(n, d)`` tensor whose row ``i`` is copied from the winning input.
    """
    candidates = torch.stack((mat1, mat2, mat3), dim=1)         # (n, 3, d)
    unit = F.normalize(candidates, p=2, dim=2)
    cosine = unit @ unit.transpose(1, 2)                         # (n, 3, 3)

    # Ignore each candidate's similarity with itself.
    self_mask = torch.eye(3, dtype=torch.bool, device=cosine.device)
    agreement = cosine.masked_fill(self_mask, 0.0).sum(dim=2)    # (n, 3)

    winner = agreement.argmax(dim=1)                             # (n,)
    gather_idx = winner.view(-1, 1, 1).expand(-1, 1, candidates.size(2))
    return candidates.gather(1, gather_idx).squeeze(1)

def matrix_euclidean_distance(matrix1, matrix2):
    """Return the Euclidean (Frobenius) distance ``||matrix1 - matrix2||`` as a Python float.

    For tensors of any rank this is the 2-norm of the flattened difference.

    Raises:
        ValueError: if the two tensors do not have the same shape.
    """
    if matrix1.shape != matrix2.shape:
        raise ValueError(
            f"matrix_euclidean_distance: shape mismatch "
            f"{tuple(matrix1.shape)} vs {tuple(matrix2.shape)}")
    # torch.norm is deprecated; vector_norm over all elements equals the Frobenius norm.
    return torch.linalg.vector_norm(matrix1 - matrix2).item()


def add_noise_to_model(model, epsilon=50, sensitivity=1.0, mechanism="gaussian"):
    """Return a copy of ``model.state_dict()`` with zero-mean noise added to each float entry.

    The noise scale is ``sensitivity / epsilon``. For ``"gaussian"`` it is used
    as the standard deviation; for ``"laplace"`` it is the Laplace scale ``b``.
    Non-floating entries (e.g. BatchNorm's integer ``num_batches_tracked``
    counter) are copied unchanged. Noise is cast to each entry's dtype, so the
    returned dtypes match the model's.

    NOTE(review): ``sigma = sensitivity / epsilon`` is not the calibrated
    (epsilon, delta) Gaussian mechanism, which needs
    ``sigma >= sqrt(2 ln(1.25/delta)) * sensitivity / epsilon``. Confirm the
    privacy claim before relying on it.

    Args:
        model: module whose ``state_dict()`` is perturbed (the model itself is
            not modified).
        epsilon: privacy parameter; must be > 0.
        sensitivity: assumed L-sensitivity of the released values.
        mechanism: ``"gaussian"`` or ``"laplace"``.

    Returns:
        dict mapping each state-dict key to a new (noisy or copied) tensor.

    Raises:
        ValueError: on an unknown mechanism or non-positive epsilon.
    """
    # Validate up front so a bad argument is reported even for an empty state dict.
    if mechanism not in ("gaussian", "laplace"):
        raise ValueError("mechanism must 'gaussian' or 'laplace'")
    if epsilon <= 0:
        raise ValueError("add_noise_to_model: epsilon must be positive")
    scale = sensitivity / epsilon

    noisy_state_dict = {}
    for name, param in model.state_dict().items():
        if not param.is_floating_point():
            # Integer/bool buffers are counters or flags; perturbing them would turn them into floats.
            noisy_state_dict[name] = param.clone()
            continue
        if mechanism == "gaussian":
            noise = torch.normal(mean=0.0, std=scale, size=param.shape, device=param.device)
        else:
            noise = torch.distributions.Laplace(0.0, scale).sample(param.shape).to(param.device)
        noisy_state_dict[name] = param + noise.to(param.dtype)

    return noisy_state_dict