# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import json
import numpy as np
import pickle
import io
import gzip
import sys
import time
from nltk.corpus import wordnet
from tqdm import tqdm
import operator
import torch
import os

sys.path.insert(0, 'third_party/CenterNet2/')
sys.path.insert(0, 'third_party/Deformable-DETR')
from tar_dataset import DiskTarDataset, _TarDataset

def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Build an image-info JSON ('categories', 'images', empty "
                    "'annotations') from a DiskTarDataset.")
    # Not read by this script; kept so existing command lines keep working.
    parser.add_argument("--imagenet_dir", default='/home/ubuntu/efs/imagenet/ImageNet-21k/')
    parser.add_argument("--tarfile_path", default='/home/ubuntu/efs/imagenet/metadata-22k/tar_files.npy')
    parser.add_argument("--tar_index_dir", default='/home/ubuntu/efs/imagenet/metadata-22k/tarindex_npy')
    parser.add_argument("--out_path", default='/home/ubuntu/efs/imagenet/annotations/imagenet-22k_image_info.json')
    parser.add_argument("--workers", default=16, type=int)
    # Directory whose entries (minus *.tar files) are the class wnids.
    # Previously hard-coded; the default preserves the old value.
    parser.add_argument("--classnames_dir", default='/home/ubuntu/efs/imagenet/ImageNet-21K/images/train/')
    return parser.parse_args(argv)


def list_classnames(classnames_dir):
    """Return the entries of ``classnames_dir`` that do not end in ``.tar``, sorted.

    Sorting makes the category-id assignment reproducible: ``os.listdir``
    returns entries in arbitrary order.
    """
    return sorted(name for name in os.listdir(classnames_dir)
                  if not name.endswith('.tar'))


def build_category(index, wnid):
    """Return the category record for one WordNet noun id.

    Args:
        index: 0-based position of the class; the stored id is ``index + 1``.
        wnid: String of the form ``'n'`` followed by the numeric synset offset.

    Returns:
        Dict with keys 'id', 'synset', 'name' (first lemma), 'def' and 'synonyms'.
    """
    synset = wordnet.synset_from_pos_and_offset('n', int(wnid[1:]))
    synonyms = [lemma.name() for lemma in synset.lemmas()]
    return {
        'id': index + 1,
        'synset': synset.name(),
        'name': synonyms[0],
        'def': synset.definition(),
        'synonyms': synonyms,
    }


def make_image_record(img, label, index):
    """Return the image record for one dataset sample, or ``None`` to skip it.

    Samples with ``label == -1`` are skipped. ``img`` must expose ``height``
    and ``width`` attributes (presumably a PIL image — confirm against the
    dataset). Ids are 1-based; ``tar_index`` keeps the raw 0-based index.
    """
    if label == -1:
        return None
    return {
        'id': int(index) + 1,
        'pos_category_ids': [int(label) + 1],
        'height': int(img.height),
        'width': int(img.width),
        'tar_index': int(index),
    }


def save_json(data, out_path):
    """Write ``data`` as JSON to ``out_path``, creating its parent directory if needed."""
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, 'w') as f:
        json.dump(data, f)


def main(argv=None):
    """Build the categories and images lists and write them to ``--out_path``."""
    args = parse_args(argv)

    start_time = time.time()
    print('Building dataset')
    dataset = DiskTarDataset(args.tarfile_path, args.tar_index_dir)
    end_time = time.time()
    print(f"Took {end_time-start_time} seconds to make the dataset.")
    print(f"Have {len(dataset)} samples.")

    classnames = list_classnames(args.classnames_dir)
    # NOTE(review): images get pos_category_ids = [label + 1], so this assumes
    # the dataset's labels index this same sorted class list. Confirm against
    # DiskTarDataset / the order of entries in tar_files.npy.
    categories = [build_category(i, wnid) for i, wnid in enumerate(classnames)]
    print('categories', len(categories))

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False,
        num_workers=args.workers,
        # With batch_size=1 the collate input is a 1-element list; take the
        # sample itself instead of a collated batch.
        collate_fn=operator.itemgetter(0),
    )
    images = []
    for img, label, index in tqdm(data_loader):
        record = make_image_record(img, label, index)
        if record is not None:
            images.append(record)

    data = {'categories': categories, 'images': images, 'annotations': []}
    for k, v in data.items():
        print(k, len(v))
    print('Saving to ', args.out_path)
    # Let write errors propagate instead of silently producing no output.
    save_json(data, args.out_path)


if __name__ == '__main__':
    main()
    
