"""
Usage: python mteb_meta.py path_to_results_folder
Example: python eval/mteb_meta.py /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/results/mteb/v4_fineweb_100b_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/no_instruction_False_include_long_prompt_True_include_meta_tokens_True_prompt_style_bos_prefix_q_prefix_d_pooling_method_lasttoken/step-00072000_ck

Creates the evaluation-results metadata (YAML front matter) for a model card from the
MTEB result JSON files in path_to_results_folder, and writes it to
./<last component of path_to_results_folder>/mteb_metadata.md.
Example output:
---
tags:
- mteb
model-index:
- name: SGPT-5.8B-weightedmean-msmarco-specb-bitfit
  results:
  - task:
      type: classification
    dataset:
      type: mteb/banking77
      name: MTEB Banking77
      config: default
      split: test
      revision: 44fa15921b4c889113cc5df03dd4901b49161ab7
    metrics:
    - type: accuracy
      value: 84.49350649350649
---
"""

import json
import logging
import os
import sys

from mteb import MTEB

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# The results folder is given on the command line. Its last path component doubles
# as the model name written into the card and as the output directory name.
if len(sys.argv) < 2:
    sys.exit("Usage: python mteb_meta.py path_to_results_folder")
results_folder = sys.argv[1].rstrip("/")
model_name = results_folder.split("/")[-1]

# Task name (result file name without its ".json" suffix) -> parsed result JSON.
all_results = {}

for file_name in os.listdir(results_folder):
    if not file_name.endswith(".json"):
        logger.info(f"Skipping non-json {file_name}")
        continue
    if "model_meta" in file_name:
        # Presumably model-level metadata rather than a task result; it is excluded.
        logger.info(f"Skipping model metadata file {file_name}")
        continue
    with open(os.path.join(results_folder, file_name), "r", encoding="utf-8") as f:
        # Strip only the trailing ".json", not every occurrence in the name.
        all_results[file_name[: -len(".json")]] = json.load(f)

# Tasks reported on the "train" split (instead of "test") when that split is present.
TRAIN_SPLIT = [
    "DanishPoliticalCommentsClassification",
]
# Tasks reported on the "validation" split (instead of "test") when that split is present.
# NOTE(review): MSMARCO is also listed in DEV_SPLIT; the validation check runs first,
# so "validation" is chosen whenever a result file contains both splits.
VALIDATION_SPLIT = [
    "AFQMC",
    "Cmnli",
    "IFlyTek",
    "TNews",
    "MSMARCO",
    "MultilingualSentiment",
    "Ocnli",
]
# Tasks reported on the "dev" split (instead of "test") when that split is present.
DEV_SPLIT = [
    "CmedqaRetrieval",
    "CovidRetrieval",
    "DuRetrieval",
    "EcomRetrieval",
    "MedicalRetrieval",
    "MMarcoReranking",
    "MMarcoRetrieval",
    "MSMARCO",
    "T2Reranking",
    "T2Retrieval",
    "VideoRetrieval",
]

# Fixed header lines of the YAML metadata block.
MARKER = "---"
TAGS = "tags:"
MTEB_TAG = "- mteb"
HEADER = "model-index:"
MODEL = f"- name: {model_name}"
RES = "  results:"

META_STRING = "\n".join(
    [
        MARKER,
        TAGS,
        MTEB_TAG,
        HEADER,
        MODEL,
        RES,
    ]
)


# One task entry. Placeholders, in order: task type, dataset type, dataset name,
# config, split, revision.
ONE_TASK = "\n".join(
    [
        "  - task:",
        "      type: {}",
        "    dataset:",
        "      type: {}",
        "      name: {}",
        "      config: {}",
        "      split: {}",
        "      revision: {}",
        "    metrics:",
    ]
)
# One metric entry. Placeholders, in order: metric name, value.
ONE_METRIC = "\n".join(
    [
        "    - type: {}",
        "      value: {}",
    ]
)
# Sub-metrics whose name contains any of these substrings are left out of the card.
SKIP_KEYS = ["std", "evaluation_time", "threshold"]

# Append one model-index task entry, with its main-score metric, for every task
# result and every evaluated dataset subset of that task.
# NOTE(review): scores are read as res_dict["scores"][split] -> list of metric dicts.
# Subsets are told apart by an "hf_subset" key in those dicts (MTEB's newer result
# layout). TODO: confirm against the MTEB version that wrote the result files.
meta_lines = []
for ds_name, res_dict in sorted(all_results.items()):
    # CQADupstackRetrieval metadata is looked up via the Android sub-task, presumably
    # because the combined name is not itself a registered MTEB task.
    tasks = MTEB(tasks=[ds_name.replace("CQADupstackRetrieval", "CQADupstackAndroidRetrieval")]).tasks
    if not tasks:
        logger.warning(f"Skipping {ds_name} as no matching MTEB task was found.")
        continue
    mteb_desc = tasks[0].metadata_dict
    hf_hub_name = mteb_desc.get("hf_hub_name", mteb_desc.get("beir_name"))
    if "CQADupstack" in ds_name:
        hf_hub_name = "BeIR/cqadupstack"
    mteb_type = mteb_desc["type"]
    revision = res_dict.get("dataset_revision")  # Okay if it's None
    scores = res_dict["scores"]

    # Report on "test" by default. Tasks listed in the split overrides use their
    # preferred split when the result file contains it.
    split = "test"
    if ds_name in TRAIN_SPLIT and "train" in scores:
        split = "train"
    elif ds_name in VALIDATION_SPLIT and "validation" in scores:
        split = "validation"
    elif ds_name in DEV_SPLIT and "dev" in scores:
        split = "dev"
    elif "test" not in scores:
        logger.info(f"Skipping {ds_name} as split {split} not present.")
        continue

    multilingual = len(mteb_desc["eval_langs"]) > 1
    # Emit every per-subset result separately, so that each language/subset of a
    # multilingual task carries its own scores instead of the first subset's.
    for subset_result in scores[split]:
        subset = subset_result.get("hf_subset", "default")
        labelled = multilingual and subset != "default"
        mteb_name = f"MTEB {ds_name} ({subset})" if labelled else f"MTEB {ds_name}"
        config = subset if labelled else "default"
        meta_lines.append(ONE_TASK.format(mteb_type, hf_hub_name, mteb_name, config, split, revision))
        for metric, score in subset_result.items():
            # Only the main score is written to the card.
            if "main_score" not in metric:
                continue
            if not isinstance(score, dict):
                score = {metric: score}
            for sub_metric, sub_score in score.items():
                if any(skip in sub_metric for skip in SKIP_KEYS):
                    continue
                meta_lines.append(
                    ONE_METRIC.format(
                        f"{metric}_{sub_metric}" if metric != sub_metric else metric,
                        # All MTEB scores are 0-1. They are multiplied by 100 for 3 reasons:
                        # 1) It's easier to visually digest (You need two chars less: "0.1" -> "1")
                        # 2) Others may multiply them by 100 when building on MTEB, making it
                        #    confusing what the range is. This happened with the Text and Code
                        #    Embeddings paper (OpenAI) vs the original BEIR paper
                        # 3) It's accepted practice (SuperGLUE, GLUE are 0-100)
                        sub_score * 100,
                    )
                )

if meta_lines:
    META_STRING += "\n" + "\n".join(meta_lines)

# Close the YAML block and write it to ./<model_name>/mteb_metadata.md, creating the
# directory if needed. An existing file is overwritten, with a warning.
META_STRING += "\n" + MARKER
output_dir = os.path.join(".", model_name)
output_path = os.path.join(output_dir, "mteb_metadata.md")
if os.path.exists(output_path):
    logger.warning("Overwriting mteb_metadata.md")
os.makedirs(output_dir, exist_ok=True)
# Write UTF-8 explicitly, matching how the result JSON files are read, so the output
# does not depend on the platform's default encoding.
with open(output_path, "w", encoding="utf-8") as f:
    f.write(META_STRING)