import networkx as nx
import numpy as np

from pattern_match import find_pattern_list
import os
import pickle
import json
import random

# Node-count bounds for the demo graph built further down
# (nx.erdos_renyi_graph(max_nodes, p=0.5, directed=True)).
# NOTE(review): min_nodes appears unused, and the main generation loop draws
# its node counts with random.randint(5, 16) rather than from these bounds.
min_nodes = 5
max_nodes = 15


def generate_random_graph_fixed_edges(n, num_edges, directed=False):
    """Return a random simple graph on nodes ``0..n-1`` with exactly ``num_edges`` edges.

    The edge set is drawn uniformly without replacement from all possible
    edges without self-loops:

    * ``directed=True``: the candidates are the ``n * (n - 1)`` ordered pairs
      ``(u, v)`` with ``u != v``. Cycles, including 2-cycles
      ``u -> v -> u``, are allowed; no acyclicity is enforced.
    * ``directed=False``: the candidates are the ``n * (n - 1) // 2``
      unordered pairs ``(u, v)`` with ``u < v``.

    Args:
        n: number of nodes; nodes are labelled ``0 .. n-1``.
        num_edges: exact number of edges to place.
        directed: build an ``nx.DiGraph`` instead of an ``nx.Graph``.

    Returns:
        An ``nx.DiGraph`` (directed) or ``nx.Graph`` (undirected).

    Raises:
        ValueError: if ``num_edges`` is negative or larger than the number of
            possible edges for this kind of graph.
    """
    if directed:
        G = nx.DiGraph()
        possible_edges = [(u, v) for u in range(n) for v in range(n) if u != v]
    else:
        G = nx.Graph()
        possible_edges = [(u, v) for u in range(n) for v in range(u + 1, n)]

    # The bound depends on directedness: ordered pairs for DiGraph,
    # unordered pairs for Graph.
    max_edges = len(possible_edges)
    if num_edges < 0:
        raise ValueError(f"num_edges must be non-negative, got {num_edges}")
    if num_edges > max_edges:
        raise ValueError(f"Too many edges: max possible edges for {n} nodes is {max_edges}")

    G.add_nodes_from(range(n))
    # Sampling without replacement gives a uniform random edge subset in one
    # pass; no rejection loop that slows down as the graph fills up.
    G.add_edges_from(random.sample(possible_edges, num_edges))
    return G



def find_directed_triangles(G):
    """Return the directed 3-cycles of ``G`` as sorted node triples.

    A triple is reported when edges ``i -> j``, ``j -> k`` and ``k -> i`` all
    exist for three *distinct* nodes. Self-loops therefore never produce
    degenerate triples such as ``(0, 0, 0)`` or ``(0, 1, 1)``.

    Each triangle is stored as ``tuple(sorted((i, j, k)))``. The same cycle
    discovered from any of its three starting nodes is counted once. This also
    merges the two opposite orientations on the same node set into one entry.

    Args:
        G: a directed graph exposing ``nodes()``, ``successors(u)`` and
            ``has_edge(u, v)`` (e.g. ``networkx.DiGraph``). Node labels must be
            mutually comparable, because each triple is sorted.

    Returns:
        A set of sorted 3-tuples of node labels.
    """
    triangles = set()
    for i in G.nodes():
        for j in G.successors(i):
            if j == i:
                continue  # self-loop: cannot be part of a 3-node cycle
            for k in G.successors(j):
                if k in (i, j):
                    continue  # k must be a third, distinct node
                if G.has_edge(k, i):
                    triangles.add(tuple(sorted((i, j, k))))
    return triangles




# Demo directed G(n, p) graph on max_nodes nodes, each possible edge kept with probability 0.5.
graph = nx.erdos_renyi_graph(n=max_nodes, p=0.5, directed=True)

def generate_adj_list(g):
    """Serialize ``g`` as one flat adjacency-list string.

    Each row yielded by ``nx.generate_adjlist(g)`` (``"node nbr1 nbr2 ..."``)
    becomes the fragment ``", node : nbr1 nbr2 ... "``. Every neighbour is
    followed by one space. The fragments are concatenated, and only the very
    first character (the leading comma) is dropped. The result therefore starts
    with a space, e.g. ``" 0 : 1 2 , 1 : 2 , 2 : "``.
    """
    fragments = []
    for row in nx.generate_adjlist(g):
        source, *targets = row.split(' ')
        neighbours = ''.join(f'{t} ' for t in targets)
        fragments.append(f', {source} : {neighbours}')
    return ''.join(fragments)[1:]

def generate_edge_list(g):
    """Serialize the edges of ``g`` as ``"u1 v1 | u2 v2 | ..."``.

    The string is built directly from the edge tuples rather than by rewriting
    ``str(g.edges())``. That rewrite depended on the container's repr, and it
    corrupted labels whose repr contains quotes, ``", "`` or parentheses. For
    integer node labels the output is identical.

    Args:
        g: a graph whose ``edges()`` yields ``(u, v)`` pairs
            (e.g. ``networkx.Graph`` / ``DiGraph``).

    Returns:
        The serialized edge list, or ``""`` when ``g`` has no edges.
    """
    return ' | '.join(f'{u} {v}' for u, v in g.edges())

# Sanity run on the demo graph: look up the 'tFFL' pattern and serialize the graph.
# NOTE(review): the main loop searches for name_1 = 'FFL', not 'tFFL' -- confirm which is intended.
FFL = find_pattern_list(graph, 'tFFL')
adj_list = generate_adj_list(g=graph)

def ans_generation(triangles, node=3):
    """Format pattern matches as a single answer string.

    Each match ``t`` contributes its first ``node`` entries joined by single
    spaces, followed by ``" , "``. The final character (a trailing space) is
    then dropped, so a non-empty result ends with ``" ,"``, e.g.
    ``ans_generation([(1, 2, 3), (4, 5, 6)]) == "1 2 3 , 4 5 6 ,"``.

    Args:
        triangles: iterable of indexable matches (e.g. tuples of node labels).
        node: number of entries of each match to emit. Any positive size is
            supported; 3 and 4 are the sizes used in this script.

    Returns:
        The formatted string, or ``""`` when ``triangles`` is empty.

    Raises:
        IndexError: if a match has fewer than ``node`` entries.
    """
    entries = [' '.join(str(t[i]) for i in range(node)) + ' , ' for t in triangles]
    return ''.join(entries)[:-1]

# Format and show the demo graph's matches (3 nodes per match).
ans = ans_generation(triangles=FFL, node=3)
print(ans)
# Per-chunk accumulators; the main loop below pickles them and resets them.
graph_dicts = {}  # "<pattern name><match count>" -> indices of samples in the chunk
graph_list = []
graph_description_adj_list = []
graph_description_edge_list = []
ans_list = []
graph_set = set()  # adjacency-list strings already seen (de-duplication key)
max_num = 5000000  # loop stops once this many samples have been accepted
sub_count = 0  # index of the next sample within the current chunk
counts = 0  # total number of accepted samples
counts_bag = set()  # values of `counts` at which a chunk has already been saved
# Running maxima of space-separated token counts (question, answer, and both).
max_length = 0
max_length_edge = 0
max_q_length = 0
max_q_length_e = 0
max_a_length = 0
name_1 = 'FFL'
name_2 = 'nd-diamond'
base_path = ''  # output directory; '' writes into the current working directory
# os.makedirs('') raises FileNotFoundError, so only create a non-empty path.
if base_path and not os.path.exists(base_path):
    os.makedirs(base_path)
di = True  # passed as directed= to generate_random_graph_fixed_edges

# Number of accepted samples per saved chunk; files are named tiny_<k>_*.
chunk_size = 10000


def _pickle_to(filename, obj):
    """Pickle ``obj`` into ``os.path.join(base_path, filename)``."""
    with open(os.path.join(base_path, filename), 'wb') as f:
        pickle.dump(obj, f)


# Main generation loop: sample random graphs, keep the unseen ones that match
# the name_1 / name_2 patterns, and flush the per-chunk accumulators to disk
# every chunk_size accepted samples. A final, possibly partial, chunk is also
# flushed when max_num is reached.
while True:
    n = random.randint(5, 16)
    if di:
        max_edges = n * (n - 1) // 2
    else:
        # Undirected graphs use at most half of the possible edges.
        max_edges = int((n * (n - 1) // 2) * 0.5)
    edges_num = random.randint(2, max_edges)
    graph = generate_random_graph_fixed_edges(n, edges_num, directed=di)
    description_adj = generate_adj_list(graph)
    description_edge = generate_edge_list(graph)

    # The adjacency-list text doubles as the de-duplication key.
    if description_adj not in graph_set:
        graph_set.add(description_adj)
        FFL = find_pattern_list(graph, name_1)
        diamond = find_pattern_list(graph, name_2)
        # Keep graphs with at least one match overall and at most 20 of each pattern.
        if len(FFL) == 0 and len(diamond) == 0:
            continue
        if len(FFL) > 20 or len(diamond) > 20:
            continue

        # Index this sample (position sub_count in the per-chunk lists) under
        # "<pattern name><match count>".
        graph_dicts.setdefault(name_1 + str(len(FFL)), []).append(sub_count)
        graph_dicts.setdefault(name_2 + str(len(diamond)), []).append(sub_count)
        FFL_ans = ans_generation(FFL, 3)
        diamond_ans = ans_generation(diamond, 4)

        # Append to every per-chunk list so index sub_count lines up across
        # all files saved for this chunk.
        graph_list.append(graph)
        ans_list.append({name_1: FFL_ans, name_2: diamond_ans})
        graph_description_adj_list.append(description_adj)
        graph_description_edge_list.append(description_edge)

        # Running maxima of space-separated token counts.
        description_length_adj = len(description_adj.split(' '))
        description_length_edge = len(description_edge.split(' '))
        ans_length = max(len(FFL_ans.split(' ')), len(diamond_ans.split(' ')))
        max_length = max(max_length, description_length_adj + ans_length)
        max_length_edge = max(max_length_edge, description_length_edge + ans_length)
        max_q_length = max(max_q_length, description_length_adj)
        max_q_length_e = max(max_q_length_e, description_length_edge)
        max_a_length = max(max_a_length, ans_length)
        counts += 1
        sub_count += 1

    chunk_full = counts % chunk_size == 0 or counts == max_num
    if chunk_full and counts != 0 and counts not in counts_bag:
        print(counts)
        # Round up so a final partial chunk gets its own index instead of
        # overwriting the previous full chunk's files.
        chunk = (counts + chunk_size - 1) // chunk_size

        _pickle_to(f'tiny_{chunk}_graphs.pkl', graph_list)
        with open(os.path.join(base_path, f'tiny_{chunk}_idx.json'), 'w') as f:
            json.dump(graph_dicts, f)
        _pickle_to(f'tiny_{chunk}_graphs_description_adj.pkl', graph_description_adj_list)
        _pickle_to(f'tiny_{chunk}_graphs_description_edge.pkl', graph_description_edge_list)
        _pickle_to(f'tiny_{chunk}_ans.pkl', ans_list)
        counts_bag.add(counts)

        for key, indices in graph_dicts.items():
            print(key, len(indices))

        # Start a fresh chunk: reset *every* per-chunk container so the next
        # chunk's files stay index-aligned.
        graph_dicts = {}
        graph_list = []
        graph_description_adj_list = []
        graph_description_edge_list = []
        ans_list = []
        sub_count = 0
        print(f'saved, max length:{max_length_edge}, {max_q_length}, {max_q_length_e}, {max_a_length},')
    if counts == max_num:
        break
# adj_list comes from the demo graph. FFL, however, is rebound inside the main
# loop, so this prints the matches of the last graph the loop searched, not the
# demo graph's 'tFFL' result.
print(adj_list, FFL, sep='\n')
