####################################################################################
# 1. Install the dependencies listed in requirements.txt.                          #
# 2. Select the GPU via config.device in ./configs/default_cifar10_configs.py.     #
# 3. Also set data_parallel in run_lib, then run this script with the              #
#    desired --mode (train_ER -> train_conditional_score -> cal_score).            #
# Defaults: gpu=0, data parallelism over GPUs [0, 1].                              #
####################################################################################
import run_lib
import logging
import os
import tensorflow as tf
import torch
import argparse
import warnings
import utils
from data import load_data
from options import Options

import importlib  # local import: keeps the original top-of-file import order untouched

warnings.filterwarnings("ignore")

# --model value -> config module name under configs/<model>/.
# "vp" and "subvp" share the DDPM++ config; "ve" uses NCSN++.
_CONFIG_NAMES = {
  "ve": "cifar10_ncsnpp_continuous",
  "vp": "cifar10_ddpmpp_continuous",
  "subvp": "cifar10_ddpmpp_continuous",
}
# Pipeline stages handled by main(), in the order they are meant to be run.
_MODES = ("train_ER", "train_conditional_score", "cal_score")

parser = argparse.ArgumentParser()

# select model
parser.add_argument("--model", type=str, default="subvp", choices=sorted(_CONFIG_NAMES),
                    help="SDE type of the score model.")
parser.add_argument("--dir", type=str, default="original",
                    help="Prefix of the working directory: <dir>_train_score/<model>/<name>.")
# First run train_ER, which pre-trains the encoder and decoder. Then run
# train_conditional_score to train the conditional score network. Finally run
# cal_score to compute the predictive and discriminative scores.
parser.add_argument("--mode", type=str, default="train_ER", choices=_MODES,
                    help="Pipeline stage to run: train_ER -> train_conditional_score -> cal_score.")

# main() also calls Options().parse(), which presumably reads the same command
# line (TODO confirm). Ignore flags this parser does not own instead of aborting.
arg, _unknown_args = parser.parse_known_args()
arg.name = _CONFIG_NAMES[arg.model]

# Load configs/<model>/<name>.py directly rather than via __import__ + eval on a
# string assembled from command-line input.
arg.config = importlib.import_module(f"configs.{arg.model}.{arg.name}").get_config()
arg.workdir = "/".join([arg.dir + "_" + "train_score", arg.model, arg.name])

utils.fix_random_seed(2)

def main():
  """Run the pipeline stage selected by ``--mode``.

  Creates ``arg.workdir`` and attaches a handler that writes root-logger
  records (INFO and above) to ``<workdir>/stdout_<mode>.txt``. It then parses
  the project ``Options``, loads the original data with ``load_data`` and
  dispatches to the matching ``run_lib`` function.

  Raises:
    ValueError: if ``arg.mode`` is not one of the supported stages.
  """
  supported_modes = ("train_ER", "train_conditional_score", "cal_score")
  # Validate before any expensive work so a typo in --mode fails loudly instead
  # of loading the data and then silently doing nothing.
  if arg.mode not in supported_modes:
    raise ValueError(
        f"Unknown --mode {arg.mode!r}; expected one of {list(supported_modes)}")

  tf.io.gfile.makedirs(arg.workdir)
  # FileHandler owns the file object, so handler.close() below also closes the file.
  handler = logging.FileHandler(
      os.path.join(arg.workdir, f'stdout_{arg.mode}.txt'), mode='w')
  formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
  handler.setFormatter(formatter)
  logger = logging.getLogger()
  logger.addHandler(handler)
  logger.setLevel('INFO')
  try:
    opt = Options().parse()
    ori_data = load_data(opt)
    if arg.mode == "train_ER":
      run_lib.train_ER(opt, ori_data, arg.workdir)
    elif arg.mode == "train_conditional_score":
      run_lib.train_conditional_score(opt, ori_data, arg.config, arg.workdir)
    else:  # "cal_score", guaranteed by the validation above
      run_lib.cal_score(opt, ori_data, arg.config, arg.workdir)
  finally:
    # Flush and close the log file even if a stage raises. Detach the handler
    # so the root logger never writes to a closed stream afterwards.
    logger.removeHandler(handler)
    handler.close()


# Only run the pipeline when executed as a script, not when imported.
if __name__ == "__main__":
  main()  # dispatches on --mode
