import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.optimizer import Optimizer
import torch.distributed as dist
import time
import random

import sys
sys.path.append("../")
from ring import *
from exponential_graph import *
from one_peer_exponential_graph import *
from adic_graph import *
from one_peer_undirected_equidyn import *
from one_peer_directed_equidyn import *
from directed_equitopo import *
from undirected_equitopo import *
from torus import *
from adic_graph_wo_permutation import *

class FullActivationScheme():
    """Activation scheme in which every node participates in every round.

    A communication graph is defined over ``n_nodes`` token ids. On every
    call to :meth:`update`, a fresh random permutation (drawn from a private,
    seeded ``random.Random``) decides which node holds which token, and the
    graph's own ``update()`` is advanced as well.

    Two consecutive rounds are tracked:

    * ``current_active_nodes`` -- token -> node assignment for this round.
    * ``next_active_nodes``    -- token -> node assignment for the next round.

    Each is a list holding one permutation per graph in ``graph_list``. This
    class always builds exactly one graph, so both lists have length 1.

    Args:
        n_nodes: Number of nodes (equal to the number of tokens).
        graph: Topology name. One of ``"ring"``, ``"exp"``,
            ``"one_peer_adic_wo_permutation"``, ``"one_peer_undirected_equidyn"``.
        seed: Seed for the permutation RNG. The same seed gives the same
            sequence of permutations.

    Raises:
        ValueError: If ``graph`` is not one of the supported topology names.
    """

    _SUPPORTED_GRAPHS = (
        "ring",
        "exp",
        "one_peer_adic_wo_permutation",
        "one_peer_undirected_equidyn",
    )

    def __init__(self, n_nodes, graph, seed=0):
        self.n_nodes = n_nodes
        # Every node is always active, so the active count is simply n_nodes.
        self.n_active_nodes = [self.n_nodes]
        self.graph_list = []

        # Build the single graph over token ids. Only the selected branch's
        # topology class is looked up.
        if graph == "ring":
            self.graph_list.append(Ring(self.n_nodes))
        elif graph == "exp":
            self.graph_list.append(ExponentialGraph(self.n_nodes))
        elif graph == "one_peer_adic_wo_permutation":
            self.graph_list.append(AdicGraphWoPermutation(self.n_nodes, 1))
        elif graph == "one_peer_undirected_equidyn":
            self.graph_list.append(OnePeerUndirectedEquiDyn(self.n_nodes))
        else:
            # Fail fast. An empty graph_list would otherwise surface later
            # as an opaque IndexError in get_out_neighbors/get_in_neighbors.
            raise ValueError(
                f"unsupported graph {graph!r}; expected one of "
                f"{', '.join(self._SUPPORTED_GRAPHS)}"
            )

        self.seed = seed
        # Private RNG so the permutation sequence is reproducible and does
        # not depend on (or disturb) the global `random` state.
        self.random_inst = random.Random(self.seed)

        self.current_active_nodes = []
        self.next_active_nodes = []

        # In the full scheme every node is active in both rounds.
        self.current_all_active_nodes = list(range(self.n_nodes))
        self.next_all_active_nodes = list(range(self.n_nodes))

        # Two updates populate both the current and the next assignment.
        self.update()
        self.update()

    def update(self):
        """Advance one round.

        The previous "next" assignment becomes "current", a new "next"
        permutation is drawn, and every graph in ``graph_list`` is updated.
        """
        self.current_active_nodes = self.next_active_nodes

        # Sampling all n_nodes elements without replacement gives a uniformly
        # random permutation: entry i is the node holding token i.
        node_list = list(range(self.n_nodes))
        self.next_active_nodes = [self.random_inst.sample(node_list, self.n_nodes)]

        for graph in self.graph_list:
            graph.update()

    def is_active(self, node_id):
        """Return True if ``node_id`` is active in the current round."""
        return node_id in self.current_all_active_nodes

    def is_next_active(self, node_id):
        """Return True if ``node_id`` is active in the next round."""
        return node_id in self.next_all_active_nodes

    def get_token_location(self, graph_id, token_id):
        """Return the node holding ``token_id`` of graph ``graph_id`` this round."""
        return self.current_active_nodes[graph_id][token_id]

    def get_next_token_location(self, graph_id, token_id):
        """Return the node holding ``token_id`` of graph ``graph_id`` next round."""
        return self.next_active_nodes[graph_id][token_id]

    @staticmethod
    def _find_token(node_id, active_nodes):
        """Locate ``node_id`` in a per-graph list of token->node permutations.

        Returns:
            ``(graph_id, token_id)`` for the first graph whose permutation
            contains ``node_id``, or ``(None, None)`` if none does.
        """
        for graph_id, tokens in enumerate(active_nodes):
            if node_id in tokens:
                return graph_id, tokens.index(node_id)
        return None, None

    def get_token_id(self, node_id):
        """Return ``(graph_id, token_id)`` held by ``node_id`` this round.

        Returns ``(None, None)`` if the node is inactive or holds no token.
        """
        if not self.is_active(node_id):
            return None, None
        return self._find_token(node_id, self.current_active_nodes)

    def get_next_token_id(self, node_id):
        """Return ``(graph_id, token_id)`` held by ``node_id`` next round.

        Returns ``(None, None)`` if the node is inactive or holds no token.
        This mirrors :meth:`get_token_id`, so callers can always tuple-unpack.
        """
        if not self.is_next_active(node_id):
            return None, None
        return self._find_token(node_id, self.next_active_nodes)

    def get_out_neighbors(self, node_id):
        """Map ``node_id``'s out-neighbour tokens to the nodes holding them next round.

        The node's current token is looked up, its out-neighbour tokens are
        taken from the graph, and each token is resolved to its next-round
        holder. The values of the graph's mapping are passed through
        unchanged (presumably mixing weights; confirm against the graph
        classes).
        """
        graph_id, token_id = self.get_token_id(node_id)
        out_token_ids = self.graph_list[graph_id].get_out_neighbors(token_id)
        return {
            self.get_next_token_location(graph_id, out_token_id): out_token_ids[out_token_id]
            for out_token_id in out_token_ids
        }

    def get_in_neighbors(self, node_id):
        """Map ``node_id``'s in-neighbour tokens to the nodes holding them this round.

        The node's next-round token is looked up, its in-neighbour tokens are
        taken from the graph, and each token is resolved to its current-round
        holder. The values of the graph's mapping are passed through
        unchanged (presumably mixing weights; confirm against the graph
        classes).
        """
        graph_id, token_id = self.get_next_token_id(node_id)
        in_token_ids = self.graph_list[graph_id].get_in_neighbors(token_id)
        return {
            self.get_token_location(graph_id, in_token_id): in_token_ids[in_token_id]
            for in_token_id in in_token_ids
        }
