import copy
import os.path

from .ufedbase import UnlearnBasicClient, UnlearnBasicServer
import numpy as np
from utils import fmodule
import torch
import torch.nn as nn
import collections
import json
from tqdm import tqdm
from utils.utils_unlearn import agg_func
from utils.finch import FINCH

class Server(UnlearnBasicServer):
    """Federated server that aggregates models and per-label prototypes.

    Each round the sampled clients reply with a model, a loss and a dict of
    local prototypes (label -> prototype tensor). The server remembers the
    latest prototypes reported by every client, builds global prototypes from
    the sampled clients' prototypes (clustering them per label with FINCH, so
    one label may end up with several global prototypes), and aggregates the
    models with ``self.aggregate``.
    """

    def __init__(self, option, model, clients, data_loader, device=None):
        super(Server, self).__init__(option, model, clients, data_loader, device)
        # label -> list of global prototype tensors; an empty list until the
        # first call to proto_aggregation replaces it with a dict.
        self.global_protos = []
        # client_id -> {label: local prototype}: last prototypes each client sent.
        self.local_protos = {}

    def iterate(self):
        """Run one round: sample clients, collect replies, aggregate protos and models."""
        self.selected_clients = self.sample()
        reply = self.communicate(self.selected_clients)
        models, protos = reply['model'], reply['proto']

        # NOTE(review): this assumes the replies are ordered exactly like
        # self.selected_clients (i.e. received clients == selected clients);
        # confirm against communicate() if clients can drop out.
        for client_id, proto in zip(self.selected_clients, protos):
            self.local_protos[client_id] = proto

        self.global_protos = self.proto_aggregation(self.local_protos)

        self.model = self.aggregate(models)
        del models
        return

    def pack(self, client_id, model=None):
        """Build the package sent to a client.

        Args:
            client_id: id of the receiving client (unused; kept for the
                base-class interface).
            model: model to send; defaults to the current global model.

        Returns:
            A dict with a deep copy of the model, the training hyper-parameters,
            the current stage and the global prototypes.
        """
        return {
            "model": copy.deepcopy(model if model is not None else self.model),
            "current_rounds": self.current_rounds,
            "lr": self.lr,
            "momentum": self.local_momentum,
            "weight_decay": self.weight_decay,
            "stage": self.stage,
            "global_protos": self.global_protos,
        }

    def proto_aggregation(self, local_protos_list):
        """Aggregate the sampled clients' local prototypes into global ones.

        Args:
            local_protos_list: dict client_id -> {label: local prototype tensor}.
                Only the entries of ``self.selected_clients`` are used; sampled
                clients with no entry are skipped.

        Returns:
            dict label -> list of global prototype tensors. A label reported
            by a single client keeps that (detached) prototype. A label reported
            by several clients is clustered with FINCH (cosine distance), and
            each cluster of the final partition contributes its mean prototype
            as a CPU tensor of shape (1, feature_dim).
        """
        # Group this round's prototypes by label.
        grouped = {}
        for client_id in self.selected_clients:
            if client_id not in local_protos_list:
                continue
            for label, proto in local_protos_list[client_id].items():
                grouped.setdefault(label, []).append(proto)

        agg_protos_label = {}
        for label, proto_list in grouped.items():
            if len(proto_list) == 1:
                agg_protos_label[label] = [proto_list[0].detach()]
                continue

            # One flattened row per client prototype: (num_protos, feature_dim).
            proto_array = np.array(
                [item.squeeze(0).detach().cpu().numpy().reshape(-1) for item in proto_list]
            )
            # c[i, p] is the cluster of prototype i in FINCH partition p; the
            # last column is the final (coarsest) partition.
            c, _, _ = FINCH(proto_array, initial_rank=None, req_clust=None, distance='cosine',
                            ensure_early_exit=False, verbose=False)
            cluster_ids = c[:, -1]

            # Average the prototypes within each cluster; a label can keep
            # several prototypes (e.g. from different domains).
            agg_protos_label[label] = [
                torch.tensor(np.mean(proto_array[cluster_ids == cluster_id], axis=0, keepdims=True))
                for cluster_id in np.unique(cluster_ids)
            ]

        return agg_protos_label


class Client(UnlearnBasicClient):
    """Federated client trained with cross-entropy plus a prototype loss.

    Once the server has sent global prototypes, every sample whose label has
    a global prototype adds a hierarchical loss: an InfoNCE term against all
    global prototypes plus an MSE term towards the mean global prototype of
    its own label. After training, the features seen in the last local epoch
    are grouped by label, combined by ``agg_func``, and returned to the server
    as this client's local prototypes.
    """

    def __init__(self, option, id, model=None):
        super(Client, self).__init__(option, id, model)
        # Global prototypes from the server: an empty list until the first
        # aggregation, then a dict label -> list of prototype tensors.
        self.global_protos = []
        # Temperature dividing the cosine similarities in calculate_infonce.
        self.infoNCET = option['infoNCET']

    def reply(self, server_pack=None):
        """Unpack the server package, train locally and return the reply package."""
        assert server_pack is not None
        self.model = self.unpack(server_pack)
        loss, local_protos = self.train()
        return self.pack(loss, local_protos)

    def pack(self, loss, client_protos, model=None):
        """Build the reply package: a deep copy of the model, the loss and the local prototypes."""
        return {
            "model": copy.deepcopy(model if model is not None else self.model),
            "loss": loss,
            "proto": client_protos,
        }

    def train(self):
        """Run ``self.epochs`` local epochs over ``self.train_data``.

        Each batch is optimised with cross-entropy plus, when global prototypes
        exist, the hierarchical prototype loss summed over matching samples and
        divided by the batch size.

        Returns:
            ``(mean_loss, agg_protos)``. ``mean_loss`` is the summed per-sample
            loss divided by ``self.datavol * self.epochs``. ``agg_protos`` is
            ``agg_func`` applied to the label -> [feature] dict collected during
            the last epoch.
        """
        self.model.train()
        total_loss = 0.0
        optimizer = self.get_optimizer()

        has_global = len(self.global_protos) != 0
        if has_global:
            # Loop-invariant: stack each label's global prototypes (and their
            # mean) once, directly on the training device.
            global_keys = list(self.global_protos.keys())
            all_global_protos_keys = np.array(global_keys)
            all_f = []
            mean_f = []
            for protos_key in global_keys:
                temp_f = torch.cat(self.global_protos[protos_key], dim=0).to(self.device).detach()
                all_f.append(temp_f)  # (proto_num, proto_vector)
                mean_f.append(torch.mean(temp_f, dim=0, keepdim=True))

        agg_protos_label = {}
        for e in range(self.epochs):
            agg_protos_label = {}
            # Local prototypes are collected from the final epoch only
            # (previously this compared the *batch* index with self.epochs).
            collect_protos = e == self.epochs - 1
            for batch_x, batch_y in self.train_data:
                self.model.zero_grad()
                batch_x = self.data_to_device(batch_x, device=self.device)
                batch_y = self.data_to_device(batch_y, device=self.device)

                f = self.model.features_extras(batch_x)
                outputs = self.model.classifier_head(f)
                lossCE = self.criterion(outputs, batch_y)

                loss_InfoNCE = None
                if has_global:
                    for i, label in enumerate(batch_y):
                        if label.item() in self.global_protos:
                            loss_instance = self.hierarchical_info_loss(
                                f[i].unsqueeze(0), label, all_f, mean_f, all_global_protos_keys)
                            if loss_InfoNCE is None:
                                loss_InfoNCE = loss_instance
                            else:
                                loss_InfoNCE = loss_InfoNCE + loss_instance
                if loss_InfoNCE is None:
                    # No global prototypes yet, or none for this batch's labels;
                    # previously the latter case crashed on `None / i`.
                    loss_InfoNCE = 0 * lossCE
                else:
                    # Averaged over the whole batch, not only matched samples.
                    loss_InfoNCE = loss_InfoNCE / len(batch_y)

                loss = lossCE + loss_InfoNCE
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * len(batch_y)

                if collect_protos:
                    # Detach so stored features do not keep each batch's graph alive.
                    feats = f.detach()
                    for i, y in enumerate(batch_y.tolist()):
                        agg_protos_label.setdefault(y, []).append(feats[i].unsqueeze(0))

        # Combine the collected features into one local prototype per label.
        agg_protos = agg_func(agg_protos_label)

        del optimizer
        return total_loss / (self.datavol * self.epochs), agg_protos

    def hierarchical_info_loss(self, f_now, label, all_f, mean_f, all_global_protos_keys):
        """Prototype loss for a single sample.

        Args:
            f_now: feature of the sample, shape (1, feature_dim).
            label: 0-d tensor holding the sample's label.
            all_f: list of (proto_num, feature_dim) tensors, aligned with
                ``all_global_protos_keys``.
            mean_f: list of (1, feature_dim) mean prototypes, same alignment.
            all_global_protos_keys: numpy array of labels with global prototypes.

        Returns:
            InfoNCE(f_now; own-label prototypes vs. all other prototypes) plus
            MSE(f_now, mean prototype of its label). The result has shape (1,).
        """
        label_key = label.item()
        pos_idx = np.where(all_global_protos_keys == label_key)[0][0]
        neg_idxs = np.where(all_global_protos_keys != label_key)[0]

        # A label may own several global prototypes; calculate_infonce treats
        # all of them as positives.
        f_pos = all_f[pos_idx].to(self.device)
        if len(neg_idxs) > 0:
            f_neg = torch.cat([all_f[ni].to(self.device) for ni in neg_idxs])
        else:
            # Only one label has global prototypes: no negatives.
            f_neg = f_pos.new_zeros((0, f_pos.shape[1]))
        xi_info_loss = self.calculate_infonce(f_now, f_pos, f_neg)

        mean_f_pos = mean_f[pos_idx].to(self.device).view(1, -1)
        cu_info_loss = nn.functional.mse_loss(f_now, mean_f_pos)

        return xi_info_loss + cu_info_loss

    def calculate_infonce(self, f_now, f_pos, f_neg):
        """InfoNCE loss with possibly several positives.

        Computes ``-log(sum_pos exp(s / T) / sum_all exp(s / T))``, where ``s``
        is the cosine similarity between ``f_now`` and each prototype and
        ``T = self.infoNCET``. It is evaluated as a difference of logsumexps so
        that small temperatures do not overflow ``exp``. The result has shape (1,).
        """
        f_proto = torch.cat((f_pos, f_neg), dim=0)
        logits = torch.cosine_similarity(f_now, f_proto, dim=1) / self.infoNCET
        logits = logits.view(1, -1)
        num_pos = f_pos.shape[0]
        return torch.logsumexp(logits, dim=1) - torch.logsumexp(logits[:, :num_pos], dim=1)

    def unpack(self, received_pkg):
        """Store the hyper-parameters and global prototypes from the server; return its model."""
        self.current_rounds = received_pkg['current_rounds']
        self.lr = received_pkg['lr']
        self.momentum = received_pkg['momentum']
        self.weight_decay = received_pkg['weight_decay']
        self.global_protos = received_pkg['global_protos']
        self.stage = received_pkg['stage']
        return received_pkg['model']