import logging
import copy

import torch
from torch import nn, optim

from .utils import *
from .image_tar_lite import train_target

def test_compact(model, target_test_loader, deivce):
    """Evaluate ``model`` on a labelled loader and report top-1 accuracy.

    The model is put in eval mode and moved to ``deivce`` (the misspelled
    parameter name is kept so keyword callers keep working).

    Args:
        model: module mapping a batch of inputs to per-class scores.
        target_test_loader: iterable of ``(data, target)`` batches that exposes
            a ``dataset`` attribute whose ``len`` is the number of samples.
        deivce: device the model and every batch are moved to.

    Returns:
        ``(acc, correct, total)``: accuracy in percent and the number of
        correct predictions as 0-dim double tensors, plus the dataset size.
        For an empty dataset the accuracy is reported as 0 instead of raising.
    """
    model.eval()
    model.to(deivce)
    len_target_dataset = len(target_test_loader.dataset)
    # Tensor accumulator: an int start value would break ``.double()`` below
    # whenever the loader yields no batch at all.
    correct = torch.zeros((), dtype=torch.long, device=deivce)
    with torch.no_grad():
        for data, target in target_test_loader:
            data, target = data.to(deivce), target.to(deivce)
            pred = torch.max(model(data), 1)[1]
            correct += torch.sum(pred == target)
    correct = correct.double()
    # max(..., 1) only matters for an empty dataset, where correct is 0 anyway.
    acc = 100. * correct / max(len_target_dataset, 1)
    return acc, correct, len_target_dataset


# The ``test_`` prefix makes pytest treat this helper as a test case whenever
# it is imported into a test module; opt out explicitly.
test_compact.__test__ = False


class FedSGDClientTrainer(object):
    """Per-client trainer that distils a large local model into a compact one.

    Each call to :meth:`train` first updates the large model (``netF`` ->
    ``netB`` -> ``netC``) through ``train_target``, then fits ``compact_model``
    to the large model's outputs with a KL loss plus a quadratic penalty that
    ties every ``weight`` tensor to the last received global model through a
    per-client dual variable (see the ADMM-named helpers below).

    Only parameters whose last dotted name component is exactly ``weight``
    take part in the dual / penalty / upload arithmetic.

    NOTE(review): every other entry of ``upload_model`` keeps the value copied
    at construction time -- confirm this is intended.
    """

    def __init__(self, client_index, client_num, local_training_data, local_test_data, local_sample_number, device,
                 compact_model, netF, netB, netC, large_model_optimizer, args, source_data=None):
        """Store this client's data and models and build the SGD optimizer.

        Args:
            client_index: key selecting this client's entry in
                ``local_training_data``, ``local_test_data`` and
                ``local_sample_number``.
            client_num: divisor applied to the KD term in :meth:`admm_loss`.
            local_training_data: per-client training loaders.
            local_test_data: per-client test loaders.
            local_sample_number: per-client sample counts.
            device: device used for the compact model and the batches.
            compact_model: model trained locally; deep-copied to create the
                global and upload models.
            netF, netB, netC: stages of the large model, applied in that order.
            large_model_optimizer: optimizer handed to ``train_target``.
            args: namespace read for ``lr``, ``wd``, ``temperature``,
                ``alpha``, ``rho``, ``comm_round``, ``epochs_client`` and
                ``out_file``.
            source_data: stored as-is; not used by this class.
        """
        self.client_index = client_index
        self.client_num = client_num
        self.local_training_data = local_training_data[client_index]
        self.local_test_data = local_test_data[client_index]
        self.local_sample_number = local_sample_number[client_index]
        self.source_data = source_data

        self.device = device
        self.compact_model = compact_model
        self.global_model = copy.deepcopy(compact_model)
        self.upload_model = copy.deepcopy(compact_model)
        self.dual = self.initial_dual()

        self.netF = netF
        self.netB = netB
        self.netC = netC

        self.large_model_optimizer = large_model_optimizer

        self.args = args
        self.round_idx = 0

        logging.info("client device = %s", self.device)
        self.compact_model.to(self.device)

        # Materialise the parameters: ``parameters()`` is a one-shot generator
        # that the optimizer constructor would otherwise leave exhausted.
        self.model_params = self.master_params = list(self.compact_model.parameters())
        self.optimizer = torch.optim.SGD(self.model_params, lr=self.args.lr, momentum=0.9, weight_decay=self.args.wd)

        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = KL_Loss(self.args.temperature)

        self.DA_logits_dict = dict()

    @staticmethod
    def _weight_params(model):
        """Yield the parameters of ``model`` whose last name component is ``weight``."""
        for name, param in model.named_parameters():
            if name.split('.')[-1] == "weight":
                yield param

    def get_sample_number(self):
        """Return this client's local sample count."""
        return self.local_sample_number

    def get_global_param(self):
        """Return detached CPU copies of the global model's weight tensors."""
        return tuple(p.detach().cpu().clone() for p in self._weight_params(self.global_model))

    def initial_dual(self):
        """Return a tuple of CPU zero tensors, one per compact-model weight tensor."""
        return tuple(torch.zeros_like(p).cpu() for p in self._weight_params(self.compact_model))

    def update_dual(self, dual, local_param, global_param):
        """Return the new dual tuple ``u + w - w0`` computed element-wise.

        The inputs are iterated in lockstep; none of them is modified.
        """
        return tuple(u + w - w0 for u, w, w0 in zip(dual, local_param, global_param))

    def get_local_param(self):
        """Return detached CPU copies of the compact model's weight tensors."""
        return tuple(p.detach().cpu().clone() for p in self._weight_params(self.compact_model))

    def update_upload_param(self, dual, local_param, global_param):
        """Set each weight tensor of ``upload_model`` to ``local + dual``.

        ``global_param`` is accepted for signature compatibility and not read.

        Returns:
            The updated ``upload_model``.
        """
        for idx, param in enumerate(self._weight_params(self.upload_model)):
            param.data = local_param[idx] + dual[idx]
        return self.upload_model

    def update_local_param(self, dual, local_param, global_param):
        """Add ``w0 - w - u`` in place to each compact-model weight tensor.

        Returns:
            The updated ``compact_model``.
        """
        for idx, param in enumerate(self._weight_params(self.compact_model)):
            u = dual[idx].to(self.device)
            w = local_param[idx].to(self.device)
            w0 = global_param[idx].to(self.device)
            param.data.add_(w0 - w - u)
        return self.compact_model

    def cal_alpha(self, end_alpha):
        """Interpolate alpha exponentially from ``args.alpha`` toward ``end_alpha``.

        The value equals ``args.alpha`` at round 0 and ``end_alpha`` once
        ``round_idx`` reaches ``args.comm_round``.
        """
        init_alpha = self.args.alpha
        n_rounds = self.args.comm_round
        exp_val = np.log(end_alpha / init_alpha) / n_rounds
        return init_alpha * np.exp(exp_val * self.round_idx)

    def admm_loss(self, loss_kd, dual, global_param):
        """Combine the KD loss with the quadratic consensus penalty.

        Returns ``alpha / client_num * loss_kd`` plus, for every weight tensor,
        ``rho / 2 * ||param - w0 + u||^2``.
        """
        loss = self.args.alpha / self.client_num * loss_kd
        for idx, param in enumerate(self._weight_params(self.compact_model)):
            w0 = global_param[idx].to(self.device)
            u = dual[idx].to(self.device)
            loss = loss + (self.args.rho / 2) * (param - w0 + u).norm() ** 2
        return loss

    def get_upload_model_params(self):
        """Move ``upload_model`` to the CPU and return its ``state_dict``."""
        return self.upload_model.cpu().state_dict()

    def set_global_model_params(self, model_parameters):
        """Load ``model_parameters`` into the local copy of the global model."""
        self.global_model.load_state_dict(model_parameters)

    def _distill_epoch(self, epoch, global_param):
        """Run one pass over the training loader, stepping the compact model."""
        self.compact_model.train()
        for batch_idx, (images, _labels, _tar_idx) in enumerate(self.local_training_data):
            images = images.to(self.device)
            log_probs = self.compact_model(images)

            # The large model only supplies targets here: skip autograd so no
            # graph is built and no gradient is accumulated on its parameters.
            with torch.no_grad():
                large_model_logits = self.netC(self.netB(self.netF(images)))

            loss_kd = self.criterion_KL(log_probs, large_model_logits)
            loss = self.admm_loss(loss_kd, self.dual, global_param)

            log_str = 'Update Epoch: {} [{}/{} ({:.0f}%)]\t KD_Loss: {:.4f} Loss: {:.4f}'.format(
                epoch, batch_idx * len(images), len(self.local_training_data.dataset),
                100. * batch_idx * len(images) / len(self.local_training_data.dataset), loss_kd.item(), loss.item())
            print(log_str, flush=True)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    # TODO: flagged by the original author as "need to revise".
    def train(self):
        """Run one communication round of local training.

        Steps: evaluate the global model on the local test data, update the
        large model through ``train_target``, refresh the dual variable,
        distil into the compact model for ``args.epochs_client`` epochs, then
        build the upload model from ``local + dual``.

        Returns:
            ``(weights, local_sample_number)`` where ``weights`` is the CPU
            ``state_dict`` of the upload model.
        """
        self.round_idx += 1

        acc, _, _ = test_compact(self.global_model, self.local_test_data, self.device)
        log_str = 'client {} - Acc of global model on local test data: {}'.format(self.client_index, acc)
        self.args.out_file.write(log_str + '\n')
        self.args.out_file.flush()
        print(log_str)

        self.compact_model.to(self.device)

        self.netF, self.netB, self.netC = train_target(self.args, self.netF, self.netB, self.netC, self.large_model_optimizer,
                                                       self.local_training_data, self.local_test_data, self.global_model)

        self.compact_model.train()
        global_param = self.get_global_param()
        local_param = self.get_local_param()

        self.dual = self.update_dual(self.dual, local_param, global_param)

        # The large model is used for inference only during distillation.
        self.netF.eval()
        self.netB.eval()
        self.netC.eval()

        for epoch in range(self.args.epochs_client):
            self._distill_epoch(epoch, global_param)

            acc, _, _ = test_compact(self.compact_model, self.local_test_data, self.device)
            log_str = 'client {} - Update Epoch: {} - Acc on test data: {}'.format(self.client_index, epoch, acc)
            print(log_str)

        # ``acc`` is the last per-epoch accuracy (or the global-model accuracy
        # from above when ``epochs_client`` is 0).
        log_str = 'client {} - Acc of compact model on local test data: {}'.format(self.client_index, acc)
        self.args.out_file.write(log_str + '\n')
        self.args.out_file.flush()

        local_param = self.get_local_param()
        self.update_upload_param(self.dual, local_param, global_param)
        weights = self.get_upload_model_params()

        return weights, self.local_sample_number




class FedSGDClientTrainer_DA(object):
    """Per-client trainer that distils a fixed large model into a compact one.

    :meth:`train` fits ``compact_model`` to the outputs of the large model
    (``netF`` -> ``netB`` -> ``netC``) with a KL loss plus a quadratic penalty
    that ties every ``weight`` tensor to the last received global model
    through a per-client dual variable (see the ADMM-named helpers below).
    Unlike ``FedSGDClientTrainer`` this class does not call ``train_target``,
    and evaluation of the global model is a separate :meth:`evaluation` call.

    Only parameters whose last dotted name component is exactly ``weight``
    take part in the dual / penalty / upload arithmetic.

    NOTE(review): every other entry of ``upload_model`` keeps the value copied
    at construction time -- confirm this is intended.
    """

    def __init__(self, client_index, client_num, local_training_data, local_test_data, local_sample_number, device,
                 compact_model, netF, netB, netC, args):
        """Store this client's data and models and build the SGD optimizer.

        Args:
            client_index: key selecting this client's entry in
                ``local_training_data``, ``local_test_data`` and
                ``local_sample_number``.
            client_num: divisor applied to the KD term in :meth:`admm_loss`.
            local_training_data: per-client training loaders.
            local_test_data: per-client test loaders.
            local_sample_number: per-client sample counts.
            device: device used for the compact model and the batches.
            compact_model: model trained locally; deep-copied to create the
                global and upload models.
            netF, netB, netC: stages of the large model, applied in that order.
            args: namespace read for ``lr``, ``wd``, ``temperature``,
                ``alpha``, ``rho``, ``comm_round``, ``epochs_client`` and
                ``out_file``.
        """
        self.client_index = client_index
        self.client_num = client_num
        self.local_training_data = local_training_data[client_index]
        self.local_test_data = local_test_data[client_index]
        self.local_sample_number = local_sample_number[client_index]

        self.device = device
        self.compact_model = compact_model
        self.global_model = copy.deepcopy(compact_model)
        self.upload_model = copy.deepcopy(compact_model)
        self.dual = self.initial_dual()

        self.netF = netF
        self.netB = netB
        self.netC = netC

        self.args = args
        self.round_idx = 0
        self.acc_record = []

        logging.info("client device = %s", self.device)
        self.compact_model.to(self.device)

        # Materialise the parameters: ``parameters()`` is a one-shot generator
        # that the optimizer constructor would otherwise leave exhausted.
        self.model_params = self.master_params = list(self.compact_model.parameters())
        self.optimizer = torch.optim.SGD(self.model_params, lr=self.args.lr, momentum=0.9, weight_decay=self.args.wd)

        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = KL_Loss(self.args.temperature)

        self.DA_logits_dict = dict()

    @staticmethod
    def _weight_params(model):
        """Yield the parameters of ``model`` whose last name component is ``weight``."""
        for name, param in model.named_parameters():
            if name.split('.')[-1] == "weight":
                yield param

    def get_sample_number(self):
        """Return this client's local sample count."""
        return self.local_sample_number

    def get_global_param(self):
        """Return detached CPU copies of the global model's weight tensors."""
        return tuple(p.detach().cpu().clone() for p in self._weight_params(self.global_model))

    def initial_dual(self):
        """Return a tuple of CPU zero tensors, one per compact-model weight tensor."""
        return tuple(torch.zeros_like(p).cpu() for p in self._weight_params(self.compact_model))

    def update_dual(self, dual, local_param, global_param):
        """Return the new dual tuple ``u + w - w0`` computed element-wise.

        The inputs are iterated in lockstep; none of them is modified.
        """
        return tuple(u + w - w0 for u, w, w0 in zip(dual, local_param, global_param))

    def get_local_param(self):
        """Return detached CPU copies of the compact model's weight tensors."""
        return tuple(p.detach().cpu().clone() for p in self._weight_params(self.compact_model))

    def update_upload_param(self, dual, local_param, global_param):
        """Set each weight tensor of ``upload_model`` to ``local + dual``.

        ``global_param`` is accepted for signature compatibility and not read.

        Returns:
            The updated ``upload_model``.
        """
        for idx, param in enumerate(self._weight_params(self.upload_model)):
            param.data = local_param[idx] + dual[idx]
        return self.upload_model

    def update_local_param(self, dual, local_param, global_param):
        """Add ``w0 - w - u`` in place to each compact-model weight tensor.

        Returns:
            The updated ``compact_model``.
        """
        for idx, param in enumerate(self._weight_params(self.compact_model)):
            u = dual[idx].to(self.device)
            w = local_param[idx].to(self.device)
            w0 = global_param[idx].to(self.device)
            param.data.add_(w0 - w - u)
        return self.compact_model

    def cal_alpha(self, end_alpha):
        """Interpolate alpha exponentially from ``args.alpha`` toward ``end_alpha``.

        The value equals ``args.alpha`` at round 0 and ``end_alpha`` once
        ``round_idx`` reaches ``args.comm_round``.
        """
        init_alpha = self.args.alpha
        n_rounds = self.args.comm_round
        exp_val = np.log(end_alpha / init_alpha) / n_rounds
        return init_alpha * np.exp(exp_val * self.round_idx)

    def admm_loss(self, loss_kd, dual, global_param):
        """Combine the KD loss with the quadratic consensus penalty.

        Returns ``alpha / client_num * loss_kd`` plus, for every weight tensor,
        ``rho / 2 * ||param - w0 + u||^2``.
        """
        loss = self.args.alpha / self.client_num * loss_kd
        for idx, param in enumerate(self._weight_params(self.compact_model)):
            w0 = global_param[idx].to(self.device)
            u = dual[idx].to(self.device)
            loss = loss + (self.args.rho / 2) * (param - w0 + u).norm() ** 2
        return loss

    def get_upload_model_params(self):
        """Move ``upload_model`` to the CPU and return its ``state_dict``."""
        return self.upload_model.cpu().state_dict()

    def set_global_model_params(self, model_parameters):
        """Load ``model_parameters`` into the local copy of the global model."""
        self.global_model.load_state_dict(model_parameters)

    def evaluation(self):
        """Evaluate the global model on the local test data and log the result.

        Returns:
            ``(accuracy, correct, test_size)`` with the accuracy converted to
            a Python number via ``.item()``.
        """
        acc, correct, test_size = test_compact(self.global_model, self.local_test_data, self.device)
        log_str = 'client {} - Acc of global model on local test data: {}'.format(self.client_index, acc)
        self.args.out_file.write(log_str + '\n')
        self.args.out_file.flush()
        print(log_str, flush=True)

        return acc.item(), correct, test_size

    def _distill_epoch(self, epoch, global_param):
        """Run one pass over the training loader, stepping the compact model."""
        self.compact_model.train()
        for batch_idx, (images, _labels, _tar_idx) in enumerate(self.local_training_data):
            images = images.to(self.device)
            log_probs = self.compact_model(images)

            # The large model only supplies targets here: skip autograd so no
            # graph is built and no gradient is accumulated on its parameters.
            with torch.no_grad():
                large_model_logits = self.netC(self.netB(self.netF(images)))

            loss_kd = self.criterion_KL(log_probs, large_model_logits)
            loss = self.admm_loss(loss_kd, self.dual, global_param)

            log_str = 'Update Epoch: {} [{}/{} ({:.0f}%)]\t KD_Loss: {:.4f} Loss: {:.4f}'.format(
                epoch, batch_idx * len(images), len(self.local_training_data.dataset),
                100. * batch_idx * len(images) / len(self.local_training_data.dataset), loss_kd.item(), loss.item())
            print(log_str, flush=True)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    # TODO: flagged by the original author as "need to revise".
    def train(self):
        """Run one round of local distillation.

        Steps: refresh the dual variable from the current local and global
        weights, distil into the compact model for ``args.epochs_client``
        epochs, then build the upload model from ``local + dual``.

        Returns:
            ``(weights, local_sample_number)`` where ``weights`` is the CPU
            ``state_dict`` of the upload model.
        """
        # Reported as-is in the final log line when ``epochs_client`` is 0.
        acc = 0
        self.compact_model.to(self.device)

        self.compact_model.train()
        global_param = self.get_global_param()
        local_param = self.get_local_param()

        self.dual = self.update_dual(self.dual, local_param, global_param)

        # The large model is used for inference only during distillation.
        self.netF.eval()
        self.netB.eval()
        self.netC.eval()

        for epoch in range(self.args.epochs_client):
            self._distill_epoch(epoch, global_param)

            acc, _, _ = test_compact(self.compact_model, self.local_test_data, self.device)
            log_str = 'client {} - Update Epoch: {} - Acc on test data: {:.4f}'.format(self.client_index, epoch, acc)
            print(log_str)

        log_str = 'client {} - Acc of compact model on local test data: {:.4f}'.format(self.client_index, acc)
        self.args.out_file.write(log_str + '\n')
        self.args.out_file.flush()

        local_param = self.get_local_param()
        self.update_upload_param(self.dual, local_param, global_param)
        weights = self.get_upload_model_params()

        return weights, self.local_sample_number

