import torch
from model.GFT_encoder import Encoder, InnerProductDecoder
from model.GFT_ft_model import TaskModel
from model.GFT_pt_model import PretrainModel
from model.GFT_vq import VectorQuantize
import torch.nn as nn
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict
import copy
from utils.basic_utils import construct_graph
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F


class Server:
    
    def __init__(self, args, device, kmeans_init=False):
        """Assemble the server-side global pretraining model.

        Args:
            args: Run configuration; stored unchanged as ``self.args``.
            device: Device the assembled ``PretrainModel`` is moved to; also
                stored as ``self.device``.
            kmeans_init: Forwarded to ``VectorQuantize`` as ``kmeans_init``.
        """
        # Stays ``None`` until a personalization round has produced
        # per-client models (see ``get_global_message``).
        self.per_global_message = None
        self.args = args
        self.device = device

        # A single width is used for input features, hidden layers and codes.
        self.dim = 768
        width = self.dim

        # NOTE: sub-modules are built in a fixed order (encoder, vq, decoders)
        # so that seeded parameter initialisation stays reproducible.
        encoder_kwargs = {
            "input_dim": width,
            "hidden_dim": width,
            "activation": nn.ReLU,
            "num_layers": 2,
            "backbone": "sage",
            "normalize": "batch",
            "dropout": 0.15,
        }
        self.encoder = Encoder(**encoder_kwargs)

        vq_kwargs = {
            "dim": width,
            "codebook_size": 128,
            "codebook_dim": width,
            "heads": 4,
            "separate_codebook_per_head": True,
            "decay": 0.8,
            "commitment_weight": 10,
            # Cosine codebook works, Euclidean codebook collapses.
            "use_cosine_sim": True,
            "orthogonal_reg_weight": 1,
            "orthogonal_reg_max_codes": 32,
            "orthogonal_reg_active_codes_only": False,
            "kmeans_init": kmeans_init,
            "ema_update": False,
        }
        self.vq = VectorQuantize(**vq_kwargs)
        # Pretrained encoder / vq checkpoints could be restored here with
        # ``load_state_dict``; that warm start is currently disabled.

        self.feat_recon_decoder = nn.Linear(width, width)
        self.topo_recon_decoder = InnerProductDecoder(
            hidden_dim=width, output_dim=width)
        self.topo_sem_recon_decoder = nn.Linear(2 * width, width)

        # Pretrain model that bundles every sub-module built above.
        pretrain_model = PretrainModel(
            encoder=self.encoder,
            vq=self.vq,
            feat_recon_decoder=self.feat_recon_decoder,
            topo_recon_decoder=self.topo_recon_decoder,
            topo_sem_recon_decoder=self.topo_sem_recon_decoder,
        )
        self.global_model = pretrain_model.to(device)


    
  
    
    
    def intra_cluster_personalization_naive(self, intra_cluster_message):
        """Distil one personalized model per client from all clients' teachers.

        For every client a frozen *teacher* (a copy of ``self.global_model``
        loaded with the client's uploaded weights) and a trainable *student*
        (a copy of that teacher) are created.  Each teacher embeds its own
        client's knowledge; each student is then trained for three Adam steps
        to reproduce every teacher's embedding (L1 distance), with client
        ``j`` weighted by the softmax of the cosine similarity between the
        mean knowledge vectors of client ``i`` and client ``j``.

        Args:
            intra_cluster_message: dict mapping a client id to a dict holding
                ``"weight"`` (iterable of tensors, zipped against the model's
                ``parameters()``) and ``"knowledge"`` (a sequence of tensors
                accepted by ``torch.vstack``; presumably one vector per
                prototype -- confirm against the clients).  The inner dicts
                are extended in place with ``"teacher"``, ``"student"``,
                ``"z"`` and ``"domain_prototype"``.

        Side effects:
            ``self.per_global_message`` is replaced by
            ``{client_id: list(student.parameters())}``.  Students are left in
            training mode.
        """
        kd_epochs = 3
        kd_lr = 0.001

        def embed(model, knowledge):
            # Codebook lookup -> feature reconstruction -> graph -> encoder.
            z_q, _, _, _ = model.vq(knowledge)
            x_hat = model.feat_recon_decoder(z_q)
            graph = construct_graph(x=x_hat)
            return model.encoder(graph[0], graph[1], graph[2])

        # Initialise each client's local model: create a teacher and a student.
        for message in intra_cluster_message.values():
            teacher = copy.deepcopy(self.global_model)
            for param, uploaded in zip(teacher.parameters(), message["weight"]):
                param.data.copy_(uploaded)
            teacher.eval()
            student = copy.deepcopy(teacher)
            student.eval()
            message["teacher"] = teacher
            message["student"] = student

        # Teacher targets are constants, so no autograd graph is needed.  The
        # stacked knowledge is cached because every KD epoch reuses it.
        knowledge_of = {}
        with torch.no_grad():
            for client_id, message in intra_cluster_message.items():
                stacked = torch.vstack(message["knowledge"]).detach()
                knowledge_of[client_id] = stacked
                message["z"] = embed(message["teacher"], stacked).detach()
                message["domain_prototype"] = torch.mean(stacked, dim=0)

        # Knowledge distillation: one student per client.
        client_ids = list(intra_cluster_message)
        for client_i in client_ids:
            message_i = intra_cluster_message[client_i]
            student = message_i["student"]
            student.train()
            optimizer = torch.optim.Adam(student.parameters(), lr=kd_lr)

            similarity = torch.stack([
                torch.cosine_similarity(
                    message_i["domain_prototype"],
                    intra_cluster_message[client_j]["domain_prototype"],
                    dim=0,
                )
                for client_j in client_ids
            ]).to(self.device)
            # ``dim`` is explicit: the implicit-dimension form is deprecated.
            kd_weight = F.softmax(similarity, dim=0)

            for _ in range(kd_epochs):
                optimizer.zero_grad()
                kd_loss = 0
                for it, client_j in enumerate(client_ids):
                    student_z = embed(student, knowledge_of[client_j])
                    target_z = intra_cluster_message[client_j]["z"]
                    kd_loss = kd_loss + kd_weight[it] * torch.mean(
                        torch.mean(torch.abs(student_z - target_z), dim=1))
                kd_loss.backward()
                optimizer.step()

        # Personalized global model: the trained student of every client.
        self.per_global_message = {
            client_id: list(message["student"].parameters())
            for client_id, message in intra_cluster_message.items()
        }
        
        
        
    
    
    def intra_cluster_personalization_domain_level(self, intra_cluster_message, compute_glb=False):
        """Distil personalized models using domain-level similarity weights.

        For every client a frozen *teacher* (a copy of ``self.global_model``
        loaded with the client's uploaded weights) and a trainable *student*
        (a copy of that teacher) are created.  Each teacher embeds its own
        client's knowledge; each student is trained for three Adam steps to
        reproduce every teacher's embedding (L1 distance), with client ``j``
        weighted by the softmax of the cosine similarity between the mean
        knowledge vectors ("domain prototypes") of client ``i`` and ``j``.

        When ``compute_glb`` is true, ``self.global_model`` is additionally
        reset to the uniform average of the uploaded weights and trained for
        three Adam steps against all teachers' embeddings (unweighted).

        Args:
            intra_cluster_message: dict mapping a client id to a dict holding
                ``"weight"`` (iterable of tensors, zipped against the model's
                ``parameters()``) and ``"knowledge"`` (a tensor; assumed to be
                2-D with one row per knowledge vector -- confirm against the
                clients).  The inner dicts are extended in place with
                ``"teacher"``, ``"student"``, ``"z"`` and
                ``"domain_prototype"``.
            compute_glb: Whether to also refresh ``self.global_model``.  It is
                ignored when there are no clients.

        Side effects:
            ``self.per_global_message`` is replaced by
            ``{client_id: list(student.parameters())}``.  Students (and the
            global model, if refreshed) are left in training mode.
        """
        kd_epochs = 3
        kd_lr = 0.001

        def embed(model, knowledge):
            # Codebook lookup -> feature reconstruction -> graph -> encoder.
            z_q, _, _, _ = model.vq(knowledge)
            x_hat = model.feat_recon_decoder(z_q)
            graph = construct_graph(x=x_hat)
            return model.encoder(graph[0], graph[1], graph[2])

        # Create a teacher and a student for every client.
        for message in intra_cluster_message.values():
            teacher = copy.deepcopy(self.global_model)
            for param, uploaded in zip(teacher.parameters(), message["weight"]):
                param.data.copy_(uploaded)
            teacher.eval()
            student = copy.deepcopy(teacher)
            student.eval()
            message["teacher"] = teacher
            # The student must be registered: the KD loop below looks it up.
            message["student"] = student

        # Teacher targets are constants, so no autograd graph is needed.
        knowledge_of = {}
        with torch.no_grad():
            for client_id, message in intra_cluster_message.items():
                knowledge = message["knowledge"].detach()
                knowledge_of[client_id] = knowledge
                message["z"] = embed(message["teacher"], knowledge).detach()
                message["domain_prototype"] = torch.mean(knowledge, dim=0)

        # Knowledge distillation: one student per client.
        client_ids = list(intra_cluster_message)
        for client_i in client_ids:
            message_i = intra_cluster_message[client_i]
            student = message_i["student"]
            student.train()
            optimizer = torch.optim.Adam(student.parameters(), lr=kd_lr)

            similarity = torch.stack([
                torch.cosine_similarity(
                    message_i["domain_prototype"],
                    intra_cluster_message[client_j]["domain_prototype"],
                    dim=0,
                )
                for client_j in client_ids
            ]).to(self.device)
            # ``dim`` is explicit: the implicit-dimension form is deprecated.
            kd_weight = F.softmax(similarity, dim=0)

            for _ in range(kd_epochs):
                optimizer.zero_grad()
                kd_loss = 0
                for it, client_j in enumerate(client_ids):
                    student_z = embed(student, knowledge_of[client_j])
                    target_z = intra_cluster_message[client_j]["z"]
                    kd_loss = kd_loss + kd_weight[it] * torch.mean(
                        torch.mean(torch.abs(student_z - target_z), dim=1))
                kd_loss.backward()
                optimizer.step()

        # Personalized global model: the trained student of every client.
        self.per_global_message = {
            client_id: list(message["student"].parameters())
            for client_id, message in intra_cluster_message.items()
        }

        # Obtain the general global model.  With no clients there is nothing
        # to average and no loss to back-propagate, so the step is skipped.
        if compute_glb and intra_cluster_message:
            with torch.no_grad():
                share = 1 / len(intra_cluster_message)
                for it, message in enumerate(intra_cluster_message.values()):
                    for local_param, global_param in zip(
                            message["weight"], self.global_model.parameters()):
                        if it == 0:
                            global_param.data.copy_(share * local_param)
                        else:
                            # ``.to`` mirrors the device tolerance of ``copy_``.
                            global_param.data += share * local_param.to(global_param.device)

            # Distil every teacher's embedding into the averaged global model.
            self.global_model.train()
            optimizer = torch.optim.Adam(self.global_model.parameters(), lr=kd_lr)
            for _ in range(kd_epochs):
                optimizer.zero_grad()
                kd_loss = 0
                for client_id, message in intra_cluster_message.items():
                    global_z = embed(self.global_model, knowledge_of[client_id])
                    kd_loss = kd_loss + torch.mean(
                        torch.mean(torch.abs(global_z - message["z"]), dim=1))
                kd_loss.backward()
                optimizer.step()
                        
        

    
    def intra_cluster_personalization_prototype_level(self, intra_cluster_message, compute_glb=False):
        """Distil personalized models using prototype-level similarity weights.

        For every client a frozen *teacher* (a copy of ``self.global_model``
        loaded with the client's uploaded weights) and a trainable *student*
        (a copy of that teacher) are created, and each teacher embeds its own
        client's stacked knowledge rows.

        While training the student of client ``i``, every knowledge row of
        every client ``j`` gets its own weight: the highest cosine similarity
        between that row and any row of client ``i``, soft-maxed over all rows
        of all clients.  The student is trained for three Adam steps on the
        row-weighted L1 distance to the teachers' embeddings.

        When ``compute_glb`` is true, ``self.global_model`` is additionally
        reset to the uniform average of the uploaded weights and trained for
        three Adam steps against all teachers' embeddings (unweighted).

        Args:
            intra_cluster_message: dict mapping a client id to a dict holding
                ``"weight"`` (iterable of tensors, zipped against the model's
                ``parameters()``) and ``"knowledge"`` (a sequence of tensors
                accepted by ``torch.vstack``; presumably one vector per
                prototype -- confirm against the clients).  The inner dicts
                are extended in place with ``"teacher"``, ``"student"``,
                ``"z"`` and ``"domain_prototype"``.
            compute_glb: Whether to also refresh ``self.global_model``.  It is
                ignored when there are no clients.

        Raises:
            ValueError: If a client's stacked knowledge has no rows, so no
                similarity to it can be computed.

        Side effects:
            ``self.per_global_message`` is replaced by
            ``{client_id: list(student.parameters())}``.  Students (and the
            global model, if refreshed) are left in training mode.  Progress
            lines are printed to stdout.
        """
        kd_epochs = 3
        kd_lr = 0.001

        def embed(model, knowledge):
            # Codebook lookup -> feature reconstruction -> graph -> encoder.
            z_q, _, _, _ = model.vq(knowledge)
            x_hat = model.feat_recon_decoder(z_q)
            graph = construct_graph(x=x_hat)
            return model.encoder(graph[0], graph[1], graph[2])

        # Create a teacher and a student for every client.
        for message in intra_cluster_message.values():
            teacher = copy.deepcopy(self.global_model)
            for param, uploaded in zip(teacher.parameters(), message["weight"]):
                param.data.copy_(uploaded)
            teacher.eval()
            student = copy.deepcopy(teacher)
            student.eval()
            message["teacher"] = teacher
            message["student"] = student

        # Teacher targets are constants, so no autograd graph is needed.  The
        # stacked and unit-normalised knowledge is cached for reuse below.
        knowledge_of = {}
        unit_rows_of = {}
        with torch.no_grad():
            for client_id, message in intra_cluster_message.items():
                stacked = torch.vstack(message["knowledge"]).detach()
                knowledge_of[client_id] = stacked
                # Same epsilon as ``torch.cosine_similarity``.
                unit_rows_of[client_id] = F.normalize(stacked, dim=1, eps=1e-8)
                message["z"] = embed(message["teacher"], stacked).detach()
                message["domain_prototype"] = torch.mean(stacked, dim=0)

        # Knowledge distillation: one student per client.
        client_ids = list(intra_cluster_message)
        for client_i in client_ids:
            print(f"[server] training client: {client_i} personalized glb model.")
            student = intra_cluster_message[client_i]["student"]
            student.train()
            optimizer = torch.optim.Adam(student.parameters(), lr=kd_lr)

            central = unit_rows_of[client_i]
            if central.shape[0] == 0:
                raise ValueError(
                    f"client {client_i} has no knowledge rows to compare against")

            # Rows are unit length, so the matrix product holds all pairwise
            # cosine similarities; keep the best match per row of client j.
            best_match = [
                (unit_rows_of[client_j] @ central.T).max(dim=1).values
                for client_j in client_ids
            ]
            # ``dim`` is explicit: the implicit-dimension form is deprecated.
            kd_weight = F.softmax(torch.cat(best_match).to(self.device), dim=0)

            for _ in range(kd_epochs):
                ptr = 0  # offset of client j's rows inside ``kd_weight``
                optimizer.zero_grad()
                kd_loss = 0
                for client_j in client_ids:
                    student_z = embed(student, knowledge_of[client_j])
                    gap = torch.abs(student_z - intra_cluster_message[client_j]["z"])
                    rows = student_z.shape[0]
                    row_weight = kd_weight[ptr:ptr + rows].view(1, -1)
                    # (1, rows) @ (rows, dim): weighted sum over rows, then
                    # averaged over the embedding dimension.
                    kd_loss = kd_loss + torch.mean(torch.mm(row_weight, gap))
                    ptr += rows
                kd_loss.backward()
                optimizer.step()

        # Personalized global model: the trained student of every client.
        self.per_global_message = {
            client_id: list(message["student"].parameters())
            for client_id, message in intra_cluster_message.items()
        }

        # Obtain the general global model.  With no clients there is nothing
        # to average and no loss to back-propagate, so the step is skipped.
        if compute_glb and intra_cluster_message:
            print("[server] computing global model.")
            with torch.no_grad():
                share = 1 / len(intra_cluster_message)
                for it, message in enumerate(intra_cluster_message.values()):
                    for local_param, global_param in zip(
                            message["weight"], self.global_model.parameters()):
                        if it == 0:
                            global_param.data.copy_(share * local_param)
                        else:
                            # ``.to`` mirrors the device tolerance of ``copy_``.
                            global_param.data += share * local_param.to(global_param.device)

            # Distil every teacher's embedding into the averaged global model.
            self.global_model.train()
            optimizer = torch.optim.Adam(self.global_model.parameters(), lr=kd_lr)
            for _ in range(kd_epochs):
                optimizer.zero_grad()
                kd_loss = 0
                for client_id, message in intra_cluster_message.items():
                    global_z = embed(self.global_model, knowledge_of[client_id])
                    kd_loss = kd_loss + torch.mean(
                        torch.mean(torch.abs(global_z - message["z"]), dim=1))
                kd_loss.backward()
                optimizer.step()


    def execute(self, local_message_dict, client_knowledge, compute_glb=False):
        # client_cluster, n_clusters, _ = self.clustering(client_knowledge)
        cluster_client = {}
        # for client_id, cluster_id in enumerate(client_cluster):
        #     if cluster_id not in cluster_client.keys():
        #         cluster_client[cluster_id] = [client_id]
        #     else:
        #         cluster_client[cluster_id].append(client_id)
        
        n_clusters = 1
        cluster_client[0] = list(range(len(local_message_dict)))
        for cluster_id in range(n_clusters):
            intra_cluster_message = {}
            for client_id in cluster_client[cluster_id]:
                intra_cluster_message[client_id] = {
                    "weight": local_message_dict[f"client_{client_id}"]["weight"],
                    "knowledge": client_knowledge[f"client_{client_id}"],
                }
            
            self.intra_cluster_personalization_domain_level(intra_cluster_message, compute_glb=compute_glb)
            # self.intra_cluster_personalization_prototype_level(intra_cluster_message, compute_glb=compute_glb)
            
            
        
                 
    def get_global_message(self):
        """Return what the server broadcasts to the clients.

        Returns:
            ``self.per_global_message`` once it has been set to anything
            other than ``None``; otherwise a fresh list of
            ``self.global_model``'s parameters.  Note that the two cases
            return different container types.
        """
        personalized = self.per_global_message
        if personalized is not None:
            return personalized
        # No personalization has run yet: fall back to the shared model.
        return list(self.global_model.parameters())
        
        
