# -*- coding: utf-8 -*-
import ray
import torch
from torch import nn
from utils.Optimizer_ADOM import ADOM

@ray.remote(num_cpus=0.2)
class Device(object):
    """Ray actor holding one node's model replica, its data stream and an ADOM optimizer."""

    def __init__(self, device_index, args, model, train_loader, data_size, L, lambda__min):
        """Build the actor.

        Args:
            device_index: Index of this node (stored on the instance).
            args: Run configuration. ``args.device`` and ``args.K`` are read
                here and ``args`` itself is forwarded to ``ADOM``.
            model: Module that additionally exposes ``get_param``,
                ``get_weights``, ``set_weights`` and ``get_gradients``.
            train_loader: Iterable yielding ``(inputs, targets)`` batches.
            data_size: Stored on the instance; not read elsewhere in this class.
            L, lambda__min: Forwarded unchanged to ``ADOM`` (presumably the
                smoothness constant and a spectral lower bound -- TODO confirm
                against the ADOM implementation).
        """
        self.args = args
        self.data_size = data_size
        self.device_index = device_index
        self.device = args.device
        self.model = model.to(self.device)

        # Initial ADOM state: one entry per parameter tensor.
        z_g_k = []
        z_f_k = []
        m_k = []
        delta_k = []
        k_ = self.args.K
        for para in self.model.get_param():
            # ``.float()`` must be called; ``.float`` alone stored a bound method.
            z_g_k.append(torch.zeros_like(para).float())
            # Independent copies so in-place updates to z_f, m or the model
            # parameters cannot leak into each other through shared storage.
            z_f_k.append(para.detach().clone())
            m_k.append(para.detach().clone())
            delta_k.append(torch.zeros_like(para).float())

        self.optimizer = ADOM(self.model.parameters(), args, L, lambda__min, z_g_k,
                              z_f_k, m_k, delta_k, k_)

        self.train_loader = train_loader
        # Persistent iterator; ``next`` pulls one batch per training step.
        self.data_iteration = iter(self.train_loader)
        self.criterion = nn.CrossEntropyLoss()

    def decentralized_train(self, now_device_index, epoch, current_weights, weight_matrix):
        """Run one local ADOM step on the next mini-batch.

        Args:
            now_device_index: Logical node index (currently unused here).
            epoch: Current epoch (currently unused here).
            current_weights: Weights loaded into the model before the step.
            weight_matrix: Mixing matrix, converted to a float32 tensor and
                passed to ``ADOM.step``.

        Returns:
            ``(weights, gradients, loss)``: the model's weights and gradients
            after the step, and this batch's training loss detached from the
            autograd graph.
        """
        self.model.set_weights(current_weights)
        self.model.train()
        try:
            inputs, targets = next(self.data_iteration)
        except StopIteration:
            # Loader exhausted: restart it so training can continue indefinitely.
            self.data_iteration = iter(self.train_loader)
            inputs, targets = next(self.data_iteration)
        # The model lives on ``self.device``, so the batch must as well.
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        loss.backward()
        self.optimizer.step(torch.tensor(weight_matrix.copy(), dtype=torch.float32))

        # Detach so the autograd graph is not serialized back through Ray.
        return self.model.get_weights(), self.model.get_gradients(), loss.detach()

    def decentralized_parallel_set_weights(self, weights):
        """Overwrite this node's model with the aggregated ``weights``."""
        self.model.set_weights(weights)

    def test(self, current_weights, test_loader):
        """Evaluate ``current_weights`` on ``test_loader``.

        Returns:
            ``(accuracy, loss)``: accuracy in percent and mean per-batch loss,
            each rounded to 4 decimals. ``(0.0, 0.0)`` when the loader yields
            no batches.
        """
        self.model.set_weights(current_weights)
        self.model.eval()
        test_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        criterion = nn.CrossEntropyLoss()

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                test_loss += criterion(outputs, targets).item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                num_batches += 1

        if num_batches == 0:
            return 0.0, 0.0
        test_acc = format(correct / total * 100 if total else 0.0, '.4f')
        # Mean over batches; the former ``test_loss / batch_idx + 1`` divided
        # by the last zero-based index and then added 1.
        test_loss = format(test_loss / num_batches, '.4f')

        return float(test_acc), float(test_loss)

def train_adom(args, devices, current_epoch, model_parameters, weight_matrix):
    """Run one local ADOM step for every logical node and collect the results.

    The ``args.world_size`` logical nodes are processed in
    ``int(args.world_size / args.num_dev)`` rounds. In each round, actor ``i``
    of ``devices`` trains node ``round_idx * args.num_dev + i``, starting from
    ``model_parameters[node]``.

    Returns:
        ``(all_weights_list, all_gradients_list)`` in node order.
    """
    all_weights_list = []
    all_gradients_list = []
    num_round = int(args.world_size / args.num_dev)

    for round_idx in range(num_round):
        # Launch this round's remote training tasks.
        result_ids = []
        for i, device in enumerate(devices):
            real_idx = round_idx * args.num_dev + i
            result_ids.append(
                device.decentralized_train.remote(
                    real_idx, current_epoch, model_parameters[real_idx], weight_matrix
                )
            )

        # ray.get on a list blocks until every task finishes and keeps
        # submission order, so no separate ray.wait is needed.
        for weights, gradients, _loss in ray.get(result_ids):
            all_weights_list.append(weights)
            all_gradients_list.append(gradients)

    return all_weights_list, all_gradients_list

def agg_adom(args, model, all_weights_list):
    """Give every node the uniform average of all nodes' weights.

    For each key returned by ``model.get_weights()``, the value is replaced by
    ``sum_j W_j[key] / args.world_size`` over the first ``args.world_size``
    entries of ``all_weights_list``. Any other key keeps node ``i``'s own
    value. The inputs are not modified, and each node receives its own tensors.

    Returns:
        List of ``args.world_size`` weight dicts, one per node.
    """
    world_size = args.world_size
    if world_size <= 0:
        return []

    scale = 1 / world_size
    # Fetch the key set once; get_weights() may build a full state copy.
    keys = list(model.get_weights())

    # The average is identical for every node, so compute it a single time.
    averaged = {}
    for key in keys:
        acc = torch.mul(all_weights_list[0][key], scale)
        for weights in all_weights_list[1:world_size]:
            acc = torch.add(acc, torch.mul(weights[key], scale))
        averaged[key] = acc

    agg_weights_list = []
    for i in range(world_size):
        node_weights = all_weights_list[i].copy()
        for key, value in averaged.items():
            node_weights[key] = value.clone()
        agg_weights_list.append(node_weights)

    return agg_weights_list
