from subprocess import call
import os, json

from torchvision.datasets import VisionDataset
from PIL import Image


GITHUB_MAIN_ORIGINAL_ANNOTATION_PATH = 'https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/coco_{}_karpathy.json'
GITHUB_MAIN_PATH = 'https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/XTD10/'
SUPPORTED_LANGUAGES = ['es', 'it', 'ko', 'pl', 'ru', 'tr', 'zh', 'en']

IMAGE_INDEX_FILE = 'mscoco-multilingual_index.json'
IMAGE_INDEX_FILE_DOWNLOAD_NAME = 'test_image_names.txt'

CAPTIONS_FILE_DOWNLOAD_NAME = 'test_1kcaptions_{}.txt'
CAPTIONS_FILE_NAME = 'multilingual_mscoco_captions-{}.json'

ORIGINAL_ANNOTATION_FILE_NAME = 'coco_{}_karpathy.json'



class Multilingual_MSCOCO(VisionDataset):

    def __init__(self, root, ann_file, transform=None, target_transform=None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)
        with open(ann_file, 'r') as fp:
            data = json.load(fp)
            
        self.data = [(img_path, txt) for img_path, txt in zip(data['image_paths'], data['annotations'])]
    
    def __getitem__(self, index):
        img, captions = self.data[index]

        # Image
        img = Image.open(os.path.join(self.root, img)).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target =  [captions, ]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


    def __len__(self) -> int:
        return len(self.data)


def _get_downloadable_file(filename, download_url, is_json=True):
    if (os.path.exists(filename) == False):
        print("Downloading", download_url)
        call("wget {} -O {}".format(download_url, filename), shell=True)
    with open(filename, 'r') as fp:
        if (is_json):
            return json.load(fp)
        return [line.strip() for line in fp.readlines()]


def create_annotation_file(root, lang_code):
    print("Downloading multilingual_ms_coco index file")
    download_path = os.path.join(GITHUB_MAIN_PATH, IMAGE_INDEX_FILE_DOWNLOAD_NAME)
    target_images = _get_downloadable_file("multilingual_coco_images.txt", download_path, False)

    print("Downloading multilingual_ms_coco captions:", lang_code)
    download_path = os.path.join(GITHUB_MAIN_PATH, CAPTIONS_FILE_DOWNLOAD_NAME.format(lang_code))
    target_captions = _get_downloadable_file('raw_multilingual_coco_captions_{}.txt'.format(lang_code), download_path, False)

    number_of_missing_images = 0
    valid_images, valid_annotations, valid_indicies = [], [], []
    for i, (img, txt) in enumerate(zip(target_images, target_captions)):
        # Create a new file name that includes the root split
        root_split = 'val2014' if 'val' in img else 'train2014'
        filename_with_root_split = "{}/{}".format(root_split, img)

        if (os.path.exists(filename_with_root_split)):
            print("Missing image file", img)
            number_of_missing_images += 1
            continue

        valid_images.append(filename_with_root_split)
        valid_annotations.append(txt)
        valid_indicies.append(i)

    if (number_of_missing_images > 0):
        print("*** WARNING *** missing {} files.".format(number_of_missing_images))

    with open(os.path.join(root, CAPTIONS_FILE_NAME.format(lang_code)), 'w') as fp:
        json.dump({'image_paths': valid_images, 'annotations': valid_annotations, 'indicies': valid_indicies}, fp)
