import numpy as np
import copy

def gen_gossip(m, pair):
    """Return the m-by-m averaging matrix for one pairwise gossip step.

    The result is the identity matrix except for the 2x2 sub-block indexed
    by the two nodes in ``pair``, whose four entries are all set to 0.5.

    Args:
        m: Number of nodes (size of the square matrix).
        pair: Two-element iterable ``(u, v)`` of node indices.

    Returns:
        A ``numpy.ndarray`` of shape ``(m, m)``.
    """
    first, second = pair
    mixing = np.identity(m)
    endpoints = (first, second)
    for row in endpoints:
        for col in endpoints:
            mixing[row, col] = 0.5
    return mixing

def seq2sigma(m, seq):
    """Return the second-largest eigenvalue of ``P.T @ P`` for a gossip sequence.

    ``P`` is the ordered product of the pairwise gossip matrices produced by
    ``gen_gossip`` for each pair in ``seq`` (starting from the identity).

    Args:
        m: Number of nodes.
        seq: Iterable of ``(u, v)`` node-index pairs, applied in order.

    Returns:
        The second-largest eigenvalue of ``P.T @ P`` as a real scalar.

    Raises:
        IndexError: If ``m < 2`` (there is no second eigenvalue to return).
    """
    P = np.eye(m)
    for pair in seq:
        P = np.dot(P, gen_gossip(m, pair))
    # P.T @ P is symmetric, so use the symmetric solver: it guarantees real
    # eigenvalues and returns them already sorted in ascending order.
    eigenvalues = np.linalg.eigvalsh(np.dot(P.T, P))
    return eigenvalues[-2]

class Network(object):
    """Undirected communication network over ``node`` workers.

    Attributes set by the constructor:
        M: Number of nodes.
        topo: Name of the topology used to generate the network.
        A: Adjacency lists; ``A[i]`` is the list of neighbours of node ``i``.
        E: Edge list of ``(i, j)`` tuples with ``i < j``.
        A_bi: Adjacency lists of the greedy bipartite subgraph.
        active: Dict mapping each node to a bool giving its side of the
            bipartition.
        E_bi: Edge list of the bipartite subgraph.
    """

    def __init__(self, node, topo, **kwargs):
        """Build the network and its bipartite subgraph.

        Args:
            node: Number of nodes.
            topo: One of ``'complete'``, ``'ring'``, ``'smallworld'``,
                ``'path'``, ``'star'``, ``'dumbbell'``, ``'wheel'``.
            **kwargs: Optional generator parameters, forwarded to the
                topology generator (``k`` and ``p`` for ``'smallworld'``
                and ``'wheel'``; ignored by the other generators).

        Raises:
            KeyError: If ``topo`` is not a known topology name.
        """
        self.M, self.topo = node, topo
        self._generate_adjacent(**kwargs)
        self._bipartite()

    def _generate_adjacent(self, **kwargs):
        '''
        Generate the selected topology and store it as adjacency lists
        (``self.A``) and as an edge list with ``i < j`` (``self.E``).

        Keyword arguments are passed through to the topology generator.
        '''
        def complete(m, **kwargs):
            # Every node is adjacent to every other node.
            return [[j for j in range(m) if j != i] for i in range(m)]

        def ring(m, **kwargs):
            A = []
            for i in range(m):
                neighbours = []
                for v in ((i - 1) % m, (i + 1) % m):
                    # Skip self-loops (m == 1) and duplicate neighbours
                    # (m == 2), which the modular wrap-around would create.
                    if v != i and v not in neighbours:
                        neighbours.append(v)
                A.append(neighbours)
            return A

        def path(m, **kwargs):
            return [[j for j in (i - 1, i + 1) if 0 <= j < m] for i in range(m)]

        def star(m, **kwargs):
            # Node m-1 is the hub; all other nodes are leaves.
            A = [[m - 1] for _ in range(m - 1)]
            A.append(list(range(m - 1)))
            return A

        def smallworld(m, **kwargs):
            # k: number of edges each node tries to add;
            # p: probability that an added edge goes to a uniformly random
            #    remaining candidate instead of the next available node.
            k, p = kwargs.get('k', 2), kwargs.get('p', .2)
            A = [[] for _ in range(m)]
            # cand[u] holds the nodes u is not yet connected to.
            cand = [[v for v in range(m) if v != u] for u in range(m)]

            for u in range(m):
                for _ in range(k):
                    if not cand[u]:
                        break
                    if np.random.rand() < p:
                        # Cast so adjacency lists hold plain Python ints.
                        v = int(np.random.choice(cand[u]))
                    else:
                        # Walk forward around the ring to the first node
                        # that is still a candidate for u.
                        v = u + 1
                        while v % m not in cand[u]:
                            v += 1
                        v = v % m

                    cand[u].remove(v)
                    cand[v].remove(u)
                    A[u].append(v)
                    A[v].append(u)
            return A

        def dumbbell(m, **kwargs):
            # Two complete graphs on nodes [0, n) and [n, m), joined by a
            # single bridge edge (n-1, n).
            n = m // 2
            A = [[j for j in range(n) if j != i] for i in range(n)]
            A.extend([j for j in range(n, m) if j != i] for i in range(n, m))
            if n >= 1:
                # Without this guard, m < 2 would index A[-1] and add a
                # bogus edge.
                A[n - 1].append(n)
                A[n].append(n - 1)
            return A

        def wheel(m, **kwargs):
            # A small-world graph on the first m-1 nodes plus a hub (node
            # m-1) connected to all of them.
            k, p = kwargs.get('k', 4), kwargs.get('p', .2)
            A = smallworld(m - 1, k=k, p=p)
            for i in range(m - 1):
                A[i].append(m - 1)
            A.append(list(range(m - 1)))
            return A

        gen_dict = {'complete':complete, 'ring':ring, 'smallworld':smallworld, 'path':path, 'star':star, 'dumbbell':dumbbell, 'wheel':wheel}
        self.A = gen_dict[self.topo](self.M, **kwargs)

        # Each undirected edge is recorded once, as (smaller, larger).
        self.E = []
        for i in range(self.M):
            self.E.extend([(i, j) for j in self.A[i] if i < j])

    def _bipartite(self):
        '''
        Use a greedy strategy to create a bipartite subgraph.
        Note that, this is ONLY used for realizing [Lian et al, 2018]'s implementation for symmetric gossiping,
        which is restricted on a bipartite graph due to the need for deadlock-avoidance.

        Starting from a random root, nodes are attached one at a time via the
        first candidate edge touching the visited set; each new node takes the
        opposite side of the node it was attached through, and only edges that
        cross the two sides are kept. If the graph is disconnected, each
        further component is rooted at its lowest-numbered node.

        Sets ``self.A_bi``, ``self.active`` and ``self.E_bi``.
        '''
        A_bipart = [[] for _ in range(self.M)]
        active = {}
        E = []

        root = int(np.random.choice(list(range(self.M))))
        active[root] = True
        # Set membership keeps the "already placed?" checks O(1).
        visited = {root}

        # Edges with at least one endpoint still outside `visited`.
        E_cand = copy.copy(self.E)

        while len(visited) < self.M:
            # Find the first candidate edge with an endpoint in `visited`,
            # oriented as (visited endpoint, new endpoint).
            link = None
            for a, b in E_cand:
                if a in visited:
                    link = (a, b)
                    break
                if b in visited:
                    link = (b, a)
                    break

            if link is None:
                # No edge leaves the visited set: the remaining nodes form
                # separate components. Start a new one instead of reusing
                # stale (or undefined) loop variables.
                new_root = next(v for v in range(self.M) if v not in visited)
                visited.add(new_root)
                active[new_root] = True
                continue

            i, j = link
            visited.add(j)
            active[j] = not active[i]
            for k in self.A[j]:
                if k in visited:
                    # Both endpoints are now placed; retire the edge.
                    E_cand.remove((k, j) if k < j else (j, k))
                    # Keep it only if it crosses the bipartition
                    # (k is on i's side, i.e. opposite to j).
                    if active[k] == active[i]:
                        E.append((k, j))
                        A_bipart[k].append(j)
                        A_bipart[j].append(k)

        self.A_bi = A_bipart
        self.active = active
        self.E_bi = E
            
        
        
    