import networkx as nx
import random
import os
import argparse
from tqdm import tqdm
import numpy as np
import re


def DAG_random_walk(source_node, target_node, num_nodes, graph=None, reach=None):
    """Sample a random forward walk from ``source_node`` to ``target_node``.

    At each step the walk moves to a uniformly chosen successor of the current
    node that either is ``target_node`` or can still reach it (according to
    ``reach[target_node]``). Only successors with index in
    ``[cur_node, num_nodes)`` are considered, so the walk assumes edges point
    from lower to higher node indices (as in the DAG built in ``__main__``).

    Args:
        source_node: Start node of the walk.
        target_node: Node the walk must end at.
        num_nodes: Number of nodes; candidate successors are drawn from
            ``range(cur_node, num_nodes)``.
        graph: Adjacency mapping (``graph[u]`` supports ``v in graph[u]``).
            Defaults to the module-level ``G``.
        reach: Mapping ``node -> iterable of nodes that can reach it``, as
            produced by ``obtain_reachability``. Defaults to the module-level
            ``reachability``.

    Returns:
        The list of visited nodes, starting with ``source_node`` and ending
        with ``target_node`` (just ``[source_node]`` when they are equal).

    Raises:
        ValueError: If the walk reaches a node with no eligible successor.
    """
    if graph is None:
        graph = G
    if reach is None:
        reach = reachability

    # Nodes the walk may step onto: anything that can still reach the target,
    # plus the target itself. A set makes each membership test O(1).
    allowed = set(reach[target_node])
    allowed.add(target_node)

    path = [source_node]
    cur_node = source_node
    while cur_node != target_node:
        # Iterate in index order so random.choice sees a deterministic list.
        candidates = [
            node for node in range(cur_node, num_nodes)
            if node in graph[cur_node] and node in allowed
        ]
        if not candidates:
            raise ValueError(
                f"random walk from {source_node} to {target_node} is stuck at "
                f"node {cur_node}: no successor with index >= {cur_node} can "
                f"reach the target"
            )
        cur_node = random.choice(candidates)
        path.append(cur_node)
    return path

def create_dataset(num_nodes, num_path_per_pair):
    """Sample random walks for every training and test (source, target) pair.

    Reads the module-level ``train_node_pair`` and ``test_node_pair`` lists.
    Every record is ``[source, target] + walk``, where ``walk`` comes from
    ``DAG_random_walk`` and itself starts at ``source`` and ends at ``target``.

    Returns:
        ``(train_set, test_set)``: ``num_path_per_pair`` records per training
        pair (grouped consecutively), and one record per test pair.
    """
    def sample_record(pair):
        # One walk for this pair, prefixed by the (source, target) query.
        src, dst = pair
        return [src, dst] + DAG_random_walk(src, dst, num_nodes)

    train_records = [
        sample_record(pair)
        for pair in train_node_pair
        for _ in range(num_path_per_pair)
    ]
    test_records = [sample_record(pair) for pair in test_node_pair]
    return train_records, test_records


def format_data(data):
    """Render one dataset record as a space-separated text line.

    The first two items (source, target) are always followed by a space, then
    the remaining items joined by spaces, then a newline. A record with only
    two items therefore yields ``"s t \\n"``.
    """
    source, target, *walk = data
    walk_text = ' '.join(map(str, walk))
    return f"{source} {target} {walk_text}\n"
        
def write_dataset(dataset, file_name):
    """Write every record of ``dataset`` to ``file_name``, one line each.

    Lines are produced by ``format_data``; an existing file is overwritten.
    """
    rendered = (format_data(record) for record in dataset)
    with open(file_name, "w") as out:
        out.writelines(rendered)

def get_reachable_nodes(random_digraph, target_node):
    """Return the nodes that have a directed path to ``target_node``.

    The transitive closure of ``random_digraph`` is rebuilt on every call, so
    calling this once per node repeats that work each time.
    """
    closure = nx.transitive_closure(random_digraph)
    # In the closure, an edge u -> target_node exists iff u can reach target_node.
    return [node for node in closure.predecessors(target_node)]

def obtain_reachability(random_digraph):
    """Map every node to the list of nodes that can reach it.

    Args:
        random_digraph: A networkx directed graph.

    Returns:
        ``(reachability, pairs)`` where ``reachability[node]`` lists the
        predecessors of ``node`` in the transitive closure (nodes with a path
        to ``node``), and ``pairs`` is the total number of such
        (source, target) pairs over all nodes.
    """
    # Build the transitive closure once. Calling get_reachable_nodes per node
    # would rebuild the whole closure num_nodes times.
    closure = nx.transitive_closure(random_digraph)
    reachability = {
        node: list(closure.predecessors(node))
        for node in random_digraph.nodes()
    }
    pairs = sum(len(sources) for sources in reachability.values())
    return reachability, pairs

if __name__ == "__main__":
    def _str_to_bool(value):
        """Parse a boolean command-line value.

        ``type=bool`` is a trap: argparse would call ``bool("False")``, which
        is True, so ``--DAG False`` could never disable the DAG constraint.
        """
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser(description='Generate a random graph based on the given parameters.')
    parser.add_argument('--num_nodes', type=int, default=100)
    parser.add_argument('--edge_prob', type=float, default=0.15, help='Probability of creating an edge between two nodes')
    parser.add_argument('--DAG', type=_str_to_bool, default=True, help='Whether the graph should be a Directed Acyclic Graph')
    parser.add_argument('--num_path_per_pair', type=int, default=10)
    parser.add_argument('--chance_in_train', type=float, default=0.2, help='Chance of a pair being in the training set')

    # NOTE(review): --unreachable and --load_graph are parsed but never used below.
    parser.add_argument('--unreachable', action='store_true', help='Allow unreachable pairs in the graph')
    parser.add_argument('--load_graph', type=str, default=None)

    args = parser.parse_args()
    num_nodes = args.num_nodes
    DAG = args.DAG
    edge_prob = args.edge_prob
    test_node_pair_ratio = 1 - args.chance_in_train  # NOTE(review): currently unused
    num_path_per_pair = args.num_path_per_pair

    # Training-eligible pairs with target - source above this go to the test split.
    MAX_TRAIN_SPAN = 50
    # Probability that a pair is routed to the *2train split (np.random draw).
    AUG_SPLIT_PROB = 0.5
    # Number of times each pair is repeated in aug_train.txt.
    AUG_REPEATS = 10

    # Random graph on nodes 0..num_nodes-1. In DAG mode only edges i -> j with
    # i < j are created, so index order is a topological order; DAG_random_walk
    # only steps to higher-indexed successors and relies on this.
    # NOTE(review): with --DAG false, backward edges exist but DAG_random_walk
    # ignores them, so some walks may dead-end (raising an error).
    G = nx.DiGraph()
    G.add_nodes_from(range(num_nodes))
    for i in range(num_nodes):
        for j in range(num_nodes):
            allowed = i < j if DAG else i != j
            # random.random() is drawn only for allowed (i, j) pairs.
            if allowed and random.random() < edge_prob:
                G.add_edge(i, j)

    # reachability[t] lists every node that can reach t.
    reachability, feasible_pairs = obtain_reachability(G)

    # For each target, a random chance_in_train fraction (at least one) of its
    # sources is training-eligible; eligible pairs spanning more than
    # MAX_TRAIN_SPAN indices, and all other pairs, go to the test split.
    train_node_pair = []
    test_node_pair = []
    for target in range(num_nodes):
        sources = reachability[target]
        if not sources:
            continue
        train_sources_num = min(max(int(len(sources) * args.chance_in_train), 1), len(sources))
        train_sources = set(random.sample(sources, train_sources_num))
        for source in sources:
            if source in train_sources and target - source <= MAX_TRAIN_SPAN:
                train_node_pair.append((source, target))
            else:
                test_node_pair.append((source, target))

    # create_dataset reads train_node_pair / test_node_pair (and, through
    # DAG_random_walk, G and reachability) as module-level globals.
    train_set, test_set = create_dataset(num_nodes, num_path_per_pair)

    # Every output lives next to this script in <num_nodes>_1_1/, independent
    # of the current working directory.
    folder_name = os.path.join(os.path.dirname(__file__), f'{num_nodes}_1_1')
    os.makedirs(folder_name, exist_ok=True)

    # true_reach[i][j] == 1 iff node j can reach node i.
    true_reach = np.zeros(shape=(num_nodes, num_nodes))
    for i in range(num_nodes):
        for j in reachability[i]:
            true_reach[i][int(j)] = 1

    adj = nx.adjacency_matrix(G).todense()
    np.save(os.path.join(folder_name, 'true_adj_matrix.npy'), adj)
    np.save(os.path.join(folder_name, 'true_reach_matrix.npy'), true_reach)
    write_dataset(train_set, os.path.join(folder_name, 'simple_train.txt'))
    write_dataset(test_set, os.path.join(folder_name, 'simple_test.txt'))
    nx.write_graphml(G, os.path.join(folder_name, 'path_graph.graphml'))

    # Re-split the pairs just written into *2train / *2test groups. Files are
    # read back from folder_name so this works from any working directory.
    train2train = []
    train2test = []
    with open(os.path.join(folder_name, 'simple_train.txt'), 'r') as f:
        # Training records come in groups of num_path_per_pair per pair; only
        # the first record of each group is used.
        for cnt, line in enumerate(f):
            if cnt % num_path_per_pair == 0:
                path = re.findall(r'\d+', line)
                if np.random.random() < AUG_SPLIT_PROB:
                    train2train.append(path)
                else:
                    train2test.append(path)

    test2train = []
    test2test = []
    with open(os.path.join(folder_name, 'simple_test.txt'), 'r') as f:
        for line in f:
            path = re.findall(r'\d+', line)
            if np.random.random() < AUG_SPLIT_PROB:
                test2train.append(path)
            else:
                test2test.append(path)

    print(len(train2train), len(train2test), len(test2train), len(test2test))

    def _write_pairs(paths, file_name, repeats=1):
        """Write '<source> <target>' for each parsed record, `repeats` times each."""
        with open(os.path.join(folder_name, file_name), 'w') as f:
            for path in paths:
                f.write(f'{path[0]} {path[1]}\n' * repeats)

    _write_pairs(train2train + test2train, 'aug_train.txt', repeats=AUG_REPEATS)
    _write_pairs(train2train, 'train2train.txt')
    _write_pairs(train2test, 'train2test.txt')
    _write_pairs(test2train, 'test2train.txt')
    _write_pairs(test2test, 'test2test.txt')
            