import torch
from model.GFT_encoder import Encoder, InnerProductDecoder
from model.GFT_ft_model import TaskModel
from model.GFT_pt_model import PretrainModel
from model.GFT_vq import VectorQuantize
import torch.nn as nn
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict
import copy
from utils.basic_utils import construct_graph
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from torch.optim import Adam, AdamW
from scipy.stats import wasserstein_distance  # 计算一维 EMD
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import mask_feature, dropout_adj


class Server:
    """Federated-learning server holding the global pretraining model.

    The server owns one ``PretrainModel`` (encoder, vector quantizer and
    reconstruction decoders). It overwrites that model's parameters with the
    clients' average in :meth:`execute`, and exposes the parameters through
    :meth:`get_global_message`.
    """

    def __init__(self, args, device, kmeans_init=True):
        """Build the global model and set the pretraining hyperparameters.

        Args:
            args: Run configuration. Stored on ``self.args``; not read
                anywhere else in this class.
            device: Device the global model is moved to.
            kmeans_init: Forwarded to ``VectorQuantize``. Presumably this
                enables k-means initialisation of the codebook; confirm in
                ``model.GFT_vq``.
        """
        self.per_global_message = None
        self.args = args
        self.device = device
        self.dim = 768  # width shared by the encoder, the codebook and the decoders

        self.encoder = Encoder(
            input_dim=self.dim,
            hidden_dim=self.dim,
            activation=nn.ReLU,
            num_layers=2,
            backbone="mysage",
            normalize="batch",
            dropout=0.15
            )

        self.vq = VectorQuantize(
            dim=self.dim,
            codebook_size=128,
            codebook_dim=self.dim,
            heads=4,
            separate_codebook_per_head=True,
            decay=0.8,
            commitment_weight=10,
            use_cosine_sim=True,  # Cosine Codebook Works, Euclidean Codebook Collapses
            orthogonal_reg_weight=1,
            orthogonal_reg_max_codes=32,
            orthogonal_reg_active_codes_only=False,
            kmeans_init=kmeans_init,
            ema_update=False,
        )

        self.feat_recon_decoder = nn.Linear(self.dim, self.dim)
        self.topo_recon_decoder = InnerProductDecoder(hidden_dim=self.dim, output_dim=self.dim)
        # Takes two concatenated dim-wide vectors, hence the 2 * dim input width.
        self.topo_sem_recon_decoder = nn.Linear(self.dim * 2, self.dim)

        # Pretrain model. Calling .to(device) also moves the submodules
        # created above, because they are registered inside it.
        self.global_model = PretrainModel(
            encoder=self.encoder, vq=self.vq,
            feat_recon_decoder=self.feat_recon_decoder,
            topo_recon_decoder=self.topo_recon_decoder,
            topo_sem_recon_decoder=self.topo_sem_recon_decoder,
        ).to(device)

        self.round_id = 0

        # Pretraining hyperparameters. None of these are read inside this
        # class; presumably the training loop consumes them. Confirm with callers.
        self.init_epochs = 10
        self.pretrain_epochs = 2
        self.pretrain_batch_size = 1024
        self.feat_p=0.2
        self.edge_p=0.2
        self.topo_recon_ratio=0.1
        self.feat_lambda=100
        self.topo_lambda=0.01
        self.topo_sem_lambda=100
        self.sem_lambda=1
        self.sem_encoder_decay=0.99
        self.pretrain_lr=1e-4
        self.separate_codebook_per_head=True
        self.separate_decoder_for_each_head=True
        self.use_cosine_sim=True
        self.use_z_in_predict=True
        self.no_lin_clf=False
        self.no_proto_clf=False

        # Server-momentum coefficient for rho * old_global + (1 - rho) * new_avg.
        # NOTE(review): execute() currently does NOT apply it.
        self.rho = 0.9

    def initialization(self, domain_prototypes, save_path="./domain_cos_dis.png"):
        """Feed the client prototypes to the quantizer and plot their distances.

        Args:
            domain_prototypes: Tensor with one row per client; row ``i`` is
                client ``i``'s prototype. Its width is presumably
                ``self.dim``, since it is fed to ``self.vq``; confirm.
            save_path: File the pairwise-distance heatmap is written to.

        Returns:
            A ``numpy.ndarray`` of shape ``(num_clients, num_clients)``.
            Entry ``[i, j]`` is the cosine distance ``1 - cos(p_i, p_j)``.
            The matrix is also stored on ``self.domain_dist_matrix``.
        """
        self.domain_prototypes = domain_prototypes
        # One forward pass over the prototypes; the output is discarded.
        # Presumably this triggers the codebook's kmeans_init; confirm in
        # VectorQuantize.
        self.vq.forward(x=self.domain_prototypes.to(self.device))

        # Pairwise cosine distance in one broadcast call: (n,1,d) vs (1,n,d) -> (n,n).
        # detach() keeps .numpy() working even if the prototypes track gradients.
        protos = self.domain_prototypes.detach().cpu()
        num_clients = protos.shape[0]
        dist_matrix = (
            1 - torch.cosine_similarity(protos.unsqueeze(1), protos.unsqueeze(0), dim=-1)
        ).numpy()
        self.domain_dist_matrix = dist_matrix

        labels = [f"Client {i}" for i in range(num_clients)]
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            sns.heatmap(
                dist_matrix,
                annot=True,
                fmt=".2f",
                cmap="YlOrRd",
                xticklabels=labels,
                yticklabels=labels,
                ax=ax,
            )
            ax.set_title("Client Dist Heatmap")
            ax.set_xlabel("Client ID")
            ax.set_ylabel("Client ID")
            fig.tight_layout()
            fig.savefig(save_path)
        finally:
            # Always release the figure; otherwise every call leaks an open figure.
            plt.close(fig)

        return dist_matrix

    def execute(self, local_message_dict):
        """Overwrite the global parameters with the uniform average of the clients'.

        Each client contributes with weight ``1 / len(local_message_dict)``.
        The previous global values are discarded. Server momentum
        (``self.rho``) is not applied.

        Args:
            local_message_dict: Mapping ``client_id -> message``. Each
                message's ``'weight'`` entry is an iterable of tensors,
                ordered like ``self.global_model.parameters()``.

        Raises:
            ValueError: If any client sends a different number of tensors
                than the global model has parameters. This is checked for
                every client before any parameter is modified.
        """
        num_clients = len(local_message_dict)
        if num_clients == 0:
            return  # nothing to aggregate; leave the global model unchanged

        global_params = list(self.global_model.parameters())

        # Validate every client up front, so a bad message cannot leave the
        # global model half-updated.
        client_params = []
        for client_id, message in local_message_dict.items():
            params = list(message['weight'])
            if len(params) != len(global_params):
                raise ValueError(
                    f"client {client_id!r} sent {len(params)} tensors, "
                    f"expected {len(global_params)}"
                )
            client_params.append(params)

        weight = 1.0 / num_clients
        with torch.no_grad():
            for it, params in enumerate(client_params):
                for local_param, global_param in zip(params, global_params):
                    contrib = weight * local_param.to(global_param.device)
                    if it == 0:
                        # The first client resets the running sum.
                        global_param.data.copy_(contrib)
                    else:
                        global_param.data.add_(contrib)

    def get_global_message(self):
        """Return the global model's parameters as a list.

        The list holds the live ``nn.Parameter`` objects, not copies, so
        in-place changes made by a receiver affect the server's model.
        """
        global_message = list(self.global_model.parameters())
        return global_message
        
        
