import os
from utils.basic_utils import dict_to_markdown
from utils.convert_annotations.processor import Processor

def main(argv=None):
    """Parse CLI options and convert one dataset's annotations to tvr format.

    Builds a config namespace from the command line, resolves the annotation
    directory as ``<root_dir>/<dset_name>``, selects the dataset splits to
    convert, prints the config, and hands it to ``Processor(config).convert()``.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as the script always did.

    Raises:
        NotImplementedError: If ``--dset_name`` matches none of the dataset
            families handled below.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Convert grounding annotation json files to tvr format.')
    # One dataset per run. A list default here would make os.path.join below raise TypeError.
    parser.add_argument("--dset_name", type=str, default='charades',
                        help="dataset name: charades, charades-CD, charades-CG, charades-Unseen, "
                             "activitynet, activitynet-CD, activitynet-CG or tacos")
    # nargs='+' collects whole split names; type=list would explode one token into characters.
    parser.add_argument("--phase", type=str, nargs='+', default=None,
                        help="splits to convert (default: the standard splits of --dset_name)")
    parser.add_argument("--thd", type=float, default=0.5, help="threshold stored as config.thd for Processor")
    parser.add_argument("--root_dir", type=str, default='data/', help="path to annotations.")
    config = parser.parse_args(argv)
    config.root_dir = os.path.join(config.root_dir, config.dset_name)

    dset_name = config.dset_name
    # Substring matching: the family name selects the branch, then specific variants refine it.
    if 'charades' in dset_name:
        if 'charades-CD' in dset_name:
            default_phases = ['train', 'val', 'test_iid', 'test_ood']
        elif 'charades-CG' in dset_name:
            default_phases = ['novel_composition', 'novel_word', 'test_trivial', 'train']
        elif 'charades-Unseen' in dset_name:
            default_phases = ['train', 'test_unseen', 'test_seen']
        else:
            # There is no val set in charades-sta dataset.
            default_phases = ['train', 'test']
    elif 'activitynet' in dset_name:
        if 'activitynet-CD' in dset_name:
            default_phases = ['train', 'val', 'test_iid', 'test_ood']
        elif 'activitynet-CG' in dset_name:
            default_phases = ['novel_composition', 'novel_word', 'test_trivial', 'train']
        else:
            default_phases = ['train', 'val_1', 'val_2']
    elif dset_name == 'tacos':
        default_phases = ['train', 'val', 'test']
    else:
        raise NotImplementedError('Invalid dataset, {} '.format(dset_name))

    # Honor an explicit --phase; otherwise use the dataset's standard splits.
    if config.phase is None:
        config.phase = default_phases
    print(dict_to_markdown(vars(config), max_str_len=120))

    processor = Processor(config)
    processor.convert()


if "__main__" == __name__:
    # Invoke the command-line entry point only when this file is executed as a script.
    main()