import argparse
from pathlib import Path
from tdw_image_dataset.image_dataset import ImageDataset


# Total image counts (train + test) shared by several presets.
HVM_NUM_IMAGES = 4608 + 1152  # same size as the HvM dataset
IMAGENET_NUM_IMAGES = 1300000 + 50000  # same size as ImageNet
NUM_IMAGES_10M = 10000000 + 100000  # 10M train + 100K test
NUM_IMAGES_100M = 100000000 + 100000  # 100M train + 100K test

# Default camera rotation range handed to ImageDataset; the object-centered
# preset uses a near-zero value so objects stay centered in the frame.
DEFAULT_CAM_ROT_RANGE = 20.0
OBJ_CENTERED_CAM_ROT_RANGE = 0.01

# The 8 categories used by 'tdw5k'.
TDW5K_WNIDS = [
    'n02774152',  # 'bag, handbag, pocketbook, purse', 12 records
    'n02933112',  # 'cabinet', 33 records
    'n03001627',  # 'chair', 25 records
    'n03761084',  # 'microwave oven', 12 records
    'n03880531',  # 'pan', 12 records
    'n04256520',  # 'sofa', 14 records
    'n04379243',  # 'table', 20 records
    'n04461879',  # 'toy', 12 records
]

# The 'tdw1m_<k>c' presets use the first k entries of this list, so each
# smaller category subset is contained in every larger one.
NESTED_SUBSET_WNIDS = [
    'n00104409',
    'n00609236',
    'n02206856',
    'n02694662',
    'n02769748',
    'n02773838',
    'n02774152',
    'n02799175',
    'n02801938',
    'n02818832',
    'n02828884',
    'n02870526',
    'n02871005',
    'n02876657',
    'n02883344',
    'n02933112',
]

# Preset name -> (num_img_total, subset_ids or None for all categories,
# cam_rot_range). Every preset renders across multiple scenes.
DATASET_PRESETS = {
    # Same size as HvM (~4608 training / 1152 testing images), 8 categories.
    'tdw5k': (HVM_NUM_IMAGES, TDW5K_WNIDS, DEFAULT_CAM_ROT_RANGE),
    # Same size as ImageNet (~1.3M training / 50K testing images), all categories.
    'tdw1m': (IMAGENET_NUM_IMAGES, None, DEFAULT_CAM_ROT_RANGE),
    # ImageNet-sized, single category: 'n03001627' (chair, 25 records).
    'tdw1m_1c_n03001627': (IMAGENET_NUM_IMAGES, ['n03001627'], DEFAULT_CAM_ROT_RANGE),
    # ImageNet-sized, single category: 'n02774152' (bag, 12 records).
    'tdw1m_1c_n02774152': (IMAGENET_NUM_IMAGES, ['n02774152'], DEFAULT_CAM_ROT_RANGE),
    # ImageNet-sized with 2/4/6/8/16 categories.
    'tdw1m_2c': (IMAGENET_NUM_IMAGES, NESTED_SUBSET_WNIDS[:2], DEFAULT_CAM_ROT_RANGE),
    'tdw1m_4c': (IMAGENET_NUM_IMAGES, NESTED_SUBSET_WNIDS[:4], DEFAULT_CAM_ROT_RANGE),
    'tdw1m_6c': (IMAGENET_NUM_IMAGES, NESTED_SUBSET_WNIDS[:6], DEFAULT_CAM_ROT_RANGE),
    'tdw1m_8c': (IMAGENET_NUM_IMAGES, NESTED_SUBSET_WNIDS[:8], DEFAULT_CAM_ROT_RANGE),
    'tdw1m_16c': (IMAGENET_NUM_IMAGES, NESTED_SUBSET_WNIDS[:16], DEFAULT_CAM_ROT_RANGE),
    # ImageNet-sized, all categories, objects centered in the generated images.
    'tdw1m_obj_centered': (IMAGENET_NUM_IMAGES, None, OBJ_CENTERED_CAM_ROT_RANGE),
    # ~10M training / 100K testing images, all categories.
    'tdw10m': (NUM_IMAGES_10M, None, DEFAULT_CAM_ROT_RANGE),
    # ~100M training / 100K testing images, all categories.
    'tdw100m': (NUM_IMAGES_100M, None, DEFAULT_CAM_ROT_RANGE),
}

# 10 scenes: 8 outdoors, 2 indoors.
SCENES = [
    "box_room_2018",  # indoor
    "building_site",
    "dead_grotto",
    "downtown_alleys",
    "iceland_beach",
    "lava_field",
    "ruin",
    "savanna_flat_6km",
    "suburb_scene_2023",
    "tdw_room",  # indoor
]


def get_dataset_config(name):
    """Look up the generation parameters for the dataset preset *name*.

    Returns:
        A ``(num_img_total, subset_ids, cam_rot_range)`` tuple. ``subset_ids``
        is ``None`` when all categories are used; otherwise it is a fresh list
        of WordNet IDs, so callers may mutate it without affecting the presets.

    Raises:
        NotImplementedError: if *name* is not a key of ``DATASET_PRESETS``.
    """
    try:
        num_img_total, subset_ids, cam_rot_range = DATASET_PRESETS[name]
    except KeyError:
        known = ', '.join(sorted(DATASET_PRESETS))
        raise NotImplementedError(
            "Unknown dataset name {!r}; expected one of: {}".format(name, known)
        ) from None
    if subset_ids is not None:
        subset_ids = list(subset_ids)  # defensive copy of the shared preset list
    return num_img_total, subset_ids, cam_rot_range


def resolve_output_dir(name, directory=''):
    """Return ``<directory>/<name>``, using the home directory when *directory* is empty."""
    base = Path(directory) if directory else Path.home()
    return base.joinpath(name)


def main(argv=None):
    """Parse command-line arguments and generate the requested dataset.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-n', '--name', default='',
        help='name of the dataset (one of: {})'.format(', '.join(sorted(DATASET_PRESETS))),
    )
    parser.add_argument('-s', '--scenes', nargs='+', default=[], help='names of the scenes to generate')
    parser.add_argument('-d', '--directory', default='', help='the directory to save the dataset to')
    args = parser.parse_args(argv)

    # Resolve the preset first so an unknown name fails before anything else happens.
    num_img_total, subset_ids, cam_rot_range = get_dataset_config(args.name)
    output_dir = resolve_output_dir(args.name, args.directory)

    c = ImageDataset(
        num_img_total=num_img_total,
        output_directory=output_dir,
        subset_wnids=subset_ids,
        scene_list=list(SCENES),
        scene_to_generate=args.scenes,
        cam_rot_range=cam_rot_range,
        )

    c.generate_multi_scene()


if __name__ == "__main__":
    main()
