# -*- coding: utf-8 -*-
import ray
import copy
import torch
from torch import optim, nn

from model.regularization import Regularization

@ray.remote(num_cpus=1)
class Device(object):
    """Ray actor holding one local model replica and its training data.

    Each actor keeps a deep copy of the model passed to the constructor, an
    SGD optimizer over that copy, and a persistent iterator over its training
    loader so that every training call consumes exactly one mini-batch.
    """

    def __init__(self, device_index, args, model, train_loader, data_size):
        """Store the configuration and build the local model and optimizer.

        Args:
            device_index: Identifier of this actor (stored, not used here).
            args: Configuration object; this class reads ``args.device``,
                ``args.lr`` and ``args.model``.
            model: Template model; it is deep-copied, never modified.
            train_loader: Iterable yielding ``(inputs, targets)`` batches.
            data_size: Size of the local dataset (stored, not used here).
        """
        self.args = args
        self.data_size = data_size
        self.device_index = device_index
        self.device = args.device
        # Deep copy so this actor owns its own parameters.
        self.model = copy.deepcopy(model).to(self.device)

        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr)

        self.train_loader = train_loader
        # Persistent iterator; advanced with next() one batch per training call.
        self.data_iteration = iter(self.train_loader)
        self.criterion = nn.CrossEntropyLoss()

    def _next_batch(self):
        """Return the next training batch, restarting the loader when exhausted."""
        try:
            inputs, targets = next(self.data_iteration)
        except StopIteration:
            self.data_iteration = iter(self.train_loader)
            inputs, targets = next(self.data_iteration)
        # The model lives on self.device, so the batch has to be there as well.
        return inputs.to(self.device), targets.to(self.device)

    def decentralized_train(self, current_weights):
        """Run one local SGD step on a single mini-batch.

        Args:
            current_weights: Weights loaded into the local model before the
                step (passed to ``model.set_weights``).

        Returns:
            A tuple ``(weights, gradients, loss)``: the model weights after
            the optimizer step, the gradients of this batch as reported by
            ``model.get_gradients()``, and the loss tensor of this batch.
        """
        self.model.set_weights(current_weights)
        self.model.train()
        inputs, targets = self._next_batch()

        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)

        if self.args.model == 'Lasso':
            # L1 regularization.
            loss += Regularization(self.model, 0.001, p=1)(self.model)
        elif self.args.model == 'RR':
            # L2 regularization.
            # NOTE(review): this branch is labelled L2 but passes p=0 --
            # confirm against the Regularization implementation.
            loss += Regularization(self.model, 0.001, p=0)(self.model)

        loss.backward()
        self.optimizer.step()

        return self.model.get_weights(), self.model.get_gradients(), loss

    def decentralized_parallel_set_weights(self, weights):
        """Overwrite the local model weights with (aggregated) ``weights``."""
        self.model.set_weights(weights)

    def decentralized_parallel_gradients_step(self, gradients, scale_factor=1.0):
        """Apply externally supplied gradients with one SGD step.

        Args:
            gradients: Gradients installed via ``model.set_gradients``.
            scale_factor: Multiplier applied to ``args.lr`` for this step.
                The default of 1.0 reproduces the unscaled step.
                NOTE(review): the driver passes this value positionally;
                treating it as a step-size multiplier is an assumption --
                confirm the intended meaning.

        Returns:
            The model weights after the step.
        """
        optimizer = optim.SGD(self.model.parameters(),
                              lr=self.args.lr * scale_factor)

        optimizer.zero_grad()
        self.model.set_gradients(gradients)
        optimizer.step()
        return self.model.get_weights()

    def test(self, current_weights, test_loader):
        """Evaluate ``current_weights`` on ``test_loader``.

        Args:
            current_weights: Weights loaded into the local model first.
            test_loader: Iterable yielding ``(inputs, targets)`` batches.

        Returns:
            ``(accuracy_percent, mean_batch_loss)``, both rounded to four
            decimal places.

        Raises:
            ValueError: If ``test_loader`` yields no samples.
        """
        self.model.set_weights(current_weights)
        self.model.eval()
        test_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        criterion = nn.CrossEntropyLoss()

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                outputs = self.model(inputs)

                test_loss += criterion(outputs, targets).item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                num_batches += 1

        if total == 0:
            raise ValueError("test_loader yielded no samples")

        test_acc = format(correct / total * 100, '.4f')
        # Average over the number of batches actually processed.
        test_loss = format(test_loss / num_batches, '.4f')

        return float(test_acc), float(test_loss)

def train_dsgd(args, devices, current_epoch, model_parameters):
    """Run one local training step for every node and collect the results.

    The nodes are processed in ``int(args.world_size / args.num_dev)`` rounds;
    in each round actor ``i`` trains with the parameters of node
    ``round * args.num_dev + i``.

    Args:
        args: Configuration object; reads ``world_size`` and ``num_dev``.
        devices: Actor handles exposing ``decentralized_train.remote``.
        current_epoch: Unused by this function.
        model_parameters: Per-node weights, indexed by node number.

    Returns:
        ``(all_weights_list, all_gradients_list)`` in node order.
    """
    collected_weights = []
    collected_gradients = []
    rounds = int(args.world_size / args.num_dev)

    for round_idx in range(rounds):
        offset = round_idx * args.num_dev
        # Submit one training task per actor; each call returns an object ref.
        pending = [
            actor.decentralized_train.remote(model_parameters[offset + slot])
            for slot, actor in enumerate(devices)
        ]
        # Block until every task submitted in this round has finished.
        ray.wait(pending, num_returns=len(pending))

        for weights, gradients, _loss in ray.get(pending):
            collected_weights.append(weights)
            collected_gradients.append(gradients)

    return collected_weights, collected_gradients

def agg_dsgd(args, model, all_weights_list, weight_matrix):
    """Mix the per-node weights according to ``weight_matrix``.

    For every node ``i`` and every key reported by ``model.get_weights()``
    the result is ``sum_j weight_matrix[i][j] * all_weights_list[j][key]``,
    accumulated starting with the diagonal term and then ``j`` in ascending
    order.

    Args:
        args: Configuration object; reads ``world_size``.
        model: Object whose ``get_weights()`` iterates over the parameter
            keys to aggregate.
        all_weights_list: Per-node mappings from key to tensor; the mappings
            and their tensors are not modified.
        weight_matrix: Indexable as ``weight_matrix[i][j]``, giving the
            mixing coefficient of node ``j`` for node ``i``.

    Returns:
        A list with one aggregated mapping per node.
    """
    # The key set does not change between nodes, so query the model once.
    keys = list(model.get_weights())
    agg_weights_list = []

    for index_i in range(args.world_size):
        row = weight_matrix[index_i]
        own = all_weights_list[index_i]
        # Shallow copy: entries are rebound below, never modified in place.
        mixed = own.copy()

        # Start with the node's own contribution (diagonal entry).
        for key in keys:
            mixed[key] = torch.mul(own[key], row[index_i])

        # Add every other node's weighted contribution.
        for index_j in range(args.world_size):
            if index_j == index_i:
                continue
            neighbour = all_weights_list[index_j]
            coeff = row[index_j]
            for key in keys:
                mixed[key] = torch.add(mixed[key], torch.mul(neighbour[key], coeff))

        agg_weights_list.append(mixed)

    return agg_weights_list

def grad_step_dsgd(args, devices, agg_weights_list, all_gradients_list, model_parameters):
    """Load the aggregated weights on every node and apply its gradients.

    Nodes are processed in ``int(args.world_size / args.num_dev)`` rounds.
    For each node the actor first receives the aggregated weights and then a
    gradient-step request; the weights returned by the step are stored.

    Args:
        args: Configuration object; reads ``world_size`` and ``num_dev``.
        devices: Actor handles, indexed ``0 .. args.num_dev - 1``.
        agg_weights_list: Aggregated weights, indexed by node number.
        all_gradients_list: Gradients, indexed by node number.
        model_parameters: List that is cleared and refilled in place.

    Returns:
        ``model_parameters`` (the same list object), now holding the
        post-step weights in node order.
    """
    # Constant factor forwarded with every gradient-step request.
    scale_factor = 0.1
    rounds = int(args.world_size / args.num_dev)
    # Empty the caller's list in place before refilling it.
    model_parameters.clear()

    for round_idx in range(rounds):
        offset = round_idx * args.num_dev
        pending = []
        for slot in range(args.num_dev):
            node = offset + slot
            actor = devices[slot]
            # Relies on the actor handling these two calls in submission order.
            actor.decentralized_parallel_set_weights.remote(agg_weights_list[node])
            # NOTE(review): two positional arguments are passed here; confirm
            # the actor method's signature accepts the scale factor.
            step_ref = actor.decentralized_parallel_gradients_step.remote(
                all_gradients_list[node], scale_factor)
            pending.append(step_ref)

        # Block until every step submitted in this round has finished.
        ray.wait(pending, num_returns=len(pending))

        for step_ref in pending:
            model_parameters.append(ray.get(step_ref))

    return model_parameters