# python visualize_datasets.py --dataset VG_Relation 
# python visualize_datasets.py --dataset VG_Attribution 
# python visualize_datasets.py --dataset COCO_Order
# python visualize_datasets.py --dataset Flickr30k_Order 
# python visualize_datasets.py --dataset CrepeAtom 
# python visualize_datasets.py --dataset CrepeNegate
# python visualize_datasets.py --dataset Crepe5Swap 


# python visualize_datasets.py --dataset EqBen_Val 
# python visualize_datasets.py --dataset Winoground 

# python visualize_datasets.py --dataset VL_CheckList_Attribute_action
# python visualize_datasets.py --dataset VL_CheckList_Attribute_color 
# python visualize_datasets.py --dataset VL_CheckList_Attribute_material 
# python visualize_datasets.py --dataset VL_CheckList_Attribute_state
# python visualize_datasets.py --dataset VL_CheckList_Attribute_size

# python visualize_datasets.py --dataset VL_CheckList_Relation_action
# python visualize_datasets.py --dataset VL_CheckList_Relation_spatial 

# python visualize_datasets.py --dataset VL_CheckList_Object_Location_center 
# python visualize_datasets.py --dataset VL_CheckList_Object_Location_margin
# python visualize_datasets.py --dataset VL_CheckList_Object_Location_mid
# python visualize_datasets.py --dataset VL_CheckList_Object_Size_large 
# python visualize_datasets.py --dataset VL_CheckList_Object_Size_medium
# python visualize_datasets.py --dataset VL_CheckList_Object_Size_small
import os
import shutil
import torch
import argparse
from tqdm import tqdm
from main import config, get_dataset, get_model, get_score_config

# Upper bound on the number of randomly selected dataset examples that
# main() exports (fewer are exported when the dataset is smaller).
NUM_SAMPLES = 400


def save_as_markdown(results, results_path):
    """Write a list of dictionaries to ``results_path`` as a Markdown table.

    The keys of the first dictionary define the table columns. A column
    whose key contains ``'image'`` is rendered as an inline Markdown image
    (``![](value)``); every other value is rendered with ``str()``, with
    pipe characters escaped and line breaks collapsed so that a value can
    never break the table layout.

    Args:
        results: Sequence of dictionaries that all provide the keys of the
            first entry. May be empty, in which case an empty file is
            written.
        results_path: Path of the Markdown file to create or overwrite.

    Raises:
        KeyError: If an entry lacks one of the keys of the first entry.
    """
    if not results:
        # Nothing to tabulate: still create/truncate the file so the caller
        # gets a consistent artifact instead of an IndexError.
        with open(results_path, "w", encoding="utf-8") as f:
            f.write("")
        return

    keys = list(results[0].keys())
    header = "| " + " | ".join(keys) + " |"
    separator = "| " + "|".join(["----"] * len(keys)) + " |"

    def escape_cell(value):
        # A raw "|" would end the cell and a line break would end the row.
        text = str(value).replace("|", "\\|")
        return text.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")

    def format_row(item):
        cells = []
        for key in keys:
            if 'image' in key:
                # Image columns hold a path that is embedded as a picture.
                cells.append(f" ![]({item[key]}) |")
            else:
                cells.append(f" {escape_cell(item[key])} |")
        return "|" + "".join(cells)

    # Combine the header, separator, and one row per dictionary.
    table = "\n".join([header, separator] + [format_row(item) for item in results])

    # Explicit UTF-8 so captions with non-ASCII characters are written the
    # same way on every platform.
    with open(results_path, "w", encoding="utf-8") as f:
        f.write(table)

def get_example(example, save_dir, k, true_caption_idx=0):
    """Save the image of a single-image example and describe it as a table row.

    Args:
        example: Mapping with an ``'image_options'`` sequence of objects
            exposing ``save(path)`` and a ``'caption_options'`` sequence.
        save_dir: Directory the image is written to. The returned path is
            built from it, so pass a relative directory when the Markdown
            file must reference the image relatively.
        k: Identifier of the example; also used as the image file name.
        true_caption_idx: Index of the correct caption within
            ``'caption_options'``.

    Returns:
        A dict with the keys ``'id'``, ``'image'`` (path of the saved image),
        ``'caption_pos'`` (the correct caption) and ``'caption_neg'`` (list
        of all remaining captions, in their original order).

    Raises:
        ValueError: If the example has no image options.
    """
    images = list(example['image_options'])
    if not images:
        raise ValueError(f"Example {k} has no image options to save.")

    new_image_path = os.path.join(save_dir, f"{k}.png")
    # All options share the single file name "<k>.png", so only the last one
    # would remain on disk anyway; write it once instead of overwriting.
    images[-1].save(new_image_path)

    captions = example['caption_options']
    return {
        'id': k,
        'image': new_image_path,
        'caption_pos': captions[true_caption_idx],
        'caption_neg': [caption for i, caption in enumerate(captions) if i != true_caption_idx],
    }

def get_winoground_example(example, save_dir, k):
    """Save every image of a paired example and describe it as a table row.

    Each entry of ``example['image_options']`` is written to
    ``<save_dir>/<k>_<position>.png``. The returned dict references the
    first two images and the first two captions.

    Args:
        example: Mapping with ``'image_options'`` (objects exposing
            ``save(path)``) and ``'caption_options'``.
        save_dir: Directory the images are written to; the returned paths
            are built from it, so keep it relative for relative links.
        k: Identifier of the example, used as the file-name prefix.

    Returns:
        A dict with the keys ``'id'``, ``'image_0'``, ``'image_1'``,
        ``'caption_0'`` and ``'caption_1'``.
    """
    def path_for(position):
        return os.path.join(save_dir, f"{k}_{position}.png")

    position = 0
    for picture in example['image_options']:
        picture.save(path_for(position))
        position += 1

    row = {'id': k}
    row['image_0'] = path_for(0)
    row['image_1'] = path_for(1)
    row['caption_0'] = example['caption_options'][0]
    row['caption_1'] = example['caption_options'][1]
    return row

def main():
    """Export a random sample of a dataset as images plus a Markdown table.

    The dataset name and root directory come from ``config()``. Up to
    ``NUM_SAMPLES`` randomly chosen examples (fixed seed, so the selection
    is reproducible) are saved under ``visualization/<dataset>`` and listed
    in ``<dataset>.md`` in the current working directory.
    """
    args = config()

    print(f"Visualizing {args.dataset} dataset.")

    # Identity preprocessing: the examples' images are later written with
    # their own ``save`` method, so they must not be transformed here.
    dataset = get_dataset(args.dataset, image_preprocess=lambda x: x, download=True, root_dir=args.root_dir)
    if len(dataset) == 0:
        # dataset[0] below would raise IndexError; there is nothing to export.
        print(f"{args.dataset} dataset is empty; nothing to visualize.")
        return

    image_save_dir = os.path.join("visualization", args.dataset)
    os.makedirs(image_save_dir, exist_ok=True)

    # Examples with exactly two image options are exported with both images
    # and both captions side by side.
    is_winoground_style = len(dataset[0]['image_options']) == 2

    # These datasets use index 1 for the correct caption; all others use 0.
    if args.dataset in ('VG_Relation', 'VG_Attribution') or "VL_CheckList" in args.dataset:
        true_caption_idx = 1
    else:
        true_caption_idx = 0

    torch.manual_seed(1)
    random_indices = torch.randperm(len(dataset))[:NUM_SAMPLES].tolist()

    results = []
    for k in tqdm(random_indices):
        # Failures propagate with their full traceback instead of dropping
        # into an interactive debugger.
        if is_winoground_style:
            results.append(get_winoground_example(dataset[k], image_save_dir, k))
        else:
            results.append(get_example(dataset[k], image_save_dir, k, true_caption_idx=true_caption_idx))

    save_as_markdown(results, f"{args.dataset}.md")

# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()