import torch
from model.GFT_encoder import Encoder, InnerProductDecoder
from model.GFT_ft_model import TaskModel
from model.GFT_pt_model import PretrainModel
from model.GFT_vq import VectorQuantize
import torch.nn as nn
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict
import copy
from utils.basic_utils import construct_graph
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from torch.optim import Adam, AdamW
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import mask_feature, dropout_adj


class Server:
    """Federated-learning server holding the global GFT pretraining model.

    The server owns a ``PretrainModel`` assembled from a graph ``Encoder``, a
    multi-head ``VectorQuantize`` codebook and three reconstruction decoders.
    Each round, ``execute`` overwrites the global parameters with the uniform
    average of the parameter lists uploaded by the participating clients, and
    ``get_global_message`` hands the current global parameters back out.
    """

    def __init__(self, args, device, kmeans_init=True):
        """Build the global model and record pretraining hyperparameters.

        Args:
            args: Run configuration; stored on ``self.args`` (not read here).
            device: Device the global model is moved to.
            kmeans_init: Forwarded to ``VectorQuantize``. Presumably enables
                k-means initialization of the codebook on the first batch it
                sees (see ``initialization``) -- confirm in ``GFT_vq``.
        """
        self.per_global_message = None
        self.args = args
        self.device = device
        # Width shared by the encoder input/hidden layers, the codebook and
        # the decoders below.
        self.dim = 768

        self.encoder = Encoder(
            input_dim=self.dim,
            hidden_dim=self.dim,
            activation=nn.ReLU,
            num_layers=2,
            backbone="mysage",
            normalize="batch",
            dropout=0.15
            )

        self.vq = VectorQuantize(
            dim=self.dim,
            codebook_size=128,
            codebook_dim=self.dim,
            heads=4,
            separate_codebook_per_head=True,
            decay=0.8,
            commitment_weight=10,
            use_cosine_sim=True,  # Cosine Codebook Works, Euclidean Codebook Collapses
            orthogonal_reg_weight=1,
            orthogonal_reg_max_codes=32,
            orthogonal_reg_active_codes_only=False,
            kmeans_init=kmeans_init,
            ema_update=False,
        )

        self.feat_recon_decoder = nn.Linear(self.dim, self.dim)
        self.topo_recon_decoder = InnerProductDecoder(hidden_dim=self.dim, output_dim=self.dim)
        # Input is 2 * dim wide; presumably two concatenated embeddings --
        # confirm against PretrainModel.
        self.topo_sem_recon_decoder = nn.Linear(self.dim * 2, self.dim)

        # pretrain model
        self.global_model = PretrainModel(
            encoder=self.encoder, vq=self.vq,
            feat_recon_decoder=self.feat_recon_decoder,
            topo_recon_decoder=self.topo_recon_decoder,
            topo_sem_recon_decoder=self.topo_sem_recon_decoder,
        ).to(device)

        self.round_id = 0

        # pretrain model params
        # NOTE(review): none of these are read inside this class; they are
        # presumably consumed by the client-side training code -- verify.
        self.pretrain_epochs = 2
        self.pretrain_batch_size = 1024
        self.feat_p=0.2
        self.edge_p=0.2
        self.topo_recon_ratio=0.1
        self.feat_lambda=100
        self.topo_lambda=0.01
        self.topo_sem_lambda=100
        self.sem_lambda=1
        self.sem_encoder_decay=0.99
        self.pretrain_lr=1e-4
        self.separate_codebook_per_head=True
        self.separate_decoder_for_each_head=True
        self.use_cosine_sim=True
        self.use_z_in_predict=True
        self.no_lin_clf=False
        self.no_proto_clf=False

    def initialization(self, domain_prototypes):
        """Store the domain prototypes and push them through the quantizer.

        Args:
            domain_prototypes: Tensor of prototype embeddings. It is kept
                as-is on ``self.domain_prototypes``, and a version moved to
                ``self.device`` is fed to ``self.vq``.
        """
        self.domain_prototypes = domain_prototypes
        # The forward output is discarded: the call is made only for its side
        # effects on the quantizer. With ``kmeans_init=True`` this is
        # presumably what seeds the codebook -- confirm in GFT_vq.
        self.vq.forward(x=self.domain_prototypes.to(self.device))

    def execute(self, local_message_dict):
        """Overwrite the global parameters with the clients' uniform average.

        Every client contributes with the same weight
        ``1 / len(local_message_dict)``. This is plain FedAvg: there is no
        weighting by sample count and no server-side momentum.

        Only ``self.global_model.parameters()`` are aggregated. Module buffers
        (if any, e.g. normalization running statistics) are left untouched.
        NOTE(review): confirm that is intended.

        Args:
            local_message_dict: Mapping ``client_id -> message``, where
                ``message['weight']`` is a sequence of tensors aligned
                position-by-position with ``self.global_model.parameters()``.
                Pairing uses ``zip``, so a shorter sequence only updates the
                leading global parameters. An empty mapping leaves the model
                unchanged.
        """
        num_clients = len(local_message_dict)
        if num_clients == 0:
            return
        weight = 1 / num_clients

        with torch.no_grad():
            for it, message in enumerate(local_message_dict.values()):
                global_params = self.global_model.parameters()
                for local_param, global_param in zip(message['weight'], global_params):
                    # Match the global tensor's device/dtype on every client.
                    # ``copy_`` did this implicitly for the first client, but
                    # in-place add does not, so CPU uploads used to fail
                    # against an accelerator-resident global model.
                    contribution = (weight * local_param).to(
                        device=global_param.device, dtype=global_param.dtype
                    )
                    if it == 0:
                        global_param.data.copy_(contribution)
                    else:
                        global_param.data.add_(contribution)

    def get_global_message(self):
        """Return the global model's parameters as a list.

        The list holds the live ``nn.Parameter`` objects, not copies, so
        in-place edits by the caller modify the global model directly.
        """
        global_message = list(self.global_model.parameters())
        return global_message
        
        
