import numpy as np
import copy
import torch
from torch.autograd import Variable
import time
from .ufedbase import UnlearnBasicClient, UnlearnBasicServer
import numpy as np
from utils import fmodule
from tqdm import tqdm
from collections import OrderedDict
import pickle
import os
class Server(UnlearnBasicServer):
    """FedEraser-style unlearning server.

    Restarts from the model weights captured before the base initializer ran
    and, every ``delta_t`` unlearning rounds, retrains the global model on the
    clients that are not being unlearned. A per-client update history recorded
    during pretraining is loaded from disk in ``__init__``.
    """

    def __init__(self, option, model, clients, data_loader, device=None):
        # Snapshot the model *before* the base initializer runs so unlearning
        # can restart from these weights.
        self.initial_model = copy.deepcopy(model).to(device)
        super(Server, self).__init__(option, model, clients, data_loader, device)
        # unlearn config
        self.cali_round = 0  # number of calibration rounds performed so far
        self.delta_t = 5  # one calibration round every delta_t unlearning rounds
        # TODO: define self.file_path instead of rebuilding the path here.
        his_path = os.path.join(
            os.path.dirname(self.save_folder),
            'fedavg',
            'pretrained_history_federaser',
            f"eraser_his_five_s{self.option['split_num']}_c_{self.option['class_num']}_bd_{self.bd}.pkl",
        )
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # history files produced by this project's own pretraining runs.
        with open(his_path, 'rb') as f:
            self.update_his = pickle.load(f)
        # Restart from the initial weights. Use a separate copy so that later
        # in-place changes to self.model cannot corrupt self.initial_model.
        self.model = copy.deepcopy(self.initial_model)
        print('Update History Loaded Successfully!')

    def run(self):
        """Run the configured stage, then save a checkpoint.

        The starting model is evaluated once. In the ``'Unlearn'`` stage, each
        ``delta_t``-th round out of ``u_rounds`` performs one calibration round
        (``unlearn_iterate``) followed by evaluation and logging; other rounds
        are skipped. The ``'PT'`` stage only evaluates and checkpoints.

        Raises:
            ValueError: if ``self.stage`` is neither ``'Unlearn'`` nor ``'PT'``.
        """
        # Validate up front (and with a real exception, which survives `-O`).
        if self.stage not in ('Unlearn', 'PT'):
            raise ValueError(f'check your stage: {self.stage!r}')
        self.current_rounds = 0
        test_metric = self.test_on_clients(dataflag='test', model=self.model)
        self.outFunc(t_metric=test_metric)
        print(self.stage)
        if self.stage == 'Unlearn':
            for rnd in tqdm(range(1, self.u_rounds + 1), desc='Unlearning Rounds'):
                self.current_rounds = rnd
                # Only every delta_t-th round performs a calibration step.
                if rnd % self.delta_t != 0:
                    continue
                self.cali_round += 1
                self.unlearn_iterate()
                # NOTE(review): the scheduler always receives self.num_rounds
                # rather than the current round; confirm this is intended.
                self.global_lr_scheduler(self.num_rounds)
                test_metric = self.test_on_clients(dataflag='test', model=self.model)
                self.outFunc(test_metric)
                self.save_log(self.out_log)
        self.save_ckp()

    def unlearn_iterate(self):
        """Run one round on every client except the ones being unlearned.

        The remaining clients are contacted via ``communicate`` and their
        returned models are aggregated into the new global model.

        NOTE(review): FedEraser's calibration (rescaling each new client update
        to the norm of the stored historical update in ``self.update_his``) is
        currently disabled, so this is plain aggregation of the remaining
        clients' models; ``self.update_his`` is not consulted here.
        """
        # Keep every client id that is not in the set of clients to forget.
        # TODO: this selection was previously flagged as a possible source of
        # problems; verify that clients_id / unlearn_clients_id hold comparable ids.
        all_ids = np.asarray(self.clients_id)
        keep_mask = ~np.isin(all_ids, self.unlearn_clients_id)
        self.selected_clients = all_ids[keep_mask]

        reply = self.communicate(self.selected_clients)
        self.model = self.aggregate(reply['model'])

    def get_param(self, m):
        """Return ``m``'s full state dict with every tensor moved to CPU.

        Tensors already on CPU are returned as-is (``Tensor.cpu`` does not copy
        them). Returns full parameters, not deltas against ``self.model``.
        """
        return {k: v.cpu() for k, v in m.state_dict().items()}

    def load_param(self, m, param_dict):
        """Load ``param_dict`` into module ``m`` and return it on ``self.device``.

        Args:
            m: module to load into; it is modified in place.
            param_dict: a state dict, e.g. as returned by ``get_param``.

        Returns:
            ``m`` moved to ``self.device``.
        """
        state_dict_on_device = {k: v.to(self.device) for k, v in param_dict.items()}
        m.load_state_dict(state_dict_on_device)
        return m.to(self.device)

    def load_his_updates(self, u_idx):
        """Rebuild the stored models of the currently selected clients.

        Args:
            u_idx: index into each client's history list in ``self.update_his``
                (keys of ``self.update_his`` are client ids).

        Returns:
            A list of models (copies of ``self.model`` with the stored state
            dicts loaded), in the same order as ``self.selected_clients``.

        Raises:
            KeyError: if a selected client has no entry in ``self.update_his``.
        """
        # Materialize models only for the selected clients instead of
        # deep-copying the model for every client in the history.
        return [
            self.load_param(copy.deepcopy(self.model), self.update_his[cid][u_idx])
            for cid in self.selected_clients
        ]

class Client(UnlearnBasicClient):
    """Client used with the FedEraser-style unlearning ``Server``.

    Adds no behavior of its own; everything is inherited from
    ``UnlearnBasicClient``. The constructor is kept explicit to document the
    expected arguments.
    """

    def __init__(self, option, id, model=None):
        super(Client, self).__init__(option, id, model)
