# -*- coding: utf-8 -*-
import ray
import copy
import torch
import numpy as np
from torch import optim, nn

from model.regularization import Regularization

@ray.remote(num_cpus=1)
class Device(object):
    """Ray actor wrapping one training node of the decentralized setup.

    Each actor owns a private deep copy of the model, an SGD optimizer bound
    to that copy, a cycling iterator over its local ``train_loader`` and a
    cross-entropy criterion.
    """

    def __init__(self, device_index, args, model, train_loader, data_size):
        """Create the actor state.

        Args:
            device_index: Identifier of this node; stored but not otherwise
                used inside this class.
            args: Configuration object; ``args.device``, ``args.lr`` and
                ``args.model`` are read by this class.
            model: Template model. It is deep-copied, so the caller's
                instance is never modified by this actor.
            train_loader: Iterable yielding ``(inputs, targets)`` batches.
            data_size: Stored as-is; not used inside this class.
        """
        self.args = args
        self.data_size = data_size
        self.device_index = device_index
        self.device = args.device
        self.model = copy.deepcopy(model).to(self.device)

        self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.lr,
                                   momentum=0.9, weight_decay=5e-4)

        self.train_loader = train_loader
        # Persistent iterator: every training call consumes exactly one batch
        # via next(); it is re-created when the loader is exhausted.
        self.data_iteration = iter(self.train_loader)
        self.criterion = nn.CrossEntropyLoss()

    def decentralized_train(self, now_device_index, epoch, current_weights):
        """Run one local SGD step on the next mini-batch.

        Args:
            now_device_index: Unused here; kept for caller compatibility.
            epoch: Unused here; kept for caller compatibility.
            current_weights: Weights loaded into the local model (through
                ``model.set_weights``) before the step.

        Returns:
            Tuple ``(weights, gradients, loss)``: the model's weights and
            gradients after the optimizer step (as returned by
            ``model.get_weights()`` / ``model.get_gradients()``) and the loss
            tensor of this step, including any regularization term.
        """
        # Start from the weights handed in by the caller.
        self.model.set_weights(current_weights)
        self.model.train()

        # Fetch the next batch, restarting the loader once it is exhausted.
        try:
            inputs, targets = next(self.data_iteration)
        except StopIteration:
            self.data_iteration = iter(self.train_loader)
            inputs, targets = next(self.data_iteration)

        # NOTE(review): the batch is not moved to self.device here; this
        # presumably relies on the loader yielding tensors on the same device
        # as the model -- confirm.
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)

        if self.args.model == 'Lasso':
            # L1 regularization
            loss += Regularization(self.model, 0.001, p=1)(self.model)
        elif self.args.model == 'RR':
            # L2 regularization.
            # NOTE(review): p=0 is what the original code passes for the L2
            # case; confirm this matches the Regularization helper's contract.
            loss += Regularization(self.model, 0.001, p=0)(self.model)

        # Back-propagate and apply the update.
        loss.backward()
        self.optimizer.step()

        return self.model.get_weights(), self.model.get_gradients(), loss

    def decentralized_parallel_set_weights(self, weights):
        """Overwrite the local model's weights with ``weights``."""
        self.model.set_weights(weights)

    def test(self, current_weights, test_loader):
        """Evaluate ``current_weights`` on ``test_loader``.

        Args:
            current_weights: Weights loaded into the local model before
                evaluation.
            test_loader: Iterable yielding ``(inputs, targets)`` batches.

        Returns:
            Tuple ``(accuracy, loss)`` of floats rounded to four decimals:
            accuracy in percent over all samples, and the mean of the
            per-batch cross-entropy losses.

        Raises:
            ValueError: If ``test_loader`` yields no samples.
        """
        self.model.set_weights(current_weights)
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                outputs = self.model(inputs)

                # .item() keeps the accumulator a plain Python float.
                total_loss += self.criterion(outputs, targets).item()
                num_batches += 1

                # Predicted class = index of the largest output along dim 1.
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        if num_batches == 0 or total == 0:
            raise ValueError("test_loader yielded no samples to evaluate")

        test_acc = format(correct / total * 100, '.4f')
        # Average over the number of batches actually seen. The previous
        # expression `test_loss / batch_idx + 1` divided by the last batch
        # index and then added one (and divided by zero for a single batch).
        test_loss = format(total_loss / num_batches, '.4f')

        return float(test_acc), float(test_loss)

def train_swarm(args, devices, current_epoch, model_parameters):
    """Run local training on every device actor and collect the results.

    For each of ``int(args.world_size / args.num_dev)`` rounds, one
    ``decentralized_train`` task is submitted per device and the returned
    weights and gradients are appended in device order.

    Args:
        args: Configuration object; ``world_size`` and ``num_dev`` are read.
        devices: Sequence of Ray actor handles exposing
            ``decentralized_train``.
        current_epoch: Forwarded unchanged to every device.
        model_parameters: Indexable collection; ``model_parameters[i]`` is
            sent to ``devices[i]`` as its starting weights.

    Returns:
        Tuple ``(all_weights_list, all_gradients_list)`` with one entry per
        device per round.
    """
    all_weights_list = []
    all_gradients_list = []
    num_round = int(args.world_size / args.num_dev)

    # NOTE(review): every round sends model_parameters[i] to device i; if the
    # rounds are meant to cover different ranks, this index may need a
    # round offset -- confirm against the caller.
    for _ in range(num_round):
        # Submit all tasks first so the devices train concurrently.
        pending_refs = [
            device.decentralized_train.remote(i, current_epoch, model_parameters[i])
            for i, device in enumerate(devices)
        ]

        # ray.get on a list blocks until every task has finished and returns
        # the results in submission order, so no separate ray.wait is needed.
        for weights, gradients, _loss in ray.get(pending_refs):
            all_weights_list.append(weights)
            all_gradients_list.append(gradients)

    return all_weights_list, all_gradients_list

def agg_swarm(args, model, all_weights_list, weight_matrix):
    """Average each node's weights with one randomly chosen neighbour.

    A partner is drawn for every rank first; the ranks are then processed in
    ascending order. Processing rank ``i`` with partner ``j`` stores the
    element-wise mean of both weight sets as rank ``i``'s result and also
    writes a deep copy of that mean into ``all_weights_list[j]``, so later
    ranks see the updated value.

    Args:
        args: Configuration object; only ``world_size`` is read.
        model: Object whose ``get_weights()`` is iterated to obtain the keys
            that are averaged.
        all_weights_list: Per-rank weight mappings (key -> tensor).
            **Mutated in place**: partner entries are overwritten.
        weight_matrix: 2-D indexable; ``weight_matrix[i][j] > 0`` marks rank
            ``j`` as a neighbour of rank ``i``.

    Returns:
        List with one weight mapping per rank. Ranks without any neighbour
        get a shallow copy of their input mapping.
    """
    world_size = args.world_size

    # Phase 1: draw one partner per rank (None when the rank is isolated).
    choices_dict = {}
    for rank in range(world_size):
        neighbours = [
            other for other in range(world_size)
            if other != rank and weight_matrix[rank][other] > 0
        ]
        if neighbours:
            # Index the size-1 result instead of calling int() on the array:
            # converting an ndim > 0 array to a scalar is deprecated in
            # NumPy >= 1.25. The random draw itself is unchanged.
            choices_dict[rank] = int(np.random.choice(neighbours, 1)[0])
        else:
            choices_dict[rank] = None

    # The key set does not depend on the rank, so fetch it once.
    weight_keys = list(model.get_weights())

    # Phase 2: pairwise averaging, in rank order.
    agg_weights_list = []
    for rank in range(world_size):
        # Shallow copy so assigning averaged tensors below does not touch the
        # caller's mapping for this rank.
        merged = all_weights_list[rank].copy()
        partner = choices_dict[rank]

        if partner is not None:
            own_weights = all_weights_list[rank]
            partner_weights = all_weights_list[partner]

            for key in weight_keys:
                merged[key] = torch.div(
                    torch.add(own_weights[key], partner_weights[key]), 2)

            # Propagate the average to the partner's slot as well.
            all_weights_list[partner] = copy.deepcopy(merged)

        agg_weights_list.append(merged)
    return agg_weights_list