import torch
from dynamic_graph import *


class Ring(DynamicGraph):
    """Ring topology with uniform 1/3 weights.

    Row ``i`` of the ``n_nodes x n_nodes`` weight matrix gives weight ``1/3``
    to node ``i`` itself and to its ring neighbours ``(i + 1) % n_nodes`` and
    ``(i - 1) % n_nodes``, so every row (and column) sums to 1. The matrix is
    handed to ``DynamicGraph.__init__`` as a one-element list.
    """

    def __init__(self, n_nodes):
        """Build the ring weight matrix.

        Args:
            n_nodes: Number of nodes in the ring. For ``n_nodes == 1`` the
                matrix is ``[[1]]``; for ``n_nodes == 2`` each node keeps 1/3
                and gives 2/3 to its single neighbour.
        """
        w = torch.zeros((n_nodes, n_nodes))
        weight = 1 / 3

        for i in range(n_nodes):
            # Accumulate instead of assigning: when n_nodes < 3 the self/next/
            # prev indices coincide (n=1: all equal i; n=2: next == prev), and
            # plain assignment would overwrite and leave rows summing to 1/3
            # or 2/3 instead of 1. For n_nodes >= 3 the result is unchanged.
            w[i, i] += weight
            w[i, (i + 1) % n_nodes] += weight
            w[i, (i - 1) % n_nodes] += weight

        super().__init__([w])



class Ring2(DynamicGraph):
    """Ring topology with uniform 1/3 weights, relabelled by ``delay_ids``.

    Logical node ``i`` occupies matrix row/column ``delay_ids.index(i)``.
    Apart from that relabelling, the weights match a plain ring: each logical
    node gives weight ``1/3`` to itself and to its ring neighbours
    ``(i + 1) % n_nodes`` and ``(i - 1) % n_nodes``. The matrix is handed to
    ``DynamicGraph.__init__`` as a one-element list.
    """

    def __init__(self, n_nodes, delay_ids):
        """Build the relabelled ring weight matrix.

        Args:
            n_nodes: Number of nodes in the ring.
            delay_ids: Sequence supporting ``.index`` (e.g. a list) whose
                position ``p`` holds the logical node id stored at matrix
                row/column ``p``. Presumably a permutation of
                ``range(n_nodes)`` — TODO confirm with callers.

        Raises:
            ValueError: If some id in ``range(n_nodes)`` is not found in
                ``delay_ids`` (raised by ``delay_ids.index``).
        """
        w = torch.zeros((n_nodes, n_nodes))
        weight = 1 / 3

        # Matrix position of each logical node, looked up once per node
        # rather than three ``.index`` scans per loop iteration.
        pos = [delay_ids.index(i) for i in range(n_nodes)]

        for i in range(n_nodes):
            row = pos[i]
            # Accumulate instead of assigning: when n_nodes < 3 the self/next/
            # prev positions coincide, and plain assignment would overwrite
            # and leave rows summing to 1/3 or 2/3 instead of 1. For
            # n_nodes >= 3 the result is unchanged.
            w[row, row] += weight
            w[row, pos[(i + 1) % n_nodes]] += weight
            w[row, pos[(i - 1) % n_nodes]] += weight

        super().__init__([w])
        
