import ast
import torch
import argparse
from run import run

def parse_hidden(text):
    """Parse a ``--hidden`` value into a list of positive layer widths.

    Accepts a Python list/tuple literal (``[100,50]``, ``(100, 50)``), a bare
    comma-separated sequence (``100,50``, which ``ast.literal_eval`` reads as a
    tuple), or a single integer (``100``, wrapped into ``[100]``).

    Args:
        text: raw command-line string.

    Returns:
        A list of positive ints (an empty list is passed through unchanged).

    Raises:
        argparse.ArgumentTypeError: if the text is not a literal, or is not a
            sequence of positive integers. argparse turns this into a clean
            usage error instead of a traceback.
    """
    try:
        # literal_eval only evaluates literals, so it is safe on user input.
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as err:
        # SyntaxError is NOT caught by argparse itself, so translate it here.
        raise argparse.ArgumentTypeError(
            f"invalid hidden layer spec {text!r}; expected e.g. [100,50] or 100"
        ) from err

    # Allow "--hidden 100" as shorthand for a single hidden layer.
    if isinstance(value, int) and not isinstance(value, bool):
        value = [value]

    # bool is a subclass of int, so exclude it explicitly.
    if not isinstance(value, (list, tuple)) or not all(
        isinstance(width, int) and not isinstance(width, bool) and width > 0
        for width in value
    ):
        raise argparse.ArgumentTypeError(
            f"invalid hidden layer spec {text!r}; layer sizes must be positive integers"
        )
    return list(value)


def build_parser():
    """Build the command-line parser for the InvarGC experiment runner.

    Returns:
        An ``argparse.ArgumentParser`` with dataset/method selection, model and
        training hyper-parameters, and the compute device.
    """
    parser = argparse.ArgumentParser(description="Run InvarGC experiments")

    # Dataset | Method
    parser.add_argument("--dataset", type=str, default="conftep",
                        help="dataset name, synthetic or real (default: %(default)s)")
    parser.add_argument("--method", type=str, default="invargc", help="method to run [invargc | ...]")

    # Model Hyper-Parameters
    parser.add_argument("--num_series", type=int, default=5, help="number of observed series d")
    parser.add_argument("--num_confound", type=int, default=1, help="number of latent confounders p")
    parser.add_argument("--hidden", type=parse_hidden, default=[100],
                        help="hidden layer dimensions, e.g. --hidden [100,50], --hidden 100,50 or --hidden 100")
    parser.add_argument("--seed", type=int, default=2021, help="random seed")

    # Training Hyper-Parameters
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--max_iter", type=int, default=2000, help="max training iterations")
    parser.add_argument("--lam_h", type=float, default=1e-4, help="ridge penalty")
    parser.add_argument("--lam_c", type=float, default=1e-2, help="confounder penalty")
    parser.add_argument("--lam_v", type=float, default=1e-3, help="intervention penalty")
    parser.add_argument("--lookback", type=int, default=5, help="early stopping lookback")
    parser.add_argument("--check_every", type=int, default=50, help="check interval")

    # GPU: default to CUDA only when torch reports it is available.
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="device to use: 'cuda' or 'cpu'")
    return parser


def main(argv=None):
    """Parse arguments and hand the resulting namespace to ``run``.

    Args:
        argv: optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    run(args)


if __name__ == "__main__":
    main()