
# We provide two representative datasets, stock and energy, together with their code. The other datasets can be obtained from the links in our paper; remove the miscellaneous data as described there.
# If you want to experiment with other data, you have to modify the conditional score network appropriately by changing the normalization parts in ./models/conditional_ncsnpp.py (compare with the energy settings).

import run_lib
import logging
import os
import tensorflow as tf
import torch
import argparse
import warnings
import utils

# Silence every warning category for the whole run.
warnings.filterwarnings("ignore")

# Command-line options. The trailing comment on each dataset-dependent option appears to list
# the values for stock, energy, air, ai4i, occupancy, in that order.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="vp", help='model selection') # vp, subvp
parser.add_argument("--mode", type=str, default="train", help='train or evaluate') # train, evaluate
parser.add_argument('--data_name', default='stock', type=str) # stock, energy, air, ai4i, occupancy
parser.add_argument('--z_dim', default=6, type=int, help='dimension of inputs') # 6, 28, 13, 5, 13
parser.add_argument('--seq_len', default=24, type=int) # sequential length, always 24
parser.add_argument('--hidden_dim', default=24, type=int, help='hidden dimension of inputs') # 24, 56, 40, 24, 40
parser.add_argument('--iteration', default=50000, type=int, help='pre-training iteration') # 50000, 100000, 50000, 50000, 100000
parser.add_argument('--nf', default=96, type=int, help='hidden dimension of temporal feature, s.') # 96, 56, 80, 96, 80
# The options below are meant to stay at their default values.
parser.add_argument('--num_layer', default=3, type=int)
parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam')
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--metric_iteration', default=10, type=int)
# GPU settings. Data parallelism is used by default; remove the corresponding code in run_lib.py
# if you do not want it.
parser.add_argument('--gpu', default=0, type=int)
# `type=list` would split the raw argument string into single characters
# (e.g. "0,1" -> ['0', ',', '1']), so the ids are read as space-separated integers instead:
#   --parallel_device 0 1
parser.add_argument('--parallel_device', default=[0,1], type=int, nargs='+')
arg = parser.parse_args()

# __import__ with a dotted name returns the top-level `configs` package; the submodules
# `configs.<model>` and `configs.<model>.score_network` become reachable as attributes of it.
config = __import__(f"configs.{arg.model}.score_network")

# Walk down to the submodule with getattr rather than eval, so the command-line value is
# never executed as Python source.
arg.config = getattr(config, arg.model).score_network.get_config()
# Use the selected CUDA device when CUDA is available, otherwise fall back to the CPU.
arg.config.device = torch.device(f'cuda:{arg.gpu}') if torch.cuda.is_available() else torch.device('cpu')

# Fixed seed (0) for every run -- presumably for reproducibility; see utils.fix_random_seed.
utils.fix_random_seed(0)

def main():
  """Run the workflow selected by the module-level ``arg.mode``.

  ``"train"`` calls ``run_lib.train_ER`` followed by
  ``run_lib.train_conditional_score``; ``"evaluate"`` calls ``run_lib.evaluate``.
  Each function receives the parsed argument namespace ``arg``.

  Raises:
    ValueError: if ``arg.mode`` is neither ``"train"`` nor ``"evaluate"``.
  """
  if arg.mode == "train":
    # pre-train encoder and decoder
    run_lib.train_ER(arg)
    # alternately train the conditional score network and the encoder-decoder pair
    run_lib.train_conditional_score(arg)
  elif arg.mode == "evaluate":
    # calculate scores; run this after training has finished
    run_lib.evaluate(arg)
  else:
    # A misspelled mode used to fall through silently and exit successfully without doing anything.
    raise ValueError(f"unknown mode {arg.mode!r}; expected 'train' or 'evaluate'")


# Script entry point: start the selected workflow only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
  main()
