import copy
import time
from datetime import datetime

import numpy as np

from Functions import *
def parallel_sgd_measure(identity, total_data, whole_dataset, exp, cte, H, iters, sampling_f, device, criterion,
                         num_worker, batch_size, test_loader, network):
    """Run local SGD with full-model averaging every ``H`` iterations.

    On each iteration ``t`` every node runs at most one local SGD round, with
    index ``t - node.lag`` once that index is non-negative.  Whenever
    ``t % H == 0`` every node is overwritten with the model returned by
    ``network.final``; on the ``"cifar10"`` path the result of ``differ`` is
    recorded just before the overwrite.

    Args:
        identity: Sequence whose first element names the dataset; the value
            ``"cifar10"`` selects the torch ``state_dict`` code path, anything
            else the NumPy path.
        total_data: Forwarded to ``network.final`` / ``node.local_sgd`` and
            used as the denominator of the per-node weights given to ``differ``.
        whole_dataset: Dataset handed to ``loss``; wrapped in a ``DataLoader``
            first when the dataset is ``"cifar10"``.
        exp, cte: Forwarded unchanged to ``learning_rate``.
        H: Averaging period, in iterations.
        iters: Number of iterations to run.
        sampling_f: The loss is recorded whenever ``t % sampling_f == 0``.
        device, criterion, test_loader: Forwarded unchanged to ``loss``.
        num_worker, batch_size: ``DataLoader`` settings (``"cifar10"`` only).
        network: Object exposing ``all_node`` and ``final(identity, total_data)``.

    Returns:
        ``[loss_over_t, real_time_over_t, comm_over_t,
        difference_by_layer_over_t]``.  ``comm_over_t`` only ever contains
        zeros because the communication counter is never incremented here.
    """
    if identity[0] == "cifar10":
        whole_dataset = torch.utils.data.DataLoader(
            whole_dataset, batch_size=batch_size, shuffle=True, num_workers=num_worker)

    # Parameter names handed to `differ`: every key mentioning "conv" plus the
    # fully connected weight.
    param_names = list(network.all_node[0].x.state_dict().keys())
    layers = [name for name in param_names if "conv" in name or "fc.weight" in name]

    loss_over_t, real_time_over_t, comm_over_t = [], [], []
    difference_by_layer_over_t = []
    real_time = 0
    comm = 0  # never updated below; recorded as-is at every sampling point

    for t in range(iters):
        # Periodic evaluation of the aggregated model.
        if t % sampling_f == 0:
            final_x = network.final(identity, total_data)
            real_time_over_t.append(real_time)
            comm_over_t.append(comm)
            current_loss = loss(identity, whole_dataset, final_x, criterion, device, test_loader)
            loss_over_t.append(current_loss)
            print(current_loss)
            print(real_time)
            print(comm)

        # Local step: a node lagging by `lag` runs round `t - lag` (if >= 0).
        for node in network.all_node:
            due_round = t - node.lag
            for sgd_round in range(max(due_round, 0), due_round + 1):
                step_size = learning_rate(sgd_round, exp, cte)
                node.local_sgd(identity, total_data, sgd_round, step_size)

        # Synchronisation: replace every local model with the aggregate.
        if t % H == 0:
            averaged = network.final(identity, total_data)
            if identity[0] != "cifar10":
                for node in network.all_node:
                    node.x = np.copy(averaged)
            else:
                local_models = [node.x for node in network.all_node]
                weights = [node.data_size / total_data for node in network.all_node]
                per_layer_gap = differ(local_models, averaged, weights, layers)
                difference_by_layer_over_t.append(per_layer_gap)
                print(per_layer_gap)
                for node in network.all_node:
                    node.x.load_state_dict(averaged.state_dict())

        real_time += 1

    return [loss_over_t, real_time_over_t, comm_over_t, difference_by_layer_over_t]

def parallel_sgd_layer(identity, total_data, whole_dataset, exp, cte, H, iters, sampling_f, device, criterion,
                       num_worker, batch_size, test_loader, alpha, network):
    """Run local SGD where layer groups are averaged on different periods.

    The state-dict keys of the first node's model are split into consecutive
    groups: a new group starts at every key containing ``"conv"`` or
    ``"fc.weight"`` and collects the keys that follow it.  The last group is
    synchronised every ``H`` iterations, every other group every
    ``H * alpha`` iterations.

    Args:
        identity: Sequence whose first element names the dataset; the value
            ``"cifar10"`` selects the torch ``state_dict`` code path.
        total_data: Forwarded to ``network.final`` and ``node.local_sgd``.
        whole_dataset: Dataset handed to ``loss``; wrapped in a ``DataLoader``
            first when the dataset is ``"cifar10"``.
        exp, cte: Forwarded unchanged to ``learning_rate``.
        H: Synchronisation period of the last layer group.
        iters: Number of iterations to run.
        sampling_f: The loss is recorded whenever ``t % sampling_f == 0``.
        device, criterion, test_loader: Forwarded unchanged to ``loss``.
        num_worker, batch_size: ``DataLoader`` settings (``"cifar10"`` only).
        alpha: Multiplier applied to ``H`` for all groups but the last.
        network: Object exposing ``all_node`` and ``final(identity, total_data)``.

    Returns:
        ``[loss_over_t, real_time_over_t, comm_over_t,
        difference_by_layer_over_t]``.  The last list is always empty and
        ``comm_over_t`` only contains zeros, since neither is updated here.
    """
    if identity[0] == "cifar10":
        whole_dataset = torch.utils.data.DataLoader(
            whole_dataset, batch_size=batch_size, shuffle=True, num_workers=num_worker)

    all_layers = list(network.all_node[0].x.state_dict().keys())

    # Partition the keys into groups, each opened by a "conv"/"fc.weight" key.
    # NOTE: the first key must open a group, otherwise there is no group yet
    # to append to and the call below fails on None.
    layer_map = []
    current_group = None
    for name in all_layers:
        opens_group = "conv" in name or "fc.weight" in name
        if not opens_group:
            current_group.append(name)
        else:
            if current_group is not None:
                layer_map.append(current_group)
            current_group = [name]
    layer_map.append(current_group)

    # One period per group: H * alpha for all but the last, H for the last.
    Hs = [H * alpha for _ in range(len(layer_map) - 1)] + [H]

    loss_over_t, real_time_over_t, comm_over_t = [], [], []
    difference_by_layer_over_t = []
    real_time = 0
    comm = 0  # never updated below; recorded as-is at every sampling point

    for t in range(iters):
        # Periodic evaluation of the aggregated model.
        if t % sampling_f == 0:
            final_x = network.final(identity, total_data)
            real_time_over_t.append(real_time)
            comm_over_t.append(comm)
            current_loss = loss(identity, whole_dataset, final_x, criterion, device, test_loader)
            loss_over_t.append(current_loss)
            print(current_loss)
            print(real_time)
            print(comm)

        # Local step: a node lagging by `lag` runs round `t - lag` (if >= 0).
        for node in network.all_node:
            due_round = t - node.lag
            for sgd_round in range(max(due_round, 0), due_round + 1):
                step_size = learning_rate(sgd_round, exp, cte)
                node.local_sgd(identity, total_data, sgd_round, step_size)

        # Collect the keys of every group whose period divides t.
        due_layers = []
        for period, group in zip(Hs, layer_map):
            if t % period == 0:
                due_layers += group

        if due_layers:
            averaged = network.final(identity, total_data)
            if identity[0] != "cifar10":
                # The NumPy path replaces the whole model, whatever is due.
                for node in network.all_node:
                    node.x = np.copy(averaged)
            else:
                averaged_state = averaged.state_dict()
                for node in network.all_node:
                    node_state = node.x.state_dict()
                    # Only the due keys take the aggregated values.
                    node_state.update({name: averaged_state[name] for name in due_layers})
                    node.x.load_state_dict(node_state)

        real_time += 1

    return [loss_over_t, real_time_over_t, comm_over_t, difference_by_layer_over_t]

def parallel_scaffold_layer(identity, total_data, whole_dataset, exp, cte, H, iters, sampling_f, device, criterion,
                            num_worker, batch_size, test_loader, alpha, network):
    """Run layer-wise periodic averaging with per-node control variates.

    After every local SGD round the node's model is shifted by
    ``lr * (server_c - node.c)``.  The state-dict keys of the first node's
    model are split into consecutive groups (a new group starts at every key
    containing ``"conv"`` or ``"fc.weight"``); the last group is synchronised
    every ``H`` iterations and every other group every ``H * alpha``
    iterations.  On the ``"cifar10"`` path a synchronisation also refreshes,
    for the due keys only, each node's control variate ``node.c`` and the
    server control variate obtained from ``network.final_c``.

    Args:
        identity: Sequence whose first element names the dataset; the value
            ``"cifar10"`` selects the torch ``state_dict`` code path.
        total_data: Forwarded to ``network.final``, ``network.final_c`` and
            ``node.local_sgd``.
        whole_dataset: Dataset handed to ``loss``; wrapped in a ``DataLoader``
            first when the dataset is ``"cifar10"``.
        exp, cte: Forwarded unchanged to ``learning_rate``.
        H: Synchronisation period of the last layer group.
        iters: Number of iterations to run.
        sampling_f: The loss is recorded whenever ``t % sampling_f == 0``.
        device, criterion, test_loader: Forwarded unchanged to ``loss``.
        num_worker, batch_size: ``DataLoader`` settings (``"cifar10"`` only).
        alpha: Multiplier applied to ``H`` for all groups but the last.
        network: Object exposing ``all_node``, ``final(identity, total_data)``
            and ``final_c(identity, total_data)``.  Nodes must already carry
            ``x``, ``c``, ``lag`` and (``"cifar10"`` path) ``x_pre``.

    Returns:
        ``[loss_over_t, real_time_over_t, comm_over_t,
        difference_by_layer_over_t]``.  The last list is always empty and
        ``comm_over_t`` only contains zeros, since neither is updated here.
    """
    if identity[0] == "cifar10":
        whole_dataset = torch.utils.data.DataLoader(
            whole_dataset, batch_size=batch_size, shuffle=True, num_workers=num_worker)

    all_layers = list(network.all_node[0].x.state_dict().keys())
    # The server control variate starts as an independent copy of node 0's.
    server_c = copy.deepcopy(network.all_node[0].c)

    # Partition the keys into groups, each opened by a "conv"/"fc.weight" key.
    # NOTE: the first key must open a group, otherwise there is no group yet
    # to append to and the call below fails on None.
    layer_map = []
    group = None
    for layer in all_layers:
        if "conv" in layer or "fc.weight" in layer:
            if group is not None:
                layer_map.append(group)
            group = [layer]
            continue
        group.append(layer)
    layer_map.append(group)

    # One period per group: H * alpha for all but the last, H for the last.
    Hs = [H * alpha for _ in range(len(layer_map) - 1)] + [H]

    loss_over_t = []
    real_time_over_t = []
    comm_over_t = []
    difference_by_layer_over_t = []
    real_time = 0
    comm = 0  # never updated below; recorded as-is at every sampling point
    # Index of the most recent local SGD round run by any node.  Tracked
    # explicitly instead of relying on the inner loop variable still being
    # bound when the control-variate coefficient is computed.
    last_round = None

    for t in range(iters):
        print('t', t)
        # Periodic evaluation of the aggregated model.
        if t % sampling_f == 0:
            final_x = network.final(identity, total_data)
            real_time_over_t.append(real_time)
            comm_over_t.append(comm)
            loss_over_t.append(loss(identity, whole_dataset, final_x, criterion, device, test_loader))
            print(loss_over_t[-1])
            print(real_time_over_t[-1])
            print(comm_over_t[-1])

        # Local step: a node lagging by `lag` runs round `t - lag` (if >= 0),
        # followed by the control-variate correction of its model.
        for node in network.all_node:
            end = t - node.lag + 1
            start = max(t - node.lag, 0)
            for sgd_round in range(start, end):
                last_round = sgd_round
                lr = learning_rate(sgd_round, exp, cte)
                node.local_sgd(identity, total_data, sgd_round, lr)
                if identity[0] == "cifar10":
                    # Weighted combination x - lr * c + lr * server_c, mirroring
                    # the NumPy branch below (assumes `aggrigate` forms the
                    # weighted sum of the given models — confirm in Functions).
                    corrected = aggrigate([node.x, node.c, server_c], [1, -lr, lr])
                    node.x.load_state_dict(corrected.state_dict())
                else:
                    node.x += lr * (- node.c + server_c)

        # Collect the keys of every group whose period divides t.
        to_be_updated_layers = []
        for period, layer_group in zip(Hs, layer_map):
            if t % period == 0:
                to_be_updated_layers += layer_group

        if to_be_updated_layers:
            temp = network.final(identity, total_data)
            if identity[0] == "cifar10":
                new_model = temp.state_dict()
                # The coefficient uses the learning rate of the latest local
                # round and the base period H (the value the previous
                # `Hs[i]`, with `i` left over from the loop, always had).
                # If no round has run yet (all lags > 0) fall back to `t`.
                lr_round = t if last_round is None else last_round
                coeff = 1 / learning_rate(lr_round, exp, cte) / H
                for node in network.all_node:
                    model_dict = node.x.state_dict()
                    node_c_dict = node.c.state_dict()
                    # Candidate control variate: coeff * (x_pre - x).
                    temp_node_c = aggrigate([node.c, server_c, node.x_pre, node.x], [0, 0, coeff, -coeff])
                    new_node_c = temp_node_c.state_dict()
                    for layer in to_be_updated_layers:
                        model_dict[layer] = new_model[layer]
                        node_c_dict[layer] = new_node_c[layer]
                    node.x.load_state_dict(model_dict)
                    # Deep copy: a shallow copy of a module shares its
                    # parameter tensors, so later in-place loads into node.x
                    # would silently overwrite the snapshot as well.
                    node.x_pre = copy.deepcopy(node.x)
                    node.c.load_state_dict(node_c_dict)
                # Refresh the server control variate for the due keys only,
                # after every node's control variate has been updated.
                new_server_c = network.final_c(identity, total_data).state_dict()
                server_c_dict = server_c.state_dict()
                for layer in to_be_updated_layers:
                    server_c_dict[layer] = new_server_c[layer]
                server_c.load_state_dict(server_c_dict)
            else:
                # The NumPy path replaces the whole model, whatever is due.
                for node in network.all_node:
                    node.x = np.copy(temp)

        real_time += 1

    return [loss_over_t, real_time_over_t, comm_over_t, difference_by_layer_over_t]