from sparsity_op import Topk_layer
from wireless_module import passive_wireless_model
from scheduling_policy import resource_allocation
import torch
import numpy as np
import torch.nn as nn
import architectures_torch as architectures
from utils import setup_logger, accuracy, AverageMeter, average_weights, hetero_average_weights_subnet,CrossEntropyLossCal,CrossEntropyLossCal_3sets,CrossEntropyLoss_BDKS_high,CrossEntropyLoss_BDKS_median,CrossEntropyLosssmalltobig
import logging
import torch.nn.functional as F
import os
import copy
from datasets_torch import get_dataloader
import random
from slimable_op import SlimmableConv2d
from model_optimizer_scheduler_torch import get_model_optimizer_scheduler


# torch.multiprocessing.set_sharing_strategy('file_system')
class Copymodelclass:
    """Bare holder for five model components, all initialised to ``None``.

    The slots are ``cloud``, ``classifier``, ``local``, ``local_b`` and
    ``cloud_b``; whoever creates the holder is expected to fill them in.
    """

    def __init__(self):
        # Create every slot empty, in a fixed order.
        for slot in ('cloud', 'classifier', 'local', 'local_b', 'cloud_b'):
            setattr(self, slot, None)

class MIA:
    def __init__(self, args):
        """Set up the trainer from the parsed configuration object ``args``.

        ``args`` must expose the attributes read below: ``filename``,
        ``random_seed``, ``load``, ``save``, ``arch``, ``num_epochs``,
        ``learning_rate``, ``local_lr``, ``dataset``, ``batch_size``,
        ``cutlayer``, ``num_agent``, ``max_channel``, ``tau``,
        ``no_subnetwork``, ``sparsity``, ``heteroSFL``, ``no_BDKS``,
        ``no_W2N``, ``theta``, ``alpha``, ``indicator``, ``datadist``,
        ``comm_env`` and ``channeldist``.

        Side effects: seeds torch / numpy / random, creates the save folder,
        opens a log file, builds data loaders, models, optimizers and
        schedulers, and runs one forward pass on the GPU.
        """
        # ---- basic settings copied from the config ----
        self.save_dir = args.filename
        self.random_seed = args.random_seed
        self.load = args.load
        self.save = args.save
        self.arch = args.arch
        self.n_epochs = args.num_epochs
        self.lr = args.learning_rate
        self.local_lr = args.local_lr
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.cutting_layer = args.cutlayer
        self.num_agent = args.num_agent
        self.max_channel = args.max_channel
        self.tau = args.tau
        self.no_subnetwork = args.no_subnetwork
        self.sparsity = args.sparsity
        self.heteroSFL = args.heteroSFL
        self.no_BDKS = args.no_BDKS
        self.no_W2N = args.no_W2N
        self.theta = args.theta
        self.alpha = args.alpha

        self.fast_channel = self.max_channel
        # "dropdata" mode is selected purely by the indicator string.
        self.dropdata = bool(args.indicator == 'dropdata')

        # ---- reproducibility ----
        for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
            seed_fn(self.random_seed)

        self.best_acc = 0

        # A local learning rate of -1 means "not set": reuse the global one.
        if self.local_lr == -1:
            self.local_lr = self.lr

        # ---- save folder ----
        self.save_dir = "./saves/{!s}/".format(self.save_dir)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        # ---- logger (named after the save folder) ----
        model_log_file = self.save_dir + '/MIA.log'
        self.logger = setup_logger(self.save_dir + '_logger', model_log_file, level=logging.DEBUG)

        # ---- SFL setting: split the configured batch size across agents ----
        self.batch_size = int(round(self.batch_size / self.num_agent))

        # ---- dataset ----
        loaders = get_dataloader(
            self.dataset,
            self.logger,
            batch_size=self.batch_size,
            num_agent=self.num_agent,
            data_dist=args.datadist,
        )
        (self.client_dataloader,
         self.pub_dataloader,
         self.num_class,
         self.num_batches,
         self.client_class) = loaders

        # ---- models, optimizers and schedulers ----
        self.model = None
        built = get_model_optimizer_scheduler(
            self.arch, self.cutting_layer, self.max_channel,
            self.num_agent, self.num_class,
            self.lr, self.local_lr,
            self.logger, self.n_epochs,
            no_subnetwork=self.no_subnetwork,
        )
        (self.model,
         self.local_optimizer_list, self.local_b_optimizer_list,
         self.cloud_b_optimizer_list, self.optimizer,
         self.train_scheduler, self.train_local_scheduler_list,
         self.train_local_b_scheduler_list, self.train_cloud_b_scheduler_list,
         self.orig_channel) = built

        if self.load:
            self.load_model(self.save_dir)

        self.comm_env = passive_wireless_model(
            args.comm_env, self.num_agent, args.channeldist, args.max_channel, self.random_seed
        )

        # Probe client 0's encoder with a single 3x32x32 input of ones to read
        # the activation shape: dim 1 overrides ``orig_channel`` and the last
        # dim is handed to the resource allocator.
        sample_data = torch.ones((1, 3, 32, 32)).cuda()
        with torch.no_grad():
            probe_shape = self.model.local_list[0](sample_data).size()
            self.orig_channel = probe_shape[1]
            base_size = probe_shape[-1]

        self.allocation = resource_allocation(
            self.num_agent, self.max_channel, self.batch_size,
            base_size, self.orig_channel, indicator=args.indicator,
        )
        self.loss_track = [AverageMeter() for _ in range(self.num_agent)]



    ################### Train from resume         ################### Train from resume         ################### Train from resume 
    def load_model(self,save_dir):
        # save server-side model
        self.model.cloud.load_state_dict(torch.load(self.save_dir+'/model_cloud.tar'))
        self.model.classifier.load_state_dict(torch.load(self.save_dir+'/model_clasifier.tar'))
        # save client-side model
        for i in range(len(self.model.local_list)):
            self.model.local_list[i].load_state_dict(torch.load(self.save_dir+'/model_local_list'+str(0)+'.tar'))
            self.model.local_b_list[i].load_state_dict(torch.load(self.save_dir+'/model_local_b_list'+str(0)+'.tar'))
        self.model.cloud_b_list[-1].load_state_dict(torch.load(self.save_dir+'/model_cloud_b_list'+str(-1)+'.tar'))
        
    


    ################### Train from scratch        ################### Train from scratch        ################### Train from scratch

    def optimizer_step(self,client_list):
        """Step the global optimizer, then the per-client and bottleneck ones.

        In ``dropdata`` mode only the client-side optimizers whose index is in
        ``client_list`` are stepped; otherwise all of them are.  The
        server-side bottleneck optimizers are always stepped.
        """
        self.optimizer.step()

        restrict = self.dropdata
        for group in (self.local_optimizer_list, self.local_b_optimizer_list):
            for idx, opt in enumerate(group):
                if restrict and idx not in client_list:
                    continue
                opt.step()

        for opt in self.cloud_b_optimizer_list:
            opt.step()

    def optimizer_zero_grad(self):
        """Clear gradients on the global optimizer and on every optimizer list."""
        self.optimizer.zero_grad()
        groups = (
            self.local_optimizer_list,
            self.local_b_optimizer_list,
            self.cloud_b_optimizer_list,
        )
        for group in groups:
            for opt in group:
                opt.zero_grad()

    def scheduler_step(self, epoch = 0):
        """Advance the global scheduler and every per-client / bottleneck scheduler.

        ``epoch`` is forwarded unchanged to each scheduler's ``step``.
        """
        self.train_scheduler.step(epoch)
        groups = (
            self.train_local_scheduler_list,
            self.train_local_b_scheduler_list,
            self.train_cloud_b_scheduler_list,
        )
        for group in groups:
            for scheduler in group:
                scheduler.step(epoch)

    def v1_client_forward(self, x_private, client_id=0):
        
        self.model.local_list[client_id].train()
        self.model.local_b_list[client_id].train()

        my_channel = self.channel_allocation[client_id]
        z_private = self.model.local_list[client_id](x_private)
        if self.sparsity:
            #top-k
            z_private_client = self.model.sparsity_layer[client_id](z_private)
        elif self.dropdata:
            z_private_client = z_private
        else:
            self.model.local_b_list[client_id].apply(lambda m: setattr(m, 'real_out_channels', my_channel))
            z_private_client = self.model.local_b_list[client_id](z_private)
            
        if z_private_client.size(1)!=my_channel and not(self.sparsity):
            raise ValueError('The output of encoder doee not match the channel.')
        return z_private_client

    def v1_cloud_b_forward(self, z_private_client, client_id=0):
        if self.sparsity or self.dropdata:
            # no BL layer, doing nothing
            z_private = []
            channel_list = []
            channels = set([self.fast_channel])
            for channel in channels:
                for _ in range(z_private_client.size(0)):
                    channel_list.append(channel)
                z_private.append(z_private_client)
        
        elif self.no_subnetwork:
            for i in range(len(self.model.cloud_b_list)):
                self.model.cloud_b_list[i].train()
            my_channel = self.channel_allocation[client_id]
            z_private = []
            channel_list = []
            if z_private_client.size(1)!=my_channel:
                raise ValueError('The input of decoder doee not match.')
            channels = set([my_channel])
            for channel in channels:
                for _ in range(z_private_client.size(0)):
                    channel_list.append(channel)
                if channel == self.max_channel:
                    self.model.cloud_b_list[-1].apply(lambda m: setattr(m, 'real_in_channels', channel))
                    z_private.append(self.model.cloud_b_list[-1](z_private_client[:,0:channel,:,:]))
                else:
                    self.model.cloud_b_list[-2].apply(lambda m: setattr(m, 'real_in_channels', channel))
                    z_private.append(self.model.cloud_b_list[-2](z_private_client[:,0:channel,:,:]))

        elif self.heteroSFL and self.no_BDKS:
            self.model.cloud_b_list[-1].train()
            my_channel = self.channel_allocation[client_id]
            z_private = []
            channel_list = []
            if z_private_client.size(1)!=my_channel:
                raise ValueError('The input of decoder doee not match.')
            channels = set([my_channel])
            for channel in channels:
                for _ in range(z_private_client.size(0)):
                    channel_list.append(channel)
                self.model.cloud_b_list[-1].apply(lambda m: setattr(m, 'real_in_channels', channel))
                z_private.append(self.model.cloud_b_list[-1](z_private_client[:,0:channel,:,:]))
                
        elif self.heteroSFL and not(self.no_BDKS):
            self.model.cloud_b_list[-1].train()
            my_channel = self.channel_allocation[client_id]
            z_private = []
            channel_list = []
            if z_private_client.size(1)!=my_channel:
                raise ValueError('The input of decoder doee not match.')
            '''Get subnetwork BL's activations'''
            channels = self.unique_fading_channels[self.unique_fading_channels<=my_channel]
            for channel in channels:
                for _ in range(z_private_client.size(0)):
                    channel_list.append(channel)
                self.model.cloud_b_list[-1].apply(lambda m: setattr(m, 'real_in_channels', channel))
                z_private.append(self.model.cloud_b_list[-1](z_private_client[:,0:channel,:,:]))
        return z_private,channel_list

    def v1_server_forward(self,x_private, label_private,client_cat,channel_cat,sample_index_cat):
        """Server-side forward pass, per-(channel, client) loss and backward.

        Args:
            x_private: concatenated activations fed to ``self.model.cloud``;
                indexed as a 4-D tensor with samples on dim 0.
            label_private: labels aligned with ``x_private`` on dim 0.
            client_cat: per-sample client ids; indexed with a CPU boolean mask
                and passed to ``np.unique`` (so presumably a CPU tensor —
                TODO confirm against caller).
            channel_cat: per-sample channel value tensor, aligned with
                ``x_private`` on dim 0.
            sample_index_cat: not used by the live code (only referenced in
                the commented-out sanity check near the end).

        Returns:
            A list of length ``self.num_agent``; entry ``c`` is the sum, over
            the channels client ``c`` appears in, of the mean per-sample loss.

        Side effects:
            Puts ``self.model.cloud`` and ``self.model.classifier`` in train
            mode and calls ``backward()`` on the mean of all per-sample
            losses.  No optimizer is stepped here.
        """
        self.model.cloud.train()
        self.model.classifier.train()
        total_loss = []
        total_losses_list = [0 for _ in range(self.num_agent)]

        '''Get the logits of different BLs'''
        # Distinct channel values present in this batch, ascending.
        channel_iter = torch.sort(torch.unique(channel_cat),descending=False).values

        # Two-level caches keyed as [channel][client].
        channel_wise_soft_label = {channel.item():{} for channel in channel_iter}
        channel_wise_output = {channel.item():{} for channel in channel_iter}
        channel_wise_label = {channel.item():{} for channel in channel_iter}
        channel_wise_soft_label_logit = {channel.item():{} for channel in channel_iter}
        # Pass 1: run cloud + classifier once per channel value and split the
        # result per client.
        for channel in channel_iter:
            output = self.model.cloud(x_private[channel_cat==channel,:,:,:])
            output = self.classifier_forward(output)
            label_private_channel_specific = label_private[channel_cat==channel]
            client_cat_channel_specific = client_cat[(channel_cat==channel).cpu()]
            # From here on ``channel`` is a plain Python number (dict key).
            channel = channel.item()
            for client in np.unique(client_cat_channel_specific):
                channel_wise_output[channel][client] = output[client_cat_channel_specific==client,:]
                # NOTE(review): softmax is called without an explicit ``dim``.
                channel_wise_soft_label[channel][client] = F.softmax(output[client_cat_channel_specific==client,:]).detach().clone()
                channel_wise_label[channel][client] = label_private_channel_specific[client_cat_channel_specific==client]
                # Detached copy of the raw logits; not read again in this method.
                channel_wise_soft_label_logit[channel][client] = output[client_cat_channel_specific==client,:].detach().clone()

        # Pass 2: choose a loss for every (channel, client) pair.
        for channel in channel_iter:
            client_cat_channel_specific = client_cat[(channel_cat==channel).cpu()]
            channel = channel.item()
            for client in np.unique(client_cat_channel_specific):
                # A client counts as "high end" when its fading channel equals
                # the fast channel.
                if self.fading_channels[client]==self.fast_channel:
                    high_end_client = True
                else:
                    high_end_client = False

                if self.heteroSFL and not(self.no_BDKS):
                    
                    if self.no_W2N:
                        # Same three-way split as the branch below, except that
                        # the first case uses CrossEntropyLossCal instead of
                        # CrossEntropyLosssmalltobig.
                        if high_end_client and channel == self.slow_channel and client in channel_wise_soft_label[self.fast_channel].keys():
                            # print(channel,client,'CrossEntropyLossCal')
                            criterion = CrossEntropyLossCal(self.slow_channel,self.fast_channel)
                            f_loss = criterion(channel_wise_output[channel][client], channel_wise_label[channel][client],self.client_class[client],self.logit_1,self.logit_16,channel,high_end_client,tau=self.tau)
                        elif high_end_client and channel == self.fast_channel:
                            # print(channel,client,'CrossEntropyLoss_BDKS_high')
                            # NOTE(review): indexes the slow-channel cache for
                            # this client without a membership check — confirm
                            # the entry always exists.
                            criterion = CrossEntropyLoss_BDKS_high(reduction='none')
                            f_loss = criterion(channel_wise_output[channel][client], channel_wise_label[channel][client],channel_wise_soft_label[self.slow_channel][client],self.client_class[client],self.label_size_accumulated[1],np.max(self.label_size_accumulated),self.logit_16,tau=self.tau,threshold=self.theta,weight_constant=self.alpha)
                        else:
                            # print(channel,client,'CrossEntropyLossCal')
                            criterion = CrossEntropyLossCal(self.slow_channel,self.fast_channel)
                            f_loss = criterion(channel_wise_output[channel][client], channel_wise_label[channel][client],self.client_class[client],self.logit_1,self.logit_16,channel,high_end_client,tau=self.tau)
                    else:
                        '''Key Implementation of HeteroSFL START'''
                        # High-end client at the slow channel: the loss also
                        # receives that client's fast-channel soft labels.
                        if high_end_client and channel == self.slow_channel and client in channel_wise_soft_label[self.fast_channel].keys():
                            # print(channel,client,'smalltobig')
                            criterion = CrossEntropyLosssmalltobig(reduction='none')
                            f_loss = criterion(channel_wise_output[channel][client], channel_wise_label[channel][client],channel_wise_soft_label[self.fast_channel][client],self.logit_1)
                        # High-end client at the fast channel: the loss also
                        # receives that client's slow-channel soft labels.
                        elif high_end_client and channel == self.fast_channel:
                            # print(channel,client,'CrossEntropyLoss_BDKS_high')
                            criterion = CrossEntropyLoss_BDKS_high(reduction='none')
                            f_loss = criterion(channel_wise_output[channel][client], channel_wise_label[channel][client],channel_wise_soft_label[self.slow_channel][client],self.client_class[client],self.label_size_accumulated[1],np.max(self.label_size_accumulated),self.logit_16,tau=self.tau,threshold=self.theta,weight_constant=self.alpha)
                        else:
                            # print(channel,client,'CrossEntropyLossCal')
                            criterion = CrossEntropyLossCal(self.slow_channel,self.fast_channel)
                            f_loss = criterion(channel_wise_output[channel][client], channel_wise_label[channel][client],self.client_class[client],self.logit_1,self.logit_16,channel,high_end_client,tau=self.tau)
                        '''Key Implementation of HeteroSFL END'''
                elif self.sparsity or self.dropdata or (self.heteroSFL and self.no_BDKS):
                    # logit calibration
                    criterion = CrossEntropyLossCal(self.slow_channel,self.fast_channel)
                    f_loss = criterion(channel_wise_output[channel][client], channel_wise_label[channel][client],self.client_class[client],self.logit_1,self.logit_16,channel,high_end_client,tau=self.tau)
                else:
                    # baseline SFL or self.no_subnetwork
                    criterion = torch.nn.CrossEntropyLoss(reduction='none')
                    f_loss = criterion(channel_wise_output[channel][client], channel_wise_label[channel][client])

                # Keep the per-sample loss tensor for the global mean below and
                # record its mean for per-client reporting.
                total_loss.append(f_loss*1)
                total_losses = f_loss.detach().cpu().numpy()
                total_losses_list[client]+=(np.mean(total_losses))

                del f_loss,total_losses
        # Single scalar objective: mean over every per-sample loss collected.
        total_loss = torch.cat(total_loss, dim = 0).cuda()
        # if total_loss.size(0)!=sample_index_cat.size(0):
        #     raise ValueError('total loss size does not match the sample index!')
        total_loss = torch.mean(total_loss)
        total_loss.backward()

        return total_losses_list
        


    def classifier_forward(self,output):
        if self.arch == "resnet18" or self.arch == "resnet34":
            output = F.avg_pool2d(output, 4)
            output = output.view(output.size(0), -1)
            output = self.model.classifier(output)
        elif self.arch == "resnet20" or self.arch == "resnet32" or self.arch == "resnet110":
            output = F.avg_pool2d(output, 8)
            output = output.view(output.size(0), -1)
            output = self.model.classifier(output)
        elif self.arch == 'densenet121':
            output = F.relu(output, inplace=True)
            output = F.adaptive_avg_pool2d(output, (1, 1))
            output = torch.flatten(output, 1)
            output = self.model.classifier(output)
        elif self.arch == 'mobilenetv2':
            output = self.model.classifier(output)
            output = output.view(output.size()[0], -1)
        else:
            output = output.view(output.size(0), -1)
            output = self.model.classifier(output)
        return output


    def calibrate_client_server(self, x_private, given_channel=1):
        # self.logger.debug('cal {}'.format(given_channel))
        cloud_bn_id = given_channel
        self.copy_test_model.cloud.train()
        self.copy_test_model.local.eval()
        self.copy_test_model.classifier.eval()
        self.copy_test_model.local_b.eval()
        self.copy_test_model.cloud_b.eval()


        self.copy_test_model.local_b.apply(lambda m: setattr(m, 'real_out_channels', given_channel))
        self.copy_test_model.cloud_b.apply(lambda m: setattr(m, 'real_in_channels', given_channel))


        with torch.no_grad():
            x_private = x_private.cuda()
            z_private_client = self.copy_test_model.local(x_private)
            if self.sparsity or self.dropdata:
                z_private = z_private_client
            else:
                z_private_client = self.copy_test_model.local_b(z_private_client)
                if z_private_client.size(1)!=given_channel:
                    raise ValueError('The outpt of encoder dimension does not match the given channel size.')
                z_private = self.copy_test_model.cloud_b(z_private_client)
            output = self.copy_test_model.cloud(z_private)
        return 0

    def set_batchnorm(self,block):
        if "__len__" in dir(block):
            for layer in block:
                self.set_batchnorm(layer)
        else:
            if self.arch=='resnet20':
                for layer in block.children():
                    if isinstance(layer, nn.BatchNorm2d):
                        layer.reset_running_stats()
            else:
                if isinstance(block, nn.BatchNorm2d):
                    block.reset_running_stats()

    def check_batchnorm(self,block):
        if "__len__" in dir(block):
            for layer in block:
                self.check_batchnorm(layer)
        else:
            if self.arch=='resnet20':
                for layer in block.children():
                    if isinstance(layer, nn.BatchNorm2d):
                        self.logger.debug('{} - layer.running_mean - {}'.format(layer,layer.running_mean[0]))
                        self.logger.debug('{} - layer.running_var - {}'.format(layer,layer.running_var[0]))
            else:
                if isinstance(block, nn.BatchNorm2d):
                    self.logger.debug('{} - layer.running_mean - {}'.format(block,block.running_mean[0]))
                    self.logger.debug('{} - layer.running_var - {}'.format(block,block.running_var[0]))


    def calibrate_validate_target(self,epoch,given_channel,cal_size=1):
        # clear the running mean and var in cloud batchnorm
        self.set_batchnorm(self.copy_test_model.cloud)
        # calibrate the batch norm using all batches
        self.SFL(epoch,mode='cal',given_channel=given_channel,cal_size=cal_size)
        self.validate_target(client_id=0,given_channel=given_channel, prefix='After cal '+str(given_channel)+' with '+str(cal_size) +' clients')


    def validate_target(self, client_id=0, given_channel=1, prefix='Before cal'):
        """
        Evaluate the test-model copy on ``self.pub_dataloader`` and log the result.

        Args:
            client_id: only used to look up ``self.channel_allocation``.
            given_channel: channel width written to the bottlenecks before
                evaluation and checked against the client bottleneck output.
            prefix: text prepended to every log line.

        Side effects:
            Logs overall loss / top-1 and per-class top-1.  When the overall
            top-1 average beats ``self.best_acc`` and ``self.save`` is truthy,
            updates ``self.best_acc`` and calls ``self.save_model()``.

        Raises:
            ValueError: outside ``sparsity`` / ``dropdata`` mode, when the
                client bottleneck output's dim 1 differs from ``given_channel``.
        """
        loss_meter = AverageMeter()
        acc_meter = AverageMeter()
        class_meters = [AverageMeter() for _ in range(self.num_class)]
        loader = self.pub_dataloader

        # The looked-up value is unused, but the lookup is kept so that a
        # missing allocation for ``client_id`` still fails here.
        _ = self.channel_allocation[client_id]

        # switch to evaluate mode
        snapshot = self.copy_test_model
        for part in (snapshot.local, snapshot.cloud, snapshot.classifier,
                     snapshot.local_b, snapshot.cloud_b):
            part.eval()
        snapshot.local_b.apply(lambda m: setattr(m, 'real_out_channels', given_channel))
        snapshot.cloud_b.apply(lambda m: setattr(m, 'real_in_channels', given_channel))

        criterion = nn.CrossEntropyLoss()
        arch = self.arch

        for images, target in loader:
            images = images.cuda()
            target = target.cuda()

            # compute output
            with torch.no_grad():
                feats = snapshot.local(images)
                if not (self.sparsity or self.dropdata):
                    feats = snapshot.local_b(feats)
                    if feats.size(1) != given_channel:
                        raise ValueError('The outpt of encoder dimension does not match the given channel size.')
                    # self.logger.debug('size output{}'.format(output.size()))
                    feats = snapshot.cloud_b(feats)
                feats = snapshot.cloud(feats)

                # Architecture-specific head, using the test copy's classifier.
                if arch == 'mobilenetv2':
                    logits = snapshot.classifier(feats)
                    logits = logits.view(logits.size()[0], -1)
                else:
                    if arch in ("resnet18", "resnet34"):
                        feats = F.avg_pool2d(feats, 4)
                        feats = feats.view(feats.size(0), -1)
                    elif arch in ("resnet20", "resnet32", "resnet110"):
                        feats = F.avg_pool2d(feats, 8)
                        feats = feats.view(feats.size(0), -1)
                    elif arch == 'densenet121':
                        feats = F.relu(feats, inplace=True)
                        feats = F.adaptive_avg_pool2d(feats, (1, 1))
                        feats = torch.flatten(feats, 1)
                    else:
                        feats = feats.view(feats.size(0), -1)
                    logits = snapshot.classifier(feats)
                loss = criterion(logits, target)

            logits = logits.float()
            loss = loss.float()
            n_samples = images.size(0)

            # measure accuracy and record loss
            overall = accuracy(logits.data, target)[0]
            loss_meter.update(loss.item(), n_samples)
            acc_meter.update(overall.item(), n_samples)

            # accuracy of classes
            for label in range(self.num_class):
                picked = target == label
                class_logits = logits.data[picked]
                class_prec = accuracy(class_logits, target[picked])[0]
                class_meters[label].update(class_prec.item(), class_logits.size(0))

        summary = (' Test:\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.avg:.3f}').format(client_id, loss=loss_meter, top1=acc_meter)
        self.logger.debug(prefix + summary)
        for label, meter in enumerate(class_meters):
            per_class = (' Test (class-{}):\t'
                         'Prec@1 {top1.avg:.3f}').format(label, top1=meter)
            self.logger.debug(prefix + per_class)

        if (self.best_acc < acc_meter.avg) and (self.save):
            self.best_acc = acc_meter.avg
            self.save_model()

    def save_model(self):
        # save server-side model
        torch.save(self.model.cloud.state_dict(), self.save_dir+'/model_cloud.tar')
        torch.save(self.model.classifier.state_dict(), self.save_dir+'/model_clasifier.tar')
        # save client-side model
        # client_list_check = [self.channelratio_group.index(channel_ratio) for channel_ratio in set(self.channelratio_group)]
        # for client_id in client_list_check:
        torch.save(self.model.local_list[0].state_dict(), self.save_dir+'/model_local_list'+str(0)+'.tar')
        torch.save(self.model.local_b_list[0].state_dict(), self.save_dir+'/model_local_b_list'+str(0)+'.tar')
        torch.save(self.model.cloud_b_list[-1].state_dict(), self.save_dir+'/model_cloud_b_list'+str('-1')+'.tar')

        #     # save bottleneck client-side
        #     torch.save(self.model.local_b_list[client_id].state_dict(), self.save_dir+'/model_local_b_list'+str(client_id)+'.tar')
        # # save bottleneck server-side
        #     torch.save(self.model.cloud_b_list[-1].state_dict(), self.save_dir+'/model_cloud_b_list'+str('-1')+'.tar')

    def sync_client(self):
        """Average client-side weights and broadcast them back to every client.

        The encoders in ``local_list`` are always averaged together.  The
        bottlenecks in ``local_b_list`` are handled per mode:

        * ``no_subnetwork``: averaged separately within the group whose fading
          channel equals ``max_channel`` and within the remaining group;
        * ``sparsity`` / ``dropdata``: left untouched;
        * otherwise: averaged with ``hetero_average_weights_subnet``.
        """
        # sync local models
        encoder_avg = average_weights(self.track_channel_sync, self.model.local_list)
        for idx in range(self.num_agent):
            self.model.local_list[idx].load_state_dict(encoder_avg)

        # sync local bottleneck layer
        if self.no_subnetwork:
            high_idx = np.where(self.fading_channels == self.max_channel)[0]
            low_idx = np.where(self.fading_channels != self.max_channel)[0]

            # aggregating within high-end / low-end group
            high_avg = average_weights(
                [self.track_channel_sync[i] for i in high_idx],
                [self.model.local_b_list[i] for i in high_idx],
            )
            low_avg = average_weights(
                [self.track_channel_sync[i] for i in low_idx],
                [self.model.local_b_list[i] for i in low_idx],
            )

            for client_id in range(self.num_agent):
                group_avg = high_avg if client_id in high_idx else low_avg
                # Each client gets its own copy of the group average.
                self.model.local_b_list[client_id].load_state_dict(copy.deepcopy(group_avg))

        elif not (self.sparsity or self.dropdata):
            bottleneck_avg = hetero_average_weights_subnet(self.track_channel_sync, self.model.local_b_list)
            for client_id in range(self.num_agent):
                self.model.local_b_list[client_id].load_state_dict(copy.deepcopy(bottleneck_avg))
        # for key in self.model.local_b_list[client_id].state_dict().keys():
        #     print(self.model.local_b_list[client_id].state_dict()[key][:,0,1,1])
        #     break

    def get_client_list_iter(self,batch_id):
        # generate the client iteration based on their sample size, evenly distributed.
        client_list = []
        if self.dropdata:
            client_list = []
            if batch_id==0:
                self.acc_batch = [len(self.client_dataloader.client_dataloader[i])*self.samplereduceratio[i]/self.num_batches for i in range(self.num_agent)]
        
            else:
                step_batch = [len(self.client_dataloader.client_dataloader[i])*self.samplereduceratio[i]/self.num_batches for i in range(self.num_agent)]
                self.acc_batch = [self.acc_batch[i]+step_batch[i] for i in range(self.num_agent)]   

            while (len(client_list)<self.num_agent) and (True in (np.array(self.acc_batch)>=1)):
                client_id = np.argsort(self.acc_batch)[-1::-1][0]
                client_list.append(client_id)
                self.acc_batch[client_id] -= 1
            return sorted(client_list)

        else:
            if batch_id==0:
                self.acc_batch = [max(1,len(self.client_dataloader.client_dataloader[i])/self.num_batches) for i in range(self.num_agent)]   
            else:
                step_batch = [len(self.client_dataloader.client_dataloader[i])/self.num_batches for i in range(self.num_agent)]
                self.acc_batch = [self.acc_batch[i]+step_batch[i] for i in range(self.num_agent)]   

            while (len(client_list)<self.num_agent):
                client_id = np.argsort(self.acc_batch)[-1::-1][0]
                client_list.append(client_id)
                self.acc_batch[client_id] -= 1
            return sorted(client_list)



    def SFL(self,epoch,mode='train',given_channel=1,cal_size=1):
        """Run one pass over ``self.num_batches`` batches in train or calibration mode.

        In ``"train"`` mode, every batch gathers the activations of the clients
        returned by ``get_client_list_iter``, checks that the concatenated
        tensors have consistent lengths, feeds them to ``v1_server_forward`` and
        then steps and zeroes the optimizers.  After the last batch a per-client
        summary is logged and ``sync_client`` is called.

        In ``"cal"`` mode, ``track_running_stats`` is switched on for the cloud
        and classifier parts of ``self.copy_test_model`` and the images of each
        batch are passed to ``calibrate_client_server``.

        Any other ``mode`` value does nothing.

        Args:
            epoch: Epoch number; only used in the training summary log line.
            mode: ``"train"`` or ``"cal"``.
            given_channel: Forwarded to ``calibrate_client_server`` in ``"cal"`` mode.
            cal_size: Not used by this method.

        Raises:
            ValueError: In ``"train"`` mode, if a loaded batch does not have the
                adjusted batch size or the concatenated tensors disagree in length.
        """
        if mode == "train":
            self.model.cloud.apply(lambda m: setattr(m, 'track_running_stats', False))
            self.model.classifier.apply(lambda m: setattr(m, 'track_running_stats', False))

            # Scalar placeholder until the first server forward pass returns losses.
            self.train_loss = -1
            self.avg_train_loss = [AverageMeter() for _ in range(self.num_agent)]
            # Remembers whether self.train_loss was ever replaced by per-client
            # losses, so the summary log below does not index the placeholder.
            loss_computed = False
            for batch_id in range(self.num_batches):
                self.batch_id = batch_id
                client_list_iter = self.get_client_list_iter(batch_id)
                channel_cat = []
                label_cat = []
                ia_cat = []
                client_cat = []
                # Sentinel -1 makes the first client's indices start at 0; it is
                # popped again before the list is turned into a tensor.
                sample_index_cat = [-1]
                if len(client_list_iter) != 0:
                    # Aggregate activations from clients
                    for client_id in client_list_iter:
                        images, labels = self.client_dataloader.load_data(client_id,self.adjusted_batch_size[client_id])
                        if (images.size(0)!=self.adjusted_batch_size[client_id]) or (labels.size(0)!=self.adjusted_batch_size[client_id]):
                            raise ValueError('The data size is different from the adjusted batch_size.')
                        self.track_channel_sync[client_id][self.channel_allocation[client_id]] += images.size(0)
                        ia = self.v1_client_forward(images,client_id)
                        ia_list,channel_list = self.v1_cloud_b_forward(ia,client_id)
                        channel_cat += channel_list
                        basic = sample_index_cat[-1]+1
                        # The same index range is repeated once per distinct
                        # channel value returned for this client.
                        for _ in set(channel_list):
                            sample_index_cat += list(range(basic,basic+int(self.adjusted_batch_size[client_id])))

                        # Labels and client ids are repeated for every activation
                        # tensor so the concatenated tensors stay aligned.
                        for ia_item in ia_list:
                            ia_cat.append(ia_item)
                            label_cat.append(labels)
                            client_cat.append(client_id*torch.ones_like(labels))
                        del ia_list

                    sample_index_cat.pop(0)
                    sample_index_cat = torch.tensor(sample_index_cat).cuda()
                    channel_cat = torch.tensor(channel_cat).cuda()
                    ia_cat = torch.cat(ia_cat, dim = 0).cuda()
                    label_cat = torch.cat(label_cat, dim = 0).cuda()
                    client_cat = torch.cat(client_cat, dim = 0).cpu()

                    if (not (self.heteroSFL and not(self.no_BDKS))):
                        # Expected length is the plain sum of the participating
                        # clients' adjusted batch sizes.
                        expected_total = np.sum(self.adjusted_batch_size[client_list_iter])
                        if (ia_cat.size(0)!=expected_total or label_cat.size(0)!=expected_total):
                            self.logger.error(
                                "expected batch total %s, got ia_cat %s, label_cat %s",
                                expected_total, ia_cat.size(0), label_cat.size(0))
                            raise ValueError("Intermidiate activation batch size does not match the ajusted batch size!")
                    else:
                        # heteroSFL with BDKS enabled: a client's batch is counted
                        # once for every unique fading channel that does not
                        # exceed its allocated channel.
                        total_ia_size = 0
                        for channel_item in self.unique_fading_channels:
                            total_ia_size += np.sum(np.array(self.adjusted_batch_size[client_list_iter]*(self.channel_allocation[client_list_iter]>=channel_item)))

                        if (ia_cat.size(0)!=total_ia_size or label_cat.size(0)!=total_ia_size):
                            self.logger.error(
                                "expected batch total %s, got ia_cat %s, label_cat %s",
                                total_ia_size, ia_cat.size(0), label_cat.size(0))
                            raise ValueError("Intermidiate activation batch size does not match the ajusted batch size!")

                    if (sample_index_cat.size(0)!=ia_cat.size(0)) or (ia_cat.size(0)!=label_cat.size(0)) or (ia_cat.size(0)!=client_cat.size(0)):
                        self.logger.error(
                            "length mismatch: sample_index_cat %s, ia_cat %s, label_cat %s, client_cat %s",
                            sample_index_cat.size(0), ia_cat.size(0), label_cat.size(0), client_cat.size(0))
                        raise ValueError("Intermidiate activation batch size does not match the sample index length!")

                    # Feed activations after BL decoder to server
                    self.train_loss = self.v1_server_forward(ia_cat, label_cat,client_cat,channel_cat,sample_index_cat)
                    loss_computed = True
                    for client in range(self.num_agent):
                        self.avg_train_loss[client].update(float(self.train_loss[client]),1)

                self.optimizer_step(client_list_iter)
                self.optimizer_zero_grad()

            # The logged batch size is the total over all clients, not per client.
            total_batch_size = int(np.sum(self.adjusted_batch_size))
            for client_id in range(self.num_agent):
                # Fall back to the scalar placeholder when no batch produced a loss.
                last_loss = self.train_loss[client_id] if loss_computed else self.train_loss
                self.logger.debug("log--[{}/{}][client-{}] batch size: {} train loss: {:1.4f} (avg: {:1.4f} )".format(
                    epoch, self.n_epochs, client_id, total_batch_size, last_loss,self.avg_train_loss[client_id].avg))

            self.sync_client()

        elif mode == "cal":
            self.copy_test_model.cloud.apply(lambda m: setattr(m, 'track_running_stats', True))
            self.copy_test_model.classifier.apply(lambda m: setattr(m, 'track_running_stats', True))
            for batch_id in range(self.num_batches):
                client_list_iter = self.get_client_list_iter(batch_id)
                sample_cat = []
                if len(client_list_iter) != 0:
                    for client_id in client_list_iter:
                        # Calibration uses the nominal batch size, not the adjusted one.
                        images, _ = self.client_dataloader.load_data(client_id,self.batch_size)
                        sample_cat.append(images)
                    sample_cat = torch.cat(sample_cat, dim = 0).cuda()
                    self.calibrate_client_server(sample_cat, given_channel=given_channel)
        
    def compute_label_size(self,client_class,fading_channels):
        """Accumulate per-client class vectors by channel value.

        Args:
            client_class: Indexable by client position; ``client_class[i]`` is
                added to the rows selected for client ``i`` (presumably a
                per-class sample count of length ``self.num_class`` -- confirm
                against the caller).
            fading_channels: Channel value of each client, in client order.

        Returns:
            A tuple ``(label_size, label_size_accumulated, unique_fading_channels)``:
            ``unique_fading_channels`` holds the sorted distinct channel values
            including ``self.slow_channel``; ``label_size[r]`` sums the vectors
            of clients whose channel equals that value; and
            ``label_size_accumulated[r]`` sums the vectors of clients whose
            channel is greater than or equal to that value.
        """
        # self.slow_channel always gets a row, even if no client currently uses it.
        channel_levels = np.unique(np.append(fading_channels, self.slow_channel))
        table_shape = (len(channel_levels), self.num_class)
        per_channel = np.zeros(table_shape)
        cumulative = np.zeros(table_shape)

        for position, client_channel in enumerate(fading_channels):
            class_vector = client_class[position]
            exact_rows = channel_levels == client_channel
            covered_rows = channel_levels <= client_channel
            per_channel[exact_rows, :] += class_vector
            cumulative[covered_rows, :] += class_vector

        return per_channel, cumulative, channel_levels


    def sparsity_level_compute(self):
        """Derive ``self.sparsity_level`` from the current fading channels.

        The level is ``(2 - fading_channels / orig_channel) / 2``; every entry
        that comes out at or below 0.5 is replaced by 0.  The result is stored
        on ``self.sparsity_level``; nothing is returned.
        """
        channel_ratio = self.fading_channels / self.orig_channel
        levels = (2 - channel_ratio) / 2
        # Levels of 0.5 or less (ratio of 1 or more) are zeroed out.
        at_or_below_half = levels <= 0.5
        levels[at_or_below_half] = 0
        self.sparsity_level = levels

    def __call__(self, log_frequency=500, verbose=False, progress_bar=True):
        """Drive the whole run: per-epoch setup, training and periodic testing.

        For every epoch from 1 to ``self.n_epochs`` this draws the channel state
        via ``comm_env``, recomputes the label statistics and the resource
        allocation, resets the per-client channel counters, optionally rebuilds
        ``self.model.sparsity_layer`` from ``Topk_layer``, steps the schedulers,
        trains with ``SFL`` (skipped when ``self.load`` is set) and, for epochs
        below 10 or above 100, runs ``calibrate_validate_target`` on deep copies
        of the model parts at the slow and the fast channel size.

        Args:
            log_frequency: Not referenced in this method.
            verbose: Not referenced in this method.
            progress_bar: Not referenced in this method.

        Returns:
            Always ``0``.
        """
        def summarize(values):
            # [mean, std, count] of one class's collected logit values.
            return [np.mean(values), np.std(values), len(values)]

        # Main training loop
        for epoch in range(1, self.n_epochs + 1):
            # Fresh per-class buffers every epoch (presumably filled during
            # training elsewhere in the class -- they are only summarized here).
            self.logit_1 = {cls: [] for cls in range(self.num_class)}
            self.logit_16 = {cls: [] for cls in range(self.num_class)}

            self.fading_channels, self.slow_channel = self.comm_env()

            (self.label_size,
             self.label_size_accumulated,
             self.unique_fading_channels) = self.compute_label_size(self.client_class, self.fading_channels)
            (self.adjusted_batch_size,
             self.channel_allocation,
             self.time_per_round,
             self.samplereduceratio) = self.allocation.allocation(self.fading_channels, predefined=[])

            # One zeroed counter per channel value and client; the largest
            # tracked channel depends on whether data dropping is active.
            top_channel = self.orig_channel if self.dropdata else self.max_channel
            self.track_channel_sync = [
                dict.fromkeys(range(top_channel + 1), 0) for _ in range(self.num_agent)
            ]

            if self.sparsity:
                # Compute the necessary sparsity level, then build one layer per
                # client (zip stops after self.num_agent entries).
                self.sparsity_level_compute()
                self.model.sparsity_layer = [
                    Topk_layer(level)
                    for level, _ in zip(self.sparsity_level, range(self.num_agent))
                ]

            self.logger.info('Epoch {}:'.format(epoch))
            self.logger.debug('self.adjusted_batch_size {}'.format(self.adjusted_batch_size))
            self.logger.debug('self.channel_allocation {}'.format(self.channel_allocation))

            self.scheduler_step(epoch)

            # Train
            if not self.load:
                self.SFL(epoch)

                if self.fast_channel in self.unique_fading_channels:
                    # Collapse the raw per-class buffers into [mean, std, count].
                    self.logit_1 = {cls: summarize(self.logit_1[cls]) for cls in range(self.num_class)}
                    self.logit_16 = {cls: summarize(self.logit_16[cls]) for cls in range(self.num_class)}

                    for cls, (mean, std, count) in self.logit_1.items():
                        self.logger.debug('self.logit_1-{} -mean {:.3f} -std {:.3f} -len {}'.format(cls, mean, std, count))
                    for cls, (mean, std, count) in self.logit_16.items():
                        self.logger.debug('self.logit_16-{} -mean{:.3f}-std{:.3f} -len {}'.format(cls, mean, std, count))

            # Model saving is currently disabled here (it used to be guarded by
            # ``epoch == 199 and self.save``).

            # Test, including the batch norm calibration; only run for
            # epochs 1-9 and epochs after 100.
            if 10 <= epoch <= 100:
                continue

            self.copy_test_model = Copymodelclass()
            snapshot = self.copy_test_model
            snapshot.cloud = copy.deepcopy(self.model.cloud)
            snapshot.classifier = copy.deepcopy(self.model.classifier)
            snapshot.local = copy.deepcopy(self.model.local_list[0])
            # With no_subnetwork set, entry 1 of the *_b lists is used instead of entry 0.
            branch = 1 if self.no_subnetwork else 0
            snapshot.local_b = copy.deepcopy(self.model.local_b_list[branch])
            snapshot.cloud_b = copy.deepcopy(self.model.cloud_b_list[branch])

            for channel_size in (self.slow_channel, self.fast_channel):
                self.calibrate_validate_target(epoch, channel_size, cal_size=self.num_agent)
            del self.copy_test_model
        return 0