from dsl.gdl import *
from util.connected import graph_is_connected, separate_a_graph
from synthesis.generalize import generalize, process_GDL_program_edges
import sys


#def synthesis_bu(data):
def learn_a_GDL_program(data, target_graph):
  """Learn GDL programs from one training graph and keep those that pass.

  A fully concrete GDL program is built from ``target_graph`` (or from each
  connected component returned by ``separate_a_graph`` when the graph is not
  connected), generalized against the training graphs sharing its label, and
  kept only if it passes the acceptance test in ``_score_and_accept``.

  Args:
    data: dataset object; this function reads ``train_graphs``,
      ``label_to_graphs``, ``graph_to_label``, ``graphs``, ``epsilon`` and
      ``expected`` from it.
    target_graph: index of the graph in ``data.graphs`` to learn from.

  Returns:
    set of ``(label, learned_GDL_program, score, frozenset(chosen_graphs))``
    tuples, one per accepted program (empty if every attempt failed).
  """
  target_label = data.graph_to_label[target_graph]
  target_labeled_graphs = data.train_graphs & data.label_to_graphs[target_label]
  # Baseline score: fraction of training graphs that carry the target label.
  default_score = float(len(target_labeled_graphs & data.train_graphs) / (len(data.train_graphs) + data.epsilon))
  learned_tuples_set = set()

  if graph_is_connected(data.graphs[target_graph], data):
    GDL_program = init_GDL_program_from_idx(data, target_graph)
    GDL_program = GDL_program_process_edges(GDL_program)
    learned_tuple = _score_and_accept(GDL_program, data, target_labeled_graphs,
                                      target_label, default_score,
                                      "This Learning failed!!")
    if learned_tuple is not None:
      learned_tuples_set.add(learned_tuple)
  else:
    # Disconnected graph: learn one program per separated component.
    for concrete_graph in separate_a_graph(data.graphs[target_graph], data):
      GDL_program = init_GDL_program_from_graph(data, concrete_graph)
      # NOTE(review): this branch orders edges with process_GDL_program_edges
      # (synthesis.generalize) while the connected branch uses the local
      # GDL_program_process_edges; confirm the difference is intentional.
      GDL_program = process_GDL_program_edges(GDL_program)
      learned_tuple = _score_and_accept(GDL_program, data, target_labeled_graphs,
                                        target_label, default_score,
                                        "This learning failed!!")
      if learned_tuple is not None:
        learned_tuples_set.add(learned_tuple)
  return learned_tuples_set


def _score_and_accept(GDL_program, data, target_labeled_graphs, target_label,
                      default_score, failure_msg):
  """Generalize GDL_program, score it and return its learned tuple, or None.

  The program is rejected (``failure_msg`` is printed and None returned) when
  its score is below ``default_score * data.expected`` or when it selects
  exactly one training graph.
  """
  learned_GDL_program = generalize(GDL_program, data, target_labeled_graphs)
  score = eval_GDL_program_on_graphs_GC_Score(learned_GDL_program, data, target_labeled_graphs)
  chosen_graphs = eval_GDL_program_on_graphs_GC(learned_GDL_program, data)
  if (score < default_score * data.expected) or (len(chosen_graphs & data.train_graphs) == 1):
    print(failure_msg)
    return None
  return (target_label, learned_GDL_program, score, frozenset(chosen_graphs))


def GDL_program_process_edges(GDL_program):
  """Reorder GDL_program.edgeVars so each edge after the first touches an earlier one.

  Edges are tuples whose items [1] and [2] are endpoint node indices. The
  first edge stays first. The edge list is then swept repeatedly in its
  current order. Each sweep appends every not-yet-placed edge that shares a
  node with the edges placed so far. This stops when all edges are placed or
  a sweep adds nothing; in the latter case, edges unreachable from the first
  edge are dropped.

  Edges are tracked by position rather than by their (from, to) pair, so
  parallel edges with identical endpoints are all kept.

  The program is modified in place and also returned. A program with no
  edges is returned unchanged.
  """
  edges = GDL_program.edgeVars
  if not edges:
    # Nothing to reorder (e.g. a single-node graph).
    return GDL_program

  placed = [False] * len(edges)
  placed[0] = True
  ordered_edges = [edges[0]]
  covered_nodes = {edges[0][1], edges[0][2]}

  while len(ordered_edges) < len(edges):
    added_any = False
    for position, edge in enumerate(edges):
      if placed[position]:
        continue
      fr_node, to_node = edge[1], edge[2]
      if fr_node not in covered_nodes and to_node not in covered_nodes:
        # Not (yet) attached to the placed part; retry on a later sweep.
        continue
      placed[position] = True
      ordered_edges.append(edge)
      covered_nodes.update((fr_node, to_node))
      added_any = True
    if not added_any:
      # Remaining edges are unreachable from the first edge.
      break

  GDL_program.edgeVars = ordered_edges
  return GDL_program

# Alternative construction: rebuild a GDL program with node variables renumbered in edge-discovery order.
def GDL_program_idx_processing(GDL_program):
  """Return a new GDL program whose node variables are renumbered in edge order.

  Node indices are assigned in the order nodes are first reached while
  walking ``GDL_program.edgeVars``. The first edge's endpoints become 0 and 1,
  and each later edge introduces at most one new node. Every edge is carried
  over with its endpoints remapped. This includes edges whose endpoints were
  both already numbered (cycle-closing edges); these are kept rather than
  skipped, so cyclic graphs keep all their edges. Nodes not incident to any
  edge are not carried over, except that a program with no edges keeps its
  node list as-is.

  The edges are expected to be ordered so that every edge after the first
  touches an already-numbered node, as GDL_program_process_edges produces.

  Raises:
    ValueError: if an edge after the first touches no already-numbered node.
  """
  new_GDL_program = GDL()
  # Fresh lists, matching the init_GDL_program_* constructors, so the appends
  # below never touch lists a GDL() default might share.
  new_GDL_program.nodeVars = []
  new_GDL_program.edgeVars = []

  edges = GDL_program.edgeVars
  if not edges:
    # No edges to order by: keep the node variables in their original order.
    new_GDL_program.nodeVars = list(GDL_program.nodeVars)
    return new_GDL_program

  node_index_map = {}

  def _new_index(node):
    # Assign the next free index (and copy the node variable) on first sight.
    if node not in node_index_map:
      node_index_map[node] = len(new_GDL_program.nodeVars)
      new_GDL_program.nodeVars.append(GDL_program.nodeVars[node])
    return node_index_map[node]

  for position, edge in enumerate(edges):
    fr_node, to_node = edge[1], edge[2]
    if position > 0 and fr_node not in node_index_map and to_node not in node_index_map:
      raise ValueError(
          "Edge ({}, {}) at position {} touches no previously seen node; "
          "order the edges with GDL_program_process_edges first".format(fr_node, to_node, position))
    # Tuple items evaluate left to right, so fr_node is numbered before to_node.
    new_GDL_program.edgeVars.append((edge[0], _new_index(fr_node), _new_index(to_node)))

  return new_GDL_program




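# Assumes the concrete graph is undirected, presumably with each edge stored in both directions in data.A; only the direction with to_node > from_node is kept. For a directed graph this filter would drop every edge with to_node < from_node.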
def init_GDL_program_from_idx(data, graph_idx):
  """Build a fully concrete GDL program for the graph ``data.graphs[graph_idx]``.

  This is init_GDL_program_from_graph applied to ``data.graphs[graph_idx]``,
  a (node_ids, edge_ids) pair. The body used to be a verbatim copy of that
  function; delegating keeps the two from drifting apart.
  """
  return init_GDL_program_from_graph(data, data.graphs[graph_idx])

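# Assumes the concrete graph is undirected, presumably with each edge stored in both directions in data.A; only the direction with to_node > from_node is kept. For a directed graph this filter would drop every edge with to_node < from_node.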
def init_GDL_program_from_graph(data, concrete_graph):
  """Build a fully concrete GDL program from a (node_ids, edge_ids) pair.

  Every node / edge feature value ``v`` becomes the degenerate interval
  ``(v, v)``, keyed by feature position. Node variables are numbered in the
  order of ``concrete_graph[0]``. Each edge variable is a tuple
  ``(feature_intervals, from_var, to_var)`` over those numbers.

  Only edges with ``to_node > from_node`` (concrete ids read from ``data.A``)
  are kept. This de-duplicates an undirected graph stored in both directions;
  for a directed graph it would drop the edges pointing to a lower id.
  """
  GDL_program = GDL()
  GDL_program.nodeVars = []
  GDL_program.edgeVars = []
  node_abs_node_map = {}

  for abs_idx, node_id in enumerate(concrete_graph[0]):
    GDL_program.nodeVars.append(
        {feat_idx: (feat_val, feat_val) for feat_idx, feat_val in enumerate(data.X_node[node_id])})
    node_abs_node_map[node_id] = abs_idx

  for edge_id in concrete_graph[1]:
    from_node = data.A[edge_id][0]
    to_node = data.A[edge_id][1]
    if not to_node > from_node:
      # Reverse copy of an undirected edge: skip before doing any work on it.
      continue
    new_itv = {feat_idx: (feat_val, feat_val) for feat_idx, feat_val in enumerate(data.X_edge[edge_id])}
    GDL_program.edgeVars.append((new_itv, node_abs_node_map[from_node], node_abs_node_map[to_node]))
  return GDL_program



