import logging
import copy
from ofa.imagenet_classification import networks
import torch
from torch import nn, optim
from .utils import *
from .image_tar_lite import train_target
# from tensorflow_privacy.privacy.analysis.compute_noise_from_budget_lib import compute_noise
from opacus.privacy_engine import get_noise_multiplier
from opacus.privacy_analysis import compute_rdp, get_privacy_spent

# added
import sys, os
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "./../")))
import model.network as network


# Candidate orders: 1.1, 1.2, ..., 10.9 in steps of 0.1, then the integers
# 12..63. NOTE(review): presumably the Renyi-DP orders for the opacus
# accounting helpers imported above (compute_rdp / get_privacy_spent) —
# not referenced anywhere in the visible code; confirm the consumer.
DEFAULT_ALPHAS = [1 + step / 10.0 for step in range(1, 100)] + [*range(12, 64)]

def test_compact(model, target_test_loader, deivce):
    model.eval()
    model.to(deivce)
    correct = 0
    criterion = torch.nn.CrossEntropyLoss()
    len_target_dataset = len(target_test_loader.dataset)
    all_feat, all_target = None, None
    with torch.no_grad():
        for data, target in target_test_loader:
            data, target = data.to(deivce), target.to(deivce)
            feature = model[:-1](data)
            s_output = model(data)
            if all_feat is None:
                all_feat = feature.float()
                all_target = target.float()
            else:
                all_feat = torch.cat((all_feat, feature.float()), 0)
                all_target = torch.cat((all_target, target.float()), 0)
            loss = criterion(s_output, target)
            pred = torch.max(s_output, 1)[1]
            correct += torch.sum(pred == target)
    acc = 100. * correct.double() / len_target_dataset
    return acc, correct.double(), len_target_dataset

def train_acc(model, target_train_loader, deivce):
    """Return the accuracy (percent) of ``model`` on a training loader.

    Args:
        model: classifier module; the prediction is the argmax over dim 1 of
            ``model(data)``.
        target_train_loader: iterable of ``(data, target, index)`` batches
            exposing ``.dataset`` with ``len()``; the third item is ignored.
        deivce: device for the model and batches (name kept for backward
            compatibility).

    Returns:
        0-dim double tensor with the accuracy in percent (NaN for an empty
        dataset).

    Side effects: switches ``model`` to eval mode and moves it to ``deivce``.
    """
    model.eval()
    model.to(deivce)
    len_target_dataset = len(target_train_loader.dataset)
    # Tensor accumulator so an empty loader does not fail on int.double().
    correct = torch.zeros((), dtype=torch.long, device=deivce)
    with torch.no_grad():
        for data, target, _ in target_train_loader:
            data, target = data.to(deivce), target.to(deivce)
            pred = torch.max(model(data), 1)[1]
            correct += torch.sum(pred == target)
    acc = 100. * correct.double() / len_target_dataset
    return acc


class FedXDDClientTrainer(object):
    """Federated client that distils a large local model into a compact model
    under an ADMM-style consensus penalty.

    One call to :meth:`train` performs a communication round:

    1. ``train_target`` (from ``.image_tar_lite``) is called on the large model
       ``netF`` -> ``netB`` -> ``netC`` with the local data and the global
       model, and its returned modules replace the current ones
       (presumably a local adaptation step — see ``train_target``).
    2. The scaled dual variables are updated, ``u <- u + w - w0``, where ``w``
       are the compact model's weights and ``w0`` the current global weights.
    3. The compact model is trained for ``args.epochs_client`` epochs on
       ``alpha / client_num * KD + rho / 2 * sum ||w - w0 + u||^2``, where KD
       is ``criterion_KL`` between the compact model's output and the large
       model's logits.
    4. ``w + u`` (plus optional seeded secure-aggregation noise) is written
       into ``upload_model``, whose ``state_dict`` is returned.

    Only parameters whose name ends in ``"weight"`` take part in steps 2-4.

    NOTE(review): non-"weight" parameters and buffers of ``upload_model``
    (biases, normalization statistics, ...) are never updated after
    construction, so the uploaded ``state_dict`` carries their initial values.
    Confirm the server only aggregates the weight entries.
    """

    def __init__(self, client_index, client_num, local_training_data, local_test_data, local_sample_number, device,
                 compact_model, netF, netB, netC, large_model_optimizer, args, test_data, secure_agg=False,
                 seed_list=None, sign_list=None):
        self.client_index = client_index
        self.client_num = client_num
        self.local_training_data = local_training_data
        self.local_test_data = local_test_data
        self.local_sample_number = local_sample_number
        self.test_data = test_data

        self.device = device
        self.compact_model = compact_model
        # Received global model (overwritten by set_global_model_params) and
        # the model whose state_dict is uploaded; both start as copies of the
        # compact model.
        self.global_model = copy.deepcopy(compact_model)
        self.upload_model = copy.deepcopy(compact_model)
        self.dual = self.initial_dual()

        # Large model, applied as netC(netB(netF(x))).
        self.netF = netF
        self.netB = netB
        self.netC = netC

        self.large_model_optimizer = large_model_optimizer

        self.args = args
        self.round_idx = 0

        logging.info("client device = " + str(self.device))
        self.compact_model.to(self.device)
        self.model_params = self.master_params = network.get_parameters(args, self.compact_model)
        self.optimizer = torch.optim.SGD(self.model_params, lr=self.args.lr, momentum=0.9, weight_decay=self.args.wd)

        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = KL_Loss(self.args.temperature)

        self.DA_logits_dict = dict()

        # When secure_agg is set, one seeded noise tensor per (seed, sign)
        # pair is added to every uploaded weight (see update_upload_param).
        self.secure_agg = secure_agg
        self.seed_list = seed_list
        self.sign_list = sign_list

    def get_sample_number(self):
        """Return the number of local training samples."""
        return self.local_sample_number

    @staticmethod
    def _weight_params(model):
        """Yield the parameters of ``model`` whose name ends in ``weight``.

        The iteration order is that of ``named_parameters`` and is the order
        shared by every dual / local / global tuple in this class.
        """
        for name, param in model.named_parameters():
            if name.split('.')[-1] == "weight":
                yield param

    def get_global_param(self):
        """Return detached CPU copies of the global model's weights (tuple)."""
        return tuple(p.detach().cpu().clone() for p in self._weight_params(self.global_model))

    def initial_dual(self):
        """Return zero-initialized CPU dual variables, one per compact weight."""
        return tuple(torch.zeros_like(p).cpu() for p in self._weight_params(self.compact_model))

    def update_dual(self, dual, local_param, global_param):
        """Return the scaled ADMM dual update ``u + w - w0`` element-wise."""
        return tuple(u + w - w0 for u, w, w0 in zip(dual, local_param, global_param))

    def get_local_param(self):
        """Return detached CPU copies of the compact model's weights (tuple)."""
        return tuple(p.detach().cpu().clone() for p in self._weight_params(self.compact_model))

    def update_upload_param(self, dual, local_param, global_param):
        """Write ``w + u`` into ``upload_model``'s weights and return it.

        ``global_param`` is accepted for interface compatibility but unused.
        With ``secure_agg`` enabled, for each ``(seed, sign)`` pair a
        ``torch.Generator`` seeded with ``int(seed)`` drives
        ``generate_noise(shape, 2.0, rng)`` and ``sign * noise`` is added to
        every weight (presumably pairwise masks that cancel on the server —
        confirm with the seed/sign producer).
        """
        for idx, param in enumerate(self._weight_params(self.upload_model)):
            param.data = local_param[idx] + dual[idx]
        if self.secure_agg:
            for seed, sign in zip(self.seed_list, self.sign_list):
                rng = torch.Generator()
                rng.manual_seed(int(seed))
                # Weights are visited in the same order for every seed so the
                # generator stream lines up with the parameter list.
                for param in self._weight_params(self.upload_model):
                    noise = generate_noise(param.shape, 2.0, rng)
                    param.data = param.data + sign * noise
        return self.upload_model

    def update_local_param(self, dual, local_param, global_param):
        """Add ``w0 - w - u`` in place to each compact weight; return the model."""
        for idx, param in enumerate(self._weight_params(self.compact_model)):
            u = dual[idx].to(self.device)
            w = local_param[idx].to(self.device)
            w0 = global_param[idx].to(self.device)
            param.data.add_(w0 - w - u)
        return self.compact_model

    def admm_loss(self, loss_kd, dual, global_param):
        """Return ``alpha / client_num * loss_kd + rho / 2 * sum ||w - w0 + u||^2``.

        ``w`` ranges over the compact model's weights (differentiable), ``w0``
        and ``u`` over ``global_param`` and ``dual`` (moved to ``self.device``).
        """
        loss = self.args.alpha / self.client_num * loss_kd
        for idx, param in enumerate(self._weight_params(self.compact_model)):
            w0 = global_param[idx].to(self.device)
            u = dual[idx].to(self.device)
            # Squared L2 norm as a sum of squares: avoids sqrt-then-square.
            # Out-of-place add keeps the autograd graph free of in-place ops.
            loss = loss + (self.args.rho / 2) * (param - w0 + u).pow(2).sum()
        return loss

    def get_upload_model_params(self):
        """Move ``upload_model`` to CPU and return its ``state_dict``."""
        return self.upload_model.cpu().state_dict()

    def set_global_model_params(self, model_parameters):
        """Load a received ``state_dict`` into ``global_model``."""
        self.global_model.load_state_dict(model_parameters)

    def evaluation(self):
        """Log the global model's accuracy on local and global test data.

        Returns:
            ``(acc, correct, test_size)`` for the global test data, with
            ``acc`` converted to a Python float.
        """
        acc_tr, _, _ = test_compact(self.global_model, self.local_test_data, self.device)
        log_str = 'client {} - Acc of global model on local test data: {}'.format(self.client_index, acc_tr)
        self.args.out_file.write(log_str + '\n')
        self.args.out_file.flush()
        print(log_str, flush=True)

        acc, correct, test_size = test_compact(self.global_model, self.test_data, self.device)
        log_str = 'client {} - Acc of global model on global test data: {}'.format(self.client_index, acc)
        self.args.out_file.write(log_str + '\n')
        self.args.out_file.flush()
        print(log_str, flush=True)

        return acc.item(), correct, test_size

    # TODO(review): the original author marked this method "need to revise".
    def train(self):
        """Run one communication round (see the class docstring).

        Returns:
            ``(state_dict of upload_model, local_sample_number)``.
        """
        self.round_idx += 1

        self.compact_model.to(self.device)
        self.netF, self.netB, self.netC = train_target(self.args, self.netF, self.netB, self.netC, self.large_model_optimizer,
                                                       self.local_training_data, self.local_test_data, self.global_model)

        global_param = self.get_global_param()
        local_param = self.get_local_param()
        self.dual = self.update_dual(self.dual, local_param, global_param)

        self.netF.eval()
        self.netB.eval()
        self.netC.eval()

        for _epoch in range(self.args.epochs_client):
            self.compact_model.train()
            for images, _labels, _tar_idx in self.local_training_data:
                images = images.to(self.device)
                log_probs = self.compact_model(images)

                # The large model only supplies distillation targets: run it
                # without autograd so no graph is built through it and
                # loss.backward() does not accumulate gradients into
                # netF/netB/netC. Student gradients are unaffected.
                with torch.no_grad():
                    large_model_logits = self.netC(self.netB(self.netF(images)))

                loss_kd = self.criterion_KL(log_probs, large_model_logits)
                loss = self.admm_loss(loss_kd, self.dual, global_param)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        local_param = self.get_local_param()
        self.update_upload_param(self.dual, local_param, global_param)
        weights = self.get_upload_model_params()

        return weights, self.local_sample_number

class FedXDDClientTrainer_DA(object):
    """Federated client that distils fixed large-model logits into a compact
    model under an ADMM-style consensus penalty.

    One call to :meth:`train`:

    1. Updates the scaled dual variables ``u <- u + w - w0`` (``w`` compact
       weights, ``w0`` global weights).
    2. Trains the compact model for ``args.epochs_client`` epochs on
       ``alpha / client_num * KD + rho / 2 * sum ||w - w0 + u||^2``, where KD
       is ``criterion_KL`` between the compact output and the large model
       ``netC(netB(netF(x)))``.
    3. Writes ``w + u`` into ``upload_model`` and returns its ``state_dict``.

    Only parameters whose name ends in ``"weight"`` take part. For models with
    a ``modelname`` attribute starting with ``'SHOT'``, only ``model[:2]`` is
    coupled. During local training, ``model[2]`` is additionally kept in eval
    mode.

    NOTE(review): non-"weight" parameters and buffers of ``upload_model`` are
    never updated after construction, so the uploaded ``state_dict`` carries
    their initial values. Confirm the server only aggregates the weight
    entries.
    """

    def __init__(self, client_index, client_num, local_training_data, local_test_data, local_sample_number, device,
                 compact_model, netF, netB, netC, args):
        self.client_index = client_index
        self.client_num = client_num
        self.local_training_data = local_training_data
        self.local_test_data = local_test_data
        self.local_sample_number = local_sample_number

        self.device = device
        self.compact_model = compact_model
        # Received global model (overwritten by set_global_model_params) and
        # the model whose state_dict is uploaded; both start as copies.
        self.global_model = copy.deepcopy(compact_model)
        self.upload_model = copy.deepcopy(compact_model)
        self.dual = self.initial_dual()

        # Large model, applied as netC(netB(netF(x))).
        self.netF = netF
        self.netB = netB
        self.netC = netC

        self.args = args
        self.round_idx = 0
        self.acc_record = []

        logging.info("client device = " + str(self.device))
        self.compact_model.to(self.device)
        self.model_params = self.master_params = network.get_parameters(args, self.compact_model)
        self.optimizer = torch.optim.SGD(self.model_params, lr=self.args.lr, momentum=0.9, weight_decay=self.args.wd)

        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = KL_Loss(self.args.temperature)

        self.DA_logits_dict = dict()

    def get_sample_number(self):
        """Return the number of local training samples."""
        return self.local_sample_number

    @staticmethod
    def _is_shot_like(model):
        """True when ``model.modelname`` exists and starts with ``'SHOT'``."""
        return hasattr(model, 'modelname') and model.modelname[0:4] == 'SHOT'

    @classmethod
    def _weight_params(cls, model):
        """Yield the ADMM-coupled parameters of ``model``.

        These are parameters whose name ends in ``weight``, taken from
        ``model[:2]`` for SHOT-like models and from the whole model otherwise.
        They are yielded in ``named_parameters`` order, which every
        dual / local / global tuple in this class shares.
        """
        module = model[:2] if cls._is_shot_like(model) else model
        for name, param in module.named_parameters():
            if name.split('.')[-1] == "weight":
                yield param

    def get_global_param(self):
        """Return detached CPU copies of the global model's coupled weights."""
        return tuple(p.detach().cpu().clone() for p in self._weight_params(self.global_model))

    def initial_dual(self):
        """Return zero CPU dual variables, one per coupled compact weight."""
        return tuple(torch.zeros_like(p).cpu() for p in self._weight_params(self.compact_model))

    def update_dual(self, dual, local_param, global_param):
        """Return the scaled ADMM dual update ``u + w - w0`` element-wise."""
        return tuple(u + w - w0 for u, w, w0 in zip(dual, local_param, global_param))

    def get_local_param(self):
        """Return detached CPU copies of the compact model's coupled weights."""
        return tuple(p.detach().cpu().clone() for p in self._weight_params(self.compact_model))

    def update_upload_param(self, dual, local_param, global_param):
        """Write ``w + u`` into ``upload_model``'s coupled weights; return it.

        ``global_param`` is accepted for interface compatibility but unused.
        """
        for idx, param in enumerate(self._weight_params(self.upload_model)):
            param.data = local_param[idx] + dual[idx]
        return self.upload_model

    def update_local_param(self, dual, local_param, global_param):
        """Add ``w0 - w - u`` in place to each coupled compact weight."""
        for idx, param in enumerate(self._weight_params(self.compact_model)):
            u = dual[idx].to(self.device)
            w = local_param[idx].to(self.device)
            w0 = global_param[idx].to(self.device)
            param.data.add_(w0 - w - u)
        return self.compact_model

    def cal_alpha(self, end_alpha):
        """Return an exponential schedule from ``args.alpha`` to ``end_alpha``.

        ``alpha(r) = args.alpha * (end_alpha / args.alpha) ** (r / comm_round)``
        evaluated at ``r = self.round_idx``. Currently unused by ``admm_loss``.
        """
        # Local import: this file never imports numpy itself.
        import numpy as np

        init_alpha = self.args.alpha
        n_rounds = self.args.comm_round
        exp_val = np.log(end_alpha / init_alpha) / n_rounds
        alpha = init_alpha * np.exp(exp_val * self.round_idx)

        return alpha

    def admm_loss(self, loss_kd, dual, global_param):
        """Return ``alpha / client_num * loss_kd + rho / 2 * sum ||w - w0 + u||^2``.

        ``alpha`` is the constant ``args.alpha`` (not ``cal_alpha``). ``w``
        ranges over the compact model's coupled weights; ``w0`` and ``u`` come
        from ``global_param`` and ``dual``, moved to ``self.device``.
        """
        loss = self.args.alpha / self.client_num * loss_kd
        for idx, param in enumerate(self._weight_params(self.compact_model)):
            w0 = global_param[idx].to(self.device)
            u = dual[idx].to(self.device)
            # Squared L2 norm as a sum of squares: avoids sqrt-then-square.
            # Out-of-place add keeps the autograd graph free of in-place ops.
            loss = loss + (self.args.rho / 2) * (param - w0 + u).pow(2).sum()
        return loss

    def get_upload_model_params(self):
        """Move ``upload_model`` to CPU and return its ``state_dict``."""
        return self.upload_model.cpu().state_dict()

    def set_global_model_params(self, model_parameters):
        """Load a received ``state_dict`` into ``global_model``."""
        self.global_model.load_state_dict(model_parameters)

    def evaluation(self):
        """Log the global model's accuracy on the local test data.

        Returns:
            ``(acc, correct, test_size)``, with ``acc`` as a Python float.
        """
        acc, correct, test_size = test_compact(self.global_model, self.local_test_data, self.device)
        log_str = 'client {} - Acc of global model on local test data: {}'.format(self.client_index, acc)
        self.args.out_file.write(log_str + '\n')
        self.args.out_file.flush()
        print(log_str, flush=True)

        return acc.item(), correct, test_size

    def train(self):
        """Run one communication round (see the class docstring).

        Returns:
            ``(state_dict of upload_model, local_sample_number)``.
        """
        self.round_idx += 1
        self.compact_model.to(self.device)

        global_param = self.get_global_param()
        local_param = self.get_local_param()
        self.dual = self.update_dual(self.dual, local_param, global_param)

        self.netF.eval()
        self.netB.eval()
        self.netC.eval()

        for _epoch in range(self.args.epochs_client):
            self.compact_model.train()
            if self._is_shot_like(self.compact_model):
                # Keep the third stage of SHOT-like models in eval mode.
                self.compact_model[2].eval()
            for images, _labels, _tar_idx in self.local_training_data:
                images = images.to(self.device)
                log_probs = self.compact_model(images)

                # The large model only supplies distillation targets: run it
                # without autograd so no graph is built through it and
                # loss.backward() does not accumulate gradients into
                # netF/netB/netC. Student gradients are unaffected.
                with torch.no_grad():
                    large_model_logits = self.netC(self.netB(self.netF(images)))

                loss_kd = self.criterion_KL(log_probs, large_model_logits)
                loss = self.admm_loss(loss_kd, self.dual, global_param)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        local_param = self.get_local_param()
        self.update_upload_param(self.dual, local_param, global_param)
        weights = self.get_upload_model_params()

        return weights, self.local_sample_number

