import os
from argparse import ArgumentParser
import tqdm
import clip
import torch
import json
import pickle


def parse_args():
    """Parse the command-line options of the caption-embedding script.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``vis_encoder``,
        ``localized_narratives`` and ``datadir``.

    Exits via argparse (``SystemExit``) when a required option is missing or
    a value is not among the allowed choices.
    """
    parser = ArgumentParser()
    parser.add_argument('--dataset', choices=['mscoco', 'flickr30k'], required=True,
                        help="Captions of what dataset to extract")
    # Required: without a value this would be None, which cannot be used as a
    # CLIP model name, so reject the invocation up front with a usage message.
    parser.add_argument('--vis-encoder', required=True,
                        choices=['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
                                 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
                        help="CLIP model to extract embeddings for")
    parser.add_argument('--localized-narratives', action='store_true', help="Use annotations from localized narratives")
    parser.add_argument('--datadir', type=str, required=True, help="Directory that holds the images of the respective datasets")
    return parser.parse_args()


def extract_caps(data, split_dict, model, device='cuda'):
    """Encode the captions of all training images with the model's text encoder.

    Args:
        data: Iterable of annotation dicts. An entry is identified by its
            ``"filename"`` value when that key exists, otherwise by
            ``"image_id"``. Captions are taken from ``sent["raw"]`` of every
            item in ``d["sentences"]`` when present, otherwise from
            ``d["caption"]`` split on ``'.'``.
        split_dict: Mapping from image identifier to split name. Entries
            whose identifier is missing, or whose split is neither
            ``'train'`` nor ``'restval'``, are skipped.
        model: Object exposing ``encode_text`` (here: a loaded CLIP model).
        device: Device the tokenized captions are moved to before encoding.

    Returns:
        List of ``(img_id, caps, emb)`` tuples, where ``caps`` is the list of
        caption strings and ``emb`` is the NumPy array produced by
        ``model.encode_text`` for them (presumably one row per caption).
    """
    info = []

    for d in tqdm.tqdm(data, desc="Loading data..."):
        img_id = d["filename"] if "filename" in d else d["image_id"]

        # .get() returns None for unknown images, which is filtered out here too.
        if split_dict.get(img_id) not in ('train', 'restval'):
            continue

        if 'sentences' in d:
            caps = [sent['raw'] for sent in d['sentences']]
        else:
            # Splitting on '.' leaves empty or whitespace-only pieces (e.g.
            # after a trailing ". "); drop them so no blank caption is encoded.
            # The kept fragments are stored unmodified.
            caps = [c for c in d["caption"].split('.') if c.strip()]

        if not caps:
            # Nothing to encode for this image; avoid sending an empty batch
            # to the text encoder.
            continue

        with torch.no_grad():
            tokenized = clip.tokenize(caps, truncate=True).to(device)
            emb = model.encode_text(tokenized).cpu().numpy()

        info.append((img_id, caps, emb))

    return info


def _read_jsonl(path):
    """Return the decoded JSON object of every non-blank line in ``path``."""
    # json.loads instead of eval(): eval executes arbitrary code found in the
    # file and fails on the JSON literals true/false/null.
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def main():
    """Extract CLIP text embeddings for the training captions and pickle them.

    With ``--localized-narratives`` the annotations are read from the JSONL
    file(s) inside ``--datadir`` and every entry is treated as a training
    sample. Otherwise ``annotations/dataset_<dataset>.json`` (relative to the
    working directory) is read and its per-image ``split`` field is used.

    The result of ``extract_caps`` is written to
    ``data/<dataset>/<encoder>_train_caps.pkl`` (or
    ``..._train_caps_localized_narratives.pkl``), where ``<encoder>`` is the
    model name with ``/`` removed.
    """
    args = parse_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    root = args.datadir

    if args.localized_narratives:
        if args.dataset == "flickr30k":
            shards = ["flickr30k_train_localized_narratives.jsonl"]
        else:
            shards = [f"coco_train_localized_narratives-0000{i}-of-00004.jsonl" for i in range(4)]

        data = []
        for shard in shards:
            data.extend(_read_jsonl(os.path.join(root, shard)))
        split_dict = {d['image_id']: "train" for d in data}
    else:
        annotations = f'annotations/dataset_{args.dataset}.json'

        with open(annotations, 'r', encoding='utf-8') as f:
            data = json.load(f)['images']
        split_dict = {d['filename']: d['split'] for d in data}

    model, _ = clip.load(args.vis_encoder)
    model = model.to(device)
    model.eval()

    out_dir = os.path.join('data', args.dataset)
    os.makedirs(out_dir, exist_ok=True)

    info = extract_caps(data, split_dict, model, device)

    # '/' in names such as "ViT-B/32" would otherwise act as a path separator.
    encoder_tag = args.vis_encoder.replace("/", "")
    if args.localized_narratives:
        out_name = f'{encoder_tag}_train_caps_localized_narratives.pkl'
    else:
        out_name = f'{encoder_tag}_train_caps.pkl'

    # Context manager so the file is flushed and closed deterministically.
    with open(os.path.join(out_dir, out_name), 'wb') as f:
        pickle.dump(info, f)


# Run the extraction only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
