from dsl.gdl import *
from util.connected import *
import copy
import datetime


def generalize(GDL_program, data, target_labeled_graphs):
  """Generalize a GDL program through a fixed sequence of score-guarded passes.

  The input program is scored once. Edges are then pruned twice, independently
  and from the same starting point: once scanning front-to-back and once
  back-to-front. The higher-scoring outcome is kept (the front-to-back result
  wins ties). The survivor has its edges reordered, then its edge constraints
  and node constraints are relaxed, and -- only when ``data.train_graphs``
  holds fewer than 1000 graphs -- it is refined further by ``generalize_all``.

  Args:
    GDL_program: program to generalize. It is scored as given; a deep copy is
      what the pruning passes receive.
    data: dataset object. ``train_graphs`` is read here and the object is
      forwarded to every pass.
    target_labeled_graphs: forwarded unchanged to the scoring function and to
      every pass.

  Returns:
    The generalized program, after ``remove_unreachable_nodes``.
  """
  start_program = copy.deepcopy(GDL_program)
  start_score = eval_GDL_program_on_graphs_GC_Score(GDL_program, data, target_labeled_graphs)

  # Both pruning directions start from the same program and score.
  front_result = remove_edges_from_front(start_program, start_score, data, target_labeled_graphs)
  back_result = remove_edges_from_back(start_program, start_score, data, target_labeled_graphs)

  # ">=" means the front-to-back result is preferred when the scores are equal.
  if front_result[1] >= back_result[1]:
    program, score = front_result
  else:
    program, score = back_result

  program = process_GDL_program_edges(program)
  program, score = generalize_edge(program, score, data, target_labeled_graphs)
  program, score = generalize_node(program, score, data, target_labeled_graphs)

  # Additional refinement is only run for small datasets.
  if len(data.train_graphs) < 1000:
    program, score = generalize_all(program, score, data, target_labeled_graphs)

  return remove_unreachable_nodes(program)

def process_unreachable_nodes(GDL_program):
  """Reset to ``{}`` every node entry that no edge refers to.

  Args:
    GDL_program: object with ``edgeVars`` (3-tuples whose second and third
      items are node indices) and ``nodeVars`` (a list indexed by node).

  Returns:
    The same ``GDL_program`` object; ``nodeVars`` is modified in place.
  """
  # Node indices that appear as an endpoint of at least one edge.
  used_nodes = {
    endpoint
    for (_, fr_node, to_node) in GDL_program.edgeVars
    for endpoint in (fr_node, to_node)
  }
  for position, _ in enumerate(GDL_program.nodeVars):
    if position in used_nodes:
      continue
    GDL_program.nodeVars[position] = {}
  return GDL_program


def process_GDL_program_edges(GDL_program):
  """Reorder the edges so that each one touches a node used by an earlier edge.

  The first edge keeps its position. Every other edge is appended as soon as
  one of its endpoints (``edge[1]`` / ``edge[2]``) belongs to an edge that has
  already been placed; the remaining edges are re-scanned in repeated passes
  until all of them are placed.

  Args:
    GDL_program: object with ``nodeVars`` and ``edgeVars``.

  Returns:
    A new ``GDL`` whose ``nodeVars`` is a deep copy of the input's and whose
    ``edgeVars`` holds deep copies of the input edges in the new order.

  Raises:
    IndexError: if the program has no edges.
    ValueError: if some edge cannot be reached from the first edge. Without
      this check a pass that places nothing would repeat forever.
  """
  current_abs_edges = copy.deepcopy(GDL_program.edgeVars)
  new_edgeVars = [current_abs_edges[0]]

  reachable = {current_abs_edges[0][1], current_abs_edges[0][2]}
  # Indices of the edges still waiting to be placed (everything but edge 0).
  candidates = set(range(len(current_abs_edges)))
  candidates.remove(0)

  while candidates:
    placed_any = False
    # Iterate a snapshot because candidates shrinks while the pass runs.
    for val in copy.deepcopy(candidates):
      edge = current_abs_edges[val]
      if (edge[1] in reachable) or (edge[2] in reachable):
        reachable.add(edge[1])
        reachable.add(edge[2])
        new_edgeVars.append(edge)
        candidates.remove(val)
        placed_any = True
    if not placed_any:
      # No remaining edge touches the reachable set, so further passes
      # could never make progress.
      raise ValueError(
        "edges are not connected: {} edge(s) unreachable from the first edge".format(len(candidates))
      )

  new_GDL_program = GDL()
  new_GDL_program.nodeVars = copy.deepcopy(GDL_program.nodeVars)
  new_GDL_program.edgeVars = new_edgeVars

  return new_GDL_program

def remove_edges_from_front(GDL_program, current_score, data, target_labeled_graphs):
  """Greedily drop edges, scanning from index 0, while the score does not fall.

  At each position a deep copy of the current best program has that edge
  removed. The last edge of the current best program is never a removal
  candidate (the scan stops at ``len(edgeVars) - 1``).

  * If the copy is no longer connected (per ``is_connected``) it is replaced
    by the part returned from ``separate_a_gdl_program``; when accepted, the
    scan restarts at index 0.
  * Otherwise the copy is reordered by ``process_GDL_program_edges``; when
    accepted, the scan continues at the same index.

  A candidate is accepted when its score is ``>=`` the best score so far;
  a rejected candidate advances the scan by one position.

  Args:
    GDL_program: starting program; it is deep-copied before each removal.
    current_score: score of ``GDL_program``.
    data: forwarded to the scoring function.
    target_labeled_graphs: forwarded to the scoring function.

  Returns:
    ``(program, score)`` where ``program`` is the best program found after
    ``remove_unreachable_nodes`` has been applied to it.
  """
  best_program, best_score = GDL_program, current_score
  position = 0
  while position < len(best_program.edgeVars) - 1:
    trial = copy.deepcopy(best_program)
    trial.edgeVars.pop(position)

    stays_connected = is_connected(trial)
    if stays_connected:
      trial = process_GDL_program_edges(trial)
    else:
      trial = separate_a_gdl_program(trial)
    trial_score = eval_GDL_program_on_graphs_GC_Score(trial, data, target_labeled_graphs)

    if trial_score >= best_score:
      best_program, best_score = trial, trial_score
      if not stays_connected:
        # The program was replaced by one of its parts: rescan from the start.
        position = 0
    else:
      position += 1

  best_program = remove_unreachable_nodes(best_program)
  return (best_program, best_score)


def remove_edges_from_back(GDL_program, current_score, data, target_labeled_graphs):
  """Try removing edges one index at a time, from the last index down to 0.

  The index range is fixed from the edge count of the input program and each
  index is attempted exactly once. For every index a deep copy of the current
  best program has that edge removed; a removal that leaves the program
  disconnected (per ``is_connected``) is skipped. Otherwise the copy is
  reordered by ``process_GDL_program_edges``, scored, and kept when its score
  is ``>=`` the best score so far.

  Args:
    GDL_program: starting program; it is deep-copied before each removal.
    current_score: score of ``GDL_program``.
    data: forwarded to the scoring function.
    target_labeled_graphs: forwarded to the scoring function.

  Returns:
    ``(program, score)`` for the best program found.
  """
  best_program, best_score = GDL_program, current_score
  for position in range(len(GDL_program.edgeVars) - 1, -1, -1):
    trial = copy.deepcopy(best_program)
    trial.edgeVars.pop(position)
    if not is_connected(trial):
      # Removing this edge would split the program; leave it in place.
      continue

    trial = process_GDL_program_edges(trial)
    trial_score = eval_GDL_program_on_graphs_GC_Score(trial, data, target_labeled_graphs)
    if trial_score >= best_score:
      best_program, best_score = trial, trial_score
  return (best_program, best_score)
 



def generalize_edge(GDL_program, current_score, data, target_labeled_graphs):
  """Try clearing each edge's constraint dict, from the last edge to the first.

  For every edge index a deep copy of the current best program gets that edge
  replaced by ``({}, from_node, to_node)``. The copy is kept when its score is
  ``>=`` the best score so far. The scan stops early once a single evaluation
  has taken longer than ``data.timeLimit`` seconds.

  Args:
    GDL_program: starting program; never modified, only deep-copied.
    current_score: score of ``GDL_program``.
    data: forwarded to the scoring function; ``timeLimit`` is read here.
    target_labeled_graphs: forwarded to the scoring function.

  Returns:
    ``(program, score)`` for the best program found.
  """
  best_program, best_score = GDL_program, current_score
  for position in range(len(best_program.edgeVars) - 1, -1, -1):
    trial = copy.deepcopy(best_program)
    old_edge = trial.edgeVars[position]
    # Keep the endpoints, drop every constraint on the edge.
    trial.edgeVars[position] = ({}, old_edge[1], old_edge[2])

    started_at = datetime.datetime.now()
    trial_score = eval_GDL_program_on_graphs_GC_Score(trial, data, target_labeled_graphs)
    if trial_score >= best_score:
      best_program, best_score = trial, trial_score

    # The clock covers this one evaluation only, not the whole scan.
    if datetime.datetime.now() - started_at > datetime.timedelta(seconds = data.timeLimit):
      break
  return (best_program, best_score)



def generalize_node(GDL_program, current_score, data, target_labeled_graphs):
  """Try clearing each node's constraint dict, from the last node to the first.

  For every node index a deep copy of the current best program gets that
  ``nodeVars`` entry replaced by ``{}``. The copy is kept when its score is
  ``>=`` the best score so far. The scan stops early once a single evaluation
  has taken longer than ``data.timeLimit`` seconds.

  Args:
    GDL_program: starting program; never modified, only deep-copied.
    current_score: score of ``GDL_program``.
    data: forwarded to the scoring function; ``timeLimit`` is read here.
    target_labeled_graphs: forwarded to the scoring function.

  Returns:
    ``(program, score)`` for the best program found.
  """
  best_program, best_score = GDL_program, current_score
  for position in range(len(best_program.nodeVars) - 1, -1, -1):
    trial = copy.deepcopy(best_program)
    trial.nodeVars[position] = {}

    started_at = datetime.datetime.now()
    trial_score = eval_GDL_program_on_graphs_GC_Score(trial, data, target_labeled_graphs)
    if trial_score >= best_score:
      best_program, best_score = trial, trial_score

    # The clock covers this one evaluation only, not the whole scan.
    if datetime.datetime.now() - started_at > datetime.timedelta(seconds = data.timeLimit):
      break
  return (best_program, best_score)



def _relaxed_interval_dicts(itvs):
  """Yield copies of ``itvs`` with exactly one feature interval relaxed a step.

  ``-99`` and ``99`` are the values this module treats as "no lower bound" and
  "no upper bound". An interval bounded on both sides yields two copies (upper
  bound opened first, then lower bound opened); an interval bounded on one
  side yields one copy with the feature removed; an interval open on both
  sides yields nothing.
  """
  for feat_idx in itvs:
    (a, b) = itvs[feat_idx]
    lower_bounded = a != -99
    upper_bounded = b != 99
    if lower_bounded and upper_bounded:
      for widened in ((a, 99), (-99, b)):
        new_itvs = copy.deepcopy(itvs)
        new_itvs[feat_idx] = widened
        yield new_itvs
    elif lower_bounded or upper_bounded:
      new_itvs = copy.deepcopy(itvs)
      del new_itvs[feat_idx]
      yield new_itvs


def _one_step_generalizations(program):
  """Lazily yield every one-step generalization of ``program``.

  Order: node interval relaxations (by node index), edge interval relaxations
  (by edge index), then single-edge removals from the last edge to the first.
  A removal is yielded only if the result is still connected (per
  ``is_connected``), after reordering by ``process_GDL_program_edges``.
  ``program`` itself is never modified; every candidate is built on a deep copy.
  """
  for node_idx, itvs in enumerate(program.nodeVars):
    for new_itvs in _relaxed_interval_dicts(itvs):
      candidate = copy.deepcopy(program)
      candidate.nodeVars[node_idx] = new_itvs
      yield candidate

  for edge_idx, (itvs, p, q) in enumerate(program.edgeVars):
    for new_itvs in _relaxed_interval_dicts(itvs):
      candidate = copy.deepcopy(program)
      candidate.edgeVars[edge_idx] = (new_itvs, p, q)
      yield candidate

  for edge_idx in range(len(program.edgeVars) - 1, -1, -1):
    candidate = copy.deepcopy(program)
    candidate.edgeVars.pop(edge_idx)
    if is_connected(candidate):
      yield process_GDL_program_edges(candidate)


def generalize_all(GDL_program, current_score, data, target_labeled_graphs):
  """Repeatedly apply the best-or-equal one-step generalization until none helps.

  Each round derives every candidate from the program the round started with
  (not from the running best): relaxing one node interval, relaxing one edge
  interval, or removing one edge while staying connected. A candidate becomes
  the running best when its score is ``>=`` the best score so far, so among
  equally good candidates the last one evaluated wins. If a round accepted
  anything, a new round starts from the running best; otherwise the search
  ends. The search also ends immediately, returning the running best, once a
  single evaluation has taken longer than ``data.timeLimit`` seconds.

  Args:
    GDL_program: starting program; never modified, only deep-copied.
    current_score: score of ``GDL_program``.
    data: forwarded to the scoring function; ``timeLimit`` is read after each
      evaluation.
    target_labeled_graphs: forwarded to the scoring function.

  Returns:
    ``(program, score)`` for the best program found.
  """
  best_GDL_program = GDL_program
  best_score = current_score

  # An outer loop replaces the former tail recursion, so a long chain of
  # accepted steps cannot hit the interpreter's recursion limit.
  while True:
    improved = False
    round_start = best_GDL_program
    for new_GDL_program in _one_step_generalizations(round_start):
      start = datetime.datetime.now()
      new_score = eval_GDL_program_on_graphs_GC_Score(new_GDL_program, data, target_labeled_graphs)
      if new_score >= best_score:
        improved = True
        best_GDL_program = new_GDL_program
        best_score = new_score
      # The clock covers this one evaluation only, not the whole round.
      elapsed = datetime.datetime.now() - start
      if elapsed > datetime.timedelta(seconds = data.timeLimit):
        return (best_GDL_program, best_score)

    if not improved:
      return (best_GDL_program, best_score)




def separate_a_gdl_program(program):
  """Split a program into two parts around its first edge; return the larger.

  The first part starts from ``program.edgeVars[0]`` and grows by repeatedly
  scanning all edges: an edge with exactly one endpoint already collected
  brings its other endpoint in and is appended. The second part holds every
  edge with neither endpoint collected by the first part. Both parts carry a
  deep copy of the full ``nodeVars``.

  NOTE(review): an edge whose two endpoints are both already collected is
  never appended to the first part, so cycle-closing or parallel edges of
  that part are dropped, while the second part keeps all of its edges.
  Confirm whether this asymmetry is intended.

  Args:
    program: object with ``nodeVars`` and a non-empty ``edgeVars`` whose
      items are indexed as ``edge[1]`` / ``edge[2]`` for the endpoints.

  Returns:
    The first part if it has strictly more edges than the second, otherwise
    the second part.
  """
  first_part = GDL()
  first_part.nodeVars = copy.deepcopy(program.nodeVars)
  seed_edge = program.edgeVars[0]
  first_part.edgeVars = [seed_edge]
  collected = {seed_edge[1], seed_edge[2]}

  # Keep scanning until a full pass brings in no new node.
  grew = True
  while grew:
    grew = False
    for edge in program.edgeVars:
      from_known = edge[1] in collected
      to_known = edge[2] in collected
      if from_known == to_known:
        # Both known (nothing new) or both unknown (not adjacent yet).
        continue
      collected.add(edge[2] if from_known else edge[1])
      first_part.edgeVars.append(edge)
      grew = True

  second_part = GDL()
  second_part.nodeVars = copy.deepcopy(program.nodeVars)
  second_part.edgeVars = [
    edge for edge in program.edgeVars
    if edge[1] not in collected and edge[2] not in collected
  ]

  # Strict ">" means the second part is returned when the edge counts are equal.
  if len(first_part.edgeVars) > len(second_part.edgeVars):
    return first_part
  return second_part