'''
Generates the adjacency matrix of a specific DAG shape (collision, chain, one-to-all, ...) and saves it with np.save.
'''

import argparse
import numpy as np
import os
import random

def generate_collition(n):
    '''
    Builds a collision graph: every other node has an edge into the last node.

    Adjacency convention: graph[i, j] == 1 marks an edge i -> j. The result
    is a float (n, n) matrix whose last column is 1 in rows 0..n-2 and which
    is 0 everywhere else (in particular, no self-loop on node n-1).
    '''
    adjacency = np.zeros((n, n))
    # rows 0..n-2 point at the sink n-1; slicing stops before the sink's own row
    adjacency[:-1, -1] = 1
    return adjacency

def generate_collition_plus_independent(n):
    '''
    Builds a collision graph in which only every other node feeds the sink.

    Nodes 0, 2, 4, ... (those below n-1) each have an edge into the last
    node n-1. The odd-numbered nodes below n-1 are left isolated.
    Returns a float (n, n) adjacency matrix (graph[i, j] == 1 means i -> j).
    '''
    adjacency = np.zeros((n, n))
    # even rows strictly before the sink's row get an edge into the last column
    adjacency[:-1:2, -1] = 1
    return adjacency

def generate_star(n):
    '''
    Placeholder for a star-shaped graph generator.

    Not implemented: every call raises NotImplementedError regardless of `n`.
    '''
    raise NotImplementedError()


def generate_chain(n):
    '''
    Builds a chain graph 0 -> 1 -> ... -> n-1.

    Returns a float (n, n) adjacency matrix with ones on the first
    super-diagonal (graph[i, i+1] == 1) and zeros everywhere else.
    '''
    # k=1 shifts the identity's ones one column right: exactly the chain edges
    return np.eye(n, k=1)

def one_to_all(n): # the one to rule them all
    '''
    Builds a graph in which node 0 has an edge to every other node.

    Returns a float (n, n) adjacency matrix (graph[i, j] == 1 means i -> j)
    whose first row is 1 in columns 1..n-1 and which is 0 everywhere else.
    '''
    adjacency = np.zeros((n, n))
    adjacency[0, 1:] = 1  # fan out from node 0, skipping its own diagonal cell
    return adjacency

def generate_collition_chain(n):
    '''
    Builds a chain 0 -> 1 -> ... -> n-1 where every node also points at n-1.

    This is the union of the chain and collision shapes. Returns a float
    (n, n) adjacency matrix (graph[i, j] == 1 means i -> j).
    '''
    adjacency = np.eye(n, k=1)  # chain edges i -> i+1 on the super-diagonal
    adjacency[:-1, -1] = 1      # collision edges from every earlier node into n-1
    return adjacency

def add_random_edges(graph, n_edges):
    '''
    Adds `n_edges` random forward edges (i -> j with i < j) to `graph` in place.

    Candidates are the currently-empty cells strictly above the main diagonal,
    including those in row 0. An upper-triangular input therefore stays
    upper-triangular, and hence acyclic. Edges are sampled without
    replacement from numpy's global RNG, so results are reproducible after
    np.random.seed(...).

    Args:
        graph: square adjacency matrix (graph[i, j] == 1 means i -> j);
            modified in place.
        n_edges: number of new edges to add (must be >= 0).

    Returns:
        The same `graph` object, for convenience.

    Raises:
        ValueError: if `n_edges` is negative or exceeds the number of empty
            upper-triangular cells.
    '''
    if n_edges < 0:
        raise ValueError(f"n_edges must be non-negative, got {n_edges}")

    n = graph.shape[0]
    # Row-major enumeration of every (i, j) with i < j, filtered to empty cells.
    rows, cols = np.triu_indices(n, k=1)
    free = graph[rows, cols] == 0
    rows, cols = rows[free], cols[free]

    if n_edges > rows.size:
        raise ValueError(
            f"Not enough candidates to add edges: requested {n_edges}, "
            f"only {rows.size} available"
        )

    # Sample candidate *indices*; np.random.choice cannot sample tuples directly.
    picked = np.random.choice(rows.size, n_edges, replace=False)
    graph[rows[picked], cols[picked]] = 1
    return graph


def save_dag(graph, path):
    '''
    Saves the adjacency matrix `graph` to `path` using np.save.

    np.save appends a ".npy" suffix when a filename lacks one. When `path`
    is a str or os.PathLike whose parent directory does not exist yet, that
    directory is created first. Other targets (e.g. open file objects) are
    handed to np.save unchanged.
    '''
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)
    np.save(path, graph)

def main(argv=None):
    '''
    Command-line entry point: builds the requested DAG shape and saves it.

    Args:
        argv: argument list to parse; None lets argparse read sys.argv[1:].
    '''
    # Map each accepted -shape value to the function that builds it.
    builders = {
        "collision": generate_collition,
        "collision_plus_independent": generate_collition_plus_independent,
        "chain": generate_chain,
        "one_to_all": one_to_all,
        "collision_chain": generate_collition_chain,
    }

    argparser = argparse.ArgumentParser()
    argparser.add_argument("-seed", type=int, default=123, help="random seed")
    argparser.add_argument(
        "-shape", type=str, required=True, choices=sorted(builders),
        help="shape of the graph",
    )
    argparser.add_argument("-n", type=int, required=True, help="number of nodes")
    argparser.add_argument("-path", type=str, required=True, help="path to save the graph")

    args = argparser.parse_args(argv)
    # Several builders index the last row/column and fail on an empty graph.
    if args.n < 1:
        argparser.error("-n must be a positive integer")

    # Seed both the numpy and stdlib RNGs so any randomized step is reproducible.
    np.random.seed(args.seed)
    random.seed(args.seed)

    graph = builders[args.shape](args.n)
    save_dag(graph, args.path)


if __name__ == "__main__":
    main()