import pickle
from data_loader import * 
from synthesis.synthesis import learn_a_GDL_program
import argparse
import datetime
import os, sys
import datetime


# learn a GDL program from a target graph (target_graph)
def learn(data, target_graph):
  """Learn a GDL program from one target graph and pickle the result.

  Args:
    data: dataset object produced by the loaders; ``data.name`` is used to
      build the output path and ``data`` is passed through unchanged to
      ``learn_a_GDL_program``.
    target_graph: identifier of the target training graph (an int when
      invoked from the command line of this script).

  Side effects:
    Prints the elapsed learning time and writes the learned tuple set to
    ``datasets/<data.name>/learned_GDL_programs/learned_GDL_program_from_<target_graph>.pickle``.
    When nothing is learned, prints a message and writes no file.
  """
  start = datetime.datetime.now()
  tuple_set = learn_a_GDL_program(data, target_graph)
  elapsed = datetime.datetime.now() - start

  # Treat both an empty result and a missing (None) result as "nothing learned"
  # instead of crashing on len(None).
  if tuple_set is None or len(tuple_set) == 0:
    print("No learned GDL program")
    return

  print("Elapsed time from graph {}: {}".format(target_graph, elapsed))

  # Create the output directory in-process rather than via a shell `mkdir`:
  # no shell interpretation of data.name, missing parent directories are
  # created, and a directory created concurrently by another run is tolerated.
  out_dir = os.path.join('datasets', data.name, 'learned_GDL_programs')
  os.makedirs(out_dir, exist_ok=True)
  out_path = os.path.join(
      out_dir, 'learned_GDL_program_from_{}.pickle'.format(target_graph))
  with open(out_path, 'wb') as f:
    pickle.dump(tuple_set, f)

def main():
  """Parse command-line arguments, load the requested dataset and run `learn`."""
  parser = argparse.ArgumentParser()
  parser.add_argument("-d", "--dataset", default="MUTAG", help="Input dataset")
  parser.add_argument("-e", "--epsilon", type=float, default=0.01, help="Input epsilon, default = 0.01")
  parser.add_argument("-t", "--timelimit", type=int, default=10, help="Input time limit, default = 10")
  # Parse as int and require it: previously the value was converted with
  # int(...) after parsing, which raised a bare TypeError when -g was omitted.
  parser.add_argument('-g', '--target_graph', type=int, required=True, help="target train graph")
  args = parser.parse_args()

  if args.dataset == 'BBBP':
    data = load_BBBP()
  elif args.dataset == 'BACE':
    data = load_BACE()
  else:
    data = load_Data(args.dataset)

  # Hyperparameters
  data.timeLimit = args.timelimit
  # The CLI epsilon is multiplied by the number of training graphs, so it acts
  # as a fraction of the training set size.
  data.epsilon = len(data.train_graphs) * args.epsilon
  data.expected = 1.0

  learn(data, args.target_graph)


if __name__ == '__main__':
  main()

 
