import sys
import os

import gdown


# Directory, relative to the current working directory, where downloaded
# demonstration files are stored.
DIR = "./expert_datasets"

# Task name -> list of available demos for that task. Each demo is a
# (postfix, google_drive_file_id) pair; the entry-point loop in this script
# saves it as "<DIR>/<task>_<postfix>.pt". Key order is the default download
# order.
DEMOS = dict(
    minigrid=[("cardinal_ac", "1joS7NckzcPpNt-1KqE4uodGXq-NKWfLw")],
    maze2d=[
        ("box_01", "1-uOE0RC3dKD658cBCSqTVh8XaypAQ1Sy"),
        ("full_ac", "1jZPHhd9Bnp2dPoJW6tF2SIPyh65oJN8s"),
    ],
    fetchpick=[("box_01", "1WZQ5i5ROCyyWRHOuKuJ2nndeTUJSTheB")],
    fetchpush=[
        ("box_005", "1cNrMIEhZITUca4KqB50DHeSVbAsX5g9m"),
        ("box_07", "1DZCTG1gPziCuBGJbnfIyVJRYYZSijej6"),
    ],
    widowx=[
        # Disabled entry, kept for reference (not downloaded):
        # ("sim_control", "1fANjAT5W1YfKB4GJ_7XFIPlXXhOoX14m"),
        ("v2", "1eTCyHbvCJcoFI3wa9Rca4xQ8qlURSzAQ"),
    ],
)


if __name__ == "__main__":
    # Download expert demonstration files from Google Drive into DIR.
    #
    # Usage: python <this script> [task ...]
    #   Each positional argument must be a key of DEMOS; with no arguments
    #   every task in DEMOS is downloaded, in DEMOS key order.
    #
    # Files that already exist are skipped. Exits with a non-zero status if an
    # unknown task is requested or if any download fails.

    # Derive the default from DEMOS so that adding a task there is enough.
    requested = sys.argv[1:] or list(DEMOS)

    # Validate up front instead of crashing with a bare KeyError mid-run.
    unknown = [task for task in requested if task not in DEMOS]
    if unknown:
        sys.exit(
            "Unknown task(s): %s. Available tasks: %s"
            % (", ".join(unknown), ", ".join(DEMOS))
        )

    os.makedirs(DIR, exist_ok=True)

    failed = []
    for task in requested:
        for postfix, file_id in DEMOS[task]:
            url = "https://drive.google.com/uc?id=" + file_id
            target_path = os.path.join(DIR, "%s_%s.pt" % (task, postfix))
            if os.path.exists(target_path):
                print("%s is already downloaded." % target_path)
                continue
            print("Downloading demo (%s_%s) from %s" % (task, postfix, url))
            # Some gdown versions report failure by returning None rather than
            # raising; record it so the failure is not silently ignored.
            if gdown.download(url, target_path) is None:
                failed.append(target_path)

    if failed:
        sys.exit("Failed to download: %s" % ", ".join(failed))
