# Modified from https://github.com/open-compass/VLMEvalKit and [TODO: name the second upstream source; the original attribution is truncated]

from tqdm import tqdm
import urllib.request
import hashlib
import os
from PIL import Image
import io
import base64
import math

# Maps a dataset name to the split name used for it ("test", "validation",
# "testmini", "testdev").
# NOTE(review): presumably handed to the dataset loader by callers outside this
# chunk — confirm. Entry order is also the dict's iteration order.
dataset_name_split_mapping = {
    # Y/N datasets (see get_dataset_type)
    "POPE": "test",
    "Winoground-YN": "test",

    # MMMU / MathVista: get_dataset_type classifies the MultiChoice variants as
    # 'multi-choice' and the OpenEnded variants as 'VQA'
    "MMMU_VAL_MultiChoice": "validation",
    "MMMU_VAL_OpenEnded": "validation",
    "MMMU_TEST_MultiChoice": "test",
    "MMMU_TEST_OpenEnded": "test",
    "MathVista_MultiChoice": "testmini",
    "MathVista_OpenEnded": "testmini",

    # Captioning datasets
    "COCO": "test",
    "Flickr30K": "test",
    "NoCaps": "validation",
    "WHOOPS-Caption": "test",

    # VQA datasets; plain "VQAv2" uses the same split as "VQAv2_TEST"
    "VQAv2": "testdev",
    "VQAv2_VAL": "validation",
    "VQAv2_TEST": "testdev",
    "OK-VQA": "test",
    "VizWiz_VAL": "validation",
    "VizWiz_TEST": "test",
    "TextVQA": "validation",
    # "WHOOPS-VQA": "test",

    # Multi-choice datasets
    "MMLU": "test",
    "PIQA_VAL": "validation",
}

# Maps each captioning dataset (all four are 'Caption' per get_dataset_type) to
# the name of the field that identifies a sample's image.
# NOTE(review): field meaning inferred from the mapping's name — confirm
# against the dataset schemas.
dataset_name_image_id_mapping = {
    "COCO": "cocoid",
    "Flickr30K": "filename",
    "NoCaps": "image_id",
    "WHOOPS-Caption": "image_id",
}

# Maps a dataset name to the name of the field holding its reference answer(s);
# caption datasets point at caption fields ("sentences_raw", "caption", ...).
# NOTE(review): inferred from the mapping's name; the record schemas are not
# visible here — confirm against the loaders.
dataset_name_answer_mapping = {
    # Y/N datasets (see get_dataset_type)
    "MME": "answer",
    "POPE": "answer",
    "Winoground-YN": "answer",
    "HallusionBench": "answer",

    # Multi-choice datasets, plus the open-ended MMMU / MathVista variants
    # (which get_dataset_type classifies as 'VQA')
    "MMBench_DEV_EN": "answer",
    "SEEDBench_IMG": "answer",
    "ScienceQA_VAL": "answer",
    "ScienceQA_TEST": "answer",
    "MMMU_VAL_MultiChoice": "answer",
    "MMMU_TEST_MultiChoice": "answer",
    "MMMU_VAL_OpenEnded": "answer",
    "MMMU_TEST_OpenEnded": "answer",
    "MathVista_MultiChoice": "answer_transformed",
    "MathVista_OpenEnded": "answer_transformed",

    # Captioning datasets
    "COCO": "sentences_raw",
    "Flickr30K": "caption",
    "NoCaps": "annotations_captions",
    "WHOOPS-Caption": "crowd_captions",

    # VQA datasets
    "VQAv2": "answers",
    "VQAv2_VAL": "multiple_choice_answer",
    "VQAv2_TEST": "answers",
    "OK-VQA": "answers",
    "VizWiz_VAL": "answers",
    "VizWiz_TEST": "answers",
    "TextVQA": "answers",
    "GQA_TESTDEV_BALANCED": "answer",
    # "WHOOPS-VQA": "answer",
    "MMVet": "answer",

    # Multi-choice datasets
    "MMLU": "answer_transformed",
    "PIQA_VAL": "answer_transformed",
}

# Remote TSV files for the benchmarks hosted by VLMEvalKit. Every file sits in
# the same directory and is named after its dataset key, so the URLs are
# generated from the key list (whose order is the dict's iteration order).
dataset_URLs = {
    name: f"https://opencompass.openxlab.space/utils/VLMEval/{name}.tsv"
    for name in (
        'MMBench_DEV_EN',
        'MMBench_TEST_EN',
        'MMBench_DEV_CN',
        'MMBench_TEST_CN',
        'CCBench',
        'MME',
        'SEEDBench_IMG',
        'MMVet',
        'ScienceQA_VAL',
        'ScienceQA_TEST',
        'HallusionBench',
    )
}

def listinstr(lst, s):
    """Return True if at least one element of ``lst`` occurs in ``s``.

    ``lst`` must be a list (asserted). Each element is tested with ``in``
    against ``s``, which is a substring test when ``s`` is a string, and the
    scan stops at the first hit.
    """
    assert isinstance(lst, list)
    return any(item in s for item in lst)


def is_none(value):
    """Return True if ``value`` should be treated as a missing entry.

    ``None`` and floating-point NaN count as missing. The NaN check uses
    ``isinstance`` rather than an exact type test, so NaN stored in a float
    subclass (e.g. ``numpy.float64``) is recognised as missing too.

    Strings such as ``'nan'`` or ``'none'`` are deliberately NOT treated as
    missing; they are returned as ordinary values.
    """
    if value is None:
        return True
    return isinstance(value, float) and math.isnan(value)

def get_dataset_type(dataset):
    """Classify a dataset name into its evaluation task family.

    The lower-cased name is checked for keyword substrings, family by family,
    in the order listed below; the first family with a hit wins. A name that
    contains keywords from several families therefore gets the earlier label.

    Returns:
        'multi-choice', 'Y/N', 'Caption' or 'VQA', or None when no keyword
        matches.
    """
    name = dataset.lower()
    rules = (
        ('multi-choice', ['mmbench', 'ccbench', 'seedbench', 'scienceqa',
                          'mathvista_multichoice', 'mmmu_val_multichoice',
                          'mmmu_test_multichoice', 'mmlu', 'piqa']),
        ('Y/N', ['mme', 'pope', 'winoground-yn', 'hallusion']),
        ('Caption', ['coco', 'flickr', 'nocaps', 'whoops-caption']),
        # 'whoops-vqa' is currently disabled and intentionally not listed.
        ('VQA', ['vqav2', 'ok-vqa', 'vizwiz', 'textvqa', 'gqa', 'ocrvqa',
                 'chartqa', 'docvqa', 'llavabench', 'mmvet',
                 'mathvista_openended', 'mmmu_val_openended',
                 'mmmu_test_openended']),
    )
    for dataset_type, keywords in rules:
        if listinstr(keywords, name):
            return dataset_type
    return None

def download_file(url, filename=None):
    """Download ``url`` to a local file, showing a tqdm progress bar.

    Args:
        url: Address to fetch; passed to ``urllib.request.urlretrieve``.
        filename: Destination path. Defaults to the last ``/``-separated
            component of ``url``.

    Returns:
        The destination path (``filename`` as given, or the derived default).

    Raises:
        ValueError: If no destination was given and none can be derived from
            ``url`` (e.g. the URL ends with ``/``).
        Any exception raised by ``urlretrieve`` is propagated unchanged.

    The data is streamed into ``<filename>.part`` and renamed onto
    ``filename`` only after the transfer completes. A failed or interrupted
    download therefore never leaves a truncated file behind and never
    clobbers an existing ``filename``.
    """
    class DownloadProgressBar(tqdm):
        def update_to(self, b=1, bsize=1, tsize=None):
            # urlretrieve passes a total size of -1 when the size is unknown
            # (no Content-Length); only a positive value is a usable total.
            if tsize is not None and tsize > 0:
                self.total = tsize
            # The hook reports cumulative progress (blocks so far * block
            # size), while tqdm.update() expects an increment.
            self.update(b * bsize - self.n)

    if filename is None:
        filename = url.split('/')[-1]
    if not filename:
        raise ValueError(f'Cannot derive a file name from URL: {url!r}')

    target = os.fspath(filename)
    partial = target + (b'.part' if isinstance(target, bytes) else '.part')
    try:
        with DownloadProgressBar(unit='B', unit_scale=True,
                                 miniters=1, desc=url.split('/')[-1]) as t:
            urllib.request.urlretrieve(url, filename=partial, reporthook=t.update_to)
        os.replace(partial, target)
    finally:
        # The partial file only still exists if the transfer or the rename failed.
        if os.path.exists(partial):
            try:
                os.remove(partial)
            except OSError:
                pass
    return filename

def encode_image_to_base64(img, target_size=-1):
    """Encode a PIL image as a base64 JPEG string.

    Args:
        img: The ``PIL.Image.Image`` to encode. It is never modified.
        target_size: If positive, the image is shrunk with ``thumbnail`` so
            that it fits within (target_size, target_size). The default ``-1``
            (or any non-positive value) disables resizing.

    Returns:
        The JPEG bytes, base64-encoded, as a ``str``.
    """
    original = img
    # JPEG cannot store an alpha channel or a palette, so flatten those modes
    # to RGB first. convert() returns a new image.
    if img.mode in ("RGBA", "P", "LA"):
        img = img.convert("RGB")
    if target_size > 0:
        # thumbnail() resizes in place; never shrink the caller's image.
        if img is original:
            img = img.copy()
        img.thumbnail((target_size, target_size))
    # Encode in memory: there is no temporary file to name, clean up, or leak
    # when save() fails.
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')

def encode_image_file_to_base64(image_path, target_size=-1):
    """Read an image file and return it as a base64 JPEG string.

    Args:
        image_path: Path of the image file to read.
        target_size: Forwarded to ``encode_image_to_base64``. A positive value
            caps the image at (target_size, target_size); ``-1`` disables
            resizing.

    Returns:
        The base64-encoded JPEG as a ``str``.
    """
    # Image.open() is lazy and keeps the file handle open until the image is
    # closed. Encoding inside the ``with`` block forces the pixel data to load
    # before the handle is released.
    with Image.open(image_path) as image:
        return encode_image_to_base64(image, target_size=target_size)
    
def decode_base64_to_image(base64_string, target_size=-1):
    """Decode a base64-encoded image into a PIL image.

    Args:
        base64_string: Base64 text of an encoded image file.
        target_size: If positive, the decoded image is shrunk in place with
            ``thumbnail`` to fit within (target_size, target_size); the
            default ``-1`` keeps the original size.

    Returns:
        The decoded ``PIL.Image.Image``. RGBA and palette images are converted
        to RGB.
    """
    stream = io.BytesIO(base64.b64decode(base64_string))
    decoded = Image.open(stream)
    needs_flattening = decoded.mode in ('RGBA', 'P')
    if needs_flattening:
        decoded = decoded.convert('RGB')
    if target_size > 0:
        decoded.thumbnail((target_size, target_size))
    return decoded

def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
    """Decode a base64-encoded image and write it to ``image_path``.

    Decoding, RGB conversion and optional downscaling are delegated to
    ``decode_base64_to_image``. The save call passes ``quality=100`` and
    ``subsampling=0``, Pillow's JPEG options for maximum quality without
    chroma subsampling. NOTE(review): other output formats presumably ignore
    these options — confirm for the formats actually used.
    """
    decoded = decode_base64_to_image(base64_string, target_size=target_size)
    save_options = {'quality': 100, 'subsampling': 0}
    decoded.save(image_path, **save_options)