

import torch
import os
import numpy as np
import h5py
import copy
import time
import random
import json
from utils.data_utils import read_client_data
from utils.dlg import DLG
from utils.attack_utils import attack,train_attack_model
import matplotlib.pyplot as plt
import xgboost as xgb


class Server(object):
    def __init__(self, args):
        """Initialise server-side state from the experiment configuration ``args``.

        ``args`` supplies hyper-parameters, client-participation settings and
        unlearning options; ``args.model`` is deep-copied to become the initial
        global model.  Client objects are not created here (``set_clients``
        does that); the containers they and later rounds populate start empty.
        """
        opts = args
        self.args = opts

        # Training configuration.
        self.device = opts.device
        self.dataset = opts.dataset
        self.num_classes = opts.num_classes
        self.global_rounds = opts.global_rounds
        self.local_epochs = opts.local_epochs
        self.batch_size = opts.batch_size
        self.learning_rate = opts.local_learning_rate
        self.global_model = copy.deepcopy(opts.model)

        # Client participation per round.
        self.num_clients = opts.num_clients
        self.join_ratio = opts.join_ratio
        self.random_join_ratio = opts.random_join_ratio
        self.num_join_clients = int(self.num_clients * self.join_ratio)
        self.current_num_join_clients = self.num_join_clients

        # Run identification and control.
        self.algorithm = opts.algorithm
        self.time_select = opts.time_select
        self.goal = opts.goal
        # Max average per-round (train + send) cost a client may have to upload.
        self.time_threthold = opts.time_threthold
        self.save_folder_name = opts.save_folder_name
        self.top_cnt = 100
        self.auto_break = opts.auto_break

        # Client registries (filled by set_clients / set_slow_clients / select_clients).
        self.clients, self.selected_clients = [], []
        self.train_slow_clients, self.send_slow_clients = [], []

        # Per-round upload buffers (filled by receive_models / receive_models_target).
        self.uploaded_weights, self.uploaded_ids, self.uploaded_models = [], [], []

        # Metric histories.
        self.rs_test_acc, self.rs_test_auc, self.rs_train_loss = [], [], []

        # Evaluation gap and client drop / slow rates.
        self.eval_gap = opts.eval_gap
        self.client_drop_rate = opts.client_drop_rate
        self.train_slow_rate = opts.train_slow_rate
        self.send_slow_rate = opts.send_slow_rate

        # DLG (gradient leakage) evaluation settings.
        self.dlg_eval = opts.dlg_eval
        self.dlg_gap = opts.dlg_gap
        self.batch_num_per_client = opts.batch_num_per_client

        # New-client evaluation.
        self.num_new_clients = opts.num_new_clients
        self.new_clients = []
        self.eval_new_clients = False
        self.fine_tuning_epoch_new = opts.fine_tuning_epoch_new

        # Unlearning configuration.  unlearning_clients is still a list of
        # client ids here; set_clients replaces it with the client objects.
        self.unlearning_clients = opts.unlearning_clients
        self.unlearning_ground = opts.unlearning_ground
        self.post_training_ground = opts.post_training_ground

        self.attack_acc = []
        self.history_update = [[] for _ in range(self.num_clients)]  # one list per client
        self.old_clients = []
        self.old_global_model = []
        self.attacker = xgb.XGBClassifier()

    def set_clients(self, clientObj):
        """Instantiate one ``clientObj`` per client id and register it in ``self.clients``.

        Each client receives its train/test sample counts (lengths of the data
        returned by ``read_client_data``), its slow-train / slow-send flags, and
        an ``unlearning`` flag that is True when its id appears in
        ``self.unlearning_clients``.  Afterwards ``self.unlearning_clients`` is
        converted from a list of ids into the matching client objects (looked
        up by index in ``self.clients``).

        Note: ``zip`` stops at the shortest sequence, so the slow-client flag
        lists must already be populated (``set_slow_clients``) or no client is
        created.
        """
        target_ids = self.unlearning_clients  # still a list of ids at this point
        per_client = zip(range(self.num_clients), self.train_slow_clients, self.send_slow_clients)
        for cid, slow_train, slow_send in per_client:
            n_train = len(read_client_data(self.dataset, cid, is_train=True))
            n_test = len(read_client_data(self.dataset, cid, is_train=False))
            is_target = bool(target_ids) and cid in target_ids
            self.clients.append(
                clientObj(
                    self.args,
                    id=cid,
                    train_samples=n_train,
                    test_samples=n_test,
                    train_slow=slow_train,
                    send_slow=slow_send,
                    unlearning=is_target,
                )
            )

        # Replace the id list with the corresponding client objects.
        if target_ids:
            self.unlearning_clients = [self.clients[cid] for cid in target_ids]
        else:
            self.unlearning_clients = []
        

    # random select slow clients
    def select_slow_clients(self, slow_rate):
        """Return a per-client boolean mask flagging ``slow_rate`` of the clients as slow.

        Exactly ``int(slow_rate * self.num_clients)`` distinct clients (clamped
        to ``[0, self.num_clients]``) are flagged ``True``.  Sampling is without
        replacement; sampling *with* replacement would let duplicate draws
        silently flag fewer clients than requested.

        Args:
            slow_rate: fraction of clients to mark as slow.

        Returns:
            list[bool] of length ``self.num_clients``.
        """
        slow_clients = [False] * self.num_clients
        num_slow = min(max(int(slow_rate * self.num_clients), 0), self.num_clients)
        for i in np.random.choice(self.num_clients, num_slow, replace=False):
            slow_clients[i] = True
        return slow_clients

    def set_slow_clients(self):
        """Draw fresh slow-training and slow-sending masks for all clients.

        The training mask is drawn first, then the sending mask, using the
        ``train_slow_rate`` and ``send_slow_rate`` fractions respectively.
        """
        train_mask = self.select_slow_clients(self.train_slow_rate)
        send_mask = self.select_slow_clients(self.send_slow_rate)
        self.train_slow_clients, self.send_slow_clients = train_mask, send_mask

    def select_clients(self):
        """Choose the clients that participate in the current round.

        With ``random_join_ratio`` the participant count is drawn uniformly from
        ``[num_join_clients, num_clients]``; otherwise every registered client
        participates (``join_ratio`` is not applied on this path).  The count is
        stored in ``self.current_num_join_clients`` and that many distinct
        clients are sampled from ``self.clients``.

        Returns:
            list of client objects in random order.
        """
        if not self.random_join_ratio:
            count = len(self.clients)
        else:
            low, high = self.num_join_clients, self.num_clients + 1
            count = np.random.choice(range(low, high), 1, replace=False)[0]
        self.current_num_join_clients = count

        picked = np.random.choice(self.clients, self.current_num_join_clients, replace=False)
        return list(picked)

    def send_models(self):
        """Push the current global model to every client and record the send cost.

        For each client, the wall time of ``set_parameters`` is measured and
        twice that value is added to ``send_time_cost['total_cost']``;
        ``send_time_cost['num_rounds']`` is incremented by one.
        """
        assert (len(self.clients) > 0)

        for client in self.clients:
            t0 = time.time()
            client.set_parameters(self.global_model)
            cost = client.send_time_cost
            cost['num_rounds'] += 1
            cost['total_cost'] += 2 * (time.time() - t0)

    def receive_models(self):
        """Collect uploads from the active, sufficiently fast selected clients.

        A random subset of ``int((1 - client_drop_rate) * current_num_join_clients)``
        selected clients is kept (simulated drop-out).  A kept client uploads
        when its average per-round train cost plus average per-round send cost
        is at most ``time_threthold`` (a division by zero counts as cost 0).
        Uploaded ids and models are stored alongside weights proportional to
        each client's ``train_samples``, normalised by their total.
        """
        assert (len(self.selected_clients) > 0)

        def round_cost(c):
            train, send = c.train_time_cost, c.send_time_cost
            try:
                return train['total_cost'] / train['num_rounds'] + send['total_cost'] / send['num_rounds']
            except ZeroDivisionError:
                return 0

        n_active = int((1 - self.client_drop_rate) * self.current_num_join_clients)
        active_clients = random.sample(self.selected_clients, n_active)
        accepted = [c for c in active_clients if round_cost(c) <= self.time_threthold]

        total = sum(c.train_samples for c in accepted)
        self.uploaded_ids = [c.id for c in accepted]
        self.uploaded_models = [c.model for c in accepted]
        self.uploaded_weights = [c.train_samples / total for c in accepted]

    def aggregate_parameters(self):
        """Replace the global model with the weighted sum of the uploaded models.

        A deep copy of the first uploaded model provides the architecture; its
        parameters are zeroed and every uploaded model is accumulated through
        ``add_parameters`` with the matching ``self.uploaded_weights`` entry.
        Only ``parameters()`` are combined; buffers are whatever the deep copy
        of the first uploaded model carries.
        """
        assert (len(self.uploaded_models) > 0)

        template = self.uploaded_models[0]
        self.global_model = copy.deepcopy(template)
        for p in self.global_model.parameters():
            p.data.zero_()

        for weight, model in zip(self.uploaded_weights, self.uploaded_models):
            self.add_parameters(weight, model)

    def add_parameters(self, w, client_model):
        """Accumulate ``w`` times ``client_model``'s parameters into the global model in place.

        Parameters are paired positionally; the client model is not modified.
        """
        pairs = zip(self.global_model.parameters(), client_model.parameters())
        for g_param, c_param in pairs:
            # Multiplication already yields a new tensor, so no explicit clone is needed.
            g_param.data += c_param.data * w

    def save_global_model(self):
        """Persist the global model and, outside retraining, auxiliary state.

        Everything goes under ``models_seed<seed_num>/<dataset>/``:

        * ``learning_state != "retrain"``: the ``history_update`` list under
          ``history/`` (only if the last client's history is non-empty), the
          ``attacker`` via its ``save_model`` (only if truthy), and the model
          as ``<ids>_attack_server.pt`` or ``_server.pt``;
        * ``learning_state == "retrain"``: the model as
          ``retrain_model_<ids joined by '_'>.pt`` (ids only when attacking).

        "Attacking" means ``args.attack == 'True'``; ``<ids>`` is the
        concatenation of ``args.unlearning_clients``.
        """
        args = self.args
        backdoor = args.attack == 'True'
        base_dir = os.path.join("models_seed" + str(args.seed_num), self.dataset)
        if not os.path.exists(base_dir):
            os.makedirs(base_dir)

        if args.learning_state == "retrain":
            id_part = '_'.join(map(str, args.unlearning_clients)) if backdoor else ''
            target = os.path.join(base_dir, "retrain_model_" + id_part + ".pt")
        else:
            if self.history_update[-1]:
                history_dir = os.path.join(base_dir, "history")
                if not os.path.exists(history_dir):
                    os.makedirs(history_dir)
                if backdoor:
                    history_name = ''.join(map(str, args.unlearning_clients)) + "_attack_client"
                else:
                    history_name = "_client"
                torch.save(self.history_update, os.path.join(history_dir, history_name + ".pt"))

            if self.attacker:
                prefix = "Backdoor_" if backdoor else "noBackdoor_"
                self.attacker.save_model(os.path.join(base_dir, prefix + "xgb_model.bin"))

            if backdoor:
                server_name = ''.join(map(str, args.unlearning_clients)) + "_attack_server"
            else:
                server_name = "_server"
            target = os.path.join(base_dir, server_name + ".pt")

        torch.save(self.global_model, target)

    def load_model(self):
        """Restore the global model and auxiliary state written by ``save_global_model``.

        For ``FedFUKD`` / ``FedEraser`` the update history is also loaded
        (memory-mapped, onto CPU) and truncated to its first 45 entries.  The
        ``attacker`` and the global model are then loaded from the same
        directory layout ``save_global_model`` uses outside retraining.

        Raises:
            FileNotFoundError: if the history file (when needed) or the model
                file is missing.  Explicit checks are used instead of
                ``assert``, which is stripped under ``python -O``.

        Security: both loads use ``weights_only=False`` and therefore unpickle
        arbitrary objects — only load checkpoints this program produced.
        """
        model_dir = os.path.join("models_seed" + str(self.args.seed_num), self.dataset)
        backdoor = self.args.attack == 'True'
        id_tag = ''.join(map(str, self.args.unlearning_clients)) if backdoor else ''

        if self.algorithm == "FedFUKD" or self.algorithm == "FedEraser":
            history_name = (id_tag + "_attack_client" if backdoor else "_client") + ".pt"
            history_path = os.path.join(model_dir, "history", history_name)
            if not os.path.exists(history_path):
                raise FileNotFoundError(f"update history not found: {history_path}")
            # NOTE(review): this slices the outer list, which __init__ creates
            # as one list per client, so it appears to keep at most 45 clients'
            # histories — confirm that 45 is intentional.
            self.history_update = torch.load(history_path, map_location='cpu',
                                             weights_only=False, mmap=True)[:45]

        attacker_path = os.path.join(model_dir, ("Backdoor_" if backdoor else "noBackdoor_") + "xgb_model.bin")
        self.attacker.load_model(attacker_path)

        model_path = os.path.join(model_dir, (id_tag + "_attack_server" if backdoor else "_server") + ".pt")
        print('model path', model_path)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"global model not found: {model_path}")

        self.global_model = torch.load(model_path, map_location=self.device, weights_only=False)

        

    def model_exists(self):
        """Return True if a saved global model exists for the current settings.

        Uses the same path scheme as ``save_global_model`` / ``load_model``
        (non-retrain case): ``models_seed<seed_num>/<dataset>/`` plus
        ``<ids>_attack_server.pt`` when ``args.attack == 'True'``, else
        ``_server.pt``.
        """
        # Must match save_global_model exactly: "models_seed" + seed_num, and
        # the ".pt" suffix applies to both the attack and non-attack names.
        model_dir = os.path.join("models_seed" + str(self.args.seed_num), self.dataset)
        if self.args.attack == 'True':
            file_name = ''.join(map(str, self.args.unlearning_clients)) + "_attack_server.pt"
        else:
            file_name = "_server.pt"
        return os.path.exists(os.path.join(model_dir, file_name))
        
    def save_results(self):
        """Write the metric curves to ``../results/<dataset>_<algorithm>_<goal>.h5``.

        Stores ``rs_test_acc``, ``rs_test_auc`` and ``rs_train_loss`` as HDF5
        datasets.  Nothing is written when no test accuracy has been recorded,
        although the results directory is still created.
        """
        prefix = self.dataset + "_" + self.algorithm
        result_path = "../results/"
        if not os.path.exists(result_path):
            os.makedirs(result_path)

        if len(self.rs_test_acc) == 0:
            return

        algo = prefix + "_" + self.goal
        file_path = result_path + "{}.h5".format(algo)
        print("File path: " + file_path)

        series = (
            ('rs_test_acc', self.rs_test_acc),
            ('rs_test_auc', self.rs_test_auc),
            ('rs_train_loss', self.rs_train_loss),
        )
        with h5py.File(file_path, 'w') as hf:
            for key, values in series:
                hf.create_dataset(key, data=values)

    def save_item(self, item, item_name):
        """Serialise ``item`` with ``torch.save`` to ``<save_folder_name>/server_<item_name>.pt``.

        The folder is created if it does not exist yet.
        """
        folder = self.save_folder_name
        if not os.path.exists(folder):
            os.makedirs(folder)
        target = os.path.join(folder, "server_" + item_name + ".pt")
        torch.save(item, target)

    def load_item(self, item_name):
        """Load an object stored by ``save_item`` from ``<save_folder_name>/server_<item_name>.pt``.

        ``weights_only=False`` is passed explicitly, matching ``load_model``.
        Since PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True``,
        which rejects arbitrary pickled objects (e.g. whole ``nn.Module``
        instances) that ``save_item`` can write.  This unpickles the file, so
        only load items this program saved itself.
        """
        path = os.path.join(self.save_folder_name, "server_" + item_name + ".pt")
        return torch.load(path, weights_only=False)

    def test_metrics(self):
        if self.eval_new_clients and self.num_new_clients > 0:
            self.fine_tuning_new_clients()
            return self.test_metrics_new_clients()
        
        num_samples = []
        tot_correct = []
        tot_auc = []

        att_correct=[]
        att_num_samples=[]

        # tag_correct=[]
        # tag_num_samples=[]

        for c in self.clients:
            if(c.unlearning==False):
                ct, ns, auc = c.test_metrics()
                tot_correct.append(ct*1.0)
                tot_auc.append(auc*ns)
                num_samples.append(ns)
            
        for c in self.unlearning_clients:
            ct, ns, auc = c.test_metrics()
            att_correct.append(ct*1.0)
            att_num_samples.append(ns)

        

        ids = [c.id for c in self.clients]

        return ids, num_samples, tot_correct, tot_auc,att_num_samples,att_correct

    def train_metrics(self):
        """Gather training statistics from the clients.

        Returns ``(ids, num_samples, losses, forget_num_sample, forget_acc,
        retain_acc)``:

        * ``num_samples`` / ``losses`` / ``retain_acc`` come from clients whose
          ``unlearning`` flag is falsy;
        * ``forget_num_sample`` / ``forget_acc`` come from
          ``self.unlearning_clients``;
        * ``ids`` lists every client.

        When ``args.learning_state == 'retrain'`` the global model is sent to
        the unlearning clients before they are measured.  In new-client
        evaluation mode the placeholder ``([0], [1], [0])`` is returned.
        """
        if self.eval_new_clients and self.num_new_clients > 0:
            return [0], [1], [0]

        num_samples, losses, retain_acc = [], [], []
        for client in (c for c in self.clients if not c.unlearning):
            loss_val, n, correct = client.train_metrics()
            num_samples.append(n)
            retain_acc.append(correct * 1.0)
            losses.append(loss_val * 1.0)

        if self.args.learning_state == 'retrain':
            self.send_models_target()

        forget_num_sample, forget_acc = [], []
        for target in self.unlearning_clients:
            _loss, n, correct = target.train_metrics()
            forget_num_sample.append(n)
            forget_acc.append(correct * 1.0)

        ids = [c.id for c in self.clients]
        return ids, num_samples, losses, forget_num_sample, forget_acc, retain_acc

    # evaluate the global model on all clients (not only the selected ones)
    def evaluate(self, acc=None, loss=None):
        stats = self.test_metrics()
        stats_train = self.train_metrics()

        test_acc = sum(stats[2])*1.0 / sum(stats[1])
        test_auc = sum(stats[3])*1.0 / sum(stats[1])
        train_loss = sum(stats_train[2])*1.0 / sum(stats_train[1])
        accs = [a / n for a, n in zip(stats[2], stats[1])]
        aucs = [a / n for a, n in zip(stats[3], stats[1])]

        attack_acc = sum(stats[5])*1.0 / sum(stats[4]) if self.unlearning_clients else 0
        retain_acc = sum(stats_train[5])*1.0 / sum(stats_train[1])
        forget_acc = sum(stats_train[4])*1.0 / sum(stats_train[3]) if self.unlearning_clients else 0
        # target_acc = sum(stats[7])*1.0 / sum(stats[6])
        
        if acc == None:
            self.rs_test_acc.append(test_acc)
        else:
            acc.append(test_acc)
        
        if loss == None:
            self.rs_train_loss.append(train_loss)
        else:
            loss.append(train_loss)

        self.attack_acc.append(attack_acc)

        print("Averaged Train Loss: {:.4f}".format(train_loss))
        print("Averaged Test Accurancy: {:.4f}".format(test_acc))
        print("Averaged Test AUC: {:.4f}".format(test_auc))
        print("Averaged Retain Accurancy: {:.4f}".format(retain_acc))
        print("Averaged Forget Accurancy: {:.4f}".format(forget_acc))
        # self.print_(test_acc, train_acc, train_loss)
        print("Std Test Accurancy: {:.4f}".format(np.std(accs)))
        print("Std Test AUC: {:.4f}".format(np.std(aucs)))
        print("Average Attack Accurancy:{:.4f}".format(attack_acc))
        # print("Average F-Accurancy:{:.4f}".format(stats[-2]))
        # print("Average C-Accurancy:{:.4f}".format(stats[-1]))
        # print("Average Target Client Accurancy:{:.4f}".format(target_acc))
        # if(isUnlearning):
        #     if(self.attack_model==None):
        #         self.attack_model=train_attack_model()
            
        #     print("Averaged KL Divergency: {:.4f}".format(KL_Divergency))



    def print_(self, test_acc, test_auc, train_loss):
        """Print test accuracy, test AUC and train loss, each to 4 decimal places."""
        templates = (
            ("Average Test Accurancy: {:.4f}", test_acc),
            ("Average Test AUC: {:.4f}", test_auc),
            ("Average Train Loss: {:.4f}", train_loss),
        )
        for template, value in templates:
            print(template.format(value))

    def check_done(self, acc_lss, top_cnt=None, div_value=None):
        """Decide whether every tracked metric curve satisfies the stopping criteria.

        For each curve ``acc_ls`` in ``acc_lss``:

        * ``top_cnt`` criterion: the maximum (as found by ``torch.topk``) lies
          more than ``top_cnt`` entries before the end of the curve;
        * ``div_value`` criterion: the curve has more than one entry and the
          standard deviation of its last ``top_cnt`` entries (the whole curve
          when ``top_cnt`` is None) is below ``div_value``.

        Every supplied criterion must hold for every curve.

        Returns:
            True if all curves satisfy the criteria (also for an empty
            ``acc_lss``), otherwise False.

        Raises:
            NotImplementedError: if a curve is checked with neither ``top_cnt``
                nor ``div_value`` given.
        """
        for acc_ls in acc_lss:
            if top_cnt is None and div_value is None:
                raise NotImplementedError

            if top_cnt is not None:
                best_idx = int(torch.topk(torch.tensor(acc_ls), 1).indices[0])
                if not len(acc_ls) - best_idx > top_cnt:
                    return False

            if div_value is not None:
                # acc_ls[-None:] would raise TypeError; use the whole curve instead.
                window = acc_ls if top_cnt is None else acc_ls[-top_cnt:]
                if not (len(acc_ls) > 1 and np.std(window) < div_value):
                    return False
        return True

    def call_dlg(self, R):
        """Run the DLG attack against every uploaded model and print the mean PSNR.

        The "gradient" passed to ``DLG`` is the parameter difference
        ``global - client``.  Up to ``batch_num_per_client`` batches of each
        client's training data are run through the client model (without
        autograd) to build ``(input, output)`` targets.  ``DLG`` results that
        are None are skipped.  ``R`` is not used.
        """
        psnr_total = 0
        n_ok = 0
        for cid, client_model in zip(self.uploaded_ids, self.uploaded_models):
            client_model.eval()
            origin_grad = [
                g.data - c.data
                for g, c in zip(self.global_model.parameters(), client_model.parameters())
            ]

            target_inputs = []
            loader = self.clients[cid].load_train_data()
            with torch.no_grad():
                for batch_idx, (x, y) in enumerate(loader):
                    if batch_idx >= self.batch_num_per_client:
                        break
                    if type(x) is list:
                        x[0] = x[0].to(self.device)
                    else:
                        x = x.to(self.device)
                    y = y.to(self.device)  # labels are moved but not used by DLG here
                    target_inputs.append((x, client_model(x)))

            psnr = DLG(client_model, origin_grad, target_inputs)
            if psnr is not None:
                psnr_total += psnr
                n_ok += 1

        if n_ok > 0:
            print('PSNR value is {:.2f} dB'.format(psnr_total / n_ok))
        else:
            print('PSNR error')

    def set_new_clients(self, clientObj):
        """Create the ``num_new_clients`` extra clients used for new-client evaluation.

        Their ids continue after the regular clients (``num_clients``,
        ``num_clients + 1``, ...).  They are never slow, and they are appended
        to ``self.new_clients``.
        """
        first_id = self.num_clients
        for cid in range(first_id, first_id + self.num_new_clients):
            n_train = len(read_client_data(self.dataset, cid, is_train=True))
            n_test = len(read_client_data(self.dataset, cid, is_train=False))
            self.new_clients.append(
                clientObj(
                    self.args,
                    id=cid,
                    train_samples=n_train,
                    test_samples=n_test,
                    train_slow=False,
                    send_slow=False,
                )
            )

    # fine-tuning on new clients
    def fine_tuning_new_clients(self):
        """Fine-tune each new client's copy of the global model on its own training data.

        Every new client first receives the global model.  It then runs
        ``fine_tuning_epoch_new`` epochs of plain SGD (``lr=self.learning_rate``)
        with cross-entropy loss over its training loader.
        """
        for client in self.new_clients:
            client.set_parameters(self.global_model)
            optimizer = torch.optim.SGD(client.model.parameters(), lr=self.learning_rate)
            criterion = torch.nn.CrossEntropyLoss()
            loader = client.load_train_data()
            client.model.train()

            for _epoch in range(self.fine_tuning_epoch_new):
                for x, y in loader:
                    if type(x) is list:
                        x[0] = x[0].to(client.device)
                    else:
                        x = x.to(client.device)
                    y = y.to(client.device)
                    loss = criterion(client.model(x), y)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

    # evaluating on new clients
    def test_metrics_new_clients(self):
        num_samples = []
        tot_correct = []
        tot_auc = []
        for c in self.new_clients:
            ct, ns, auc = c.test_metrics()
            tot_correct.append(ct*1.0)
            tot_auc.append(auc*ns)
            num_samples.append(ns)

        ids = [c.id for c in self.new_clients]

        return ids, num_samples, tot_correct, tot_auc
    
    def send_models_target(self):
        """Send the current global model to every unlearning (target) client.

        Unlike ``send_models``, no send-time cost is recorded.
        """
        assert (len(self.unlearning_clients) > 0)
        for target in self.unlearning_clients:
            target.set_parameters(self.global_model)
    
    def receive_models_target(self):
        """Collect uploads from the unlearning (target) clients only.

        A target client uploads when its average per-round train cost plus
        average per-round send cost is at most ``time_threthold`` (a division
        by zero counts as cost 0).  Weights are proportional to
        ``train_samples`` and normalised by their total.
        """
        assert (len(self.unlearning_clients) > 0)

        def round_cost(c):
            train, send = c.train_time_cost, c.send_time_cost
            try:
                return train['total_cost'] / train['num_rounds'] + send['total_cost'] / send['num_rounds']
            except ZeroDivisionError:
                return 0

        accepted = [c for c in self.unlearning_clients if round_cost(c) <= self.time_threthold]

        total = sum(c.train_samples for c in accepted)
        self.uploaded_ids = [c.id for c in accepted]
        self.uploaded_models = [c.model for c in accepted]
        self.uploaded_weights = [c.train_samples / total for c in accepted]
    
    def save_unlearning(self, mia_pre):
        """Write this run's unlearning metrics as JSON.

        The file is ``../results/json_file/<dataset>_<algorithm>_<attack>.json``.
        It records the algorithm name, the test-accuracy and attack-accuracy
        histories, ``self.unlearn_Budget`` under "time", and ``mia_pre`` under
        "MIA attack precision".
        """
        entry = dict(
            [
                ("name", self.algorithm),
                ("test accuracy", self.rs_test_acc),
                ("attack accuracy", self.attack_acc),
                ("time", self.unlearn_Budget),
                ("MIA attack precision", mia_pre),
            ]
        )

        result_path = "../results/json_file/"
        if not os.path.exists(result_path):
            os.makedirs(result_path)

        algo = self.dataset + "_" + self.algorithm + "_" + str(self.args.attack)
        file_path = result_path + "{}.json".format(algo)
        with open(file_path, 'w') as file:
            json.dump(entry, file, indent=2)

    def save_loss(self, global_rounds):
        """Plot ``self.rs_train_loss`` per round and save it as a PNG under ``loss_img/``.

        Expects one loss value for each round ``0..global_rounds`` (that is,
        ``global_rounds + 1`` points); ``plt.plot`` raises on mismatched
        lengths.  The x-axis is limited to ``[1, global_rounds]``.  The figure
        is closed after saving, so repeated calls do not accumulate open
        figures in pyplot's global state.
        """
        save_dir = "loss_img"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        train_loss = self.rs_train_loss
        fig = plt.figure(figsize=(12, 6))
        try:
            plt.plot(range(0, global_rounds + 1), train_loss,
                     label='train loss',
                     color='blue',
                     linewidth=2)

            plt.xlabel('Epoch', fontsize=12)
            plt.ylabel('Loss', fontsize=12)
            plt.title('loss curve', fontsize=14, pad=20)
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.xlim(1, global_rounds)
            plt.legend()

            save_path = os.path.join(save_dir, 'loss_curve_' + str(self.args.algorithm) + '_' + str(self.args.dataset) +
                                     '_' + str(self.args.learning_state) + '_' + str(self.args.attack) + '.png')
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving failed.
            plt.close(fig)
    
    def post_training(self, m):
        """Run ``post_training_ground + 1`` extra federated rounds with a projected correction.

        Each round does the following:

        1. Select clients and send the global model to all clients and to the
           unlearning clients.
        2. Evaluate, train the selected clients, then receive and aggregate.
        3. Project the aggregation step ``(new - previous global parameters)``
           onto ``ga = previous global parameters - m``.
        4. Add that projection once more to the new global parameters via
           ``overwrite_grad``.

        Per-round wall time is appended to ``self.Budget``, which is created
        here if the instance does not have it yet.

        Args:
            m: 1-D tensor with as many elements as the global model has
               parameters (presumably a flattened reference model — confirm
               with the caller).
        """
        if not hasattr(self, 'Budget'):
            self.Budget = []

        for i in range(self.post_training_ground + 1):
            s_t = time.time()
            self.selected_clients = self.select_clients()
            self.send_models()
            self.send_models_target()

            print(f"\n-------------Round number: {i}-------------")
            print("\nEvaluate global model")
            self.evaluate()

            for client in self.selected_clients:
                client.train()

            self.receive_models()

            # Flattened-vector arithmetic only; no autograd graph is needed.
            with torch.no_grad():
                gm = torch.cat([p.view(-1) for p in self.global_model.parameters()], dim=0)
                ga = gm - m

            self.aggregate_parameters()

            with torch.no_grad():
                ngm = torch.cat([p.view(-1) for p in self.global_model.parameters()], dim=0)
                # Component of this round's update (ngm - gm) along ga.
                gt = (torch.dot((ngm - gm), ga) / (torch.norm(ga) + 1e-4) ** 2) * ga
                self.overwrite_grad(self.global_model.parameters, gt)

            self.Budget.append(time.time() - s_t)
            print('-'*25, 'time cost', '-'*25, self.Budget[-1])

        print("\nBest accuracy.")
        print(max(self.rs_test_acc))
        print("\nAverage time cost per round.")
        # The first entry is skipped; fall back to all entries if it is the only one.
        timed_rounds = self.Budget[1:] or self.Budget
        print(sum(timed_rounds) / len(timed_rounds) if timed_rounds else 0.0)

    def overwrite_grad(self, pp, newgrad):
        """Add consecutive slices of the flat tensor ``newgrad`` to the parameters from ``pp()``.

        Despite the name, no gradients are touched.  Each parameter's ``.data``
        is replaced by ``param.data + slice``, where the slice is the next
        ``param.numel()`` elements of ``newgrad`` reshaped to the parameter's
        shape.

        Args:
            pp: zero-argument callable yielding the parameters (e.g.
                ``model.parameters``).
            newgrad: 1-D tensor with at least as many elements as all yielded
                parameters combined.
        """
        offset = 0
        for param in pp():
            size = param.numel()
            chunk = newgrad[offset:offset + size].view_as(param.data)
            param.data = param.data + chunk
            offset += size