import copy
import logging
import random
import time
import numpy as np
import torch
from utils import transform_list_to_tensor
from client import Client


# This file is only used for standalone (single-machine) training.
class FDIL(object):
    def __init__(self, dataset, device, args, model_trainer):
        """Set up the standalone federated incremental-learning driver.

        Args:
            dataset: Six-element sequence unpacked as ``(local_num_dict,
                train_data_local_dict, test_data_local_dict,
                incremental_train_data, incremental_test_data, class_num)``.
                The last element is not used here.
            device: Passed through unchanged to every ``Client``.
            args: Run configuration. This constructor reads ``args.model`` and
                ``args.epsilon`` and, through ``_setup_clients``,
                ``args.client_num_in_total``.
            model_trainer: Object holding the global model; shared by all clients.

        Raises:
            ValueError: If ``args.model`` has no known list of protected layers.
        """
        # Parameter names whose updates are projected in ``_aggregate``, per model.
        orth_layers_by_model = {
            "resnet_fot_bn": [
                'conv1.weight',
                'layer1.0.conv1.weight', 'layer1.0.conv2.weight',
                'layer1.1.conv1.weight', 'layer1.1.conv2.weight',
                'layer2.0.conv1.weight', 'layer2.0.conv2.weight', 'layer2.0.shortcut.0.weight',
                'layer2.1.conv1.weight', 'layer2.1.conv2.weight',
                'layer3.0.conv1.weight', 'layer3.0.conv2.weight', 'layer3.0.shortcut.0.weight',
                'layer3.1.conv1.weight', 'layer3.1.conv2.weight',
                'layer4.0.conv1.weight', 'layer4.0.conv2.weight', 'layer4.0.shortcut.0.weight',
                'layer4.1.conv1.weight', 'layer4.1.conv2.weight',
            ],
            "alexnet_fot": ['conv1.weight', 'conv2.weight', 'conv3.weight', 'fc1.weight', 'fc2.weight'],
        }
        # Fail fast, before any client is built: previously an unsupported model
        # only surfaced later as an AttributeError on ``orth_layer_names``.
        if args.model not in orth_layers_by_model:
            raise ValueError(
                "Unsupported model {!r}; expected one of {}".format(
                    args.model, sorted(orth_layers_by_model)))

        self.device = device
        self.args = args
        (local_num_dict, train_data_local_dict, test_data_local_dict,
         incremental_train_data, incremental_test_data, _class_num) = dataset
        self.client_indexes = []
        self.client_list = []
        # Per-client incremental splits (the original note says 4 per client for CIFAR-10).
        self.incremental_train_data = incremental_train_data
        self.incremental_test_data = incremental_test_data
        self.train_data_local_num_dict = local_num_dict
        self.train_data_local_dict = train_data_local_dict
        self.test_data_local_dict = test_data_local_dict
        self.model_dict = dict()
        self.model_trainer = model_trainer
        self.train_acc = []
        self.test_acc = []
        self._setup_clients(local_num_dict, train_data_local_dict, test_data_local_dict,
                            incremental_train_data, incremental_test_data, model_trainer)

        # client id -> activations gathered in the current round (None when not gathered)
        self.activation_dict = dict()
        self.orth_layer_names = orth_layers_by_model[args.model]
        # layer name -> basis of the protected subspace; None until first expanded
        self.orth_set = {name: None for name in self.orth_layer_names}
        self.epsilon = args.epsilon

    def _setup_clients(self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict,incremental_train_data,incremental_test_data,model_trainer):
        """Create one ``Client`` per id in ``range(args.client_num_in_total)``.

        Every client receives its own training data, sample count and incremental
        train/test splits (each looked up by client id), while the *whole*
        ``test_data_local_dict`` is handed to each client unchanged. The new
        clients are appended to ``self.client_list`` in id order.
        """
        print("############setup_clients (START)#############")
        total_clients = self.args.client_num_in_total
        for cid in range(total_clients):
            local_train = train_data_local_dict[cid]
            sample_count = train_data_local_num_dict[cid]
            inc_train = incremental_train_data[cid]
            inc_test = incremental_test_data[cid]
            # The test dictionary is passed in full: every client tests with all samples.
            new_client = Client(cid, local_train, test_data_local_dict, sample_count,
                                inc_train, inc_test, self.args, self.device, model_trainer)
            self.client_list.append(new_client)
        print("############setup_clients (END)#############")

    def train(self,savepath):

        start_time = time.time()
        task_used_B=[]
        for round_idx in range(self.args.comm_round):
            
            node1_time = time.time()
            print("到达round: "+str(round_idx)+" 耗时:"+str(node1_time - start_time)+"秒")

            w_global = self.model_trainer.get_model_params()

            print("################Communication round : {}".format(round_idx))
            w_locals = []

            # 计算当前任务中的当前轮数  
            current_round_in_task = round_idx % self.args.incremental_round 
            ta_id = round_idx // self.args.incremental_round 

            # incremental learning (update)
            if round_idx % self.args.incremental_round == 0 and round_idx != 0:
                print("Start updating each client dataset by incremental learning! " )
                for client in self.client_list:
                    client.update_incremental(w_global,ta_id)

                print("Finishing updating! " )

            #选择客户端
            self._client_sampling(round_idx, self.args.client_num_in_total,
                                                   self.args.client_num_per_round,self.client_list)
                                                   
            print("client_indexes = " + str(self.client_indexes))
            
            client_runtime_info={}
            print("-------model actually train------")
            # choose client 训练
            if current_round_in_task == 0:#新任务
                task_used_B= [0] * self.args.client_num_in_total

            for idx in self.client_indexes:#w每次用全局模型聚合后的还是聚合前每个客户端自己的
                client_idx = idx
                for i in self.client_list:
                    if i.client_idx == client_idx:
                        client = i

                weight,new_B = client.train(copy.deepcopy(w_global),ta_id,task_used_B[client_idx])
                task_used_B[client_idx]=new_B
                activations = None 
                if (round_idx+1) % self.args.incremental_round == 0:
                    activations = client.get_act(self.orth_set)
                w_locals.append((client.get_sample_number(), copy.deepcopy(weight)))
                client_runtime_info[client_idx] = activations
            self.activation_dict.update(client_runtime_info)

            w_global = self._aggregate(w_locals)
            if round_idx % 1 == 0:
                self._local_test_on_all_clients(round_idx)
                print(task_used_B)

            # if round_idx % 1 == 0:
            #     f = open(str(self.args.dataset)+"_re/tu_dan0418_copy3_"+str(self.args.dataset)+"_a="+str(self.args.alpha)+"_model_"+str(self.args.model)+"_lr="+str(self.args.lr)+"_size="+str(self.args.memory_size+1)+".txt",'w')
            #     for i in range(len(self.train_acc)):
            #         f.write("train acc:"+str(self.train_acc[i])+" "+"test acc:"+str(self.test_acc[i])+'\n')
            #     f.close()
            #     print( str(round_idx)+": train acc:"+str(self.train_acc[i])+" "+"test acc:"+str(self.test_acc[i])+'\n')


            if (self.args.incremental_round - 10) <= current_round_in_task < self.args.incremental_round:
                pa = self.model_trainer.get_model_params()
                torch.save(pa,savepath+'/model_%d_%d.pth'%(ta_id,current_round_in_task))
        end_time = time.time()
        print("end耗时:"+str(end_time - start_time)+"秒")
  
    def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round,client_list):
        if client_num_in_total == client_num_per_round:
            self.client_indexes = [client_index for client_index in range(client_num_in_total)]
        else:
            num_clients = min(client_num_per_round, client_num_in_total)
            np.random.seed(round_idx)  
            self.client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False)

#ADD    ExpandOrthogonalSet
    def expand_orth_set(self):
        print("Expand orth set")
        ratios = {}
        num_samples = {}
        activations = {}
        act_list = list(self.activation_dict.values())
        keys = act_list[0].keys()
        for k in keys:
            for i in range(0, len(act_list)):
                local_act = act_list[i]
                if i == 0:
                    activations[k] = local_act[k][0]
                    ratios[k] = [local_act[k][1]]
                    num_samples[k] = [local_act[k][2]]
                else:
                    activations[k] += local_act[k][0]
                    ratios[k].append(local_act[k][1])
                    num_samples[k].append(local_act[k][2])

        for key in activations.keys():
            weights = np.array(num_samples[key]) / np.sum(num_samples[key])
            weighted_avg = np.sum(weights * np.array(ratios[key]))
            org_eps = self.epsilon
            new_eps = (weighted_avg - (1 - org_eps)) / weighted_avg
            #find svds of remaining
            U, S, V = torch.svd(activations[key])
            #find how many singular vectors will be used
            total = torch.norm(activations[key])**2 
            for i in range(len(S)):
                hand = torch.norm(S[0:i+1])**2
                if hand / total > new_eps:
                    break

            if self.orth_set[key] == None:
                self.orth_set[key] = U[:,0:i+1]
            else:
                self.orth_set[key] = torch.cat((self.orth_set[key], U[:,0:i+1]),dim=1)
            
            self.orth_set[key], _ = torch.qr(self.orth_set[key])
        self.epsilon += self.args.eps_inc
#

    # Server-side aggregation
    def _aggregate(self, w_locals):
        training_num = 0
        for idx in range(len(w_locals)):
            (sample_num, averaged_params) = w_locals[idx]
            training_num += sample_num
        # 更新初步集成模型
        (sample_num, averaged_params) = w_locals[0]
        for k in averaged_params.keys():
            for i in range(0, len(w_locals)):
                local_sample_number, local_model_params = w_locals[i]
                w = local_sample_number / training_num
                if i == 0:
                    averaged_params[k] = local_model_params[k] * w
                else:
                    averaged_params[k] += local_model_params[k] * w
#ADD  FedProject
        global_params = self.model_trainer.get_model_params()#聚合前的旧全局W
        _,global_gradients = w_locals[0]

        for k in global_params.keys():
            global_gradients[k] = global_params[k] - averaged_params[k]#全局更新量 δglobal

        #Apply projected gradient descent
        for key in self.orth_layer_names:
            if self.orth_set[key] == None: continue
            if "conv" in key or "shortcut" in key:
                grad = global_gradients[key]
                projected = self.orth_set[key] @ self.orth_set[key].T @ grad.view(grad.size(0), -1).T
                global_gradients[key] = grad - projected.T.view(grad.size())
            else:
                grad = global_gradients[key]
                projected = self.orth_set[key] @ self.orth_set[key].T @ grad.T  #−Oℓ1Oℓ1Tδℓglobal
                global_gradients[key] = grad - projected.T #(6):δℓ∗global → δℓglobal −Oℓ1Oℓ1Tδℓglobal

        for k in global_params.keys():
            averaged_params[k] = global_params[k] - global_gradients[k]#(7)W ←W − µδ∗global

        # update the global model which is cached at the server side
        self.model_trainer.set_model_params(averaged_params)
        
        ##GPSE-server
        if list(self.activation_dict.values())[0] is not None:
            print("Expanding orth set")
            self.expand_orth_set()#GGPSE:ExpandOrthogonalSet
            for key in self.orth_layer_names:
                if self.orth_set[key] == None: continue
                # print(self.orth_set[key].shape)
                shape1, shape2 = self.orth_set[key].shape
                # print({f"Space/{key}": shape2/shape1})#, "task": round_idx
        self.activation_dict = dict() 
        return averaged_params

    def _local_test_on_all_clients(self, round_idx):
        """Evaluate every client and record the pooled accuracies.

        For each client ``local_test(False)`` and then ``local_test(True)`` are
        called; the first result feeds the "training" statistics and the second
        the "test" statistics (presumably the flag selects the test split --
        confirm against ``Client.local_test``). Accuracy and loss are the summed
        ``test_correct`` / ``test_loss`` divided by the summed ``test_total``.
        Both stat dicts are printed and the accuracies are appended to
        ``self.train_acc`` and ``self.test_acc``.

        Args:
            round_idx: Current communication round; only used in the banner line.
        """
        print("################local_test_on_all_clients : {}".format(round_idx))

        # (accumulator name, key in the dict returned by Client.local_test)
        metric_fields = (
            ('num_samples', 'test_total'),
            ('num_correct', 'test_correct'),
            ('losses', 'test_loss'),
        )
        # Keyed by the flag handed to local_test: False -> training, True -> test.
        gathered = {
            False: {name: [] for name, _ in metric_fields},
            True: {name: [] for name, _ in metric_fields},
        }

        for participant in self.client_list:
            for flag in (False, True):
                local_metrics = participant.local_test(flag)
                for name, source_key in metric_fields:
                    gathered[flag][name].append(copy.deepcopy(local_metrics[source_key]))

        on_train = gathered[False]
        on_test = gathered[True]

        # Pooled metrics over all clients.
        train_acc = sum(on_train['num_correct']) / sum(on_train['num_samples'])
        train_loss = sum(on_train['losses']) / sum(on_train['num_samples'])

        test_acc = sum(on_test['num_correct']) / sum(on_test['num_samples'])
        test_loss = sum(on_test['losses']) / sum(on_test['num_samples'])

        print({'training_acc': train_acc, 'training_loss': train_loss})
        print({'test_acc': test_acc, 'test_loss': test_loss})

        self.train_acc.append(train_acc)
        self.test_acc.append(test_acc)


    def test(self,num,savepath):

        model_pa = torch.load(savepath)
        self.model_trainer.set_model_params(model_pa)
        ### 选择参与训练的客户端（self.args.client_num_per_round个），索引保存在self.client_indexes
        # \(client_idx, local_training_data, local_test_data, local_sample_number, incremental_train_data,incremental_test_data,args, device, model_trainer):
        # client = Client(0, 'none','none',0,'none','none', self.args,self.device,self.model_trainer)
        # client= Client(0, train_data_local_dict[0], test_data_local_dict, train_data_local_num_dict[0], 
        #                                 incremental_train_data[0], incremental_test_data[0], self.args, self.device, model_trainer)
        client = self.client_list[0]
        test_metrics = {
            'num_samples': [],
            'num_correct': [],
            'losses': []
        }
        result_acc=[]
        client.test_update_incremental(num)
        test_local_metrics = client.local_test(True)
        test_metrics['num_samples'].append(copy.deepcopy(test_local_metrics['test_total']))
        test_metrics['num_correct'].append(copy.deepcopy(test_local_metrics['test_correct']))
        test_acc = test_local_metrics['test_correct'] /test_local_metrics['test_total']
        # test_acc_total = sum(test_metrics['num_correct']) / sum(test_metrics['num_samples'])
        stats = {'test_acc': test_acc}
        print(stats)
        result_acc.append(test_acc)
        # for i in range(num):
        #     client.test_update_incremental(i)
        #     test_local_metrics = client.local_test(True)
        #     test_metrics['num_samples'].append(copy.deepcopy(test_local_metrics['test_total']))
        #     test_metrics['num_correct'].append(copy.deepcopy(test_local_metrics['test_correct']))
        #     test_acc = test_local_metrics['test_correct'] /test_local_metrics['test_total']
        #     test_acc_total = sum(test_metrics['num_correct']) / sum(test_metrics['num_samples'])
        #     stats = {'test_acc': test_acc, 'test_acc_total': test_acc_total}
        #     print(stats)
        #     result_acc.append(test_acc)
        print('-----------------------------------------')
        return result_acc

