from model.generators import GRAN, DartsGenerator
from utils.nas_helper import fmt_graphs

from easydict import EasyDict as edict
import torch
import yaml

import code
import os

def test():
  """Open an interactive console for inspecting a trained generator.

  Repeatedly prompts for an experiment directory until its ``config.yaml``
  can be opened, builds the generator class named by
  ``config.generator.cls``, then starts a ``code.interact`` session whose
  namespace is this function's locals, which include:

    load(iter)   -- load weights from ``gen_snapshot_iter{iter}.pt``
    sample(num)  -- print ``num`` distinct sampled graphs
    ls()         -- list the experiment directory's contents

  plus ``config``, ``generator`` and ``exp_dir``.
  """
  while True:
    print("[LOG] Please input exp dir path\n>>> ", end="")
    exp_dir = input().strip()
    config_file = os.path.join(exp_dir, "config.yaml")
    try:
      # `with` closes the handle instead of leaking it until GC.
      with open(config_file, 'r') as config_fh:
        config = edict(yaml.safe_load(config_fh))
      break
    except (FileNotFoundError, NotADirectoryError):
      # NotADirectoryError: the entered path exists but is a regular file,
      # so joining "config.yaml" onto it cannot be opened; re-prompt.
      print("[ERR] File not found, please try again")

  # NOTE(review): eval() executes arbitrary text taken from config.yaml.
  # The value is presumably "GRAN" or "DartsGenerator" (the classes this
  # file imports); a {name: class} lookup would be safer. Only point this
  # tool at trusted experiment directories.
  generator = eval(config.generator.cls)(config)

  def load(iter):
    """Load generator weights from ``gen_snapshot_iter{iter}.pt`` in exp_dir.

    Prints an error and leaves the current weights untouched if the file
    does not exist. (``iter`` shadows the builtin, but the name is kept so
    ``load(iter=...)`` calls typed in the console keep working.)
    """
    path = os.path.join(exp_dir, f"gen_snapshot_iter{iter}.pt")
    try:
      # map_location="cpu" lets snapshots saved from a GPU be deserialized
      # on CPU-only hosts; load_state_dict then copies the values into the
      # generator's existing parameters.
      state_dict = torch.load(path, map_location="cpu")
    except FileNotFoundError:
      print("[ERR] File not found, please try again")
      return
    generator.load_state_dict(state_dict)

  def sample(num=1, max_tries=None):
    """Print ``num`` distinct graphs drawn one at a time from the generator.

    Two graphs count as duplicates when their ``str()`` forms are equal.

    Args:
      num: number of distinct graphs to collect.
      max_tries: optional cap on calls to ``generator.sample``. ``None``
        (the default) keeps drawing until ``num`` distinct graphs are found,
        which never finishes if the generator cannot produce that many.
    """
    samples = []
    seen = set()  # str() keys of graphs already collected
    tries = 0
    while len(samples) < num:
      if max_tries is not None and tries >= max_tries:
        print(f"[ERR] Only found {len(samples)} distinct graphs after {tries} tries")
        break
      tries += 1
      # generator.sample(1) returns an iterable batch; dedup each graph in
      # it individually (comparing the whole batch's str() against single
      # graphs never matches).
      for graph in generator.sample(1):
        key = str(graph)
        if key not in seen:
          seen.add(key)
          samples.append(graph)
    print(fmt_graphs(config, samples))

  def ls():
    """Print the experiment directory's entries, sorted, one per line."""
    print("\n".join(sorted(os.listdir(exp_dir))))

  code.interact(local=locals(), banner="[LOG] Beginning interactive console session")

