import os
import json
import torch
from collections import defaultdict

# Google Cloud credentials: the key-file path variable is set to an empty
# string. NOTE(review): presumably a redacted local key path — confirm how the
# client library treats an empty value in your environment.
os.environ.update(GOOGLE_APPLICATION_CREDENTIALS="")

# Parent of this script's directory; the data paths used below are built from it.
root_dir = os.path.join(os.path.dirname(__file__), "../")

from google.cloud import storage

from mm_embeddings import MODELS, get_embeddings

# Multimodal embedding model used for every prompt below. It is constructed from
# the "vlm2vec-qwen7b" factory in mm_embeddings and given the local VHELM raw-data
# directory (what it does with that path is defined in mm_embeddings).
EMBEDDING_MODEL = MODELS["vlm2vec-qwen7b"](
    os.path.join(root_dir, "data/vhelm/raw/vhelm")
)

# GCS bucket to mirror from, and the local directory the mirror is written to.
bucket_name = "crfm-helm-public"
save_dir = os.path.join(root_dir, "data/vhelm/raw")
os.makedirs(save_dir, exist_ok=True)

# Download every object under the VHELM scenarios prefix, mirroring the
# bucket's object paths under `save_dir`.
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
folder_name = "vhelm/benchmark_output/scenarios"
blobs = bucket.list_blobs(prefix=folder_name)
for blob in blobs:
    # "Folder" placeholder objects have names ending in "/". Their local path
    # would be a directory, which download_to_filename cannot write to, so skip them.
    if blob.name.endswith("/"):
        continue
    file_path = os.path.join(save_dir, blob.name)
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    blob.download_to_filename(file_path)

# Download records: for each run prefix, fetch only the three per-run files this
# script reads (instance lists, per-instance stats, rendered prompts).
folders = [
    "vhelm/benchmark_output/runs/v2.1.2/",
]
RECORD_SUFFIXES = ("per_instance_stats.json", "instances.json", "display_requests.json")

# Reuse the `bucket` handle created for the scenario download; the script
# previously built a second client for the same bucket here.
for folder_name in folders:
    for blob in bucket.list_blobs(prefix=folder_name):
        # str.endswith accepts a tuple: one call instead of an `or` chain.
        if blob.name.endswith(RECORD_SUFFIXES):
            file_path = os.path.join(save_dir, blob.name)
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            blob.download_to_filename(file_path)
            
# Index the downloaded run folders. Names look like
# "blink:category=Counting,model=<model>,groups=blink_perception": the text
# before "model=" (stripped of surrounding ',', ':' and spaces) is the dataset
# key; the text after it, up to the next comma, is the model name.
models = set()
datasets = set()
model_to_datasets = defaultdict(list)

save_dir = os.path.join(root_dir, "data/vhelm/raw/vhelm/benchmark_output/runs/v2.1.2/")
for entry in os.listdir(save_dir):
    if "model=" not in entry:
        continue
    head, tail = entry.split("model=")
    dataset, model = head.strip(',: '), tail.split(',')[0]
    datasets.add(dataset)
    models.add(model)
    model_to_datasets[model].append(entry)

print(f"downloaded {len(models)} models in total")
print(sorted(models))

# Models whose results are kept, grouped by provider (folder-name spelling).
selected_models = [
    # Anthropic
    "anthropic_claude-3-5-sonnet-20240620",
    "anthropic_claude-3-5-sonnet-20241022",
    "anthropic_claude-3-7-sonnet-20250219",
    "anthropic_claude-3-7-sonnet-20250219-thinking-64k",
    "anthropic_claude-3-haiku-20240307",
    "anthropic_claude-3-opus-20240229",
    "anthropic_claude-3-sonnet-20240229",
    # Google
    "google_gemini-1.0-pro-vision-001",
    "google_gemini-1.5-flash-001-safety-block-none",
    "google_gemini-1.5-flash-preview-0514",
    "google_gemini-1.5-pro-001-safety-block-none",
    "google_gemini-1.5-flash-002",
    "google_gemini-1.5-pro-002",
    "google_gemini-1.5-pro-preview-0409",
    "google_gemini-1.5-pro-preview-0514",
    "google_gemini-2.0-flash-001",
    "google_gemini-2.0-flash-exp",
    "google_gemini-2.0-flash-lite-001",
    "google_gemini-2.0-flash-lite-preview-02-05",
    "google_gemini-2.0-flash-thinking-exp-01-21",
    "google_gemini-2.0-pro-exp-02-05",
    "google_gemini-2.5-pro-exp-03-25",
    "google_paligemma-3b-mix-224",
    "google_paligemma-3b-mix-448",
    # OpenAI
    "openai_gpt-4-1106-vision-preview",
    "openai_gpt-4-turbo-2024-04-09",
    "openai_gpt-4.1-2025-04-14",
    "openai_gpt-4.1-mini-2025-04-14",
    "openai_gpt-4.1-nano-2025-04-14",
    "openai_gpt-4.5-preview-2025-02-27",
    "openai_gpt-4o-2024-05-13",
    "openai_gpt-4o-2024-08-06",
    "openai_gpt-4o-2024-11-20",
    "openai_gpt-4o-mini-2024-07-18",
    "openai_o1-2024-12-17",
    "openai_o3-2025-04-16-high-reasoning-effort",
    "openai_o4-mini-2025-04-16-high-reasoning-effort",
    # Qwen
    "qwen_qwen-vl-chat",
    "qwen_qwen2-vl-72b-instruct",
    "qwen_qwen2-vl-7b-instruct",
    "qwen_qwen2.5-vl-32b-instruct",
    "qwen_qwen2.5-vl-3b-instruct",
    "qwen_qwen2.5-vl-72b-instruct",
    "qwen_qwen2.5-vl-7b-instruct",
]

# Dataset-name prefixes inspected by the size check below.
selected_datasets = [
    "blink",
    "mmmu",
    "flickr30k",
    "gqa",
    "math_vista",
    "mme",
    "real_world_qa",
    "seed_bench",
    "unicorn",
    "vibe_eval",
    "vqa",
]

# Sanity check: for one reference model, report how many run folders and
# display requests each selected dataset has.
# Keep the loop variable named `dataset`: other code in this module may read
# the module-level `dataset` this loop leaves behind.
reference_folders = model_to_datasets["anthropic_claude-3-5-sonnet-20240620"]
for dataset in selected_datasets:
    folders = [name for name in reference_folders if name.startswith(dataset)]
    print(f"{dataset} has {len(folders)} folders")
    total_requests = 0
    for folder in folders:
        requests_path = os.path.join(save_dir, folder, "display_requests.json")
        with open(requests_path) as handle:
            n_requests = len(json.load(handle))
        print(f"  {folder} => {n_requests}")
        total_requests += n_requests
    print(f"{dataset}: {total_requests}")
    print("---------")

# Recorded output (apparently from an earlier run) of the dataset-size check
# above for anthropic_claude-3-5-sonnet-20240620. It is a bare string literal
# that is never bound to a name, so it has no effect at runtime.
'''
blink has 14 folders
  blink:category=Relative_Depth,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_perception => 124
  blink:category=Visual_Similarity,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_perception => 135
  blink:category=Jigsaw,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_reasoning => 150
  blink:category=Forensic_Detection,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_knowledge => 132
  blink:category=IQ_Test,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_reasoning => 150
  blink:category=Semantic_Correspondence,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_perception => 139
  blink:category=Visual_Correspondence,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_perception => 172
  blink:category=Multi-view_Reasoning,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_reasoning => 133
  blink:category=Spatial_Relation,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_perception => 143
  blink:category=Functional_Correspondence,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_knowledge => 130
  blink:category=Object_Localization,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_perception => 122
  blink:category=Art_Style,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_perception => 117
  blink:category=Counting,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_perception => 120
  blink:category=Relative_Reflectance,model=anthropic_claude-3-5-sonnet-20240620,groups=blink_perception => 134
blink: 1901
---------
mmmu has 30 folders
  mmmu:subject=Pharmacy,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 27
  mmmu:subject=Diagnostics_and_Laboratory_Medicine,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Biology,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 29
  mmmu:subject=Manage,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 24
  mmmu:subject=Basic_Medical_Science,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 28
  mmmu:subject=Psychology,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 29
  mmmu:subject=Music,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Art_Theory,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Literature,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Agriculture,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Art,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Mechanical_Engineering,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Architecture_and_Engineering,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 29
  mmmu:subject=Electronics,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 16
  mmmu:subject=Accounting,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Physics,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 29
  mmmu:subject=Geography,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 27
  mmmu:subject=Math,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 29
  mmmu:subject=Design,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Materials,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Marketing,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 29
  mmmu:subject=History,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Sociology,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Energy_and_Power,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Chemistry,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 23
  mmmu:subject=Computer_Science,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 27
  mmmu:subject=Economics,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 29
  mmmu:subject=Clinical_Medicine,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Public_Health,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 30
  mmmu:subject=Finance,question_type=multiple-choice,model=anthropic_claude-3-5-sonnet-20240620 => 22
mmmu: 847
---------
flickr30k has 1 folders
  flickr30k:model=anthropic_claude-3-5-sonnet-20240620 => 1000
flickr30k: 1000
---------
gqa has 1 folders
  gqa:model=anthropic_claude-3-5-sonnet-20240620 => 1000
gqa: 1000
---------
math_vista has 8 folders
  math_vista:grade=elementary_school,question_type=free_form,model=anthropic_claude-3-5-sonnet-20240620 => 151
  math_vista:grade=elementary_school,question_type=multi_choice,model=anthropic_claude-3-5-sonnet-20240620 => 50
  math_vista:grade=high_school,question_type=free_form,model=anthropic_claude-3-5-sonnet-20240620 => 31
  math_vista:grade=high_school,question_type=multi_choice,model=anthropic_claude-3-5-sonnet-20240620 => 275
  math_vista:grade=college,question_type=free_form,model=anthropic_claude-3-5-sonnet-20240620 => 57
  math_vista:grade=college,question_type=multi_choice,model=anthropic_claude-3-5-sonnet-20240620 => 55
  math_vista:grade=daily_life,question_type=free_form,model=anthropic_claude-3-5-sonnet-20240620 => 221
  math_vista:grade=daily_life,question_type=multi_choice,model=anthropic_claude-3-5-sonnet-20240620 => 160
math_vista: 1000
---------
mme has 4 folders
  mme:subject=posters,model=anthropic_claude-3-5-sonnet-20240620 => 294
  mme:subject=landmark,model=anthropic_claude-3-5-sonnet-20240620 => 400
  mme:subject=artwork,model=anthropic_claude-3-5-sonnet-20240620 => 400
  mme:subject=celebrity,model=anthropic_claude-3-5-sonnet-20240620 => 340
mme: 1434
---------
real_world_qa has 1 folders
  real_world_qa:model=anthropic_claude-3-5-sonnet-20240620 => 765
real_world_qa: 765
---------
seed_bench has 2 folders
  seed_bench:subject=visual-reasoning,model=anthropic_claude-3-5-sonnet-20240620 => 331
  seed_bench:subject=instance-interaction,model=anthropic_claude-3-5-sonnet-20240620 => 97
seed_bench: 428
---------
unicorn has 2 folders
  unicorn:subject=OODCV-VQA,model=anthropic_claude-3-5-sonnet-20240620 => 1000
  unicorn:subject=Sketchy-VQA,model=anthropic_claude-3-5-sonnet-20240620 => 1000
unicorn: 2000
---------
vibe_eval has 2 folders
  vibe_eval:subject=difficulty-hard,model=anthropic_claude-3-5-sonnet-20240620 => 100
  vibe_eval:subject=difficulty-normal,model=anthropic_claude-3-5-sonnet-20240620 => 169
vibe_eval: 269
---------
vqa has 3 folders
  vqa:model=anthropic_claude-3-5-sonnet-20240620,data_augmentation=dialect_prob=1.0_source=SAE_target=AAVE,groups=vqa_dialect => 1944
  vqa:model=anthropic_claude-3-5-sonnet-20240620,data_augmentation=robustness,groups=vqa_robustness => 2000
  vqa:model=anthropic_claude-3-5-sonnet-20240620,groups=vqa_base => 1000
vqa: 4944
---------
'''

# Global result tables: per-dataset model stats and per-dataset prompt
# embeddings, filled in below by merging each cached .pth file.
VHELM_DATA, VHELM_EMBED = {}, {}

def _check_ordering(folders):
    # check all models used the same query
    examples = {}
    folders_to_exclude = set()
    for d in folders:
        with open(os.path.join(save_dir, d, "instances.json")) as f:
            instances = [ex['id'] for ex in json.load(f)]
        if dataset in examples:
            if instances != examples[dataset]:
                print(f"{d} is different")
                folders_to_exclude.add(d)
        else:
            examples[dataset] = instances
        # validate evaluation results are complete and use the same order
        with open(os.path.join(save_dir, d, "per_instance_stats.json")) as f:
            instances = [ex['instance_id'] for ex in json.load(f)]
        if examples[dataset] != instances[:len(examples[dataset])]:
            print(f"{d} has different instances in stats")
            folders_to_exclude.add(d)
        # validate display_requests files are complete and use the same order
        with open(os.path.join(save_dir, d, "display_requests.json")) as f:
            instances = [ex['instance_id'] for ex in json.load(f)]
        if examples[dataset] != instances:
            print(f"{d} has different instances in display")
            folders_to_exclude.add(d)
            
    print(f"{len(folders_to_exclude)} models to exclude: {folders_to_exclude}")
    
    return list(set(folders).difference(folders_to_exclude))

def _format_content(content):
    formatted_content = []
    for cnt in content:
      if cnt["content_type"] == "text/plain":
        formatted_content.append({"type": "text", "text": cnt["text"]})
      elif cnt["content_type"] == "image/jpeg":
        formatted_content.append({"type": "image", "location": cnt["location"]})
      elif cnt["content_type"] == "image/png":
        formatted_content.append({"type": "image", "location": cnt["location"]})
      else:
        raise NotImplementedError()
      
    return formatted_content

def _get_metrics(filename, metric_name):
  with open(filename) as f:
    stats = json.load(f)
    
  values = []
  instance_ids = []
  for stat in stats:
      for st in stat["stats"]:
          if st['name']['name'] == metric_name:
              value = st['sum'] # model based metrics may fail and no 'mean' in stat
              if stat["instance_id"] in instance_ids:
                values[instance_ids.index(stat["instance_id"])] = value
              else:
                values.append(value)
                instance_ids.append(stat["instance_id"])
              
  input_tokens = []
  instance_ids = []
  for stat in stats:
      for st in stat["stats"]:
          if st['name']['name'] == "num_prompt_tokens":
              input_token = st['mean']
              if stat["instance_id"] in instance_ids:
                input_tokens[instance_ids.index(stat["instance_id"])] = input_token
              else:
                input_tokens.append(input_token)
                instance_ids.append(stat["instance_id"])
              
  assert len(input_tokens) == len(values), f"{len(input_tokens)} != {len(values)}"
  
  output_tokens = []
  instance_ids = []
  for stat in stats:
      for st in stat["stats"]:
          if st['name']['name'] == "num_output_tokens":
              output_token = st['mean']
              if stat["instance_id"] in instance_ids:
                output_tokens[instance_ids.index(stat["instance_id"])] = output_token
              else:
                output_tokens.append(output_token)
                instance_ids.append(stat["instance_id"])
              
  assert len(output_tokens) == len(values), f"{len(output_tokens)} != {len(values)}"
                
  return values, input_tokens, output_tokens

def process_vqa():
    """Gather VQA base-split scores for the selected models plus prompt embeddings.

    Only run folders ending in ``vqa_base`` are considered. Folders rejected by
    ``_check_ordering`` are dropped; each remaining model contributes its
    per-instance ``quasi_prefix_exact_match`` values and token counts. Prompts
    are read from the alphabetically first surviving folder and embedded with
    ``EMBEDDING_MODEL``.

    Returns:
        ``({"vqa": {model: stats}}, {"vqa": {"contents": ..., "embeddings": ...}})``
    """
    metric = "quasi_prefix_exact_match"

    candidates = [
        name for name in os.listdir(save_dir)
        if name.startswith("vqa") and name.endswith("vqa_base")
        and name.split("model=")[1].split(',')[0] in selected_models
    ]
    folders = sorted(_check_ordering(candidates))

    per_model = {}
    for folder in folders:
        model_name = folder.split("model=")[1].split(',')[0]
        stats_path = os.path.join(save_dir, folder, "per_instance_stats.json")
        scores, prompt_tokens, completion_tokens = _get_metrics(stats_path, metric)
        per_model[model_name] = {
            "stat_type": "binary",
            "stat_name": metric,
            "values": scores,
            "input_tokens": prompt_tokens,
            "output_tokens": completion_tokens,
        }

    with open(os.path.join(save_dir, folders[0], "display_requests.json")) as fh:
        display_requests = json.load(fh)

    prompts = [
        _format_content(entry["request"]["multimodal_prompt"]["media_objects"])
        for entry in display_requests
    ]
    prompt_embed = {
        "contents": prompts,
        "embeddings": get_embeddings(EMBEDDING_MODEL, prompts),
    }
    return {"vqa": per_model}, {"vqa": prompt_embed}
    
# Load the cached VQA results (produced by process_vqa) and merge them into the
# global tables. To rebuild the cache, run:
#   vqa_data, vqa_embed = process_vqa()
#   torch.save((vqa_data, vqa_embed), os.path.join(root_dir, "data/vhelm/cache/vqa.pth"))
vqa_data, vqa_embed = torch.load(
    os.path.join(root_dir, "data/vhelm/cache/vqa.pth")
)
VHELM_DATA.update(vqa_data)
VHELM_EMBED.update(vqa_embed)

def process_vibe_eval():
    """Gather Vibe-Eval scores for the selected models plus prompt embeddings.

    For each difficulty subset, run folders of the selected models are filtered
    through ``_check_ordering``. Each surviving model contributes its
    per-instance ``prometheus_vision`` scores divided by 5 (labelled
    ``"likert-5"``) and token counts. The raw score range is printed per model.
    Prompts come from the alphabetically first surviving folder and are embedded
    with ``EMBEDDING_MODEL``.

    Returns:
        ``(data, embed)``, both keyed by subset prefix.
    """
    metric = "prometheus_vision"
    subsets = ("vibe_eval:subject=difficulty-hard", "vibe_eval:subject=difficulty-normal")
    results, embeddings = {}, {}

    for subset in subsets:
        candidates = [
            name for name in os.listdir(save_dir)
            if name.startswith(subset)
            and name.split("model=")[1].split(',')[0] in selected_models
        ]
        folders = sorted(_check_ordering(candidates))

        per_model = {}
        for folder in folders:
            model_name = folder.split("model=")[1].split(',')[0]
            scores, prompt_tokens, completion_tokens = _get_metrics(
                os.path.join(save_dir, folder, "per_instance_stats.json"), metric
            )
            # Show the raw score range before dividing by 5.
            print("--values--", min(scores), max(scores))
            per_model[model_name] = {
                "stat_type": "likert-5",
                "stat_name": metric,
                "values": [score / 5 for score in scores],
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
            }

        requests_path = os.path.join(save_dir, folders[0], "display_requests.json")
        with open(requests_path) as fh:
            display_requests = json.load(fh)
        prompts = [
            _format_content(entry["request"]["multimodal_prompt"]["media_objects"])
            for entry in display_requests
        ]
        prompt_embed = {
            "contents": prompts,
            "embeddings": get_embeddings(EMBEDDING_MODEL, prompts),
        }

        results[subset] = per_model
        embeddings[subset] = prompt_embed

    return results, embeddings

# Load the cached Vibe-Eval results (produced by process_vibe_eval) and merge
# them into the global tables. To rebuild the cache, run:
#   vibe_eval_data, vibe_eval_embed = process_vibe_eval()
#   torch.save((vibe_eval_data, vibe_eval_embed), os.path.join(root_dir, "data/vhelm/cache/vibe_eval.pth"))
vibe_eval_data, vibe_eval_embed = torch.load(
    os.path.join(root_dir, "data/vhelm/cache/vibe_eval.pth")
)
VHELM_DATA.update(vibe_eval_data)
VHELM_EMBED.update(vibe_eval_embed)

def process_unicorn():
    """Gather UNICORN scores for the selected models plus prompt embeddings.

    For each subject prefix, run folders of the selected models are filtered
    through ``_check_ordering``. Each surviving model contributes its
    per-instance ``quasi_prefix_exact_match`` values and token counts. Prompts
    come from the alphabetically first surviving folder and are embedded with
    ``EMBEDDING_MODEL``.

    Returns:
        ``(data, embed)``, both keyed by subject prefix. ``data[prefix]`` maps
        model name to its stats dict; ``embed[prefix]`` holds the prompt
        ``contents`` and their ``embeddings``.
    """
    metric = "quasi_prefix_exact_match"
    prefixes = ("unicorn:subject=OODCV-VQA", "unicorn:subject=Sketchy-VQA")
    all_data, all_embed = {}, {}

    for prefix in prefixes:
        candidates = [
            name for name in os.listdir(save_dir)
            if name.startswith(prefix)
            and name.split("model=")[1].split(',')[0] in selected_models
        ]
        folders = sorted(_check_ordering(candidates))

        per_model = {}
        for folder in folders:
            model_name = folder.split("model=")[1].split(',')[0]
            scores, prompt_tokens, completion_tokens = _get_metrics(
                os.path.join(save_dir, folder, "per_instance_stats.json"), metric
            )
            per_model[model_name] = {
                "stat_type": "binary",
                "stat_name": metric,
                "values": scores,
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
            }

        requests_path = os.path.join(save_dir, folders[0], "display_requests.json")
        with open(requests_path) as fh:
            display_requests = json.load(fh)
        prompts = [
            _format_content(entry["request"]["multimodal_prompt"]["media_objects"])
            for entry in display_requests
        ]
        prompt_embed = {
            "contents": prompts,
            "embeddings": get_embeddings(EMBEDDING_MODEL, prompts),
        }

        all_data[prefix] = per_model
        all_embed[prefix] = prompt_embed

    return all_data, all_embed

# Load the cached UNICORN results (produced by process_unicorn) and merge them
# into the global tables. To rebuild the cache, run:
#   unicorn_data, unicorn_embed = process_unicorn()
#   torch.save((unicorn_data, unicorn_embed), os.path.join(root_dir, "data/vhelm/cache/unicorn.pth"))
unicorn_data, unicorn_embed = torch.load(
    os.path.join(root_dir, "data/vhelm/cache/unicorn.pth")
)
VHELM_DATA.update(unicorn_data)
VHELM_EMBED.update(unicorn_embed)

def process_seed_bench():
    """Gather SEED-Bench scores for the selected models plus prompt embeddings.

    For each subject prefix (visual-reasoning, instance-interaction), run folders
    of the selected models are filtered through ``_check_ordering``. Each
    surviving model contributes its per-instance ``quasi_prefix_exact_match``
    values and token counts. Prompts come from the alphabetically first
    surviving folder and are embedded with ``EMBEDDING_MODEL``.

    Returns:
        ``(data, embed)``, both keyed by subject prefix.
    """
    metric = "quasi_prefix_exact_match"
    prefixes = [
        f"seed_bench:subject={subject}"
        for subject in ("visual-reasoning", "instance-interaction")
    ]
    all_data, all_embed = {}, {}

    for prefix in prefixes:
        candidates = [
            name for name in os.listdir(save_dir)
            if name.startswith(prefix)
            and name.split("model=")[1].split(',')[0] in selected_models
        ]
        folders = sorted(_check_ordering(candidates))

        per_model = {}
        for folder in folders:
            model_name = folder.split("model=")[1].split(',')[0]
            scores, prompt_tokens, completion_tokens = _get_metrics(
                os.path.join(save_dir, folder, "per_instance_stats.json"), metric
            )
            per_model[model_name] = {
                "stat_type": "binary",
                "stat_name": metric,
                "values": scores,
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
            }

        requests_path = os.path.join(save_dir, folders[0], "display_requests.json")
        with open(requests_path) as fh:
            display_requests = json.load(fh)
        prompts = [
            _format_content(entry["request"]["multimodal_prompt"]["media_objects"])
            for entry in display_requests
        ]
        prompt_embed = {
            "contents": prompts,
            "embeddings": get_embeddings(EMBEDDING_MODEL, prompts),
        }

        all_data[prefix] = per_model
        all_embed[prefix] = prompt_embed

    return all_data, all_embed

# Load the cached SEED-Bench results (produced by process_seed_bench) and merge
# them into the global tables. To rebuild the cache, run:
#   seed_bench_data, seed_bench_embed = process_seed_bench()
#   torch.save((seed_bench_data, seed_bench_embed), os.path.join(root_dir, "data/vhelm/cache/seed_bench.pth"))
seed_bench_data, seed_bench_embed = torch.load(
    os.path.join(root_dir, "data/vhelm/cache/seed_bench.pth")
)
VHELM_DATA.update(seed_bench_data)
VHELM_EMBED.update(seed_bench_embed)

def process_real_world_qa():
    """Gather RealWorldQA scores for the selected models plus prompt embeddings.

    Run folders of the selected models are filtered through
    ``_check_ordering``. Each surviving model contributes its per-instance
    ``quasi_exact_match`` values and token counts. Prompts come from the
    alphabetically first surviving folder and are embedded with
    ``EMBEDDING_MODEL``.

    Returns:
        ``(data, embed)``, both keyed by ``"real_world_qa"``.
    """
    metric = "quasi_exact_match"
    all_data, all_embed = {}, {}

    for prefix in ("real_world_qa",):
        candidates = [
            name for name in os.listdir(save_dir)
            if name.startswith(prefix)
            and name.split("model=")[1].split(',')[0] in selected_models
        ]
        folders = sorted(_check_ordering(candidates))

        per_model = {}
        for folder in folders:
            model_name = folder.split("model=")[1].split(',')[0]
            scores, prompt_tokens, completion_tokens = _get_metrics(
                os.path.join(save_dir, folder, "per_instance_stats.json"), metric
            )
            per_model[model_name] = {
                "stat_type": "binary",
                "stat_name": metric,
                "values": scores,
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
            }

        requests_path = os.path.join(save_dir, folders[0], "display_requests.json")
        with open(requests_path) as fh:
            display_requests = json.load(fh)
        prompts = [
            _format_content(entry["request"]["multimodal_prompt"]["media_objects"])
            for entry in display_requests
        ]
        prompt_embed = {
            "contents": prompts,
            "embeddings": get_embeddings(EMBEDDING_MODEL, prompts),
        }

        all_data[prefix] = per_model
        all_embed[prefix] = prompt_embed

    return all_data, all_embed

# Load the cached RealWorldQA results (produced by process_real_world_qa) and
# merge them into the global tables. To rebuild the cache, run:
#   real_world_qa_data, real_world_qa_embed = process_real_world_qa()
#   torch.save((real_world_qa_data, real_world_qa_embed), os.path.join(root_dir, "data/vhelm/cache/real_world_qa.pth"))
real_world_qa_data, real_world_qa_embed = torch.load(
    os.path.join(root_dir, "data/vhelm/cache/real_world_qa.pth")
)
VHELM_DATA.update(real_world_qa_data)
VHELM_EMBED.update(real_world_qa_embed)

def process_mme():
    """Gather MME scores for the selected models plus prompt embeddings.

    For each subject prefix (posters, landmark, artwork, celebrity), run folders
    of the selected models are filtered through ``_check_ordering``. Each
    surviving model contributes its per-instance ``quasi_prefix_exact_match``
    values and token counts. Prompts come from the alphabetically first
    surviving folder and are embedded with ``EMBEDDING_MODEL``.

    Returns:
        ``(data, embed)``, both keyed by subject prefix.
    """
    metric = "quasi_prefix_exact_match"
    prefixes = [
        f"mme:subject={subject}"
        for subject in ("posters", "landmark", "artwork", "celebrity")
    ]
    all_data, all_embed = {}, {}

    for prefix in prefixes:
        candidates = [
            name for name in os.listdir(save_dir)
            if name.startswith(prefix)
            and name.split("model=")[1].split(',')[0] in selected_models
        ]
        folders = sorted(_check_ordering(candidates))

        per_model = {}
        for folder in folders:
            model_name = folder.split("model=")[1].split(',')[0]
            scores, prompt_tokens, completion_tokens = _get_metrics(
                os.path.join(save_dir, folder, "per_instance_stats.json"), metric
            )
            per_model[model_name] = {
                "stat_type": "binary",
                "stat_name": metric,
                "values": scores,
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
            }

        requests_path = os.path.join(save_dir, folders[0], "display_requests.json")
        with open(requests_path) as fh:
            display_requests = json.load(fh)
        prompts = [
            _format_content(entry["request"]["multimodal_prompt"]["media_objects"])
            for entry in display_requests
        ]
        prompt_embed = {
            "contents": prompts,
            "embeddings": get_embeddings(EMBEDDING_MODEL, prompts),
        }

        all_data[prefix] = per_model
        all_embed[prefix] = prompt_embed

    return all_data, all_embed

# Load the cached MME results (produced by process_mme) and merge them into the
# global tables. To rebuild the cache, run:
#   mme_data, mme_embed = process_mme()
#   torch.save((mme_data, mme_embed), os.path.join(root_dir, "data/vhelm/cache/mme.pth"))
mme_data, mme_embed = torch.load(
    os.path.join(root_dir, "data/vhelm/cache/mme.pth")
)
VHELM_DATA.update(mme_data)
VHELM_EMBED.update(mme_embed)

def process_math_vista_free_form():
    """Gather MathVista free-form scores for the selected models plus prompt embeddings.

    For each grade (elementary_school, high_school, college, daily_life) with
    ``question_type=free_form``, run folders of the selected models are filtered
    through ``_check_ordering``. Each surviving model contributes its
    per-instance ``exact_match`` values and token counts. Prompts come from the
    alphabetically first surviving folder and are embedded with
    ``EMBEDDING_MODEL``.

    Returns:
        ``(data, embed)``, both keyed by the grade/question-type prefix.
    """
    metric = "exact_match"
    grades = ("elementary_school", "high_school", "college", "daily_life")
    prefixes = [f"math_vista:grade={grade},question_type=free_form" for grade in grades]
    all_data, all_embed = {}, {}

    for prefix in prefixes:
        candidates = [
            name for name in os.listdir(save_dir)
            if name.startswith(prefix)
            and name.split("model=")[1].split(',')[0] in selected_models
        ]
        folders = sorted(_check_ordering(candidates))

        per_model = {}
        for folder in folders:
            model_name = folder.split("model=")[1].split(',')[0]
            scores, prompt_tokens, completion_tokens = _get_metrics(
                os.path.join(save_dir, folder, "per_instance_stats.json"), metric
            )
            per_model[model_name] = {
                "stat_type": "binary",
                "stat_name": metric,
                "values": scores,
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
            }

        requests_path = os.path.join(save_dir, folders[0], "display_requests.json")
        with open(requests_path) as fh:
            display_requests = json.load(fh)
        prompts = [
            _format_content(entry["request"]["multimodal_prompt"]["media_objects"])
            for entry in display_requests
        ]
        prompt_embed = {
            "contents": prompts,
            "embeddings": get_embeddings(EMBEDDING_MODEL, prompts),
        }

        all_data[prefix] = per_model
        all_embed[prefix] = prompt_embed

    return all_data, all_embed

# Load the cached MathVista free-form results (produced by
# process_math_vista_free_form) and merge them into the global tables.
# To rebuild the cache, run:
#   math_vista_free_form_data, math_vista_free_form_embed = process_math_vista_free_form()
#   torch.save((math_vista_free_form_data, math_vista_free_form_embed), os.path.join(root_dir, "data/vhelm/cache/math_vista_free_form.pth"))
math_vista_free_form_data, math_vista_free_form_embed = torch.load(
    os.path.join(root_dir, "data/vhelm/cache/math_vista_free_form.pth")
)
VHELM_DATA.update(math_vista_free_form_data)
VHELM_EMBED.update(math_vista_free_form_embed)

def process_math_vista_multi_choice():
    """Gather MathVista multiple-choice scores for the selected models plus prompt embeddings.

    For each grade (elementary_school, high_school, college, daily_life) with
    ``question_type=multi_choice``, run folders of the selected models are
    filtered through ``_check_ordering``. Each surviving model contributes its
    per-instance ``exact_match`` values and token counts. Prompts come from the
    alphabetically first surviving folder and are embedded with
    ``EMBEDDING_MODEL``.

    Returns:
        ``(data, embed)``, both keyed by the grade/question-type prefix.
    """
    metric = "exact_match"
    grades = ("elementary_school", "high_school", "college", "daily_life")
    prefixes = [f"math_vista:grade={grade},question_type=multi_choice" for grade in grades]
    all_data, all_embed = {}, {}

    for prefix in prefixes:
        candidates = [
            name for name in os.listdir(save_dir)
            if name.startswith(prefix)
            and name.split("model=")[1].split(',')[0] in selected_models
        ]
        folders = sorted(_check_ordering(candidates))

        per_model = {}
        for folder in folders:
            model_name = folder.split("model=")[1].split(',')[0]
            scores, prompt_tokens, completion_tokens = _get_metrics(
                os.path.join(save_dir, folder, "per_instance_stats.json"), metric
            )
            per_model[model_name] = {
                "stat_type": "binary",
                "stat_name": metric,
                "values": scores,
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
            }

        requests_path = os.path.join(save_dir, folders[0], "display_requests.json")
        with open(requests_path) as fh:
            display_requests = json.load(fh)
        prompts = [
            _format_content(entry["request"]["multimodal_prompt"]["media_objects"])
            for entry in display_requests
        ]
        prompt_embed = {
            "contents": prompts,
            "embeddings": get_embeddings(EMBEDDING_MODEL, prompts),
        }

        all_data[prefix] = per_model
        all_embed[prefix] = prompt_embed

    return all_data, all_embed

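# math_vista_multi_choice_data, math_vista_multi_choice_embed = process_math_vista_multi_choice()
# torch.save((math_vista_multi_choice_data, math_vista_multi_choice_embed), os.path.join(root_dir, "data/vhelm/cache/math_vista_multi_choice.pth"))
# NOTE(review): unlike the other datasets above, no torch.load of
# math_vista_multi_choice.pth follows, so multiple-choice MathVista results never
# reach VHELM_DATA / VHELM_EMBED. Confirm that this omission is intentional.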

def process_gqa():
    """Gather GQA scores for the selected models plus prompt embeddings.

    Run folders of the selected models are filtered through
    ``_check_ordering``. Each surviving model contributes its per-instance
    ``quasi_prefix_exact_match`` values and token counts. Prompts come from the
    alphabetically first surviving folder and are embedded with
    ``EMBEDDING_MODEL``.

    Returns:
        ``(data, embed)``, both keyed by ``"gqa"``.
    """
    metric = "quasi_prefix_exact_match"
    all_data, all_embed = {}, {}

    for prefix in ("gqa",):
        candidates = [
            name for name in os.listdir(save_dir)
            if name.startswith(prefix)
            and name.split("model=")[1].split(',')[0] in selected_models
        ]
        folders = sorted(_check_ordering(candidates))

        per_model = {}
        for folder in folders:
            model_name = folder.split("model=")[1].split(',')[0]
            scores, prompt_tokens, completion_tokens = _get_metrics(
                os.path.join(save_dir, folder, "per_instance_stats.json"), metric
            )
            per_model[model_name] = {
                "stat_type": "binary",
                "stat_name": metric,
                "values": scores,
                "input_tokens": prompt_tokens,
                "output_tokens": completion_tokens,
            }

        requests_path = os.path.join(save_dir, folders[0], "display_requests.json")
        with open(requests_path) as fh:
            display_requests = json.load(fh)
        prompts = [
            _format_content(entry["request"]["multimodal_prompt"]["media_objects"])
            for entry in display_requests
        ]
        prompt_embed = {
            "contents": prompts,
            "embeddings": get_embeddings(EMBEDDING_MODEL, prompts),
        }

        all_data[prefix] = per_model
        all_embed[prefix] = prompt_embed

    return all_data, all_embed

# Load the cached GQA results (produced by process_gqa) and merge them into the
# global tables. To rebuild the cache, run:
#   gqa_data, gqa_embed = process_gqa()
#   torch.save((gqa_data, gqa_embed), os.path.join(root_dir, "data/vhelm/cache/gqa.pth"))
gqa_data, gqa_embed = torch.load(
    os.path.join(root_dir, "data/vhelm/cache/gqa.pth")
)
VHELM_DATA.update(gqa_data)
VHELM_EMBED.update(gqa_embed)

def process_flickr30k():
    """Collect per-model Flickr30k scores and prompt embeddings from the downloaded HELM runs.

    For each dataset prefix (only ``"flickr30k"`` here), the function does three things:
    it scans ``save_dir`` for run folders whose name starts with the prefix and whose
    model is listed in ``selected_models``; it reads the ``prometheus_vision``
    per-instance values and token counts through ``_get_metrics``, dividing every
    value by 5; and it embeds the multimodal prompts taken from the first folder's
    ``display_requests.json``.

    Returns:
        ``(all_data, all_embed)``: dicts keyed by dataset name. ``all_data[dataset]``
        maps a model name to its stats dict. ``all_embed[dataset]`` holds the
        formatted ``"contents"`` and their ``"embeddings"``.

    Raises:
        FileNotFoundError: if no run folder is left for a dataset after filtering.
    """
    PRIMARY_METRIC = "prometheus_vision"

    def _model_name(folder):
        # The model name is the text between "model=" and the next comma.
        return folder.split("model=")[1].split(',')[0]

    all_data = {}
    all_embed = {}
    for dataset in ["flickr30k"]:
        # Skip folders lacking "model=" (same rule as the scan at the top of the
        # file). Without this guard, _model_name would raise IndexError on them.
        result_folders = [
            d for d in os.listdir(save_dir)
            if d.startswith(dataset) and "model=" in d and _model_name(d) in selected_models
        ]
        result_folders = sorted(_check_ordering(result_folders))
        if not result_folders:
            raise FileNotFoundError(
                f"no run folders for {dataset!r} with a selected model under {save_dir}"
            )

        data = {}
        for d in result_folders:
            values, input_tokens, output_tokens = _get_metrics(
                os.path.join(save_dir, d, "per_instance_stats.json"), PRIMARY_METRIC
            )
            data[_model_name(d)] = {
                # NOTE(review): the values below are fractional (score / 5), not 0/1.
                # The "binary" label is kept because downstream consumers may key
                # on it. Confirm the intended stat_type.
                "stat_type": "binary",
                "stat_name": PRIMARY_METRIC,
                # Presumably maps a 1-5 rating into (0, 1]. TODO confirm the
                # prometheus_vision scale.
                "values": [v / 5 for v in values],
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
            }

        # Prompts are read from a single folder. This presumes every selected model
        # saw the same instances in the same order (TODO confirm _check_ordering
        # enforces this).
        requests_path = os.path.join(save_dir, result_folders[0], "display_requests.json")
        with open(requests_path, encoding="utf-8") as f:
            display_requests = json.load(f)
        contents = [
            _format_content(r["request"]["multimodal_prompt"]["media_objects"])
            for r in display_requests
        ]

        all_data[dataset] = data
        all_embed[dataset] = {
            "contents": contents,
            "embeddings": get_embeddings(EMBEDDING_MODEL, contents),
        }

    return all_data, all_embed

# Load the cached Flickr30k results. If the cache file is missing, rebuild it with
# process_flickr30k() and save it, so a fresh checkout no longer requires editing this file.
_flickr30k_cache_path = os.path.join(root_dir, "data/vhelm/cache/flickr30k.pth")
if os.path.exists(_flickr30k_cache_path):
    flickr30k_data, flickr30k_embed = torch.load(_flickr30k_cache_path)
else:
    flickr30k_data, flickr30k_embed = process_flickr30k()
    os.makedirs(os.path.dirname(_flickr30k_cache_path), exist_ok=True)
    torch.save((flickr30k_data, flickr30k_embed), _flickr30k_cache_path)
VHELM_DATA.update(flickr30k_data)
VHELM_EMBED.update(flickr30k_embed)

def process_mmmu():
    """Collect per-model MMMU multiple-choice scores and prompt embeddings from the HELM runs.

    Each MMMU subject is handled as its own dataset,
    ``"mmmu:subject=<Subject>,question_type=multiple-choice"``. For each one, the
    function does three things: it scans ``save_dir`` for run folders whose name
    starts with that string and whose model is listed in ``selected_models``; it
    reads the ``quasi_prefix_exact_match`` per-instance values and token counts
    through ``_get_metrics``; and it embeds the multimodal prompts taken from the
    first folder's ``display_requests.json``.

    Returns:
        ``(all_data, all_embed)``: dicts keyed by dataset name, in the subject order
        listed below. ``all_data[dataset]`` maps a model name to its stats dict.
        ``all_embed[dataset]`` holds the formatted ``"contents"`` and their
        ``"embeddings"``.

    Raises:
        FileNotFoundError: if no run folder is left for a dataset after filtering.
    """
    PRIMARY_METRIC = "quasi_prefix_exact_match"

    # Order matters: it fixes the key order of the returned dicts.
    subjects = [
        "Pharmacy",
        "Diagnostics_and_Laboratory_Medicine",
        "Biology",
        "Manage",
        "Basic_Medical_Science",
        "Psychology",
        "Music",
        "Art_Theory",
        "Literature",
        "Agriculture",
        "Art",
        "Mechanical_Engineering",
        "Architecture_and_Engineering",
        "Electronics",
        "Accounting",
        "Physics",
        "Geography",
        "Math",
        "Design",
        "Materials",
        "Marketing",
        "History",
        "Sociology",
        "Energy_and_Power",
        "Chemistry",
        "Computer_Science",
        "Economics",
        "Clinical_Medicine",
        "Public_Health",
        "Finance",
    ]
    mmmu_datasets = [
        f"mmmu:subject={subject},question_type=multiple-choice" for subject in subjects
    ]

    def _model_name(folder):
        # The model name is the text between "model=" and the next comma.
        return folder.split("model=")[1].split(',')[0]

    all_data = {}
    all_embed = {}
    for dataset in mmmu_datasets:
        # Skip folders lacking "model=" (same rule as the scan at the top of the
        # file). Without this guard, _model_name would raise IndexError on them.
        result_folders = [
            d for d in os.listdir(save_dir)
            if d.startswith(dataset) and "model=" in d and _model_name(d) in selected_models
        ]
        result_folders = sorted(_check_ordering(result_folders))
        if not result_folders:
            raise FileNotFoundError(
                f"no run folders for {dataset!r} with a selected model under {save_dir}"
            )

        data = {}
        for d in result_folders:
            values, input_tokens, output_tokens = _get_metrics(
                os.path.join(save_dir, d, "per_instance_stats.json"), PRIMARY_METRIC
            )
            data[_model_name(d)] = {
                "stat_type": "binary",
                "stat_name": PRIMARY_METRIC,
                "values": values,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
            }

        # Prompts are read from a single folder. This presumes every selected model
        # saw the same instances in the same order (TODO confirm _check_ordering
        # enforces this).
        requests_path = os.path.join(save_dir, result_folders[0], "display_requests.json")
        with open(requests_path, encoding="utf-8") as f:
            display_requests = json.load(f)
        contents = [
            _format_content(r["request"]["multimodal_prompt"]["media_objects"])
            for r in display_requests
        ]

        all_data[dataset] = data
        all_embed[dataset] = {
            "contents": contents,
            "embeddings": get_embeddings(EMBEDDING_MODEL, contents),
        }

    return all_data, all_embed

# Load the cached MMMU results. If the cache file is missing, rebuild it with
# process_mmmu() and save it, so a fresh checkout no longer requires editing this file.
_mmmu_cache_path = os.path.join(root_dir, "data/vhelm/cache/mmmu.pth")
if os.path.exists(_mmmu_cache_path):
    mmmu_data, mmmu_embed = torch.load(_mmmu_cache_path)
else:
    mmmu_data, mmmu_embed = process_mmmu()
    os.makedirs(os.path.dirname(_mmmu_cache_path), exist_ok=True)
    torch.save((mmmu_data, mmmu_embed), _mmmu_cache_path)
VHELM_DATA.update(mmmu_data)
VHELM_EMBED.update(mmmu_embed)

def process_blink():
    """Collect per-model BLINK scores and prompt embeddings from the downloaded HELM runs.

    Each BLINK category is handled as its own dataset, ``"blink:category=<Category>"``.
    For each one, the function does three things: it scans ``save_dir`` for run
    folders whose name starts with that string and whose model is listed in
    ``selected_models``; it reads the ``quasi_prefix_exact_match`` per-instance
    values and token counts through ``_get_metrics``; and it embeds the multimodal
    prompts taken from the first folder's ``display_requests.json``.

    Returns:
        ``(all_data, all_embed)``: dicts keyed by dataset name, in the category order
        listed below. ``all_data[dataset]`` maps a model name to its stats dict.
        ``all_embed[dataset]`` holds the formatted ``"contents"`` and their
        ``"embeddings"``.

    Raises:
        FileNotFoundError: if no run folder is left for a dataset after filtering.
    """
    PRIMARY_METRIC = "quasi_prefix_exact_match"

    # Order matters: it fixes the key order of the returned dicts.
    categories = [
        "Relative_Depth",
        "Visual_Similarity",
        "Jigsaw",
        "Forensic_Detection",
        "IQ_Test",
        "Semantic_Correspondence",
        "Visual_Correspondence",
        "Multi-view_Reasoning",
        "Spatial_Relation",
        "Functional_Correspondence",
        "Object_Localization",
        "Art_Style",
        "Counting",
        "Relative_Reflectance",
    ]
    blink_datasets = [f"blink:category={category}" for category in categories]

    def _model_name(folder):
        # The model name is the text between "model=" and the next comma.
        return folder.split("model=")[1].split(',')[0]

    all_data = {}
    all_embed = {}
    for dataset in blink_datasets:
        # Skip folders lacking "model=" (same rule as the scan at the top of the
        # file). Without this guard, _model_name would raise IndexError on them.
        result_folders = [
            d for d in os.listdir(save_dir)
            if d.startswith(dataset) and "model=" in d and _model_name(d) in selected_models
        ]
        result_folders = sorted(_check_ordering(result_folders))
        if not result_folders:
            raise FileNotFoundError(
                f"no run folders for {dataset!r} with a selected model under {save_dir}"
            )

        data = {}
        for d in result_folders:
            values, input_tokens, output_tokens = _get_metrics(
                os.path.join(save_dir, d, "per_instance_stats.json"), PRIMARY_METRIC
            )
            data[_model_name(d)] = {
                "stat_type": "binary",
                "stat_name": PRIMARY_METRIC,
                "values": values,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
            }

        # Prompts are read from a single folder. This presumes every selected model
        # saw the same instances in the same order (TODO confirm _check_ordering
        # enforces this).
        requests_path = os.path.join(save_dir, result_folders[0], "display_requests.json")
        with open(requests_path, encoding="utf-8") as f:
            display_requests = json.load(f)
        contents = [
            _format_content(r["request"]["multimodal_prompt"]["media_objects"])
            for r in display_requests
        ]

        all_data[dataset] = data
        all_embed[dataset] = {
            "contents": contents,
            "embeddings": get_embeddings(EMBEDDING_MODEL, contents),
        }

    return all_data, all_embed

# Load the cached BLINK results. If the cache file is missing, rebuild it with
# process_blink() and save it, so a fresh checkout no longer requires editing this file.
_blink_cache_path = os.path.join(root_dir, "data/vhelm/cache/blink.pth")
if os.path.exists(_blink_cache_path):
    blink_data, blink_embed = torch.load(_blink_cache_path)
else:
    blink_data, blink_embed = process_blink()
    os.makedirs(os.path.dirname(_blink_cache_path), exist_ok=True)
    torch.save((blink_data, blink_embed), _blink_cache_path)
VHELM_DATA.update(blink_data)
VHELM_EMBED.update(blink_embed)

# Write the merged results. Paths are resolved against root_dir, like every other
# path in this script, so the output location does not depend on the current
# working directory.
torch.save(VHELM_DATA, os.path.join(root_dir, "data/vhelm/data.pth"))
torch.save(VHELM_EMBED, os.path.join(root_dir, "data/vhelm/vlm2vec_qwen7b.pth"))