import copy

from Functions import *


class Network:
    """Graph of learning nodes that exchange models by gossip or random walk.

    Nodes are objects in ``all_node`` (indexed by position). The visible
    methods read/write these node attributes: ``x`` (the model),
    ``data_size``, ``u`` and ``c`` (D2 mode), ``neighbor`` (list used to
    look up link indices), ``observe_delay(link_index)``, ``async_agg``
    (queue of pending asynchronous updates) and ``async_H`` (cooldown).
    """

    def __init__(self, model, all_node, edge, action, pre_action):
        """Create the network.

        Args:
            model: initial model; a deep copy is stored in ``x_agg``.
            all_node: sequence of node objects (iterated and indexed by int).
            edge: ``edge[i]`` is the list of neighbour indices of node ``i``
                (presumably a dict keyed by node index -- confirm at caller).
            action, pre_action: stored as-is; not used by the methods below.
        """
        self.all_node = all_node
        self.edge = edge
        self.action = action
        self.pre_action = pre_action
        self.visited = []
        self.x_agg = copy.deepcopy(model)
        # q, p and m are not read by the methods in this class; presumably
        # consumed elsewhere (p starts as the uniform distribution over nodes).
        self.q = np.ones(len(self.all_node))
        self.p = np.ones(len(self.all_node)) / len(self.all_node)
        self.m = 10

    def async_gossip(self, identity, t, H):
        """Advance asynchronous pairwise gossip by one time step.

        Phase 1 (initiation): each idle node (empty ``async_agg`` queue and
        ``async_H == 0``) picks a random neighbour. If that neighbour's queue
        is also empty, both compute the data-size-weighted average of their
        current models. Each one queues that average together with a snapshot
        of its own model. The entry is due at ``t + ceil(delay)``, where
        ``delay`` is what the *other* side observes on the link. The
        initiator then gets a cooldown of ``H`` steps.

        Phase 2 (delivery): every node decrements its cooldown and applies
        updates due exactly at ``t`` as ``x <- future_x + (x - x_at_send)``,
        so local changes made while the exchange was in flight are kept.
        Entries already overdue (due < t) are reported and discarded.

        Args:
            identity: dataset descriptor; ``identity[0] == "cifar10"`` means
                ``x`` is a model combined via ``aggrigate``/``load_state_dict``,
                otherwise ``x`` must support ``+``, ``-``, ``*`` and ``/``.
            t: current time step.
            H: cooldown (in steps) given to a node after it initiates.

        Returns:
            Number of messages sent in this step (2 per initiated exchange).
        """
        c = 0
        for node, node_obj in enumerate(self.all_node):
            if node_obj.async_agg or node_obj.async_H != 0:
                continue  # busy with an exchange or still cooling down
            neighb = np.random.choice(self.edge[node])
            neighb_obj = self.all_node[neighb]
            if neighb_obj.async_agg:
                continue  # neighbour is busy with another exchange
            c += 2
            node_obj.async_H = H
            normalizer = node_obj.data_size + neighb_obj.data_size
            if identity[0] == "cifar10":
                future_x = aggrigate(
                    [node_obj.x, neighb_obj.x],
                    [node_obj.data_size / normalizer, neighb_obj.data_size / normalizer],
                )
            else:
                future_x = (node_obj.x * node_obj.data_size
                            + neighb_obj.x * neighb_obj.data_size) / normalizer
            # ``t - (-d // 1)`` equals ``t + ceil(d)``: arrival rounded up to a step.
            node_obj.async_agg.append(
                [t - (-neighb_obj.observe_delay(neighb_obj.neighbor.index(node)) // 1),
                 future_x, copy.deepcopy(node_obj.x)])
            neighb_obj.async_agg.append(
                [t - (-node_obj.observe_delay(node_obj.neighbor.index(neighb)) // 1),
                 future_x, copy.deepcopy(neighb_obj.x)])

        for node_obj in self.all_node:
            node_obj.async_agg.sort(key=lambda tup: tup[0])
            node_obj.async_H = max(0, node_obj.async_H - 1)
            consumed = 0  # leading queue entries processed (applied or stale)
            for due, future_x, x_at_send in node_obj.async_agg:
                if due == t:
                    if identity[0] == "cifar10":
                        node_obj.x.load_state_dict(
                            aggrigate([future_x, node_obj.x, x_at_send], [1, 1, -1]).state_dict())
                    else:
                        node_obj.x = future_x + (node_obj.x - x_at_send)
                elif due < t:
                    # Missed its delivery step: report it and drop it (it is
                    # not applied late) so it cannot block the queue.
                    print("***************ERROR*****************")
                else:
                    break  # queue is sorted; nothing further is due yet
                consumed += 1
            # Remove every processed entry. Counting only applied entries would
            # delete stale ones instead and leave applied updates queued.
            del node_obj.async_agg[:consumed]
        return c

    @staticmethod
    def _mix_stacked(w, values):
        """Mix per-node arrays ``values`` (indexed ``values[node, i, k]``) with ``w``.

        Does one ``np.matmul(w, values[:, i, :])`` per middle index ``i`` and
        returns an array indexed ``[node, i, k]``.
        """
        n_nodes, n_mid = len(values), len(values[0])
        per_slice = [np.matmul(w, values[:, i, :]) for i in range(n_mid)]
        return np.array([[per_slice[j][i] for j in range(n_mid)] for i in range(n_nodes)])

    def gossip(self, one_link, D2, identity):
        """Run one synchronous gossip round.

        Args:
            one_link: if True, one random node replaces its model with the
                data-size-weighted average of itself and one random
                neighbour. Otherwise every node mixes with all of its
                neighbours using the normalised ``gossip_weight`` rows.
            D2: full-round mode only. If True, also mix the nodes' ``u``
                values and update ``c <- c - u + mixed_u``.
            identity: dataset descriptor; ``identity[0]`` selects the model
                representation ("cifar10": model objects merged via
                ``aggrigate``; "w8a": vectors mixed by one matmul; otherwise
                arrays indexed ``[i, k]`` mixed slice by slice).

        Returns:
            ``(time, c)``:
            - ``time`` is the delay of the round. In one-link mode it is the
              delay the neighbour observes on the link. Otherwise it is the
              largest delay any node observes on any of its links (0 if none).
            - ``c`` is the number of messages: 1 in one-link mode, otherwise
              the total length of all neighbour lists.
        """
        if one_link:
            node = np.random.choice(len(self.all_node))
            neighb = np.random.choice(self.edge[node])
            neighb_obj = self.all_node[neighb]
            node_obj = self.all_node[node]
            normalizer = node_obj.data_size + neighb_obj.data_size
            if identity[0] == "cifar10":
                node_obj.x.load_state_dict(aggrigate(
                    [node_obj.x, neighb_obj.x],
                    [node_obj.data_size / normalizer, neighb_obj.data_size / normalizer],
                ).state_dict())
            else:
                node_obj.x = (node_obj.x * node_obj.data_size
                              + neighb_obj.x * neighb_obj.data_size) / normalizer
            time = neighb_obj.observe_delay(neighb_obj.neighbor.index(node))
            return time, 1

        # Snapshot all models first so every node mixes the pre-round values.
        x_values = np.array([copy.deepcopy(node.x) for node in self.all_node])
        if D2:
            u_values = np.array([node.u for node in self.all_node])
        w = []
        for node in range(len(self.all_node)):
            w_node = self.gossip_weight(node)
            w.append(w_node / w_node.sum())  # row-normalise the mixing weights
        w = np.array(w)

        if identity[0] == "cifar10":
            for node, node_obj in enumerate(self.all_node):
                node_obj.x.load_state_dict(aggrigate(x_values, w[node]).state_dict())
            if D2:
                for node, node_obj in enumerate(self.all_node):
                    mixed_u = aggrigate(u_values, w[node])
                    node_obj.c.load_state_dict(
                        aggrigate([node_obj.c, node_obj.u, mixed_u], [1, -1, 1]).state_dict())
        elif identity[0] == "w8a":
            mixed = np.matmul(w, x_values)
            for node, node_obj in enumerate(self.all_node):
                node_obj.x = np.copy(mixed[node])
            if D2:
                mixed_u = np.matmul(w, u_values)
                for node, node_obj in enumerate(self.all_node):
                    node_obj.c = node_obj.c - node_obj.u + mixed_u[node]
        else:
            mixed = self._mix_stacked(w, x_values)
            for node, node_obj in enumerate(self.all_node):
                node_obj.x = np.copy(mixed[node])
            if D2:
                mixed_u = self._mix_stacked(w, u_values)
                for node, node_obj in enumerate(self.all_node):
                    node_obj.c = node_obj.c - node_obj.u + mixed_u[node]

        time = max(
            max([node_obj.observe_delay(i) for i in range(len(self.edge[node]))] + [0])
            for node, node_obj in enumerate(self.all_node)
        )
        # Accept either a dict {node: neighbours} or a list of neighbour lists.
        neighbour_lists = self.edge.values() if isinstance(self.edge, dict) else self.edge
        c = sum(len(nbrs) for nbrs in neighbour_lists)
        return time, c

    def gossip_weight(self, node):
        """Return the unnormalised mixing-weight row for ``node``.

        Entry ``i`` is ``data_size`` of node ``i`` if ``i`` is a neighbour of
        ``node``, plus ``data_size`` again if ``i == node``. All other entries
        are 0. ``gossip`` normalises the row to sum to 1.
        """
        return np.array([self.all_node[i].data_size * (i in self.edge[node]) for i in range(len(self.all_node))]) + np.array(
            [self.all_node[i].data_size * int(i == node) for i in range(len(self.all_node))])

    def _weighted_mean(self, values, identity, total_data):
        """Return the data-size-weighted mean of per-node ``values``.

        Non-cifar10 values use ``np.average``, which normalises by the sum of
        data sizes. For cifar10, each weight is ``data_size / total_data``.
        ``total_data`` is presumably the sum of all data sizes; confirm at
        the caller.
        """
        if identity[0] != "cifar10":
            return np.average(values, axis=0,
                              weights=[node.data_size for node in self.all_node])
        return aggrigate(values, [node.data_size / total_data for node in self.all_node])

    def final(self, identity, total_data):
        """Return the data-size-weighted average of all node models ``x``."""
        return self._weighted_mean([node.x for node in self.all_node], identity, total_data)

    def final_c(self, identity, total_data):
        """Return the data-size-weighted average of all node ``c`` values."""
        return self._weighted_mean([node.c for node in self.all_node], identity, total_data)

    def rw(self, node):
        """Take one random-walk step from ``node``.

        Proposes a uniformly random neighbour and accepts it with probability
        ``min(1, deg(node) / deg(candidate))``. This is a Metropolis-Hastings
        step targeting the uniform distribution over nodes. On rejection the
        walk stays at ``node``.
        """
        random_select = np.random.choice(self.edge[node])
        p = np.random.uniform()
        if p <= min(1, len(self.edge[node]) / len(self.edge[random_select])):
            return random_select
        return node
