import os
import json
import subprocess
import numpy as np
import pandas as pd
import torchvision
import torch

from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from easydict import EasyDict as edict
from torchvision.datasets.utils import download_url
from datasets import load_dataset

from .perturbations import shuffle_nouns_and_adj, shuffle_allbut_nouns_and_adj, shuffle_within_trigrams, shuffle_trigrams
from .retrieval import pre_caption

# Hugging Face access token, passed as `use_auth_token` to `load_dataset` when
# `Winoground` fetches the "facebook/winoground" dataset. Put your token here;
# with None, authentication is left to the `datasets` library
# (NOTE(review): confirm how your `datasets` version resolves credentials).
auth_token = None

def get_winoground_retrieval_acc(scores):
    """Return the Winoground retrieval accuracies as ``(acc_i2t, acc_t2i)``.

    ``scores`` is either a single N x N_image x N_text array used for both
    directions, or a tuple whose element 0 holds the text-to-image scores and
    whose element 1 holds the image-to-text scores (both N x N_image x N_text).
    """
    pair = scores if isinstance(scores, tuple) else (scores, scores)
    t2i_matrix, i2t_matrix = pair[0], pair[1]
    i2t_accuracy = get_image2text_acc_from_winoground_matrix(i2t_matrix)
    t2i_accuracy = get_text2image_acc_from_winoground_matrix(t2i_matrix)
    return i2t_accuracy, t2i_accuracy

def get_image2text_acc_from_winoground_matrix(scores_i2t):
    """Image-to-text retrieval accuracy over a Winoground score array.

    ``scores_i2t`` is N x N_image x N_text with two images and two captions per
    sample. Image 0 should rank caption 0 highest and image 1 should rank
    caption 1 highest; each of the 2N image queries is one trial. Because
    ``argmax`` returns the first maximal index, a tie counts as correct.

    Returns the fraction of correct queries as a Python float.
    """
    image0_block, image1_block = np.array_split(scores_i2t, 2, axis=1)
    per_query = np.concatenate(
        [
            image0_block.squeeze(axis=1),
            # Reverse image 1's caption order so its correct caption (caption 1)
            # lands in column 0, like image 0's.
            np.flip(image1_block, axis=-1).squeeze(axis=1),
        ],
        axis=0,
    )
    return float(np.mean(per_query.argmax(axis=1) == 0))

def get_text2image_acc_from_winoground_matrix(scores_t2i):
    """Text-to-image retrieval accuracy over a Winoground score array.

    ``scores_t2i`` is N x N_image x N_text (two images, two captions per
    sample). It is reordered to N x N_text x N_image so that each caption is a
    query over the two images: caption 0 should pick image 0 and caption 1
    should pick image 1. Each of the 2N caption queries is one trial; since
    ``argmax`` returns the first maximal index, a tie counts as correct.

    Returns the fraction of correct queries as a Python float.
    """
    by_caption = np.transpose(scores_t2i, (0, 2, 1))  # N x N_text x N_image
    caption0_block, caption1_block = np.array_split(by_caption, 2, axis=1)
    per_query = np.concatenate(
        [
            caption0_block.squeeze(axis=1),
            # Reverse caption 1's image order so its correct image (image 1)
            # lands in column 0, like caption 0's.
            np.flip(caption1_block, axis=-1).squeeze(axis=1),
        ],
        axis=0,
    )
    return float(np.mean(per_query.argmax(axis=1) == 0))

def get_winoground_scores(scores):
    """Turn Winoground score arrays into per-example score records.

    ``scores`` is either one N x N_image x N_text array used for both
    directions, or a tuple ``(scores_t2i, scores_i2t)``. Returns a list with
    one dict per example, holding an ``'i2t'`` and a ``'t2i'`` record. Each
    record has the example index under ``"id"`` and the four pairwise scores
    under ``"c<caption>_i<image>"`` keys.
    """
    if isinstance(scores, tuple):
        scores_t2i, scores_i2t = scores[0], scores[1]
    else:
        scores_t2i = scores_i2t = scores

    def as_record(sample_idx, mat):
        # `mat` is indexed [image][caption], hence the swapped subscripts.
        return {
            "id": sample_idx,
            "c0_i0": mat[0][0],
            "c0_i1": mat[1][0],
            "c1_i0": mat[0][1],
            "c1_i1": mat[1][1],
        }

    sample_indices = range(scores_t2i.shape[0])
    return [
        {"i2t": as_record(sample_idx, mat_i2t), "t2i": as_record(sample_idx, mat_t2i)}
        for sample_idx, mat_i2t, mat_t2i in zip(sample_indices, scores_i2t, scores_t2i)
    ]

def get_winoground_acc(scores):
    """Compute the Winoground text, image and group scores.

    ``scores`` is a list of records shaped like those from
    ``get_winoground_scores``. Each record has an ``'i2t'`` and a ``'t2i'`` dict
    with keys ``c0_i0``, ``c0_i1``, ``c1_i0``, ``c1_i1``.

    - text: within ``'i2t'``, each image prefers its own caption.
    - image: within ``'t2i'``, each caption prefers its own image.
    - group: both hold for the same example.

    Comparisons are strict, so ties count as failures. Returns a dict of
    fractions keyed ``'text'``, ``'image'``, ``'group'``. An empty ``scores``
    raises ZeroDivisionError.
    """
    def text_ok(rec):
        return rec["c0_i0"] > rec["c1_i0"] and rec["c1_i1"] > rec["c0_i1"]

    def image_ok(rec):
        return rec["c0_i0"] > rec["c0_i1"] and rec["c1_i1"] > rec["c1_i0"]

    text_flags = [bool(text_ok(entry["i2t"])) for entry in scores]
    image_flags = [bool(image_ok(entry["t2i"])) for entry in scores]
    group_hits = sum(t and i for t, i in zip(text_flags, image_flags))

    total = len(scores)
    return {
        'text': sum(text_flags) / total,
        'image': sum(image_flags) / total,
        'group': group_hits / total,
    }

class Winoground(Dataset):
    """Winoground benchmark, loaded from the ``facebook/winoground`` test split.

    Each item holds two images and two captions. ``evaluate_scores`` reports the
    text/image/group scores and retrieval accuracies, overall and per tag.
    """

    def __init__(self, image_preprocess=None, root_dir='./'):
        """
        image_preprocess: callable applied to each RGB PIL image. If None, the
            PIL images are returned unchanged.
        root_dir: cache directory passed to ``load_dataset``.
        """
        self.winoground = load_dataset("facebook/winoground",
                                       cache_dir=root_dir, use_auth_token=auth_token)["test"]
        self.preprocess = image_preprocess
        self.original_tags = self.get_original_tags()
        self.new_tags = self.get_new_tags()

    def __len__(self):
        return len(self.winoground)

    def _prepare_image(self, image):
        """Convert a PIL image to RGB and apply ``self.preprocess`` when one is set."""
        image = image.convert("RGB")
        # The default preprocess is None; calling it unguarded would raise
        # "'NoneType' object is not callable" (the other datasets guard this too).
        if self.preprocess is not None:
            image = self.preprocess(image)
        return image

    def __getitem__(self, idx):
        example = self.winoground[idx]
        image_0 = self._prepare_image(example["image_0"])
        image_1 = self._prepare_image(example["image_1"])

        item = edict({"id": example["id"],
                      "image_options": [image_0, image_1],
                      "caption_options": [example["caption_0"], example["caption_1"]]})
        return item

    def get_original_tags(self):
        """Group example ids by the dataset's own annotations.

        Returns a dict mapping ``"1 Main Pred"`` / ``"2 Main Pred"`` (from
        ``num_main_preds``) and each ``collapsed_tag`` value to the list of
        example ids carrying it, in dataset order.

        Raises ValueError if ``num_main_preds`` is neither 1 nor 2.
        """
        main_pred_tags = {1: "1 Main Pred", 2: "2 Main Pred"}
        tags = {}
        for example in self.winoground:
            main_pred_tag = main_pred_tags.get(example['num_main_preds'])
            if main_pred_tag is None:
                # This won't happen
                raise ValueError(f"num_main_preds: {example['num_main_preds']}")
            tags.setdefault(main_pred_tag, []).append(example["id"])
            tags.setdefault(example["collapsed_tag"], []).append(example["id"])
        return tags

    def get_new_tags(self, path="./why_winoground_hard.json"):
        """Load the extra tags from the JSON file at ``path``.

        The file maps example ids (as strings) to lists of tag names. Returns a
        dict mapping each tag to the list of int ids carrying it. Ids with an
        empty tag list are collected under ``"No Tag"``.
        """
        with open(path, encoding="utf-8") as f:
            new_tag_dict = json.load(f)
        tags = {}
        for idx, curr_tags in new_tag_dict.items():
            if len(curr_tags) == 0:
                tags.setdefault("No Tag", []).append(int(idx))
            for tag in curr_tags:
                tags.setdefault(tag, []).append(int(idx))
        return tags

    def evaluate_scores(self, scores):
        """Print and return Winoground scores, overall and per tag.

        scores: N x N_image x N_text array, or a ``(scores_t2i, scores_i2t)``
            tuple, as accepted by ``get_winoground_scores``.

        Returns ``(results, group_score)``. ``results['all']`` holds the
        text/image/group scores plus the ``'i2t'``/``'t2i'`` retrieval accuracies.
        Every other key is a tag mapped to its text/image/group scores. Tag ids
        are used directly as positional indices into ``scores``.
        """
        winoground_scores = get_winoground_scores(scores)
        acc = get_winoground_acc(winoground_scores)
        print("Winoground text score:", acc['text'])
        print("Winoground image score:", acc['image'])
        print("Winoground group score:", acc['group'])
        acc_i2t, acc_t2i = get_winoground_retrieval_acc(scores)
        print("Winoground i2t retrieval accuracy:", acc_i2t)
        print("Winoground t2i retrieval accuracy:", acc_t2i)

        def subset_acc(ids):
            return get_winoground_acc([winoground_scores[i] for i in ids])

        results = {'all': {**acc, 'i2t': acc_i2t, 't2i': acc_t2i}}
        for tag, ids in self.original_tags.items():
            results[tag] = subset_acc(ids)
        for tag, ids in self.new_tags.items():
            results[tag] = subset_acc(ids)
            print(f"Winoground {tag} text score: {results[tag]['text']}")
            print(f"Winoground {tag} image score: {results[tag]['image']}")
            print(f"Winoground {tag} group score: {results[tag]['group']}")
        return results, acc['group']
    


class VG_Relation(Dataset):
    """Visual Genome Relation (VG-R) split of the ARO benchmark.

    Each test case is an image crop with a true and a false caption.
    ``__getitem__`` returns them as ``[false_caption, true_caption]``, so a
    model is correct when it scores option 1 highest.
    """

    # Relations left out of the macro accuracy reported by ``evaluate_scores``
    # (named ``symmetric`` in the original code). Built once as a frozenset for
    # O(1) membership tests instead of rebuilding a list on every call.
    _EXCLUDED_RELATIONS = frozenset([
        'adjusting', 'attached to', 'between', 'bigger than', 'biting', 'boarding',
        'brushing', 'chewing', 'cleaning', 'climbing', 'close to', 'coming from',
        'coming out of', 'contain', 'crossing', 'dragging', 'draped over', 'drinking',
        'drinking from', 'driving', 'driving down', 'driving on', 'eating from',
        'eating in', 'enclosing', 'exiting', 'facing', 'filled with', 'floating in',
        'floating on', 'flying', 'flying above', 'flying in', 'flying over',
        'flying through', 'full of', 'going down', 'going into', 'going through',
        'grazing in', 'growing in', 'growing on', 'guiding', 'hanging from',
        'hanging in', 'hanging off', 'hanging over', 'higher than', 'holding onto',
        'hugging', 'in between', 'jumping off', 'jumping on', 'jumping over',
        'kept in', 'larger than', 'leading', 'leaning over', 'leaving', 'licking',
        'longer than', 'looking in', 'looking into', 'looking out', 'looking over',
        'looking through', 'lying next to', 'lying on top of', 'making', 'mixed with',
        'mounted on', 'moving', 'on the back of', 'on the edge of', 'on the front of',
        'on the other side of', 'opening', 'painted on', 'parked at', 'parked beside',
        'parked by', 'parked in', 'parked in front of', 'parked near', 'parked next to',
        'perched on', 'petting', 'piled on', 'playing', 'playing in', 'playing on',
        'playing with', 'pouring', 'reaching for', 'reading', 'reflected on',
        'riding on', 'running in', 'running on', 'running through', 'seen through',
        'sitting behind', 'sitting beside', 'sitting by', 'sitting in front of',
        'sitting near', 'sitting next to', 'sitting under', 'skiing down', 'skiing on',
        'sleeping in', 'sleeping on', 'smiling at', 'sniffing', 'splashing',
        'sprinkled on', 'stacked on', 'standing against', 'standing around',
        'standing behind', 'standing beside', 'standing in front of', 'standing near',
        'standing next to', 'staring at', 'stuck in', 'surrounding', 'swimming in',
        'swinging', 'talking to', 'topped with', 'touching', 'traveling down',
        'traveling on', 'tying', 'typing on', 'underneath', 'wading in', 'waiting for',
        'walking across', 'walking by', 'walking down', 'walking next to',
        'walking through', 'working in', 'working on', 'worn on', 'wrapped around',
        'wrapped in', 'by', 'of', 'near', 'next to', 'with', 'beside', 'on the side of',
        'around',
    ])

    def __init__(self, image_preprocess, root_dir="./", download=False):
        '''
        image_preprocess: a function that takes in a PIL image and returns a tensor.
        root_dir: Directory for the VG-R dataset.
        download: Whether to download the images if ``<root_dir>/images`` does not exist.

        The annotation JSON is fetched with ``gdown`` whenever it is missing,
        independently of ``download``.
        '''
        self.root_dir = root_dir
        annotation_file = os.path.join(root_dir, "visual_genome_relation.json")
        image_dir = os.path.join(root_dir, "images")
        if not os.path.exists(image_dir):
            print("Image Directory for VG_Relation could not be found!")
            if download:
                self.download()
            else:
                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")

        if not os.path.exists(annotation_file):
            subprocess.call(["gdown", "--id", "1kX2iCHEv0CADL8dSO1nMdW-V0NqIAiP3", "--output", annotation_file])

        with open(annotation_file, "r") as f:
            self.dataset = json.load(f)

        # Resolve image paths against image_dir and remember each case's relation
        # name so evaluate_scores can report per-relation accuracy.
        self.all_relations = list()
        for item in self.dataset:
            item["image_path"] = os.path.join(image_dir, item["image_path"])
            self.all_relations.append(item["relation_name"])

        self.image_preprocess = image_preprocess

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        test_case = self.dataset[index]
        image = Image.open(test_case["image_path"]).convert('RGB')
        # Get the bounding box that contains the relation. This is to remove the irrelevant details in the scene.
        image = image.crop((test_case["bbox_x"], test_case["bbox_y"], test_case["bbox_x"] + test_case["bbox_w"], test_case["bbox_y"] + test_case["bbox_h"]))

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        # Each test case has a correct and incorrect caption.
        true_caption = test_case["true_caption"]
        false_caption = test_case["false_caption"]
        item = edict({"image_options": [image], "caption_options": [false_caption, true_caption]})
        return item

    def download(self):
        """Fetch the shared VG-R/VG-A image zip with ``gdown`` and unzip it in ``root_dir``."""
        os.makedirs(self.root_dir, exist_ok=True)
        image_zip_file = os.path.join(self.root_dir, "vgr_vga_images.zip")
        subprocess.call(["gdown", "--no-cookies", "1qaPlrwhGNMrR3a11iopZUT_GPP_LrgP9", "--output", image_zip_file])
        subprocess.call(["unzip", "vgr_vga_images.zip"], cwd=self.root_dir)

    def evaluate_scores(self, scores):
        """
        Scores: N x 1 x 2, i.e. first caption is the perturbed one, second is the positive one
        (or a tuple whose element 1 holds such image-to-text scores).

        Returns ``(result_records, macro_accuracy)``. ``result_records`` has one
        entry per relation not in ``_EXCLUDED_RELATIONS``, in sorted order.
        ``macro_accuracy`` is the mean of their accuracies (NaN if none remain).
        """
        scores_i2t = scores[1] if isinstance(scores, tuple) else scores

        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
        correct_mask = (preds == 1)

        all_relations = np.array(self.all_relations)

        result_records = []
        # Log the accuracy of every relation that is not excluded.
        for relation in np.unique(all_relations):
            if relation in self._EXCLUDED_RELATIONS:
                continue
            relation_mask = (all_relations == relation)
            result_records.append({
                "Relation": relation,
                "Accuracy": correct_mask[relation_mask].mean(),
                "Count": relation_mask.sum(),
                "Dataset": "Visual Genome Relation"
            })

        if result_records:
            macro_accuracy = pd.DataFrame(result_records).Accuracy.mean()
        else:
            # An empty DataFrame has no "Accuracy" column.
            macro_accuracy = float("nan")
        print(f"VG-Relation Macro Accuracy: {macro_accuracy}")
        return result_records, macro_accuracy



class VG_Attribution(Dataset):
    """Visual Genome Attribution (VG-A) split of the ARO benchmark.

    Each test case is an image crop with a true and a false caption.
    ``__getitem__`` returns them as ``[false_caption, true_caption]``, so a
    model is correct when it scores option 1 highest.
    """

    def __init__(self, image_preprocess, root_dir="./", download=False):
        '''
        image_preprocess: a function that takes in a PIL image and returns a tensor.
        root_dir: Directory for the VG-A dataset.
        download: Whether to download the images if ``<root_dir>/images`` does not exist.

        The annotation JSON is fetched with ``gdown`` whenever it is missing,
        independently of ``download``.
        '''
        self.root_dir = root_dir
        annotation_file = os.path.join(root_dir, "visual_genome_attribution.json")
        image_dir = os.path.join(root_dir, "images")
        if not os.path.exists(image_dir):
            print("Image Directory for VG_Attribution could not be found!")
            if download:
                self.download()
            else:
                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")

        if not os.path.exists(annotation_file):
            subprocess.call(["gdown", "--id", "13tWvOrNOLHxl3Rm9cR3geAdHx2qR3-Tw", "--output", annotation_file])

        with open(annotation_file, "r") as f:
            self.dataset = json.load(f)

        for item in self.dataset:
            item["image_path"] = os.path.join(image_dir, item["image_path"])

        # Attribute pair of each test case, joined as "<attr0>_<attr1>"; used to
        # group accuracies in evaluate_scores.
        self.all_attributes = [f"{item['attributes'][0]}_{item['attributes'][1]}" for item in self.dataset]
        self.image_preprocess = image_preprocess

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        test_case = self.dataset[index]
        image = Image.open(test_case["image_path"]).convert('RGB')
        # Crop to the annotated bounding box to remove irrelevant details in the scene.
        image = image.crop((test_case["bbox_x"], test_case["bbox_y"], test_case["bbox_x"] + test_case["bbox_w"], test_case["bbox_y"] + test_case["bbox_h"]))

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        # Each test case has a correct and incorrect caption.
        true_caption = test_case["true_caption"]
        false_caption = test_case["false_caption"]
        item = edict({"image_options": [image], "caption_options": [false_caption, true_caption]})
        return item

    def download(self):
        """Fetch the shared VG-R/VG-A image zip with ``gdown`` and unzip it in ``root_dir``."""
        os.makedirs(self.root_dir, exist_ok=True)
        image_zip_file = os.path.join(self.root_dir, "vgr_vga_images.zip")
        subprocess.call(["gdown", "--no-cookies",  "1qaPlrwhGNMrR3a11iopZUT_GPP_LrgP9", "--output", image_zip_file])
        subprocess.call(["unzip", "vgr_vga_images.zip"], cwd=self.root_dir)

    def evaluate_scores(self, scores, min_count=25):
        """
        Scores: N x 1 x 2, i.e. first caption is the perturbed one, second is the positive one
        (or a tuple whose element 1 holds such image-to-text scores).

        min_count: attribute pairs with fewer than this many test cases are left
            out of the per-pair records and the macro average (default 25).

        Returns ``(result_records, macro_accuracy)``. ``macro_accuracy`` is the
        mean per-pair accuracy, or NaN when no pair reaches ``min_count``.
        """
        scores_i2t = scores[1] if isinstance(scores, tuple) else scores

        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
        correct_mask = (preds == 1)
        result_records = []
        all_attributes = np.array(self.all_attributes)
        for attr in np.unique(all_attributes):
            attr_mask = (all_attributes == attr)
            if attr_mask.sum() < min_count:
                continue
            result_records.append({
                "Attributes": attr,
                "Accuracy": correct_mask[attr_mask].mean(),
                "Count": attr_mask.sum(),
                "Dataset": "Visual Genome Attribution"
            })

        if result_records:
            macro_accuracy = pd.DataFrame(result_records).Accuracy.mean()
        else:
            # An empty DataFrame has no "Accuracy" column.
            macro_accuracy = float("nan")
        print(f"VG-Attribution Macro Accuracy: {macro_accuracy}")
        return result_records, macro_accuracy



class COCO_Order(Dataset):
    """COCO-Order split of the ARO benchmark (Karpathy val/test captions).

    Each test case pairs an image with its caption (option 0) followed by four
    word-order perturbations of it. A model is correct when it scores option 0
    highest.
    """

    def __init__(self, image_preprocess=None, root_dir="./", max_words=30, split="test",
                 download=False, save_test_cases=True):
        """
        COCO Order Dataset.
        image_preprocess: image preprocessing function (skipped when None).
        root_dir: Parent directory; the data lives in ``<root_dir>/coco2014``,
            which ``download()`` fills with the val2014/test2014 image folders.
        max_words: Cropping the caption to max_words.
        split: 'val' or 'test'
        download: Whether to download the dataset if it does not exist.
        save_test_cases: Whether to load/save the test cases from/to
            ``aro_captions/test_cases_<split>_coco2014.json`` (relative to the
            current working directory). A cached file is reused as-is, even if
            it was built with a different ``max_words``.
        """
        self.root_dir = os.path.join(root_dir, 'coco2014')
        if not os.path.exists(self.root_dir):
            print("Directory for COCO could not be found!")
            if download:
                print("Downloading COCO now.")
                self.download()
            else:
                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")

        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
        download_url(urls[split], self.root_dir)

        with open(os.path.join(self.root_dir, filenames[split]), 'r') as f:
            self.annotation = json.load(f)
        self.image_preprocess = image_preprocess
        self.img_root = self.root_dir
        self.split = split

        if not save_test_cases:
            self.test_cases = self._build_test_cases(max_words)
            return

        # Perturbations are random, so cache them to keep evaluations comparable
        # across runs.
        self.test_cases_path = os.path.join('aro_captions', f"test_cases_{split}_coco2014.json")
        if os.path.exists(self.test_cases_path):
            with open(self.test_cases_path, 'r') as f:
                self.test_cases = json.load(f)
        else:
            self.test_cases = self._build_test_cases(max_words)
            # The cache directory may not exist yet on a fresh checkout.
            os.makedirs(os.path.dirname(self.test_cases_path), exist_ok=True)
            with open(self.test_cases_path, 'w') as f:
                json.dump(self.test_cases, f)

    def _build_test_cases(self, max_words):
        """Build one test case per caption: the cleaned caption plus its 4 perturbations."""
        perturb_functions = [shuffle_nouns_and_adj, shuffle_allbut_nouns_and_adj, shuffle_within_trigrams, shuffle_trigrams]
        test_cases = []
        for ann in tqdm(self.annotation):
            for caption in ann['caption']:
                caption_options = [pre_caption(caption, max_words)]
                for perturb_fn in perturb_functions:
                    caption_options.append(pre_caption(perturb_fn(caption), max_words))
                test_cases.append({"image": ann["image"], "caption_options": caption_options})
        return test_cases

    def __len__(self):
        return len(self.test_cases)

    def __getitem__(self, index):
        test_case = self.test_cases[index]
        image_path = os.path.join(self.img_root, test_case["image"])

        image = Image.open(image_path).convert('RGB')
        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        item = edict({"image_options": [image], "caption_options": test_case["caption_options"]})
        return item

    def download(self):
        """Download and unzip the COCO 2014 val and test images into ``root_dir``."""
        os.makedirs(self.root_dir, exist_ok=True)

        subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=self.root_dir)
        subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)

        subprocess.call(["wget", "http://images.cocodataset.org/zips/test2014.zip"], cwd=self.root_dir)
        subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)

    def evaluate_scores(self, scores):
        """
        scores: N x 1 x K image-to-text scores where option 0 is the original
            caption (or a tuple whose element 1 holds them).

        Returns ``(records, macro_accuracy)`` where ``macro_accuracy`` is the
        fraction of test cases whose top-scored option is 0.
        """
        scores_i2t = scores[1] if isinstance(scores, tuple) else scores

        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
        correct_mask = (preds == 0)
        records = [{"Precision@1": np.mean(correct_mask)}]
        macro_accuracy = pd.DataFrame(records)['Precision@1'].mean()
        print(f"COCO-order {self.split} Macro Accuracy: {macro_accuracy}")
        return records, macro_accuracy


class Flickr30k_Order(Dataset):
    """Flickr30k-Order split of the ARO benchmark.

    Each test case pairs an image with its caption (option 0) followed by four
    word-order perturbations of it. A model is correct when it scores option 0
    highest.
    """

    def __init__(self, image_preprocess, split, root_dir="./", max_words=30, save_test_cases=True,
                 *args, **kwargs):
        """
        image_preprocess: image preprocessing function (skipped when None).
        split: 'val' or 'test'
        root_dir: Parent directory; the data lives in ``<root_dir>/flickr30k``,
            which should contain the `flickr30k-images` directory holding all
            the images.
        max_words: Cropping the caption to max_words.
        save_test_cases: Whether to load/save the test cases from/to
            ``aro_captions/test_cases_<split>_flickr30k.json`` (relative to the
            current working directory). A cached file is reused as-is, even if
            it was built with a different ``max_words``.
        Extra positional/keyword arguments are accepted and ignored.
        """
        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
        filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
        self.root_dir = os.path.join(root_dir, 'flickr30k')
        if not os.path.exists(self.root_dir):
            print("Directory for Flickr30k could not be found!")
            flickr_url = "https://forms.illinois.edu/sec/229675"
            raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")

        download_url(urls[split], self.root_dir)
        with open(os.path.join(self.root_dir, filenames[split]), 'r') as f:
            self.annotation = json.load(f)
        self.image_preprocess = image_preprocess
        self.split = split

        if not save_test_cases:
            self.test_cases = self._build_test_cases(max_words)
            return

        # Perturbations are random, so cache them to keep evaluations comparable
        # across runs.
        self.test_cases_path = os.path.join('aro_captions', f"test_cases_{split}_flickr30k.json")
        if os.path.exists(self.test_cases_path):
            with open(self.test_cases_path, 'r') as f:
                self.test_cases = json.load(f)
        else:
            self.test_cases = self._build_test_cases(max_words)
            # The cache directory may not exist yet on a fresh checkout.
            os.makedirs(os.path.dirname(self.test_cases_path), exist_ok=True)
            with open(self.test_cases_path, 'w') as f:
                json.dump(self.test_cases, f)

    def _build_test_cases(self, max_words):
        """Build one test case per caption: the cleaned caption plus its 4 perturbations."""
        perturb_functions = [shuffle_nouns_and_adj, shuffle_allbut_nouns_and_adj, shuffle_within_trigrams, shuffle_trigrams]
        test_cases = []
        for ann in tqdm(self.annotation):
            for caption in ann['caption']:
                caption_options = [pre_caption(caption, max_words)]
                for perturb_fn in perturb_functions:
                    caption_options.append(pre_caption(perturb_fn(caption), max_words))
                test_cases.append({"image": ann["image"], "caption_options": caption_options})
        return test_cases

    def __len__(self):
        return len(self.test_cases)

    def __getitem__(self, index):
        test_case = self.test_cases[index]
        image_path = os.path.join(self.root_dir, test_case["image"])
        image = Image.open(image_path).convert('RGB')

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        item = edict({"image_options": [image], "caption_options": test_case["caption_options"]})
        return item

    def evaluate_scores(self, scores):
        """
        scores: N x 1 x K image-to-text scores where option 0 is the original
            caption (or a tuple whose element 1 holds them).

        Returns ``(result_records, macro_accuracy)`` where ``macro_accuracy`` is
        the fraction of test cases whose top-scored option is 0.
        """
        scores_i2t = scores[1] if isinstance(scores, tuple) else scores

        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
        correct_mask = (preds == 0)
        result_records = [{"Precision@1": np.mean(correct_mask)}]
        macro_accuracy = pd.DataFrame(result_records)['Precision@1'].mean()
        print(f"Flickr-order {self.split} Macro Accuracy: {macro_accuracy}")
        return result_records, macro_accuracy


class Crepe_Productivity(Dataset):
    """CREPE productivity split over Visual Genome images.

    Each item is an image crop with the ground-truth caption first, followed by
    its hard negatives. A model is correct when it scores option 0 highest.
    """

    def __init__(self, image_preprocess, root_dir="./", download=False, hard_neg_type='atom', complexity='all'):
        '''
        image_preprocess: a function that takes in a PIL image and returns a tensor
            (skipped when None).
        root_dir: Directory holding the Visual Genome ``VG_100K`` and ``VG_100K_2`` image folders.
        download: Currently unused. The images are downloaded into ``root_dir``
            whenever either folder is missing, regardless of this flag.
        hard_neg_type: One of 'atom', 'swap_3', 'swap_5', 'negate'.
        complexity: An int in 4..12, or 'all' to use every complexity.

        The annotation CSV is read from ``crepe_productivity_0409/`` relative to
        the current working directory (not ``root_dir``).
        '''
        import ast
        # Validate arguments first so a typo fails fast instead of after the
        # (large) image download below.
        assert hard_neg_type in ['atom', 'swap_3', 'swap_5', 'negate']
        assert complexity in range(4, 13) or complexity == 'all'

        self.root_dir = root_dir
        if not os.path.exists(os.path.join(root_dir, "VG_100K")) or not os.path.exists(os.path.join(root_dir, "VG_100K_2")):
            print(f"Directory for Crepe dataset could not be found! Downloading to {root_dir}...")
            self.download()
        self.hard_neg_type = hard_neg_type
        self.complexity = complexity

        if complexity == 'all':
            self.data_path = os.path.join("crepe_productivity_0409", f"prod_vg_hard_negs_{hard_neg_type}_all.csv")
        else:
            self.data_path = os.path.join("crepe_productivity_0409", hard_neg_type, f"prod_vg_hard_negs_{hard_neg_type}_complexity_{complexity}.csv")
        df = pd.read_csv(self.data_path)
        assert 'x' in df.columns and 'y' in df.columns and 'width' in df.columns and 'height' in df.columns, "missing x, y, width, or height."
        self.xs = df['x'].tolist()
        self.ys = df['y'].tolist()
        self.heights = df['height'].tolist()
        self.widths = df['width'].tolist()
        # The hard_negs column stores each list as a Python-literal string.
        self.hard_negs = [ast.literal_eval(ls_str) for ls_str in df['hard_negs']]
        self.images = df['image_id'].tolist()
        self.captions = df['caption'].tolist()

        self.image_preprocess = image_preprocess

    def get_image_by_id(self, image_id):
        """Open ``<root_dir>/VG_100K[_2]/<image_id>.jpg`` as an RGB PIL image.

        Raises FileNotFoundError if the image is in neither folder.
        """
        for folder in ("VG_100K", "VG_100K_2"):
            path = os.path.join(self.root_dir, folder, f"{image_id}.jpg")
            if os.path.exists(path):
                return Image.open(path).convert("RGB")
        raise FileNotFoundError(f'The image with id {image_id} is not found.')

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        raw_image = self.get_image_by_id(self.images[idx])
        # torchvision's crop takes (top, left, height, width).
        image = torchvision.transforms.functional.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx])
        if self.image_preprocess is not None:
            image = self.image_preprocess(image)
        texts = [str(self.captions[idx])] + list(self.hard_negs[idx])
        item = edict({"image_options": [image], "caption_options": texts})
        return item

    def download(self):
        """Download and unzip the two Visual Genome image archives into ``root_dir``."""
        os.makedirs(self.root_dir, exist_ok=True)
        subprocess.call(["wget", "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip"], cwd=self.root_dir)
        subprocess.call(["wget", "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip"], cwd=self.root_dir)
        subprocess.call(["unzip", "images.zip"], cwd=self.root_dir)
        subprocess.call(["unzip", "images2.zip"], cwd=self.root_dir)

    def evaluate_scores(self, scores):
        """
        scores: N x 1 x K image-to-text scores where option 0 is the true
            caption (or a tuple whose element 1 holds them).

        Returns ``(result_records, macro_accuracy)`` where ``macro_accuracy`` is
        the fraction of test cases whose top-scored option is 0.
        """
        scores_i2t = scores[1] if isinstance(scores, tuple) else scores

        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
        correct_mask = (preds == 0)
        result_records = [{"Precision@1": np.mean(correct_mask)}]
        macro_accuracy = pd.DataFrame(result_records)['Precision@1'].mean()
        print(f"Crepe {self.hard_neg_type} {self.complexity} Macro Accuracy: {macro_accuracy}")
        return result_records, macro_accuracy
        

class EqBen_All(Dataset):
    """EQBen full set: each item has one image and one caption.

    Images are read from `<root_dir>/image/` and annotations from
    `<root_dir>/ann_json_finegrained_random.json`. Each annotation entry has an
    'image' path, relative to the image folder, and a 'caption'.
    """

    def __init__(self, image_preprocess=None, root_dir='./', download=False):
        self.preprocess = image_preprocess
        self.transform = image_preprocess

        self.root_dir = root_dir
        self.img_root = os.path.join(root_dir, 'image')
        self.ann_root = os.path.join(root_dir, 'ann_json_finegrained_random.json')

        # Download (or fail clearly) BEFORE reading annotations: `download()`
        # is what fetches the annotation file in the first place.
        if not os.path.exists(self.img_root):
            print("Directory for EQBen could not be found!")
            if download:
                print("Downloading EQBen now.")
                self.download()
            else:
                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")

        with open(self.ann_root) as ann_file:
            self.global_ann = json.load(ann_file)

    def process_img_pixel(self, image_path):
        """Open a regular image file as RGB and apply the model-specific transform."""
        image = Image.open(image_path).convert("RGB")
        return self.transform(image)

    def process_img_npy(self, image_path):
        """Load a `.npy` frame (used for YouCook2), reverse its last axis, and apply the transform.

        NOTE(review): the [2, 1, 0] reorder presumably converts BGR to RGB; confirm against the data source.
        """
        image = Image.fromarray(np.load(image_path)[:, :, [2, 1, 0]], 'RGB')
        return self.transform(image)

    def download(self):
        """Fetch and extract the EQBen images, and fetch the annotation JSON, into `self.root_dir`."""
        os.makedirs(self.img_root, exist_ok=True)
        archive = 'eqben_image_full.tar.gz'
        if not os.path.exists(os.path.join(self.root_dir, archive)):
            subprocess.call(["wget", "https://storage.googleapis.com/eqben-data/eqben_image_full.tar.gz"], cwd=self.root_dir)

        # "-xf" and the archive name must be separate argv entries.
        subprocess.call(["tar", "-xf", archive], cwd=self.root_dir)

        if not os.path.exists(self.ann_root):
            subprocess.call(["wget", "https://storage.googleapis.com/eqben-data/eqben_ann/ann_json_finegrained_random.json"], cwd=self.root_dir)

    def __len__(self):
        return len(self.global_ann)

    def __getitem__(self, index):
        """Return one image/caption pair.

        The first path component of the entry's 'image' names the source
        database. 'eqbenyoucook2' entries are stored as `.npy` arrays.
        """
        ann = self.global_ann[index]
        database = ann['image'].split('/')[0]
        img_path = os.path.join(self.img_root, ann['image'])
        if database == 'eqbenyoucook2':
            image = self.process_img_npy(img_path)
        else:
            image = self.process_img_pixel(img_path)
        return edict({"image_options": [image], "caption_options": [ann['caption']]})

    def evaluate_scores(self, scores):
        """Scoring is not implemented for this unpaired set.

        Each item has a single image and caption, so there is no
        Winoground-style pair to compare. Always returns (None, None).
        """
        return None, None

class EqBen_Val(Dataset):
    """EQBen 10% validation subset: each item is a Winoground-style pair of two images and two captions.

    Images are read from `<root_dir>/image_subset/` and annotations from
    `<root_dir>/eqben_subset_10percent_final.json`. Each entry has keys
    'image0', 'image1', 'caption0' and 'caption1'.
    """

    def __init__(self, image_preprocess=None, root_dir='./', download=False):
        self.preprocess = image_preprocess
        self.transform = image_preprocess

        self.root_dir = root_dir
        self.img_root = os.path.join(root_dir, 'image_subset')
        self.ann_root = os.path.join(root_dir, 'eqben_subset_10percent_final.json')

        # Download (or fail clearly) BEFORE reading annotations: `download()`
        # is what fetches the annotation file in the first place.
        if not os.path.exists(self.img_root):
            print("Directory for EQBen could not be found!")
            if download:
                print("Downloading EQBen now.")
                self.download()
            else:
                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")

        with open(self.ann_root) as ann_file:
            self.global_ann = json.load(ann_file)

    def process_img_pixel(self, image_path):
        """Open a regular image file as RGB and apply the model-specific transform."""
        image = Image.open(image_path).convert("RGB")
        return self.transform(image)

    def process_img_npy(self, image_path):
        """Load a `.npy` frame (used for YouCook2), reverse its last axis, and apply the transform.

        NOTE(review): the [2, 1, 0] reorder presumably converts BGR to RGB; confirm against the data source.
        """
        image = Image.fromarray(np.load(image_path)[:, :, [2, 1, 0]], 'RGB')
        return self.transform(image)

    def download(self):
        """Fetch and extract the subset images, and fetch the subset annotation JSON, into `self.root_dir`."""
        os.makedirs(self.img_root, exist_ok=True)
        archive = 'subset_image.tar.gz'
        if not os.path.exists(os.path.join(self.root_dir, archive)):
            subprocess.call(["wget", "https://storage.googleapis.com/eqben-data/eqben_subset/subset_image.tar.gz"], cwd=self.root_dir)

        # "-xvf" and the archive name must be separate argv entries.
        subprocess.call(["tar", "-xvf", archive], cwd=self.root_dir)

        if not os.path.exists(self.ann_root):
            subprocess.call(["wget", "https://storage.googleapis.com/eqben-data/eqben_subset/eqben_subset_10percent_final.json"], cwd=self.root_dir)

    def __len__(self):
        return len(self.global_ann)

    def _load_image(self, file_name):
        """Load `file_name` from `self.img_root`, sending `.npy` files to `process_img_npy` and all others to `process_img_pixel`."""
        path = os.path.join(self.img_root, file_name)
        if file_name.split('.')[-1] == 'npy':
            return self.process_img_npy(path)
        return self.process_img_pixel(path)

    def __getitem__(self, index):
        """Return both images and both captions of pair `index`, in annotation order."""
        ann = self.global_ann[index]
        image_0 = self._load_image(ann['image0'])
        image_1 = self._load_image(ann['image1'])
        return edict({"image_options": [image_0, image_1], "caption_options": [ann['caption0'], ann['caption1']]})

    def evaluate_scores(self, scores):
        """Report Winoground text/image/group scores and i2t/t2i retrieval accuracy.

        Returns ({'all': {...group metrics..., 'i2t': ..., 't2i': ...}}, group_score).
        """
        winoground_scores = get_winoground_scores(scores)
        acc = get_winoground_acc(winoground_scores)
        print("EQBen_Val text score:", acc['text'])
        print("EQBen_Val image score:", acc['image'])
        print("EQBen_Val group score:", acc['group'])
        acc_i2t, acc_t2i = get_winoground_retrieval_acc(scores)
        print("EQBen_Val i2t retrieval accuracy:", acc_i2t)
        print("EQBen_Val t2i retrieval accuracy:", acc_t2i)

        results = {}
        results['all'] = acc
        results['all'].update({'i2t': acc_i2t, 't2i': acc_t2i})
        return results, acc['group']


def image_url_download(url_file, to_folder):
    """Best-effort download of every file listed in a JSON name -> URL map.

    url_file: path to a JSON object that maps output file names to URLs.
    to_folder: destination directory; it is created, with parents, if missing.

    Files that already exist in `to_folder` are skipped. A download failure is
    counted rather than raised, and any partial file it left is removed.
    KeyboardInterrupt and other non-network errors still propagate.

    Prints the failure count and returns it.
    """
    import http.client
    from urllib.request import urlretrieve

    os.makedirs(to_folder, exist_ok=True)
    with open(url_file, 'r') as f:
        contents = json.load(f)

    failed = 0
    for img_name, url in contents.items():
        dest = os.path.join(to_folder, img_name)
        if os.path.exists(dest):
            continue
        try:
            urlretrieve(url, dest)
        except (OSError, ValueError, http.client.HTTPException):
            # URLError/HTTPError/ContentTooShortError are OSError subclasses; malformed URLs raise ValueError.
            failed += 1
            # A truncated file would otherwise be treated as "already downloaded" on the next run.
            if os.path.exists(dest):
                os.remove(dest)
    print(failed)
    return failed

class VL_CheckList(Dataset):
    def __init__(self, image_preprocess, root_dir="./", download=False, split='Attribute', subsplit='color', subsubsplit=None):
        '''
        VL-CheckList: each test case is one image with a (negative, positive) caption pair.

        image_preprocess: a function that takes in a PIL image and returns a tensor (applied in __getitem__ when not None).
        root_dir: Parent directory; the dataset lives in `<root_dir>/VL-CheckList`.
        download: If True, run `download()` (idempotent for already-present parts) before loading.
            If False, `<root_dir>/VL-CheckList` must already exist.
        split / subsplit / subsubsplit: keys into `self.dataset_names` selecting which annotation files to load.
            If `subsubsplit` is None, every annotation file nested under `split/subsplit` is loaded
            (e.g. all of 'Object'/'Location').
        '''
        self.root_dir = os.path.join(root_dir, "VL-CheckList")
        self.swig_dir = os.path.join(self.root_dir, "swig")
        self.hake_dir = os.path.join(self.root_dir, "hake")
        self.vcoco_dir = os.path.join(self.hake_dir, "vcoco")
        self.vg_dir = os.path.join(self.root_dir, "vg")
        self.hcvrd_dir = os.path.join(self.hake_dir, "hcvrd")

        self.vg_dir_1 = os.path.join(root_dir, "VG_100K")
        self.vg_dir_2 = os.path.join(root_dir, "VG_100K_2")
        # Previously this raised whenever download=False, making already-downloaded data unusable.
        if download:
            self.download()
        elif not os.path.exists(self.root_dir):
            raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")

        # Leaf entries: 'root' = image directory (already includes self.root_dir);
        # 'json' = annotation file path, resolved relative to the current working directory.
        self.dataset_names = {
            'Attribute': {
                'action': {
                    'vaw': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vaw/action.json"},
                    'vg': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vg/action.json"},
                },
                'color': {
                    'vaw': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vaw/color.json"},
                    'vg': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vg/color.json"},
                },
                'material': {
                    'vaw': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vaw/material.json"},
                    'vg': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vg/material.json"},
                },
                'size': {
                    'vaw': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vaw/size.json"},
                    'vg': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vg/size.json"},
                },
                'state': {
                    'vaw': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vaw/state.json"},
                    'vg': {'root': self.vg_dir, 'json': "vl_checklist/Attribute/vg/state.json"},
                },
            },
            'Object': {
                'Location': {
                    'center': {
                        'hake': {'root': self.hake_dir, 'json': "vl_checklist/Object/Location/hake_location/hake_obj_center.json"},
                        'swig_agent': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_agent/swig_agent_center.json"},
                        'swig_destination': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_destination/swig_destination_center.json"},
                        'swig_item': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_item/swig_item_center.json"},
                        'swig_tool': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_tool/swig_tool_center.json"},
                        'vg_obj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Location/vg_location/vg_obj_center.json"},
                        'vg_subj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Location/vg_location/vg_subj_center.json"},
                    },
                    'margin': {
                        'hake': {'root': self.hake_dir, 'json': "vl_checklist/Object/Location/hake_location/hake_obj_margin.json"},
                        'swig_agent': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_agent/swig_agent_margin.json"},
                        'swig_destination': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_destination/swig_destination_margin.json"},
                        'swig_item': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_item/swig_item_margin.json"},
                        'swig_tool': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_tool/swig_tool_margin.json"},
                        'vg_obj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Location/vg_location/vg_obj_margin.json"},
                        'vg_subj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Location/vg_location/vg_subj_margin.json"},
                    },
                    'mid': {
                        'hake': {'root': self.hake_dir, 'json': "vl_checklist/Object/Location/hake_location/hake_obj_mid.json"},
                        'swig_agent': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_agent/swig_agent_mid.json"},
                        'swig_destination': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_destination/swig_destination_mid.json"},
                        'swig_item': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_item/swig_item_mid.json"},
                        'swig_tool': {'root': self.swig_dir, 'json': "vl_checklist/Object/Location/swig_location/swig_tool/swig_tool_mid.json"},
                        'vg_obj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Location/vg_location/vg_obj_mid.json"},
                        'vg_subj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Location/vg_location/vg_subj_mid.json"},
                    },
                },
                'Size': {
                    'large': {
                        'hake': {'root': self.hake_dir, 'json': "vl_checklist/Object/Size/hake_size/hake_obj_large.json"},
                        'swig_agent': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/sw​ig_agent/swig_agent_large.json".replace("\u200b", "")},
                        'swig_destination': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_destination/swig_destination_large.json"},
                        'swig_item': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_item/swig_item_large.json"},
                        'swig_tool': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_tool/swig_tool_large.json"},
                        'vg_obj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Size/vg_size/vg_obj_large.json"},
                        'vg_subj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Size/vg_size/vg_subj_large.json"},
                    },
                    'medium': {
                        'hake': {'root': self.hake_dir, 'json': "vl_checklist/Object/Size/hake_size/hake_obj_medium.json"},
                        'swig_agent': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_agent/swig_agent_medium.json"},
                        'swig_destination': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_destination/swig_destination_medium.json"},
                        'swig_item': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_item/swig_item_medium.json"},
                        'swig_tool': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_tool/swig_tool_medium.json"},
                        'vg_obj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Size/vg_size/vg_obj_medium.json"},
                        'vg_subj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Size/vg_size/vg_subj_medium.json"},
                    },
                    'small': {
                        'hake': {'root': self.hake_dir, 'json': "vl_checklist/Object/Size/hake_size/hake_obj_small.json"},
                        'swig_agent': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_agent/swig_agent_small.json"},
                        'swig_destination': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_destination/swig_destination_small.json"},
                        'swig_item': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_item/swig_item_small.json"},
                        'swig_tool': {'root': self.swig_dir, 'json': "vl_checklist/Object/Size/swig_size/swig_tool/swig_tool_small.json"},
                        'vg_obj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Size/vg_size/vg_obj_small.json"},
                        'vg_subj': {'root': self.vg_dir, 'json': "vl_checklist/Object/Size/vg_size/vg_subj_small.json"},
                    },
                },
            },
            'Relation': {
                'action': {
                    'hake': {'root': self.hake_dir, 'json': "vl_checklist/Relation/hake_action.json"},
                    'swig': {'root': self.swig_dir, 'json': "vl_checklist/Relation/swig_action.json"},
                    'vg': {'root': self.vg_dir, 'json': "vl_checklist/Relation/vg/action.json"},
                },
                'spatial': {
                    'vg': {'root': self.vg_dir, 'json': "vl_checklist/Relation/vg/spatial.json"},
                },
            },
        }
        self.check_all_images_exist()

        self.split = split
        self.subsplit = subsplit
        self.subsubsplit = subsubsplit
        assert self.split in self.dataset_names
        assert self.subsplit in self.dataset_names[self.split]
        subset = self.dataset_names[self.split][self.subsplit]
        if self.subsubsplit is not None:
            assert self.subsubsplit in self.dataset_names[self.split][self.subsplit]
            subset = self.dataset_names[self.split][self.subsplit][self.subsubsplit]
        self.dataset = self.get_dataset(subset)

        self.image_preprocess = image_preprocess

    @staticmethod
    def _iter_leaf_datasets(tree):
        """Yield every leaf entry (a dict with a 'json' key) nested at any depth inside `tree`, depth-first in dict order."""
        if 'json' in tree:
            yield tree
            return
        for child in tree.values():
            yield from VL_CheckList._iter_leaf_datasets(child)

    def get_dataset(self, subset):
        """Build the flat list of test cases for every annotation file under `subset`.

        Each annotation file is a list of `[image_name, {"POS": [...], "NEG": [...]}]`
        entries, with POS[i] paired to NEG[i]. Pairs whose two captions are
        identical are skipped, since they cannot discriminate.

        Returns a list of dicts with keys 'image_path', 'true_caption' and 'false_caption'.
        """
        test_cases = []
        for dataset in self._iter_leaf_datasets(subset):
            with open(dataset['json'], 'r') as json_file:
                tuples_list = json.load(json_file)

            for img_tuple in tuples_list:
                # dataset['root'] already includes self.root_dir; joining self.root_dir again
                # doubled the prefix whenever root_dir was relative (e.g. the default './').
                img_path = os.path.join(dataset['root'], img_tuple[0])

                assert os.path.exists(img_path)
                assert len(img_tuple[1]["NEG"]) == len(img_tuple[1]["POS"])
                for pos_caption, neg_caption in zip(img_tuple[1]["POS"], img_tuple[1]["NEG"]):
                    if pos_caption == neg_caption:
                        continue
                    test_cases.append({
                        'image_path': img_path,
                        'true_caption': pos_caption,
                        'false_caption': neg_caption,
                    })

        print(f"Found {len(test_cases)} test cases in total.")
        return test_cases

    def check_all_images_exist(self):
        """Return True iff every image referenced by any annotation file in `self.dataset_names` exists.

        Each missing path is printed; missing images do not raise here.
        """
        all_images_exist = True
        for dataset in self._iter_leaf_datasets(self.dataset_names):
            with open(dataset['json'], 'r') as json_file:
                tuples_list = json.load(json_file)

            for img_tuple in tuples_list:
                img_path = os.path.join(dataset['root'], img_tuple[0])
                if not os.path.exists(img_path):
                    print(f"Image not found: {img_path}")
                    all_images_exist = False
        return all_images_exist

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return test case `index` with caption_options ordered [false_caption, true_caption]."""
        test_case = self.dataset[index]
        image = Image.open(test_case["image_path"]).convert('RGB')

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        # Each test case has a correct and incorrect caption; the correct one is at index 1.
        item = edict({"image_options": [image], "caption_options": [test_case["false_caption"], test_case["true_caption"]]})
        return item

    def download(self):
        """Fetch and unpack the image sources referenced by `self.dataset_names`.

        Most steps are skipped when their output already exists. The VG and COCO
        archives are placed in the parent of `self.root_dir` and symlinked in.
        """
        os.makedirs(self.root_dir, exist_ok=True)
        os.makedirs(self.hake_dir, exist_ok=True)
        os.makedirs(self.vg_dir, exist_ok=True)
        os.makedirs(self.vcoco_dir, exist_ok=True)
        os.makedirs(self.hcvrd_dir, exist_ok=True)
        swig_file = os.path.join(self.root_dir, "images_512.zip")
        if not os.path.exists(swig_file):
            # The URL must name the archive; the bare bucket URL only returns a listing page.
            subprocess.call(["wget", "https://swig-data-weights.s3.us-east-2.amazonaws.com/images_512.zip"], cwd=self.root_dir)
            subprocess.call(["unzip", "images_512.zip"], cwd=self.root_dir)
            subprocess.call(["mv", 'images_512', 'swig'], cwd=self.root_dir)

        if not os.path.exists(os.path.join(self.hake_dir, "openimages")):
            # NOTE(review): this checks for "openimages.tar.gz" but unzips "openimages.zip";
            # one of the two names is wrong. Confirm which file the gdown id produces.
            open_images_file = os.path.join(self.hake_dir, "openimages.tar.gz")
            if not os.path.exists(open_images_file):
                subprocess.call(["gdown", "--id", "1dfClthFhwPIKKgb2w6mPeU5om5gmmajB"], cwd=self.hake_dir)
            subprocess.call(["unzip", "openimages.zip"], cwd=self.hake_dir)

        if not os.path.exists(os.path.join(self.hake_dir, "hake_images_20190730")):
            hake1_file = os.path.join(self.hake_dir, "hake_images_20190730.tar.gz")
            if not os.path.exists(hake1_file):
                subprocess.call(["gdown", "--id", "1Smrsy9AsOUyvj66ytGmB5M3WknljwuXL"], cwd=self.hake_dir)
            subprocess.call(["tar", "-xvf", "hake_images_20190730.tar.gz"], cwd=self.hake_dir)

        if not os.path.exists(os.path.join(self.hake_dir, "hake_images_20200614")):
            hake2_file = os.path.join(self.hake_dir, "hake_images_20200614.tar.gz")
            if not os.path.exists(hake2_file):
                subprocess.call(["gdown", "--id", "14K_4FfjviJNDVLJdGM96W2ZLN55dDb2-"], cwd=self.hake_dir)
            subprocess.call(["tar", "-xvf", "hake_images_20200614.tar.gz"], cwd=self.hake_dir)

        if not os.path.exists(os.path.join(self.hake_dir, "hico_20160224_det")):
            hico_file = os.path.join(self.hake_dir, "hico_20160224_det.tar.gz")
            if not os.path.exists(hico_file):
                subprocess.call(["gdown", "--id", "1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk"], cwd=self.hake_dir)
            subprocess.call(["tar", "-xvf", "hico_20160224_det.tar.gz"], cwd=self.hake_dir)

        if not os.path.exists(os.path.join(self.hake_dir, "pic")):
            pic_file = os.path.join(self.hake_dir, "pic.zip")
            if not os.path.exists(pic_file):
                subprocess.call(["gdown", "--id", "1Sh-aiTDECxl1SmTsgeer4TWuri9_I9lN"], cwd=self.hake_dir)
            # "unzip" and the archive name must be separate argv entries.
            subprocess.call(["unzip", "pic.zip"], cwd=self.hake_dir)

        root_dir = os.path.join(self.root_dir, "..")
        vg1_file = os.path.join(root_dir, "images.zip")
        vg2_file = os.path.join(root_dir, "images2.zip")
        if not os.path.exists(os.path.join(self.vg_dir, "VG_100K")):
            if not os.path.exists(os.path.join(root_dir, 'VG_100K')):
                if not os.path.exists(vg1_file):
                    subprocess.call(["wget", "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip"], cwd=root_dir)
                subprocess.call(["unzip", "images.zip"], cwd=root_dir)
            subprocess.call(["ln", '-s', os.path.join(root_dir, 'VG_100K'), os.path.join(self.vg_dir, "VG_100K")])

        if not os.path.exists(os.path.join(self.vg_dir, "VG_100K_2")):
            if not os.path.exists(os.path.join(root_dir, 'VG_100K_2')):
                if not os.path.exists(vg2_file):
                    subprocess.call(["wget", "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip"], cwd=root_dir)
                subprocess.call(["unzip", "images2.zip"], cwd=root_dir)
            subprocess.call(["ln", '-s', os.path.join(root_dir, 'VG_100K_2'), os.path.join(self.vg_dir, "VG_100K_2")])

        # 15284 is the expected number of entries in the HCVRD folder once fully unpacked.
        if len(os.listdir(self.hcvrd_dir)) != 15284:
            subprocess.call(["gdown", "--id", "1PnCrere0bZxCyAk7BQpFRMC7BhU7ZjNn"], cwd=self.hake_dir)
            subprocess.call(["unzip", "hcvrd.zip"], cwd=self.hake_dir)

        coco_dir = os.path.join(root_dir, "coco2014")
        os.makedirs(coco_dir, exist_ok=True)
        if not os.path.exists(os.path.join(coco_dir, "train2014.zip")):
            subprocess.call(["wget", "http://images.cocodataset.org/zips/train2014.zip"], cwd=coco_dir)
            subprocess.call(["unzip", "train2014.zip"], cwd=coco_dir)
        if not os.path.exists(os.path.join(self.vcoco_dir, 'train2014')):
            subprocess.call(['ln', '-s', os.path.join(coco_dir, 'train2014'), os.path.join(self.vcoco_dir, 'train2014')])

        if not os.path.exists(os.path.join(coco_dir, "val2014.zip")):
            subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=coco_dir)
            subprocess.call(["unzip", "val2014.zip"], cwd=coco_dir)
        if not os.path.exists(os.path.join(self.vcoco_dir, 'val2014')):
            subprocess.call(['ln', '-s', os.path.join(coco_dir, 'val2014'), os.path.join(self.vcoco_dir, 'val2014')])

    def evaluate_scores(self, scores):
        """
        Scores: N x 1 x 2, i.e. first caption is the perturbed one, second is the positive one.

        A tuple input is treated as (t2i, i2t), and only the i2t element is used.
        Returns (correct_mask, accuracy): correct_mask[i] is True when the positive
        caption scores strictly higher (argmax ties resolve to index 0, the negative).
        """
        scores_i2t = scores[1] if isinstance(scores, tuple) else scores

        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
        correct_mask = (preds == 1)
        macro_accuracy = np.mean(correct_mask)
        print(f"VL-CheckList {self.split} Macro Accuracy: {macro_accuracy}")
        return correct_mask, macro_accuracy

class Laion(Dataset):
    """LAION subset defined by a `<split>.json` list of {'path', 'caption'} records.

    Items are `(image, caption)` tuples. Images are read from `<root_dir>/laion/<path>`.
    """

    # Shards available per subset size; split files are named `subset_<size>_<shard:02d>.json`.
    _SHARDS_PER_SIZE = {3: 10, 10: 10, 50: 10, 100: 10, 500: 3, 1000: 10, 2000: 3, 5000: 2}

    def __init__(self, root_dir='./', split="subset_5000_01", image_preprocess=None, download=True):
        self.root_dir = root_dir
        self.split = split
        self.image_preprocess = image_preprocess
        self.download = download
        self.data_dir = os.path.join(self.root_dir, "laion")

        file_names = {
            f"subset_{size}_{shard:02d}.json"
            for size, num_shards in self._SHARDS_PER_SIZE.items()
            for shard in range(num_shards)
        }
        assert f"{split}.json" in file_names, f"Unknown LAION split: {split}"

        # Prefer the split file next to the images under root_dir. Fall back to the
        # cwd-relative `laion/` folder that earlier versions always read from.
        split_file = os.path.join(self.data_dir, f"{split}.json")
        if not os.path.exists(split_file):
            split_file = os.path.join('laion', f"{split}.json")
        with open(split_file) as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return `(image, caption)`; `image_preprocess` is applied when it is not None."""
        record = self.data[idx]
        img_path = os.path.join(self.data_dir, f"{record['path']}")
        img = Image.open(img_path).convert('RGB')
        if self.image_preprocess is not None:
            img = self.image_preprocess(img)
        return img, record['caption']


def get_winoground(image_preprocess, root_dir='./'):
    """Construct the Winoground dataset located at `root_dir`."""
    dataset = Winoground(root_dir=root_dir, image_preprocess=image_preprocess)
    return dataset

def get_visual_genome_relation(image_preprocess, root_dir='./', download=True):
    """Construct the VG_Relation dataset; `download` is forwarded to the dataset unchanged."""
    dataset = VG_Relation(root_dir=root_dir, download=download, image_preprocess=image_preprocess)
    return dataset

def get_visual_genome_attribution(image_preprocess, root_dir='./', download=True):
    """Construct the VG_Attribution dataset; `download` is forwarded to the dataset unchanged."""
    dataset = VG_Attribution(root_dir=root_dir, download=download, image_preprocess=image_preprocess)
    return dataset

def get_coco_order(image_preprocess, max_words=30, download=True, root_dir='./', split="test"):
    """Construct the COCO_Order dataset for `split`, with captions capped at `max_words` (forwarded as-is)."""
    dataset = COCO_Order(
        image_preprocess=image_preprocess,
        root_dir=root_dir,
        split=split,
        max_words=max_words,
        download=download,
    )
    return dataset

def get_flickr30k_order(image_preprocess, max_words=30, download=True, root_dir='./', split="test"):
    """Construct the Flickr30k_Order dataset for `split`, with captions capped at `max_words` (forwarded as-is)."""
    dataset = Flickr30k_Order(
        image_preprocess=image_preprocess,
        root_dir=root_dir,
        split=split,
        max_words=max_words,
        download=download,
    )
    return dataset

def get_crepe_productivity(image_preprocess, root_dir='./', hard_neg_type='negate', complexity='all', download=False):
    """Construct the Crepe_Productivity dataset for one hard-negative type and complexity level."""
    dataset = Crepe_Productivity(
        image_preprocess=image_preprocess,
        root_dir=root_dir,
        hard_neg_type=hard_neg_type,
        complexity=complexity,
        download=download,
    )
    return dataset

def get_vl_checklist(image_preprocess, root_dir='./', download=True, split='Attribute', subsplit='color', subsubsplit=None):
    """Construct the VL_CheckList dataset for the given split path.

    `download` defaults to True here, so unless the caller overrides it the
    dataset runs its `download()` step on construction.
    """
    dataset = VL_CheckList(
        image_preprocess=image_preprocess,
        root_dir=root_dir,
        download=download,
        split=split,
        subsplit=subsplit,
        subsubsplit=subsubsplit,
    )
    return dataset

def get_eqben_all(image_preprocess, root_dir='./', download=True):
    """Construct the full EqBen_All dataset (note: `download` defaults to True here, unlike the class)."""
    dataset = EqBen_All(root_dir=root_dir, download=download, image_preprocess=image_preprocess)
    return dataset

def get_eqben_val(image_preprocess, root_dir='./', download=True):
    """Construct the EqBen_Val validation subset (note: `download` defaults to True here, unlike the class)."""
    dataset = EqBen_Val(root_dir=root_dir, download=download, image_preprocess=image_preprocess)
    return dataset