from Functions import *

class Node:
    """One worker in a simulated decentralized / asynchronous SGD run.

    A node owns a slice of the dataset (``indices[0]:indices[1]``), its local
    model ``x``, per-stream snapshots ``x_tau`` of the model taken at its last
    synchronization, and bookkeeping for traversing its ``neighbor`` list.

    Two data paths exist, selected by ``identity[0]``:
      * ``"cifar10"``: torch models, a ``DataLoader`` over a ``Subset`` and an
        SGD optimizer bound to ``self.x``'s parameters.
      * anything else: ``dataset`` is a ``[features, labels]`` pair of slices
        and no loader/optimizer is created.
    """

    def __init__(self, identity, model, indices, whole_dataset, neighbor, lag, delay_dist, stream, delay_shift, device, criterion, num_worker, batch_size):
        """Build the node's model copies, data slice and (for CIFAR-10) loader/optimizer.

        Args:
            identity: sequence whose first item names the dataset; ``"cifar10"``
                selects the torch path.
            model: initial model; every model-valued attribute is a deep copy of it.
            indices: ``(start, stop)`` bounds of this node's slice of ``whole_dataset``.
            whole_dataset: a torch dataset (CIFAR-10 path) or a
                ``(features, labels)`` pair of sliceable sequences.
            neighbor: list of neighboring node ids; positions index ``delay_dist``.
            lag: stored as-is.
            delay_dist: per-neighbor scale of the exponential delay (see ``observe_delay``).
            stream: number of independent traversal streams.
            delay_shift: constant added to every sampled delay.
            device, criterion: forwarded to ``logistic_regression``.
            num_worker, batch_size: ``DataLoader`` settings (CIFAR-10 path only).
        """
        self.delay_shift = delay_shift
        self.model = model
        self.indices = indices
        self.neighbor = neighbor
        self.lag = lag
        self.x = copy.deepcopy(model)
        # One independent snapshot per stream, so a comprehension (not list * n)
        # is required to avoid aliasing the same model object.
        self.x_tau = [copy.deepcopy(model) for _ in range(stream)]
        self.x_pre = copy.deepcopy(model)
        self.t_pre = 0
        self.tau = [0] * stream
        self.count = [0] * len(neighbor)
        self.delay_dist = delay_dist
        self.traverse_parent = [None] * stream
        self.async_agg = []
        self.async_H = 0
        # Previously deep-copied twice; a single copy is sufficient.
        self.u = copy.deepcopy(model)
        self.u_other = copy.deepcopy(model)
        self.device = device
        self.criterion = criterion
        self.c = copy.deepcopy(model)
        if identity[0] == "cifar10":
            # NOTE(review): presumably aggrigate computes the weighted sum
            # 1*model - 1*model, i.e. zeroes ``c`` -- confirm against Functions.
            self.c.load_state_dict(aggrigate([model, model], [1, -1]).state_dict())
            self.dataset = torch.utils.data.Subset(whole_dataset, list(range(indices[0], indices[1])))
            self.data_size = len(self.dataset)
            self.train_loader = torch.utils.data.DataLoader(self.dataset, batch_size=min(batch_size, len(self.dataset)), shuffle=True, num_workers=num_worker)
            self.train_batch = []
            self.optimizer = torch.optim.SGD(self.x.parameters(), .1,
                            momentum=0.9,
                            weight_decay=1e-4)
        else:
            self.dataset = [whole_dataset[0][indices[0]:indices[1]], whole_dataset[1][indices[0]:indices[1]]]
            self.data_size = len(self.dataset[0])
            self.train_loader = None
            self.train_batch = None
            self.optimizer = None

    def load_data(self):
        """Append every batch yielded by ``train_loader`` (one epoch) to ``train_batch``."""
        self.train_batch.extend(self.train_loader)

    def observe_delay(self, a_index):
        """Sample a link delay to ``neighbor[a_index]``.

        Returns ``delay_shift`` plus a draw from ``np.random.exponential`` with
        scale ``delay_dist[a_index]``; consumes global NumPy RNG state.
        """
        return self.delay_shift + np.random.exponential(self.delay_dist[a_index])

    def traverse_select(self, from_node, all_node, visited, st):
        """Choose the next node for stream ``st``'s traversal.

        On the first call since the last reset, ``from_node`` is recorded as
        this stream's parent. Then:
          1. If some neighbor that is also in ``all_node`` is unvisited, return
             it with a sampled delay (parent is kept).
          2. Else, if any node of ``all_node`` is still unvisited, return the
             recorded parent with a sampled delay and clear the parent.
          3. Else clear the parent and return ``("Done", 0)``.

        NOTE(review): step 1 iterates a ``set``, so which neighbor wins depends
        on hash order (non-reproducible across runs for str ids). Step 2 raises
        ``ValueError`` if the recorded parent is not in ``neighbor``.
        """
        if self.traverse_parent[st] is None:
            self.traverse_parent[st] = from_node
        parent = self.traverse_parent[st]
        for node in set(self.neighbor).intersection(all_node):
            if node not in visited:
                a_index = self.neighbor.index(node)
                return self.neighbor[a_index], self.observe_delay(a_index)
        if any(node not in visited for node in all_node):
            a_index = self.neighbor.index(parent)
            del_observed = self.observe_delay(a_index)
            self.traverse_parent[st] = None
            return self.neighbor[a_index], del_observed
        self.traverse_parent[st] = None
        return "Done", 0

    def local_sgd(self, identity, total_data, t, lr):
        """Run one local SGD step on one sample/batch and return ``logistic_regression``'s ``gr``.

        CIFAR-10: batches are taken in order from ``train_batch``, refilling it
        from ``train_loader`` when empty. Otherwise a sample is drawn by
        ``uniform_data_catch`` from this node's ``[features, labels]`` slice.
        ``t`` is accepted for interface compatibility and is not used.
        """
        if identity[0] != "cifar10":
            i_feature, i_label = uniform_data_catch(self.dataset[0], self.dataset[1])
        else:
            if not self.train_batch:
                self.load_data()
            # In-place pop instead of re-slicing the whole list every step.
            i_feature, i_label = self.train_batch.pop(0)
        # NOTE(review): if logistic_regression returns a new model object here,
        # self.optimizer (bound to the old self.x's parameters) would be stale -- confirm.
        self.x, gr = logistic_regression(identity, total_data, self.x, i_feature, i_label, lr, self.optimizer, self.device, self.criterion)
        return gr

    def synchronize(self, identity, t, x_agg, total_data, st):
        """Fold this node's progress into ``x_agg`` and reset local state to it.

        Computes ``x_agg + r*x + (1-r)*x_pre - x_tau[st]`` with
        ``r = data_size / total_data`` (both branches implement this formula),
        then sets ``x``, ``x_pre`` and ``x_tau[st]`` to independent copies of the
        result and records ``t + 1`` in ``t_pre`` and ``tau[st]``.

        Returns:
            The new aggregate model.
        """
        ratio = self.data_size / total_data
        if identity[0] == "cifar10":
            x_agg = aggrigate([x_agg, self.x, self.x_pre, self.x_tau[st]], [1, ratio, 1 - ratio, -1])
            # load_state_dict (rather than rebinding self.x) keeps the parameter
            # objects that self.optimizer was constructed on.
            self.x.load_state_dict(x_agg.state_dict())
        else:
            x_agg = x_agg + ratio * (self.x - self.x_pre) + self.x_pre - self.x_tau[st]
            self.x = copy.deepcopy(x_agg)
        self.t_pre = t + 1
        self.x_tau[st] = copy.deepcopy(x_agg)
        self.x_pre = copy.deepcopy(x_agg)
        self.tau[st] = t + 1
        return x_agg

    def ada_local_sgd(self, dataset, total_data, t, H, learning_rate):
        """Adaptive-H local SGD step (legacy).

        NOTE(review): reads ``self.feature``, ``self.label``, ``self.d`` and
        ``self.x_hat``, none of which ``__init__`` sets, and calls
        ``logistic_regression`` with a different signature than ``local_sgd``
        does. This method appears to be unmaintained -- confirm before use.
        """
        data_size = len(self.feature)
        i_feature, i_label = uniform_data_catch(self.feature, self.label)
        temp, gr = logistic_regression(dataset, total_data, self.x, i_feature, i_label, learning_rate, num_steps=1)
        self.u += temp - self.x
        self.x = np.zeros(self.d) + (t+1)/max(((t+1)//H*H),1)*self.u_other + data_size/total_data * self.u
        self.x_hat = update_x_hat(self.x, self.x_hat, t)
        return gr

    def ada_synchronize(self, u_agg, total_data):
        """Set ``u_other`` to ``u_agg`` minus this node's weighted ``u`` (legacy).

        NOTE(review): reads ``self.feature``, which ``__init__`` never sets.
        """
        data_size = len(self.feature)
        self.u_other = u_agg - data_size/total_data * self.u


