"""
Download data from HF Hub

"""

from argparse import ArgumentParser
from datasets import load_dataset
import logging
from pathlib import Path
from pprint import pformat


SUBSETS = {
    "cooking": [
        "17_40",
        "20_39",
        "20_26",
        "10_46",
        "16_44",
        "4_3",
        "15_33",
        "1_34",
        "22_24",
        "18_3",
        "8_31",
        "15_28",
        "20_25",
        "13_20",
        "23_39",
        "25_6",
        "4_35",
        "1_36",
        "13_9",
        "8_30",
        "3_49",
        "2_38",
        "21_47",
        "13_45",
        "27_37",
        "13_41",
        "29_32",
        "15_37",
        "8_44",
        "9_108",
        "5_28",
        "22_41",
        "26_42",
        "16_20",
        "1_143",
        "8_26",
        "9_8",
        "9_36",
        "25_22",
        "29_28",
        "15_29",
        "2_42",
        "17_45",
        "9_13",
        "4_5",
        "29_34",
        "7_26",
        "4_36",
        "21_24",
        "29_17",
        "1_28",
        "8_45",
        "15_2",
        "21_50",
        "13_44",
        "25_41",
        "2_4",
        "1_30",
        "28_26",
        "5_37",
        "22_32",
        "22_38",
        "7_30",
        "17_43",
        "4_43",
        "15_46",
        "18_28",
        "26_6",
        "18_49",
        "8_11",
        "16_26",
        "25_3",
        "9_24",
        "10_31",
        "28_3",
        "17_28",
        "23_19",
        "1_43",
        "23_41",
        "20_32",
        "28_38",
        "27_34",
        "18_27",
        "29_35",
        "16_2",
        "10_7",
        "21_46",
        "27_3",
        "1_32",
        "23_32",
        "5_35",
        "16_18",
        "15_17",
        "8_15",
        "21_32",
        "29_37",
        "20_48",
        "18_2",
        "26_18",
        "18_19",
        "9_47",
        "1_136",
        "9_19",
        "17_10",
        "17_5",
        "21_29",
        "26_46",
        "3_34",
        "21_103",
        "7_35",
        "27_31",
        "9_25",
        "5_44",
        "3_36",
        "7_3",
        "28_45",
        "29_22",
        "21_44",
        "27_4",
        "28_28",
        "26_29",
        "3_50",
        "8_50",
        "16_1",
        "27_26",
        "28_16",
        "5_42",
        "12_9",
        "27_38",
        "13_14",
        "5_11",
        "22_40",
        "10_50",
        "5_15",
        "29_19",
        "16_39",
        "22_4",
        "17_37",
        "23_23",
        "9_4",
        "20_29",
        "28_14",
        "4_32",
        "20_44",
        "1_42",
        "8_40",
        "21_19",
        "2_3",
        "7_48",
        "10_6",
        "29_49",
        "22_37",
        "17_36",
        "22_137",
        "1_37",
        "23_9",
        "18_31",
        "8_25",
        "18_24",
        "9_45",
        "27_29",
        "27_49",
        "2_47",
        "13_5",
        "29_29",
        "22_30",
        "4_40",
        "5_27",
        "15_30",
        "28_21",
        "21_14",
        "8_35",
        "12_41",
        "25_42",
        "28_25",
        "12_43",
        "3_5",
        "28_10",
        "17_19",
        "10_47",
        "25_11",
        "12_15",
        "2_8",
        "5_19",
        "20_22",
        "29_15",
        "12_5",
        "13_32",
        "22_26",
        "28_42",
        "16_42",
        "13_38",
        "23_17",
        "1_25",
        "12_10",
        "7_24",
        "18_41",
        "21_3",
        "18_101",
        "29_5",
        "29_7",
        "21_15",
        "18_11",
        "16_40",
        "4_44",
        "13_36",
        "4_30",
        "29_129",
        "7_135",
        "7_50",
        "2_41",
        "25_40",
        "7_19",
        "18_33",
        "17_23",
        "3_22",
        "28_29",
        "16_35",
        "10_26",
        "15_18",
        "1_49",
        "25_13",
        "15_41",
        "2_28",
        "22_2",
        "9_12",
        "10_48",
        "17_20",
        "26_34",
        "7_38",
        "12_51",
    ],
    "assembly": [
        "a02",
        "a03",
        "a07",
        "a08",
        "a09",
        "a10",
        "a12",
        "a13",
        "a14",
        "a16",
        "a18",
        "a19",
        "a20",
        "a23",
        "a24",
        "a26",
        "a28",
        "a29",
        "a30",
        "b01a",
        "b01b",
        "b02a",
        "b02b",
        "b03a",
        "b03b",
        "b04a",
        "b04b",
        "b04c",
        "b04d",
        "b05a",
        "b05b",
        "b05c",
        "b05d",
        "b06a",
        "b06b",
        "b06c",
        "b06d",
        "b08a",
        "b08b",
        "b08d",
        "c02a",
        "c02b",
        "c02c",
        "c03a",
        "c03b",
        "c03c",
        "c03d",
        "c03e",
        "c03f",
        "c04a",
        "c04d",
        "c05a",
        "c06a",
        "c06b",
        "c06c",
        "c06d",
        "c06f",
        "c07a",
        "c07b",
        "c07c",
        "c08a",
        "c08b",
        "c08c",
        "c09a",
        "c09b",
        "c09c",
        "c10a",
        "c10b",
        "c10c",
        "c11a",
        "c12a",
        "c12e",
        "c13a",
        "c13c",
        "c13d",
        "c13e",
        "c13f",
        "c14a",
    ],
}


def main(args):
    if args.target in ["cooking", "both"]:
        logging.info("Download ProMQA(-Cooking)")
        for task_id in SUBSETS["cooking"]:
            logging.info(f"Download subset: {task_id}")
            load_dataset("kimihiroh/promqa-cooking-2", task_id, split="test")

    if args.target in ["assembly", "both"]:
        logging.info("Download ProMQA-Assembly")
        for task_id in SUBSETS["assembly"]:
            logging.info(f"Download subset: {task_id}")
            load_dataset("kimihiroh/promqa-assembly", task_id, split="test")

        logging.info("Download instruction graphs")
        load_dataset("kimihiroh/assembly101-graph", split="test")


if __name__ == "__main__":
    parser = ArgumentParser(description="Download data")
    parser.add_argument(
        "--target",
        type=str,
        help="target data to download",
        choices=["cooking", "assembly", "both"],
        default="both",
    )
    parser.add_argument(
        "--dirpath_log", type=Path, help="dirpath for log", default="./log"
    )
    args = parser.parse_args()

    if not args.dirpath_log.exists():
        args.dirpath_log.mkdir(parents=True)

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(args.dirpath_log / "download.log"),
        ],
    )

    logging.info(pformat(vars(args)))

    main(args)