import itertools

import networkx as nx
import numpy as np
import ray
import torch

def matrix_normal(m):
    """Binarize ``m`` in place: every non-zero entry becomes 1, zeros stay 0.

    Each row is traversed using its own length, so non-square matrices are
    fully processed. NumPy 2-D arrays also work, because iterating over them
    yields row views that write back into ``m``.

    Args:
        m: mutable 2-D container indexed as ``m[i][j]`` (e.g. nested lists).

    Returns:
        The same object ``m``, mutated in place.
    """
    for row in m:
        for j in range(len(row)):
            row[j] = 1 if row[j] != 0 else 0
    return m

def grad_mean_std(all_gradients_list):
    """Average, over workers, of each worker's gradient mean and standard deviation.

    For each entry of ``all_gradients_list``, the non-``None`` layer tensors
    are flattened and concatenated. The mean and (unbiased) std of that flat
    vector are computed, and both statistics are then averaged across entries.

    Args:
        all_gradients_list: iterable with one item per worker. Each item is an
            iterable of per-layer tensors; ``None`` entries are skipped.

    Returns:
        Tuple ``(mean_of_means, mean_of_stds)`` as NumPy floats.
    """
    mean_list = []
    std_list = []
    for gradients in all_gradients_list:
        # reshape (unlike view) also accepts non-contiguous tensors; detach so
        # no autograd history leaks into the statistics.
        flat = torch.cat([layer.detach().reshape(-1)
                          for layer in gradients if layer is not None])
        # .item() yields plain Python floats, so np.mean works regardless of
        # the tensors' device or requires_grad flag.
        mean_list.append(torch.mean(flat).item())
        std_list.append(torch.std(flat).item())
    return np.mean(mean_list), np.mean(std_list)

def test_multi(args, devices, model_parameters, current_epoch, test_loader, writer, file, all_gradients_list):
    """Evaluate every worker's model on ``test_loader`` through Ray and log the averages.

    Models are dispatched in chunks of at most ``args.num_dev``. Slot ``k`` of
    a chunk runs on ``devices[k]`` via ``devices[k].test.remote(params,
    test_loader)``, which must return an ``(acc, loss)`` pair. The last chunk
    may be shorter, so every one of the ``args.world_size`` models is
    evaluated even when ``world_size`` is not a multiple of ``num_dev``.

    Args:
        args: config object providing ``world_size`` and ``num_dev``.
        devices: Ray actor handles exposing a remote ``test`` method.
        model_parameters: per-worker parameters, indexed ``0 .. world_size - 1``.
        current_epoch: iteration label used in the log line and CSV row.
        test_loader: evaluation data passed to each remote ``test`` call.
        writer: CSV-style writer; one ``[epoch, avg_acc, avg_loss]`` row is written.
        file: file object underlying ``writer``; flushed after the write.
        all_gradients_list: currently unused; kept for interface compatibility.
    """
    total_acc = 0
    total_loss = 0

    for start in range(0, args.world_size, args.num_dev):
        chunk_size = min(args.num_dev, args.world_size - start)
        acc_loss_ids = [
            devices[slot].test.remote(model_parameters[start + slot], test_loader)
            for slot in range(chunk_size)
        ]
        # ray.get on a list blocks until every result is ready and keeps order.
        for acc, loss in ray.get(acc_loss_ids):
            total_acc += acc
            total_loss += loss

    avg_acc = format(total_acc / args.world_size, '.4f')
    avg_loss = format(total_loss / args.world_size, '.4f')

    print(f'[Iter{current_epoch}] | Avg Acc: {avg_acc} | Avg Loss: {avg_loss}')
    writer.writerow([current_epoch, avg_acc, avg_loss])
    file.flush()

def weight_norm(model_parameters, model):
    """Mean pairwise L2 distance between workers' weight tensors.

    For every parameter name yielded by iterating ``model.get_weights()`` that
    contains ``'weight'``, the L2 norm of ``params_i[key] - params_j[key]`` is
    computed for every unordered pair of workers ``i < j``. The result is the
    mean over all pairs and all such parameters. Models other than LR have
    several weight tensors, and all of them are averaged together.

    Args:
        model_parameters: sequence with one mapping per worker, from
            parameter name to tensor.
        model: object whose ``get_weights()`` iterates parameter names
            (presumably a dict keyed by name -- confirm against the caller).

    Returns:
        The mean distance. Returns NaN when no parameter name contains
        ``'weight'`` or fewer than two workers are given.
    """
    pair_norms = []
    for key in model.get_weights():
        if 'weight' not in key:
            continue
        for params_a, params_b in itertools.combinations(model_parameters, 2):
            diff = torch.sub(params_a[key], params_b[key]).detach()
            # NOTE: the reference implementation used p=1; this uses the L2 norm.
            pair_norms.append(torch.norm(diff, p=2).item())

    if not pair_norms:
        return float('nan')
    return np.mean(pair_norms)

def compute_L(train_loader_list, bsz):
    """Estimate a constant L from the largest batch-input norm across all loaders.

    The result is ``max_norm ** 2 / (batch_idx * bsz)``, where:

    - ``max_norm`` is the largest L2 norm of any ``inputs`` batch yielded by
      any loader.
    - ``batch_idx`` is the zero-based index of the final batch of the last
      loader that yielded anything.

    Presumably L is a smoothness/Lipschitz constant for step-size selection
    -- confirm with the caller.

    Args:
        train_loader_list: iterable of loaders, each yielding ``(inputs, targets)``.
        bsz: batch size used in the denominator.

    Returns:
        The estimate as a Python float.

    Raises:
        ValueError: if no loader yields any batch.
        ZeroDivisionError: if the last non-empty loader yielded exactly one
            batch (``batch_idx == 0``).
    """
    max_norm = 0
    last_batch_idx = None
    for loader in train_loader_list:
        for last_batch_idx, (inputs, _targets) in enumerate(loader):
            max_norm = max(max_norm, torch.norm(inputs, p=2))
    if last_batch_idx is None:
        raise ValueError("compute_L: train_loader_list yielded no batches")
    # NOTE(review): last_batch_idx is zero-based, so last_batch_idx * bsz is one
    # batch short of that loader's sample count, and it reflects only one
    # loader. Kept unchanged to preserve experiment results -- confirm intent.
    return float(max_norm ** 2 / (last_batch_idx * bsz))

def get_laplacian(g):
    """Return the graph Laplacian scaled by its largest eigenvalue, plus an eigenvalue ratio.

    Args:
        g: a NetworkX graph.

    Returns:
        Tuple ``(laplacian, ratio)``:

        - ``laplacian`` is the dense float64 Laplacian of ``g`` divided by the
          last entry of ``nx.laplacian_spectrum(g)``.
        - ``ratio`` is that last eigenvalue divided by the second one
          (largest / second-smallest, assuming ascending order -- confirm
          against NetworkX docs).
    """
    eigenvalues = nx.laplacian_spectrum(g)
    top_eig = eigenvalues[-1]
    second_eig = eigenvalues[1]
    dense_lap = nx.laplacian_matrix(g).toarray().astype(np.float64)
    return dense_lap / top_eig, top_eig / second_eig

def generate_weight_matrix(world_size, topology_list):
    """Build a scaled Laplacian for every connected topology in ``topology_list``.

    Each adjacency matrix is first adjusted according to ``world_size``:

    - ``5``: the matrix is tiled 2x2 (its node count doubles).
    - greater than ``10``: only the leading 10x10 block is kept.

    Disconnected graphs are skipped.

    Args:
        world_size: number of workers; selects the reshaping rule above.
        topology_list: iterable of square adjacency matrices (NumPy arrays
            assumed, since they are concatenated and 2-D sliced).

    Returns:
        Tuple ``(W_list, lambda_min)``:

        - ``W_list`` holds the Laplacian matrix from ``get_laplacian`` for each
          connected topology.
        - ``lambda_min`` is the minimum of ``1 / chi`` over those topologies
          (``chi`` is ``get_laplacian``'s second return value), capped at 1.
    """
    W_list = []
    lambda_min = 1.
    for topology in topology_list:
        if world_size == 5:
            doubled_cols = np.concatenate((topology, topology), axis=1)
            topology = np.concatenate((doubled_cols, doubled_cols), axis=0)
        elif world_size > 10:
            topology = topology[:10, :10]
        # nx.from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array
        # (available since 2.0) builds the same graph from an ndarray.
        G = nx.from_numpy_array(topology)
        if nx.is_connected(G):
            W, chi = get_laplacian(G)
            lambda_min = min(1 / chi, lambda_min)
            W_list.append(W)
    return W_list, lambda_min
