import logging
import copy
import os.path as osp
import os
import shutil
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from .utils import *
# added
import sys, os
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "./../")))
import model.network as network
def test_compact(model, target_test_loader, deivce):
    """Evaluate the classification accuracy of ``model`` on ``target_test_loader``.

    The model is put in eval mode and moved to ``deivce`` (the misspelt
    parameter name is kept so existing keyword callers keep working).

    Args:
        model: module whose output is treated as per-class scores; the
            prediction is the arg-max over dimension 1.
        target_test_loader: iterable of ``(data, target)`` batches exposing a
            ``.dataset`` attribute; ``len(.dataset)`` is the denominator.
        deivce: device to run the evaluation on.

    Returns:
        ``(acc, correct, n)``: accuracy in percent (float64 tensor), number of
        correct predictions (float64 tensor) and the dataset size (int). For
        an empty dataset ``acc`` and ``correct`` are 0.
    """
    model.eval()
    model.to(deivce)
    # Keep the running count as a tensor so ``.double()`` below also works
    # when the loader yields no batches at all.
    correct = torch.zeros((), dtype=torch.long, device=deivce)
    num_samples = len(target_test_loader.dataset)
    with torch.no_grad():
        for data, target in target_test_loader:
            data, target = data.to(deivce), target.to(deivce)
            output = model(data)
            pred = torch.max(output, 1)[1]
            correct += torch.sum(pred == target)
    correct = correct.double()
    # Guard the denominator so an empty dataset reports 0% rather than NaN.
    acc = 100. * correct / max(num_samples, 1)
    return acc, correct, num_samples


# This is an evaluation helper, not a unit test: stop pytest from collecting
# it (and failing on its arguments) when it is imported into a test module.
test_compact.__test__ = False

class FedXDDServerTrainer(object):
    """Server-side trainer for FedXDD.

    Each round the server collects client state dicts (``model_dict``),
    trains the global model on its own source data with
    ``(1 - alpha) * CE + rho/2 * sum_i ||w_i - w||^2`` (see
    ``consensus_loss``), evaluates it on the source validation split and
    returns its state dict.
    """

    def __init__(self, client_num, source_data, device, server_model, args):
        """
        Args:
            client_num: number of participating clients (indices ``0..client_num-1``).
            source_data: dict with ``"source_tr"`` (iterable of ``(images, labels)``
                batches used for training) and ``"source_te"`` (loader used for
                validation by ``test_compact``).
            device: device the global model is trained on.
            server_model: initial global model; ``.to(device)`` moves it in place.
            args: run configuration; this class reads ``alpha``, ``lr``, ``wd``,
                ``temperature``, ``rho``, ``comm_round``, ``out_file``,
                ``server_record_name`` and ``mixed_record_name``.
        """
        self.client_num = client_num
        self.source_data = source_data["source_tr"]
        self.validation_data = source_data["source_te"]
        self.device = device
        self.args = args
        self.alpha = args.alpha

        self.round_idx = 0

        # Module.to works in place, so the deep copies below are taken from
        # the already-moved model.
        self.global_model = server_model
        self.global_model.to(self.device)

        # The auxiliary/upload copies feed the dual-variable helpers below,
        # whose main caller (the ADMM variant) is currently commented out.
        self.auxiliary_model = copy.deepcopy(server_model)
        # self.auxiliary_model.to(self.device)

        self.upload_model = copy.deepcopy(server_model)

        self.model_params = self.master_params = network.get_parameters(args, self.global_model)
        self.optimizer = torch.optim.SGD(self.model_params, lr=self.args.lr, momentum=0.9,
                                         weight_decay=self.args.wd)
        # self.optimizer = torch.optim.SGD(self.model_params, lr=self.args.lr, momentum=0.9)
        # self.optimizer = network.gen_optim_whole(args, self.global_model)

        self.dual = self.initial_dual()
        self.iter_source_training = iter(self.source_data)

        self.criterion_CE = nn.CrossEntropyLoss()
        self.criterion_KL = KL_Loss(self.args.temperature)
        self.best_acc = 0.0

        # Validation accuracy of the global model, one entry per round.
        self.acc_record = []

        # Per-client uploads keyed by client index. Note that sample_num_dict
        # is written by both add_local_trained_result and add_local_acc_result.
        self.model_dict = dict()
        self.sample_num_dict = dict()
        self.train_acc_dict = dict()
        self.train_loss_dict = dict()
        self.correct_num_dict = dict()
        self.test_acc_avg = 0.0
        self.test_loss_avg = 0.0

        self.flag_client_model_uploaded_dict = {idx: False for idx in range(self.client_num)}

    @staticmethod
    def _weight_tensors(model):
        """Return detached CPU clones of ``model``'s parameters whose name ends in ``weight``."""
        return tuple(
            param.detach().cpu().clone()
            for name, param in model.named_parameters()
            if name.split('.')[-1] == "weight"
        )

    def initial_dual(self):
        """Return zero-filled CPU tensors, one per ``*.weight`` parameter of the auxiliary model."""
        return tuple(
            torch.zeros_like(param).cpu()
            for name, param in self.auxiliary_model.named_parameters()
            if name.split('.')[-1] == "weight"
        )

    def update_dual(self, dual, local_param, global_param):
        """Return ``u + w - w0`` for each aligned (dual, local, global) triple."""
        return tuple(u + w - w0 for u, w, w0 in zip(dual, local_param, global_param))

    def update_upload_param(self, dual, local_param):
        """Set each ``*.weight`` of ``upload_model`` to ``local_param + dual`` and return the model."""
        # upload_model = self.client_model
        idx = 0
        for name, param in self.upload_model.named_parameters():
            if name.split('.')[-1] == "weight":
                u = dual[idx]
                w = local_param[idx]
                param.data = w + u
                idx += 1
        return self.upload_model

    def get_upload_model_params(self):
        """Move ``upload_model`` to CPU and return its state dict."""
        return self.upload_model.cpu().state_dict()

    def get_global_param(self):
        """Return CPU clones of the global model's ``*.weight`` parameters."""
        return self._weight_tensors(self.global_model)

    def get_auxiliary_param(self):
        """Return CPU clones of the auxiliary model's ``*.weight`` parameters."""
        return self._weight_tensors(self.auxiliary_model)

    def get_global_model_params(self):
        """Move the global model to CPU and return its state dict.

        The returned tensors share storage with the model's parameters.
        ``aggregate_and_train`` moves the model back to ``self.device`` at the
        start of each round.
        """
        return self.global_model.cpu().state_dict()

    def set_global_model_params(self, model_parameters):
        """Load ``model_parameters`` into the global model."""
        self.global_model.load_state_dict(model_parameters)

    def add_local_trained_result(self, index, model_params, sample_num):
        """Store client ``index``'s uploaded state dict and sample count, and mark it received."""
        logging.info("add_model. index = %d" % index)
        self.model_dict[index] = model_params
        self.sample_num_dict[index] = sample_num
        self.flag_client_model_uploaded_dict[index] = True

    def add_local_acc_result(self, index, correct_num, test_sample_num):
        """Store client ``index``'s evaluation result and mark it received.

        Overwrites ``sample_num_dict[index]`` with the test sample count.
        """
        logging.info("add_model. index = %d" % index)
        self.sample_num_dict[index] = test_sample_num
        self.correct_num_dict[index] = correct_num
        self.flag_client_model_uploaded_dict[index] = True

    def check_whether_all_receive(self):
        """Return True (and reset all flags) iff every client has reported this round."""
        flags = self.flag_client_model_uploaded_dict
        if not all(flags[idx] for idx in range(self.client_num)):
            return False
        for idx in range(self.client_num):
            flags[idx] = False
        return True

    # def aggregate_and_train_ADMM(self):
    #     model_list = []
    #     training_num = 0
    #
    #     self.auxiliary_model.to(self.device)
    #
    #     acc = test_compact(self.auxiliary_model, self.source_data, self.device)
    #
    #     log_str = 'Acc of auxiliary model on source data: {}'.format(acc)
    #     print(log_str)
    #
    #     self.auxiliary_model.train()
    #
    #     # accum_grad = dict()
    #     batch_num = 0
    #
    #     global_param = self.get_global_param()
    #     auxiliary_param = self.get_auxiliary_param()
    #
    #     self.dual = self.update_dual(self.dual, auxiliary_param, global_param)
    #
    #     # for i in range(self.args.epochs_client):
    #     for i in range(2):
    #         for batch_idx, (images, labels) in enumerate(self.source_data):
    #             images, labels = images.to(self.device), labels.to(self.device)
    #             batch_num = batch_num + 1
    #             # logging.info("shape = " + str(images.shape))
    #             log_probs = self.auxiliary_model(images)
    #
    #             loss_ce = (self.criterion_CE(log_probs, labels)).to(self.device)
    #
    #             loss = self.admm_loss(loss_ce, self.dual, global_param)
    #
    #             # log_str = 'Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
    #             #     i, batch_idx * len(images), len(self.source_data.dataset),
    #             #        100. * batch_idx / len(self.source_data.dataset), loss_ce.item())
    #             # print(log_str, flush=True)
    #
    #             self.optimizer.zero_grad()
    #             loss.backward()
    #             self.optimizer.step()
    #
    #             # print('batch # {}, reg loss {}'.format(batch_idx, loss_ce.item()), flush=True)
    #     auxiliary_param = self.get_auxiliary_param()
    #     self.update_upload_param(self.dual, auxiliary_param)
    #
    #     weights = self.get_upload_model_params()
    #
    #     for idx in range(self.client_num):
    #         # if self.args.is_mobile == 1:
    #         #     self.model_dict[idx] = transform_list_to_tensor(self.model_dict[idx])
    #         model_list.append((self.sample_num_dict[idx], self.model_dict[idx]))
    #         training_num += self.sample_num_dict[idx]
    #
    #     logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict)))
    #     (num0, averaged_params) = model_list[0]
    #     # logging.info(averaged_params)
    #     for k in averaged_params.keys():
    #         for i in range(0, len(model_list)):
    #             local_sample_number, local_model_params = model_list[i]
    #             # w = local_sample_number / training_num
    #             w = 1 / (self.client_num + 1)
    #             if i == 0:
    #                 averaged_params[k] = local_model_params[k] * w
    #             else:
    #                 averaged_params[k] += local_model_params[k] * w
    #
    #         # logging.info(averaged_params[k])
    #         # logging.info(accum_grad[k]/ self.client_num)
    #         averaged_params[k] = averaged_params[k] + weights[k] / (self.client_num + 1)
    #
    #         # logging.info(averaged_params[k].device)
    #         # logging.info(accum_grad[k].device)
    #         # averaged_params[k] = averaged_params[k]
    #
    #     self.set_global_model_params(averaged_params)
    #     # self.model_global.to(self.device)
    #     #
    #     # for name, param in self.model_global.named_parameters():
    #     #     u = averaged_params[name].to(self.device)
    #     #     param.data.add_(u)
    #
    #     model_parameters = self.get_global_model_params()
    #
    #     return model_parameters

    def consensus_loss(self, loss_ce, local_model_params_list):
        """Return ``(1 - alpha) * loss_ce + rho/2 * sum_i sum_p ||w_i[p] - w[p]||^2``.

        ``p`` ranges over the global model's named parameters and ``i`` over
        the first ``client_num`` entries of ``local_model_params_list`` (dicts
        keyed by parameter name). Each entry is moved to ``self.device``. This
        is a no-op when the caller has already placed it there, which
        ``aggregate_and_train`` does once per round.
        """
        # self.cal_alpha(1.0)
        loss = (1.0 - self.alpha) * loss_ce
        # loss = self.alpha * loss_ce
        for name, param in self.global_model.named_parameters():
            for i in range(self.client_num):
                local_weight = local_model_params_list[i][name].to(self.device)
                loss += (self.args.rho / 2) * (local_weight - param).norm() ** 2

        return loss

    def cal_alpha(self, end_alpha):
        """Set ``self.alpha`` on a geometric schedule from ``args.alpha`` toward ``end_alpha`` over ``comm_round`` rounds."""
        init_alpha = self.args.alpha
        n_rounds = self.args.comm_round
        exp_val = np.log(end_alpha / init_alpha) / n_rounds
        self.alpha = init_alpha * np.exp(exp_val * self.round_idx)

    def aggregate_and_train(self):
        """Run one server round and return the updated global parameters.

        Trains the global model for one pass over ``self.source_data`` with
        ``consensus_loss`` against the uploaded client models. Then evaluates
        it on ``self.validation_data`` and writes the accuracy to
        ``args.out_file``. On the final round (``round_idx == args.comm_round``)
        it appends this run's per-round accuracies as a new row of
        ``<server_record_name>.npy``.

        Returns:
            ``(state_dict, correct, validation_size)``. ``state_dict`` comes from
            ``get_global_model_params``, which moves the model to CPU.
            ``correct`` and ``validation_size`` come from ``test_compact``.
        """
        self.round_idx += 1
        self.global_model.to(self.device)

        self.global_model.train()

        # Move the client tensors the consensus term needs onto the training
        # device once per round, instead of once per batch inside consensus_loss.
        param_names = [name for name, _ in self.global_model.named_parameters()]
        local_model_params_list = [
            {name: self.model_dict[idx][name].to(self.device) for name in param_names}
            for idx in range(self.client_num)
        ]

        # self.cal_alpha(1.0)
        # print('server alpha: {:.2f}'.format(self.alpha), flush=True)


        # if self.alpha == 1.0:
        #     averaged_params = local_model_params_list[0]
        #     for k in averaged_params.keys():
        #         for i in range(0, self.client_num):
        #             w = 1 / self.client_num
        #             if i == 0:
        #                 averaged_params[k] = local_model_params_list[i][k] * w
        #             else:
        #                 averaged_params[k] += local_model_params_list[i][k] * w
        #
        #     self.set_global_model_params(averaged_params)
        # else:
        #     for batch_idx, (images, labels) in enumerate(self.source_data):
        #         images, labels = images.to(self.device), labels.to(self.device)
        #         batch_num = batch_num + 1
        #         # logging.info("shape = " + str(images.shape))
        #         log_probs = self.global_model(images)
        #
        #         loss_ce = (self.criterion_CE(log_probs, labels)).to(self.device)
        #
        #         loss = self.consensus_loss(loss_ce, local_model_params_list)
        #
        #         # log_str = 'Update Epoch: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        #         #     batch_idx * len(images), len(self.source_data.dataset),
        #         #        100. * batch_idx / len(self.source_data.dataset), loss.item())
        #         # print(log_str, flush=True)
        #
        #         self.optimizer.zero_grad()
        #         loss.backward()
        #         self.optimizer.step()
        #
        #         # print('batch # {}, reg loss {}'.format(batch_idx, loss_ce.item()), flush=True)

        for batch_idx, (images, labels) in enumerate(self.source_data):
            images, labels = images.to(self.device), labels.to(self.device)
            log_probs = self.global_model(images)

            # Already on self.device because its inputs are.
            loss_ce = self.criterion_CE(log_probs, labels)

            loss = self.consensus_loss(loss_ce, local_model_params_list)

            # log_str = 'Update Epoch: [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
            #     batch_idx * len(images), len(self.source_data.dataset),
            #        100. * batch_idx / len(self.source_data.dataset), loss.item())
            # print(log_str, flush=True)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # acc_src, _, _ = test_compact(self.global_model, self.source_data, self.device)
        # log_str = 'Acc of global model on source training data: {}'.format(acc_src.item())
        # print(log_str, flush=True)

        acc, correct, validation_size = test_compact(self.global_model, self.validation_data, self.device)
        log_str = 'Acc of global model on source validation data: {}'.format(acc.item())
        self.args.out_file.write(log_str + '\n')
        self.args.out_file.flush()
        print(log_str, flush=True)

        self.acc_record.append(acc.item())

        if self.round_idx == self.args.comm_round:
            # np.save appends ".npy" to the name, hence the suffix in the checks.
            savename = self.args.server_record_name
            run_row = np.array(self.acc_record).reshape(1, -1)
            if not osp.exists(savename + '.npy'):
                np.save(savename, run_row)
            else:
                pre_acc_record = np.load(savename + '.npy', allow_pickle=True)
                pre_acc_record = np.concatenate((pre_acc_record, run_row), axis=0)
                np.save(savename, pre_acc_record)

        model_parameters = self.get_global_model_params()

        return model_parameters, correct, validation_size

    def cal_acc_total(self, correct, validation_size):
        """Log and record the accuracy pooled over the server validation set and all clients.

        Adds the server's ``correct`` / ``validation_size`` to each client's
        ``correct_num_dict`` / ``sample_num_dict`` entry. The pooled percentage
        is appended to ``<mixed_record_name>.npy`` and written to
        ``args.out_file``, which is then closed.
        """
        # Out-of-place additions: ``total_correct += ...`` would mutate the
        # caller's ``correct`` tensor in place.
        total_correct = correct
        total_data_num = validation_size
        for i in range(self.client_num):
            total_correct = total_correct + self.correct_num_dict[i]
            total_data_num = total_data_num + self.sample_num_dict[i]

        acc_total = 100 * total_correct / total_data_num
        # float() accepts both 0-dim tensors and plain Python numbers.
        acc_value = float(acc_total)

        savename = self.args.mixed_record_name
        if not osp.exists(savename + '.npy'):
            np.save(savename, np.array(acc_value))
        else:
            mix_acc_record = np.load(savename + '.npy', allow_pickle=True)
            mix_acc_record = np.append(mix_acc_record, acc_value)
            np.save(savename, mix_acc_record)

        log_str = 'Acc of global model on mixed data: {}'.format(acc_total)
        self.args.out_file.write(log_str + '\n')
        self.args.out_file.flush()
        print(log_str + '\n', flush=True)

        # NOTE(review): after this close, any later write to out_file (e.g. by
        # aggregate_and_train) raises. Presumably this is only called once,
        # after the final round. Confirm with the caller.
        self.args.out_file.close()

        return

