import logging
import random
import math
import copy
import torch
import numpy as np
import torch.nn.functional as F
import time
from utils import transform_list_to_tensor

class Client:
    """One training client holding labelled local data plus incremental data chunks.

    Training data is stored as a list of batches. Each batch is an indexable
    pair ``(x, y)`` whose elements are tensors sharing the batch dimension 0.
    Training, testing and activation extraction are delegated to
    ``model_trainer``.

    On construction only a ``args.lable_ratio`` fraction of each training set
    is kept (see ``get_data_base_ratio``). ``update_incremental`` merges the
    next incremental chunk into the local data, capping old samples at
    ``args.memory_buffer``.
    """

    def __init__(self, client_idx, local_training_data, local_test_data, local_sample_number, incremental_train_data,incremental_test_data,args, device,
                 model_trainer):
        self.client_idx = client_idx
        self.local_training_data = local_training_data
        self.local_test_data = local_test_data
        # Original test set, used by test_update_incremental to rebuild the test data.
        self.init_test_data = local_test_data
        self.local_sample_number = local_sample_number
        self.incremental_train_data = incremental_train_data
        self.incremental_test_data = incremental_test_data
        print("{}: self.local_sample_number = {}".format(self.client_idx, self.local_sample_number))
        self.args = args
        self.device = device
        self.model_trainer = model_trainer
        # Index of the next incremental chunk to merge in update_incremental.
        self.incremental_id = 0
        # Trainer parameters captured at construction (not used inside this class).
        self.personal_model = self.model_trainer.get_model_params()

        # Unfiltered copies, used as the source for label-ratio subsampling.
        self.init_local_training_data = local_training_data
        self.init_incremental_train_data = incremental_train_data
        self.init_incremental_test_data = incremental_test_data
        self.get_lable_data()

    def _append_sample(self, batches, x, y):
        """Append one sample to ``batches``; start a new batch when the last is full.

        ``x``/``y`` are single samples (no batch dim). They are unsqueezed and
        concatenated along dim 0 onto the last batch if it holds fewer than
        ``args.batch_size`` samples. Otherwise a new batch ``[x, y]`` is
        appended. The last batch is mutated in place, so it must be a list.
        """
        x = x.unsqueeze(0)
        y = y.unsqueeze(0)
        if batches and len(batches[-1][1]) < self.args.batch_size:
            batches[-1][0] = torch.cat((batches[-1][0], x), dim=0)
            batches[-1][1] = torch.cat((batches[-1][1], y), dim=0)
        else:
            batches.append([x, y])

    def get_lable_data(self):
        """Subsample the initial training and incremental sets to ``args.lable_ratio``.

        Replaces ``local_training_data`` and ``incremental_train_data`` with
        the subsampled versions. Decrements ``local_sample_number`` by the
        number of samples dropped from the local training set.
        """
        data, unlable_num = self.get_data_base_ratio(self.init_local_training_data)
        self.local_training_data = data
        self.local_sample_number = self.local_sample_number - unlable_num

        self.incremental_train_data = []
        for chunk in self.init_incremental_train_data:
            data, _ = self.get_data_base_ratio(chunk)
            self.incremental_train_data.append(data)

    def get_data_base_ratio(self,init_data):
        """Keep a random ``args.lable_ratio`` fraction of the samples in ``init_data``.

        Args:
            init_data: list of ``(x, y)`` batches.

        Returns:
            ``(new_batches, unlable_num)``. ``new_batches`` holds the kept
            samples in their original order, re-packed into batches of
            ``args.batch_size`` (the last may be partial). ``unlable_num`` is
            ``floor((1 - lable_ratio) * total)``, the number of samples dropped.
        """
        R = self.args.lable_ratio
        num = sum(len(batch[1]) for batch in init_data)
        unlable_num = math.floor((1 - R) * num)
        # NOTE: reseeds NumPy's *global* RNG. Kept deliberately so the dropped
        # subset (and any later global draws) stay reproducible across runs.
        np.random.seed(123)
        if num > 0:
            un_id = set(np.random.choice(num, unlable_num, False).tolist())
        else:
            un_id = set()

        new_train = []
        sample_id = 0  # position of the sample within the flattened dataset
        for batch in init_data:
            batch_x, batch_y = batch[0], batch[1]
            for j in range(len(batch_y)):
                if sample_id not in un_id:
                    self._append_sample(new_train, batch_x[j], batch_y[j])
                sample_id += 1

        return new_train,unlable_num

    def get_sample_number(self):
        """Return the current number of local training samples."""
        return self.local_sample_number

    def update_incremental(self, w_global,ta_id):
        """Merge the next incremental data chunk into the local train/test data.

        Loads ``w_global`` into the trainer. If the current number of local
        samples exceeds ``args.memory_buffer``, a random subset of old samples
        is removed so that exactly ``memory_buffer`` old samples remain. The new
        batches come first and surviving old samples fill in after them.
        Otherwise old and new batches are concatenated and the two partial
        tail batches are merged. The matching incremental test data is
        appended to ``local_test_data``.

        ``ta_id`` is not used here; it is kept for interface compatibility.
        """
        if self.incremental_id >= len(self.incremental_train_data):
            print("Client: " + str(self.client_idx) + " has no more incremental dataset")
            return

        new_batches = self.incremental_train_data[self.incremental_id]
        # Number of samples in this incremental chunk (the last batch may be partial).
        num_il = sum(len(batch[1]) for batch in new_batches)
        print("Client: " + str(self.client_idx) + " will increase " + str(num_il) + " data samples")

        # How many old samples must be dropped to fit within args.memory_buffer.
        delete_num = max(self.local_sample_number - self.args.memory_buffer, 0)

        self.model_trainer.set_model_params(w_global)

        if delete_num > 0:
            print("-----------------------------------------delete reply" )
            # Reseeds the global NumPy RNG for a reproducible choice of dropped samples.
            np.random.seed(123)
            delete_indexes = set(np.random.choice(self.local_sample_number, delete_num, False).tolist())

            # memory_buffer old samples survive, plus all new ones.
            self.local_sample_number = self.args.memory_buffer + num_il

            # New batches first (copied so the stored chunk is not mutated),
            # then surviving old samples top up the last batch and beyond.
            new_train = [list(batch) for batch in new_batches]
            sample_id = 0
            for batch in self.local_training_data:
                for j in range(len(batch[1])):
                    if sample_id not in delete_indexes:
                        self._append_sample(new_train, batch[0][j], batch[1][j])
                    sample_id += 1
        else:
            self.local_sample_number = num_il + self.local_sample_number

            old_batches = self.local_training_data
            # Full new batches, then full old batches, then the partial new tail,
            # which is then topped up with the samples of the partial old tail.
            new_train = [list(batch) for batch in new_batches[:-1]]
            new_train += [list(batch) for batch in old_batches[:-1]]
            new_train += [list(batch) for batch in new_batches[-1:]]
            for batch in old_batches[-1:]:
                for j in range(len(batch[1])):
                    self._append_sample(new_train, batch[0][j], batch[1][j])

        self.local_training_data = new_train
        # Test data is concatenated as-is; its batches are not re-packed.
        self.local_test_data = self.local_test_data + self.incremental_test_data[self.incremental_id]

        last_batch_len = len(new_batches[-1][0]) if new_batches else 0
        print("Client: " + str(self.client_idx)+"  " +str(len(new_batches))+"   "+str(last_batch_len))

        self.incremental_id += 1

    def train(self, w_global,ta_id,used_B):
        """Train locally from ``w_global`` unless the ``used_B`` budget is exhausted.

        The budget is ``floor(B / batch_size * epochs * incremental_round)``.
        If ``used_B`` is below it, ``model_trainer.train`` runs and its return
        value becomes the new ``used_B``. Otherwise no training happens.

        Returns:
            ``(weights, new_B)``: the trainer's parameters after the call and
            the updated budget usage.
        """
        self.model_trainer.id = self.client_idx
        self.model_trainer.set_model_params(w_global)

        new_B = used_B
        budget = math.floor((self.args.B / self.args.batch_size) * self.args.epochs * self.args.incremental_round)
        if used_B < budget:
            new_B = self.model_trainer.train(self.local_training_data, self.device, self.args, ta_id, used_B)
        else:
            print('-----Client '+str(self.client_idx)+' no train')
        weights = self.model_trainer.get_model_params()

        return weights,new_B

    def local_test(self, b_use_test_dataset):
        """Evaluate on the local test set (if truthy) or the local training set."""
        test_data = self.local_test_data if b_use_test_dataset else self.local_training_data
        return self.model_trainer.test(test_data, self.device, self.args)

    def get_act(self,orth_set):
        """Return ``model_trainer.get_activations`` computed on the local training data."""
        return self.model_trainer.get_activations(self.local_training_data, self.device, orth_set, self.args.dataset)

    def test_update_incremental(self,num):
        """Reset the test data to the initial set plus the first ``num`` incremental test chunks."""
        self.local_test_data = self.init_test_data
        for chunk in self.incremental_test_data[:num]:
            self.local_test_data = self.local_test_data + chunk

