import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from flcore.clients.clientdor import clientDOR
from flcore.servers.serverbase import Server
from flcore.clients.clientbase import load_item, save_item
from threading import Thread
from collections import defaultdict
from torch.utils.data import DataLoader


class FedDOR(Server):
    """FedDOR server: trains a class-prototype generator from client prototypes.

    Each global round the selected clients train locally and upload per-class
    prototypes. The server then:
      1. computes, for every class, the L2 distance from its client-averaged
         prototype to the nearest other class (``self.gap``, ``min_gap``,
         ``max_gap``);
      2. trains the ``Generatable_Prototype`` network (persisted as item
         ``'GP'``) on the uploaded prototypes;
      3. saves the generator's output for every class as ``'global_protos'``.
    """

    def __init__(self, args, times):
        super().__init__(args, times)

        # select slow clients
        self.set_slow_clients()
        self.set_clients(clientDOR)

        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
        print("Finished creating server and clients.")

        self.Budget = []  # wall-clock seconds spent in each global round
        self.num_classes = args.num_classes

        # The server-side optimizer reuses the clients' local learning rate.
        self.server_learning_rate = args.local_learning_rate
        self.batch_size = args.batch_size
        self.server_epochs = args.server_epochs
        self.margin_threthold = args.margin_threthold  # spelling follows the args field
        self.feature_dim = args.feature_dim
        self.server_hidden_dim = self.feature_dim

        # A fresh generator is created and saved unless save_folder_name
        # contains 'temp' without being exactly 'temp'. In that case a
        # previously saved 'GP' item is presumably reused (TODO confirm).
        if args.save_folder_name == 'temp' or 'temp' not in args.save_folder_name:
            GP = Generatable_Prototype(
                self.num_classes,
                self.server_hidden_dim,
                self.feature_dim,
                self.device
            ).to(self.device)
            save_item(GP, self.role, 'GP', self.save_folder_name)
            print(GP)
        self.CEloss = nn.CrossEntropyLoss()
        self.MSEloss = nn.MSELoss()

        # 1e9 is a "no other class seen yet" sentinel, refined in receive_protos().
        self.gap = torch.ones(self.num_classes, device=self.device) * 1e9
        self.min_gap = None
        self.max_gap = None

    def train(self):
        """Run up to ``self.global_rounds`` rounds, then print a summary and save results.

        Evaluation happens at the start of every ``eval_gap``-th round, before
        local training.
        """
        for i in range(self.global_rounds):
            s_t = time.time()
            self.selected_clients = self.select_clients()

            if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate heterogeneous models")
                self.evaluate(epoch=i)

            for client in self.selected_clients:
                client.train()

            self.receive_protos()
            self.update_GP(epoch=i)

            self.Budget.append(time.time() - s_t)
            print('-'*50, self.Budget[-1])

            if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt):
                break

        print("\nBest accuracy.")
        if self.rs_test_acc:
            print(max(self.rs_test_acc))
        # The average time per round excludes round 0. Guard the case of a
        # single completed round, which would otherwise divide by zero.
        later_rounds = self.Budget[1:]
        if later_rounds:
            print(sum(later_rounds) / len(later_rounds))

        self.save_results()

    def receive_protos(self):
        """Collect the selected clients' prototypes and compute class-wise gaps.

        Side effects:
            ``uploaded_ids``: ids of the selected clients.
            ``uploaded_protos``: flat list of ``(proto, label)`` pairs, later
                consumed by ``update_GP``.
            ``gap``: for each class, the L2 distance from its averaged prototype
                to the nearest other class. Classes without a partner receive
                ``min_gap``.
            ``min_gap`` / ``max_gap``: extremes of ``gap``.
        """
        assert (len(self.selected_clients) > 0)

        self.uploaded_ids = []
        self.uploaded_protos = []
        uploaded_protos_per_client = []
        for client in self.selected_clients:
            self.uploaded_ids.append(client.id)
            protos = load_item(client.role, 'protos', client.save_folder_name)
            for k in protos.keys():
                self.uploaded_protos.append((protos[k], k))
            uploaded_protos_per_client.append(protos)

        # calculate class-wise minimum distance
        # Keys are used directly as indices into self.gap. They are presumably
        # int labels in [0, num_classes) (verify against client code).
        self.gap = torch.ones(self.num_classes, device=self.device) * 1e9
        avg_protos = proto_cluster(uploaded_protos_per_client)
        for k1 in avg_protos.keys():
            for k2 in avg_protos.keys():
                if k1 > k2:  # visit each unordered pair once
                    dis = torch.norm(avg_protos[k1] - avg_protos[k2], p=2)
                    self.gap[k1] = torch.min(self.gap[k1], dis)
                    self.gap[k2] = torch.min(self.gap[k2], dis)
        self.min_gap = torch.min(self.gap)
        # Entries still at the 1e9 sentinel (no partner class seen) fall back
        # to the global minimum gap. This replaces a per-element Python loop.
        self.gap = torch.where(self.gap > 1e8, self.min_gap, self.gap)
        self.max_gap = torch.max(self.gap)
        print('class-wise minimum distance', self.gap)
        print('min_gap', self.min_gap)
        print('max_gap', self.max_gap)

    def update_GP(self, epoch=None):
        """Train the prototype generator on this round's uploaded prototypes.

        Loads item ``'GP'`` and runs ``server_epochs`` epochs of SGD over
        ``self.uploaded_protos`` with ``intra_class_loss + inter_class_loss``.
        It then saves the generator, clears ``uploaded_protos`` and saves
        ``{class_id: generated prototype}`` as ``'global_protos'``.

        Args:
            epoch: current round index. Accepted for interface compatibility
                and not used here.
        """
        GP = load_item(self.role, 'GP', self.save_folder_name)
        GP_opt = torch.optim.SGD(GP.parameters(), lr=self.server_learning_rate)
        GP.train()

        all_classes = list(range(self.num_classes))
        # With shuffle=True, each new iteration over the loader reshuffles, so
        # one loader can serve every epoch.
        proto_loader = DataLoader(self.uploaded_protos, self.batch_size,
                                  drop_last=False, shuffle=True)
        loss = None  # stays None when there is nothing to train on
        for _ in range(self.server_epochs):
            for proto, y in proto_loader:
                proto = proto.to(self.device)
                y = torch.as_tensor(y, dtype=torch.int64, device=self.device)

                # Regenerate every round-trip: the generator's weights change each step.
                proto_gen = GP(all_classes)
                loss = intra_class_loss(proto, y, proto_gen) + inter_class_loss(proto, y, proto_gen)

                GP_opt.zero_grad()
                loss.backward()
                GP_opt.step()

        # Previously this raised UnboundLocalError when no batch was processed.
        if loss is not None:
            print(f'Server loss: {loss.item()}')
        else:
            print('Server loss: skipped (no uploaded prototypes or server_epochs == 0)')
        self.uploaded_protos = []
        save_item(GP, self.role, 'GP', self.save_folder_name)

        GP.eval()
        global_protos = defaultdict(list)
        with torch.no_grad():  # inference only; outputs are detached anyway
            for class_id in range(self.num_classes):
                global_protos[class_id] = GP(torch.tensor(class_id, device=self.device)).detach()
        save_item(global_protos, self.role, 'global_protos', self.save_folder_name)

def proto_cluster(protos_list):
    """Average prototypes per label across several label->prototype mappings.

    Args:
        protos_list: iterable of mappings, each from a label to a prototype tensor.

    Returns:
        A ``defaultdict(list)`` from every label seen to the detached element-wise
        mean of that label's prototypes. Keys appear in first-seen order.
    """
    grouped = defaultdict(list)
    for client_protos in protos_list:
        for label, proto in client_protos.items():
            grouped[label].append(proto)

    # Only existing keys are overwritten here, so mutating during iteration is safe.
    for label, members in grouped.items():
        grouped[label] = torch.stack(members).mean(dim=0).detach()

    return grouped
            

class Generatable_Prototype(nn.Module):
    """Generate a prototype vector for each class id.

    Architecture: ``Embedding(num_categories, feat_dim)`` -> ``Linear(feat_dim,
    server_hid_dim)`` + ``ReLU`` -> ``Linear(server_hid_dim, feat_dim)``.

    Args:
        num_categories: number of classes (rows of the embedding table).
        server_hid_dim: width of the hidden layer.
        feat_dim: embedding width and output prototype width.
        dev: device given at construction. It is stored as ``self.dev`` for
            compatibility. ``forward`` places ids on the device of the
            embedding weights, so moving the module with ``.to()`` keeps working.
    """

    def __init__(self, num_categories, server_hid_dim, feat_dim, dev):
        super().__init__()
        self.dev = dev  # retained for existing callers/pickles; forward() does not rely on it
        self.class_embeddings = nn.Embedding(num_embeddings=num_categories, embedding_dim=feat_dim)
        self.hidden_transform = nn.Sequential()
        self.hidden_transform.add_module('linear1', nn.Linear(feat_dim, server_hid_dim))
        self.hidden_transform.add_module('activation', nn.ReLU())

        self.projection = nn.Linear(server_hid_dim, feat_dim)

    def forward(self, category_id):
        """Return prototypes for ``category_id``.

        Args:
            category_id: int, list of ints, or integer tensor of class ids.

        Returns:
            Tensor of shape ``shape(category_id) + (feat_dim,)``.
        """
        # Follow the parameters' device rather than the construction-time
        # ``self.dev``, which goes stale after ``.to()`` or loading elsewhere.
        weight_device = self.class_embeddings.weight.device
        cat_id_tensor = torch.as_tensor(category_id, dtype=torch.long, device=weight_device)

        category_emb = self.class_embeddings(cat_id_tensor)
        hidden_repr = self.hidden_transform(category_emb)
        return self.projection(hidden_repr)

def intra_class_loss(representations, class_labels, prototypes):
    """Mean cosine distance between each sample and its own class prototype.

    Args:
        representations: (B, D) sample representations.
        class_labels: (B,) int64 labels that index rows of ``prototypes``.
        prototypes: (C, D) class prototypes.

    Returns:
        0-dim tensor: ``mean_i(1 - cos(representations[i], prototypes[class_labels[i]]))``.
        Zero-norm vectors contribute cosine 0 instead of NaN.
    """
    # F.normalize clamps the norm by a small eps, so all-zero rows stay finite.
    rep_norm = F.normalize(representations, p=2, dim=1)
    proto_norm = F.normalize(prototypes, p=2, dim=1)

    # Row i holds the normalised prototype of sample i's class.
    matched_protos = proto_norm[class_labels]

    # Row-wise dot product: O(B*D), instead of building a (B, B) matrix only
    # to read its diagonal.
    cosine_sim = (rep_norm * matched_protos).sum(dim=1)

    return (1.0 - cosine_sim).mean()

def inter_class_loss(sample_reps, class_labels, class_protos, weight=10.0):
    """Penalise alignment between each sample and the prototypes of other classes.

    Args:
        sample_reps: (B, D) sample representations.
        class_labels: (B,) int64 labels that index rows of ``class_protos``.
        class_protos: (C, D) class prototypes.
        weight: multiplier on the mean similarity. The default 10.0 matches
            the previously hard-coded factor.

    Returns:
        0-dim tensor: ``weight`` times the mean absolute cosine similarity over
        all (sample, non-target class) pairs. Returns 0 when there are no such
        pairs (single class or empty batch); previously this case gave NaN.
    """
    normed_samples = F.normalize(sample_reps, p=2, dim=1)
    normed_protos = F.normalize(class_protos, p=2, dim=1)

    # (B, C) |cos| between every sample and every prototype. The sign is
    # dropped, so anti-aligned prototypes are penalised as well.
    abs_similarities = torch.matmul(normed_samples, normed_protos.T).abs()

    batch_size, num_classes = abs_similarities.shape
    # True everywhere except at each sample's own class.
    other_class_mask = ~F.one_hot(class_labels, num_classes=num_classes).bool()
    # Exactly one target entry per row, so the count follows from the shape.
    valid_count = batch_size * (num_classes - 1)
    if valid_count == 0:
        return abs_similarities.sum() * 0.0

    return abs_similarities[other_class_mask].sum() * weight / valid_count