import math
import torch

def _grid_shape(total_client):
    """Return ``(w, h)``, the most-square factor pair of ``total_client`` with ``w <= h``."""
    w = 1
    d = 1
    # Keep the largest divisor that does not exceed sqrt(total_client).
    while d * d <= total_client:
        if total_client % d == 0:
            w = d
        d += 1
    return w, total_client // w


def _grid_neighbors(row, col, w, h):
    """Return flat indices of cell ``(row, col)`` and its in-bounds 4-neighbours."""
    cols = [row * w + col]
    if row > 0:
        cols.append((row - 1) * w + col)
    if row < h - 1:
        cols.append((row + 1) * w + col)
    if col > 0:
        cols.append(row * w + col - 1)
    if col < w - 1:
        cols.append(row * w + col + 1)
    return cols


def graph_generator(mode, total_client):
    """Build a ``total_client x total_client`` weight matrix for a topology.

    Every row of the returned matrix sums to 1 (up to float rounding).
    Entry ``A[i, j]`` is the weight row ``i`` assigns to column ``j``.

    Supported values of ``mode``:

    * ``'ring'``: row ``i`` is uniform over ``i - 1``, ``i`` and ``i + 1``
      (indices modulo ``total_client``).
    * ``'star'``: row 0 is uniform over all columns; every other row ``i``
      puts 0.5 on column 0 and 0.5 on column ``i``.
    * ``'grid'``: indices are laid out row-major on a ``w x h`` grid, where
      ``(w, h)`` is the most-square factorisation of ``total_client``
      (15x20, 14x25, 20x20 and 18x25 for 300, 350, 400 and 450). Each row
      is uniform over the cell itself and its existing up/down/left/right
      neighbours, with no wrap-around.
    * ``'exp'``: row ``i`` is uniform over ``i`` and ``(i + 2**j) %
      total_client`` for ``j`` in ``range(floor(log2(total_client)))``.
    * ``'full'``: every entry is ``1 / total_client``.

    Args:
        mode: Name of the topology, one of the strings listed above.
        total_client: Number of rows/columns of the matrix.

    Returns:
        A 2-D ``torch.Tensor`` of shape ``(total_client, total_client)``.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported names.
    """
    A = torch.zeros(total_client, total_client)
    if mode == 'ring':
        for i in range(total_client):
            # A set collapses coinciding neighbours when total_client < 3,
            # so the row still sums to 1 instead of overwriting itself.
            cols = sorted({(i - 1) % total_client, i, (i + 1) % total_client})
            A[i, cols] = 1.0 / len(cols)
    elif mode == 'star':
        A = torch.eye(total_client) * 0.5
        A[:, 0] = 0.5
        # Assigned last so that row 0 (the hub) ends up uniform.
        A[0, :] = 1. / total_client
    elif mode == 'grid':
        w, h = _grid_shape(total_client)
        for row in range(h):
            for col in range(w):
                cols = _grid_neighbors(row, col, w, h)
                # 1/5 for interior cells, 1/4 on borders, 1/3 in corners.
                A[row * w + col, cols] = 1.0 / len(cols)
    elif mode == 'exp':
        # Exact floor(log2(total_client)); avoids float rounding of
        # math.log(x, 2) and the domain error at 0.
        link_num = max(int(total_client).bit_length() - 1, 0)
        for i in range(total_client):
            cols = [i] + [(i + 2 ** j) % total_client for j in range(link_num)]
            A[i, cols] = 1.0 / len(cols)
    elif mode == 'full':
        A = torch.ones((total_client, total_client)) / total_client
    else:
        raise NotImplementedError
    return A