from MODEL import *
from models.calibration import *
from models.fold_bn import *
from models.resnet import *
from models.spiking_layer import *


class Clients:
    """One client holding a local ResNet-20 ANN plus its training state.

    Each instance owns a data loader over its subset of the training set, an
    SGD optimizer with a MultiStepLR scheduler, and a per-client checkpoint
    path. ``args``, ``resnet20``, ``Train`` and the SNN helpers used below are
    expected to come from the star-imports at the top of the file.
    """

    # State-dict suffixes exported for every batch-norm module.
    _BN_SUFFIXES = ('weight', 'bias', 'running_mean', 'running_var')

    def __init__(self, train_dataset, dataidx, device, id, model_save, eval_loader, malicious=False):
        """Build the local model, data loader, optimizer and scheduler.

        Args:
            train_dataset: full training dataset; only ``dataidx`` is used.
            dataidx: indices of ``train_dataset`` assigned to this client.
            device: device the local model (and returned tensors) live on.
            id: client identifier, also used in the checkpoint file name.
            model_save: path prefix for this client's checkpoint file.
            eval_loader: loader forwarded to ``Train`` for evaluation.
            malicious: initial value of the ``malicious`` flag.
        """
        self.device = device
        self.local_model = self._new_model()
        self.client_id = id
        self.malicious = malicious

        self.train_loader = torch.utils.data.DataLoader(
            torch.utils.data.Subset(train_dataset, dataidx),
            batch_size=args.batch_size,
            shuffle=True
        )
        self.model_save = model_save + '-client' + str(id) + '.pth'  # store local ann

        # self.local_model.load_state_dict(torch.load(self.model_save))
        self.eval_loader = eval_loader
        self.optimizer = torch.optim.SGD(self.local_model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-5)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=self.optimizer,
                                                              milestones=[300, 450],
                                                              gamma=0.1)
        # Write the freshly initialised weights so the checkpoint file exists
        # before the first call to local_train.
        torch.save(self.local_model.state_dict(), self.model_save)

    def _new_model(self):
        """Return a fresh ResNet-20 (with BN) on ``self.device``."""
        num_classes = 10 if args.dataset == 'cifar10' else 100
        return resnet20(use_bn=True, num_classes=num_classes).to(self.device)

    # Local training
    def local_train(self, Global_model_save, global_round, server_diff1, server_diff2):
        """Train locally, convert the result to an SNN and export parameters.

        Args:
            Global_model_save: path of the global model checkpoint; only read
                when ``global_round > 0``.
            global_round: index of the current round; round 0 skips the
                global-to-local transfer.
            server_diff1: forwarded to ``transfer_snn_to_ann_model``.
            server_diff2: forwarded to ``transfer_snn_to_ann_model``.

        Returns:
            A tuple ``(diff, diff1, diff2)``:
              * ``diff``  - state dict of the calibrated spiking model
                (built after BN folding).
              * ``diff1`` - BN weight/bias/running stats of the local ANN
                checkpoint, keyed by state-dict name.
              * ``diff2`` - Conv2d/Linear ``.weight`` tensors of the local ANN
                checkpoint, keyed by state-dict name.
            Tensors in ``diff1``/``diff2`` are placed on ``self.device``.
        """
        print('\n')
        print('Client: ', self.client_id)

        if global_round > 0:
            Glo_model = torch.load(Global_model_save)
            Local_model_dict = torch.load(self.model_save)
            transfer_snn_to_ann_model(Glo_model, Local_model_dict, self.local_model, server_diff1, server_diff2)

        Train(self.local_model, self.device, self.train_loader, self.client_id, self.model_save, self.optimizer,
              self.scheduler, self.eval_loader, global_round)

        # SNN transfer: rebuild the ANN from the checkpoint at self.model_save
        # (presumably refreshed by Train - confirm), fold BN into the
        # preceding layers, wrap it as a spiking model and calibrate it.
        Local_model_dict = torch.load(self.model_save)
        temp_model = self._new_model()
        temp_model.load_state_dict(Local_model_dict)
        search_fold_and_remove_bn(temp_model)
        temp_model = SpikeModel(model=temp_model, sim_length=args.T, specials=res_specials)
        get_maximum_activation(self.train_loader, model=temp_model, momentum=0.9, iters=5, mse=True,
                               percentile=None,
                               sim_length=args.T, channel_wise=False)
        bias_corr_model(model=temp_model, train_loader=self.train_loader, correct_mempot=False)
        temp_model.set_spike_state(use_spike=True)  # snn

        print(' lr', self.optimizer.state_dict()['param_groups'][0]['lr'])

        # ---------------------------------------------------------------
        # Parameter passing
        diff = temp_model.state_dict()  # don't have bn

        # Use the client's configured device rather than a hard-coded
        # .cuda(), so CPU-only runs and non-default GPUs work as well.
        diff1 = {}
        diff2 = {}
        for mod_name, module in self.local_model.named_modules():
            prefix = str(mod_name)
            if is_bn(module):
                for suffix in self._BN_SUFFIXES:
                    key = prefix + '.' + suffix
                    diff1[key] = Local_model_dict[key].to(self.device)
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                key = prefix + '.weight'
                diff2[key] = Local_model_dict[key].to(self.device)

        return diff, diff1, diff2

    def local_malicious(self):
        """Set this client's ``malicious`` flag to True."""
        self.malicious = True
