import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.data_utils import HSI_train


class Client(object):
    """
    Base class for clients in federated learning.

    Each client keeps a private deep copy of the model supplied through
    ``args`` and, unless the algorithm is ``'MPT'``, an Adam optimizer with
    a ``MultiStepLR`` scheduler. The helper methods copy parameter values
    between models and build the client's training ``DataLoader``.
    """

    def __init__(self, args, id, **kwargs):
        """
        Set up the client's model, optimizer, scheduler and loss.

        Args:
            args: Configuration namespace. The attributes read here are
                ``model``, ``dataset``, ``test_data``, ``mask4d_ls``,
                ``device``, ``batch_size``, ``local_learning_rate``,
                ``local_steps``, ``MP`` and ``algorithm``; ``backbone``
                when ``MP`` is truthy or the algorithm is ``'MPT'``;
                ``PTP``, ``milestones``, ``learning_rate_decay_gamma`` and
                ``learning_rate_decay`` when the algorithm is not
                ``'MPT'``; ``KL_reduction`` when the algorithm is
                ``'MCPA'``.
            id: Client identifier (integer).
            **kwargs: Must contain ``train_slow`` and ``send_slow``.

        Raises:
            KeyError: If ``train_slow`` or ``send_slow`` is not supplied.
        """
        # Deep copy so local updates never touch the model held in args.
        self.model = copy.deepcopy(args.model)
        self.dataset = args.dataset
        self.test_data = args.test_data
        self.mask4d_ls = args.mask4d_ls
        self.device = args.device
        self.id = id  # integer

        # Filled in by load_train_data().
        self.train_samples = None
        self.batch_size = args.batch_size
        self.learning_rate = args.local_learning_rate
        self.local_steps = args.local_steps

        self.args = args

        if self.args.MP or self.args.algorithm == 'MPT':
            self.backbone = copy.deepcopy(args.backbone)

        # Walk every submodule (not only the direct children) so that a
        # BatchNorm2d layer nested inside a container is detected as well.
        self.has_BatchNorm = any(
            isinstance(module, nn.BatchNorm2d) for module in self.model.modules()
        )

        self.train_slow = kwargs['train_slow']
        self.send_slow = kwargs['send_slow']
        self.train_time_cost = {'num_rounds': 0, 'total_cost': 0.0}
        self.send_time_cost = {'num_rounds': 0, 'total_cost': 0.0}

        # No optimizer or scheduler is created here for the 'MPT' algorithm.
        if self.args.algorithm != 'MPT':
            if self.args.PTP:
                # Only parameters whose name contains 'ada' stay trainable;
                # every other parameter is frozen.
                for name, param in self.model.named_parameters():
                    if 'ada' in name:
                        param.requires_grad = True
                        print('>>> [%s] requires grad =True'%name)
                    else:
                        param.requires_grad = False
                        print('>>> [%s] requires grad =False'%name)
                trainable = [p for p in self.model.parameters() if p.requires_grad]
            else:
                trainable = self.model.parameters()

            self.optimizer = torch.optim.Adam(trainable,
                                              lr=self.learning_rate,
                                              betas=(0.9, 0.999))
            self.learning_rate_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer=self.optimizer,
                milestones=args.milestones,
                gamma=args.learning_rate_decay_gamma
            )
            self.learning_rate_decay = args.learning_rate_decay

        # Place the criterion on the configured device instead of
        # hard-coding CUDA.
        if self.args.algorithm == 'MCPA':
            criterion = torch.nn.KLDivLoss(reduction=self.args.KL_reduction)
        else:
            criterion = torch.nn.MSELoss()
        self.loss = criterion.to(self.device)

    def load_train_data(self, batch_size, epoch_sum_num):
        """
        Build the training ``DataLoader`` for this client.

        Constructs an ``HSI_train`` dataset from ``self.dataset``, the client
        id, ``args.trn_split`` and ``epoch_sum_num``, and stores the
        dataset's ``usr_trn_sz`` in ``self.train_samples``.

        Args:
            batch_size: Batch size passed to the ``DataLoader``.
            epoch_sum_num: Forwarded unchanged to ``HSI_train``.

        Returns:
            A shuffling ``DataLoader`` that drops the last incomplete batch
            and uses ``args.workers`` worker processes.
        """
        HSI_dataset = HSI_train(args=self.args,
                                train_set=self.dataset,
                                id=self.id,
                                trn_split=self.args.trn_split,
                                epoch_sum_num=epoch_sum_num)
        self.train_samples = HSI_dataset.usr_trn_sz
        print('>>> load_train_data(): HSI_dataset.usr_trn_sz', self.train_samples)
        return DataLoader(HSI_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=self.args.workers,
                          drop_last=True)

    def load_test_data(self, batch_size=None):
        """
        Not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def load_pretrain(self, model_checkpoint):
        """
        Load pretrained weights non-strictly and print the key mismatches.

        The weights go into ``self.backbone`` when ``args.MP`` is truthy or
        the algorithm is ``'MPT'``; otherwise into ``self.model``.

        Args:
            model_checkpoint: Mapping whose ``'model_weights'`` entry is the
                state dict to load.
        """
        if self.args.MP or self.args.algorithm == "MPT":
            target = self.backbone
        else:
            target = self.model
        # strict=False: mismatched keys are reported instead of raising.
        missing_key, unexpected_key = target.load_state_dict(
            model_checkpoint['model_weights'], strict=False)
        print('Missing keys=', missing_key) # verified
        print('Unexpected_keys=', unexpected_key) # verified

    def set_parameters(self, model):
        """
        Overwrite this client's model parameters with those of ``model``.

        For algorithms other than ``'MPT'`` with any of ``args.PTP``,
        ``args.MABFT`` or ``args.FMABFT`` set, only parameters whose name
        (in this client's model) contains ``'ada'`` are overwritten. In every
        other case all parameters are overwritten.

        Args:
            model: Source model; parameters are paired with the client's
                by position.
        """
        # The short-circuit keeps the PTP/MABFT/FMABFT flags unread for 'MPT'.
        adapter_only = self.args.algorithm != 'MPT' and (
            self.args.PTP or self.args.MABFT or self.args.FMABFT)

        if adapter_only:
            pairs = zip(model.named_parameters(), self.model.named_parameters())
            for (_, new_param), (name, old_param) in pairs:
                if 'ada' in name:
                    old_param.data = new_param.data.clone()
        else:
            self.clone_model(model, self.model)

    def clone_model(self, model, target):
        """
        Copy every parameter value of ``model`` into ``target``.

        Parameters are paired by position; each value is cloned so the two
        models do not share storage.
        """
        for param, target_param in zip(model.parameters(), target.parameters()):
            target_param.data = param.data.clone()

    def update_parameters(self, model, new_params):
        """
        Overwrite the parameters of ``model`` with clones of ``new_params``.

        Args:
            model: Model whose parameters are overwritten in place.
            new_params: Iterable of tensors (or parameters), paired with
                ``model.parameters()`` by position.
        """
        for param, new_param in zip(model.parameters(), new_params):
            param.data = new_param.data.clone()


