import os

# Project root: the parent of the directory that contains this file.
root_dir = os.path.join(os.path.dirname(__file__), "../")

import torch
from datasets import load_dataset, get_dataset_config_names
from huggingface_hub import HfApi

from text_embeddings import MODELS, get_embeddings

import tiktoken
# Tokenizer used by get_input_tokens / get_output_tokens to count tokens.
enc = tiktoken.get_encoding("o200k_base")

import huggingface_hub
# NOTE(review): the token literal is empty — presumably redacted before commit.
# A valid Hugging Face token is likely required here for the script to run;
# confirm and prefer loading it from the environment over hard-coding it.
huggingface_hub.login(token="")

# Directory passed as `cache_dir` to load_dataset; created up front if missing.
save_dir = os.path.join(root_dir, "data/openllm2/raw")
os.makedirs(save_dir, exist_ok=True)

api = HfApi()
# Every dataset returned by the search whose id starts with the
# "open-llm-leaderboard/" prefix.
# NOTE(review): not referenced anywhere else in the visible code — confirm
# whether this (network) listing is still needed.
all_dataset_info = [d for d in api.list_datasets(search="open-llm-leaderboard/") if d.id.startswith('open-llm-leaderboard/')]
    
# Dataset ids (as passed to load_dataset) of the per-model "-details" datasets
# under the "open-llm-leaderboard/" namespace, grouped by model family.
# Commented-out entries are kept for reference but are NOT part of
# `selected_models`.

# Meta Llama family.
llama_models = [
    # "open-llm-leaderboard/meta-llama__Meta-Llama-3.1-8B-Instruct-details",
    # "open-llm-leaderboard/meta-llama__Meta-Llama-3.1-70B-Instruct-details",
    
    "open-llm-leaderboard/meta-llama__Meta-Llama-3-8B-Instruct-details",
    "open-llm-leaderboard/meta-llama__Meta-Llama-3-70B-Instruct-details",
    
    "open-llm-leaderboard/meta-llama__Llama-2-7b-chat-hf-details",
    "open-llm-leaderboard/meta-llama__Llama-2-13b-chat-hf-details",
    "open-llm-leaderboard/meta-llama__Llama-2-70b-chat-hf-details",
    
    # "open-llm-leaderboard/meta-llama__Meta-Llama-3-8B-details",
    # "open-llm-leaderboard/meta-llama__Meta-Llama-3-70B-details",
    
    # "open-llm-leaderboard/meta-llama__Llama-2-7b-hf-details",
    # "open-llm-leaderboard/meta-llama__Llama-2-13b-hf-details",
    # "open-llm-leaderboard/meta-llama__Llama-2-70b-hf-details",
    
    # "open-llm-leaderboard/meta-llama__Meta-Llama-3.1-8B-details",
    # "open-llm-leaderboard/meta-llama__Meta-Llama-3.1-70B-details",
]

# 01-ai Yi family.
yi_models = [
    "open-llm-leaderboard/01-ai__Yi-1.5-34B-Chat-details",
    "open-llm-leaderboard/01-ai__Yi-1.5-9B-Chat-details",
    "open-llm-leaderboard/01-ai__Yi-1.5-6B-Chat-details",
    
    # "open-llm-leaderboard/01-ai__Yi-1.5-34B-details",
    # "open-llm-leaderboard/01-ai__Yi-1.5-9B-details",
    # "open-llm-leaderboard/01-ai__Yi-1.5-6B-details",
    
    # "open-llm-leaderboard/01-ai__Yi-1.5-34B-32K-details",
    # "open-llm-leaderboard/01-ai__Yi-1.5-9B-32K-details",
    
    # "open-llm-leaderboard/01-ai__Yi-1.5-34B-Chat-16K-details",
    # "open-llm-leaderboard/01-ai__Yi-1.5-9B-Chat-16K-details",
    
    "open-llm-leaderboard/01-ai__Yi-34B-Chat-details",
    "open-llm-leaderboard/01-ai__Yi-6B-Chat-details",
    
    # "open-llm-leaderboard/01-ai__Yi-34B-details",
    # "open-llm-leaderboard/01-ai__Yi-9B-details",
    # "open-llm-leaderboard/01-ai__Yi-6B-details",
    
    # "open-llm-leaderboard/01-ai__Yi-34B-200K-details",
    # "open-llm-leaderboard/01-ai__Yi-9B-200K-details",
    # "open-llm-leaderboard/01-ai__Yi-6B-200K-details",
]

# Qwen2 and Qwen1.5 families (Qwen2.5 is listed separately in `qwen25_model`).
qwen_models = [
    "open-llm-leaderboard/Qwen__Qwen2-72B-Instruct-details",
    "open-llm-leaderboard/Qwen__Qwen2-7B-Instruct-details",
    "open-llm-leaderboard/Qwen__Qwen2-1.5B-Instruct-details",
    "open-llm-leaderboard/Qwen__Qwen2-0.5B-Instruct-details",
    
    # "open-llm-leaderboard/Qwen__Qwen2-0.5B-details",
    # "open-llm-leaderboard/Qwen__Qwen2-1.5B-details",
    # "open-llm-leaderboard/Qwen__Qwen2-7B-details",
    # "open-llm-leaderboard/Qwen__Qwen2-72B-details",
    
    "open-llm-leaderboard/Qwen__Qwen1.5-110B-Chat-details",
    "open-llm-leaderboard/Qwen__Qwen1.5-32B-Chat-details",
    "open-llm-leaderboard/Qwen__Qwen1.5-14B-Chat-details",
    "open-llm-leaderboard/Qwen__Qwen1.5-7B-Chat-details",
    "open-llm-leaderboard/Qwen__Qwen1.5-4B-Chat-details",
    "open-llm-leaderboard/Qwen__Qwen1.5-1.8B-Chat-details",
    "open-llm-leaderboard/Qwen__Qwen1.5-0.5B-Chat-details",
    
    # "open-llm-leaderboard/Qwen__Qwen1.5-110B-details",
    # "open-llm-leaderboard/Qwen__Qwen1.5-32B-details",
    # "open-llm-leaderboard/Qwen__Qwen1.5-14B-details",
    # "open-llm-leaderboard/Qwen__Qwen1.5-7B-details",
    # "open-llm-leaderboard/Qwen__Qwen1.5-4B-details",
    # "open-llm-leaderboard/Qwen__Qwen1.5-1.8B-details",
    # "open-llm-leaderboard/Qwen__Qwen1.5-0.5B-details",
    
    # "open-llm-leaderboard/Qwen__Qwen1.5-MoE-A2.7B-details",
    "open-llm-leaderboard/Qwen__Qwen1.5-MoE-A2.7B-Chat-details",
]

# Mistral / Mixtral family.
mistral_models = [
    # "open-llm-leaderboard/mistralai__Mixtral-8x22B-v0.1-details",
    # "open-llm-leaderboard/mistralai__Mixtral-8x7B-v0.1-details",
    
    # "open-llm-leaderboard/mistralai__Mixtral-8x22B-Instruct-v0.1-details",
    "open-llm-leaderboard/mistralai__Mixtral-8x7B-Instruct-v0.1-details",
    
    # "open-llm-leaderboard/mistralai__Mistral-7B-v0.1-details",
    # "open-llm-leaderboard/mistralai__Mistral-7B-v0.3-details",
    
    # "open-llm-leaderboard/mistral-community__Mistral-7B-v0.2-details",
    # "open-llm-leaderboard/mistral-community__Mixtral-8x22B-v0.1-details",
    # "open-llm-leaderboard/mistral-community__mixtral-8x22B-v0.3-details",
    
    "open-llm-leaderboard/mistralai__Mistral-7B-Instruct-v0.1-details",
    "open-llm-leaderboard/mistralai__Mistral-7B-Instruct-v0.2-details",
    "open-llm-leaderboard/mistralai__Mistral-7B-Instruct-v0.3-details",
    
    # "open-llm-leaderboard/mistralai__Mistral-Nemo-Instruct-2407-details",
    # "open-llm-leaderboard/mistralai__Mistral-Nemo-Base-2407-details",
]

# Google models: gemma, recurrentgemma and flan-t5.
gemma_models = [
    # "open-llm-leaderboard/google__gemma-2-9b-it-details",
    # "open-llm-leaderboard/google__gemma-2-2b-it-details",
    
    # "open-llm-leaderboard/google__gemma-2-9b-details",
    # "open-llm-leaderboard/google__gemma-2-2b-details",
    
    "open-llm-leaderboard/google__gemma-7b-it-details",
    "open-llm-leaderboard/google__gemma-2b-it-details",
    
    # "open-llm-leaderboard/google__gemma-7b-details",
    # "open-llm-leaderboard/google__gemma-2b-details",
    
    # "open-llm-leaderboard/google__recurrentgemma-2b-details",
    # "open-llm-leaderboard/google__recurrentgemma-9b-details",
    
    "open-llm-leaderboard/google__recurrentgemma-2b-it-details",
    "open-llm-leaderboard/google__recurrentgemma-9b-it-details",
    
    "open-llm-leaderboard/google__gemma-1.1-2b-it-details",
    "open-llm-leaderboard/google__gemma-1.1-7b-it-details",
    
    "open-llm-leaderboard/google__flan-t5-small-details",
]

# Microsoft Phi family (Orca entries are listed but disabled).
phi_model = [
    "open-llm-leaderboard/microsoft__Phi-3-medium-4k-instruct-details",
    "open-llm-leaderboard/microsoft__Phi-3-mini-4k-instruct-details",
    
    # "open-llm-leaderboard/microsoft__Phi-3-small-128k-instruct-details",
    # "open-llm-leaderboard/microsoft__Phi-3-mini-128k-instruct-details",
    
    # "open-llm-leaderboard/microsoft__phi-1-details",
    # "open-llm-leaderboard/microsoft__phi-1_5-details",
    # "open-llm-leaderboard/microsoft__phi-2-details",
    
    # "open-llm-leaderboard/microsoft__Orca-2-7b-details",
    # "open-llm-leaderboard/microsoft__Orca-2-13b-details",
]

# GPT-2 family (openai-community namespace).
openai_model = [
    "open-llm-leaderboard/openai-community__gpt2-large-details",
    "open-llm-leaderboard/openai-community__gpt2-details",
    "open-llm-leaderboard/openai-community__gpt2-medium-details",
    "open-llm-leaderboard/openai-community__gpt2-xl-details",
]

# BAAI Infinity-Instruct models.
infinity_model = [
    'open-llm-leaderboard/BAAI__Infinity-Instruct-3M-0625-Llama3-70B-details',
    'open-llm-leaderboard/BAAI__Infinity-Instruct-3M-0625-Llama3-8B-details',
]

# CohereForAI aya-23 models.
cohere_model = [
    'open-llm-leaderboard/CohereForAI__aya-23-35B-details',
    'open-llm-leaderboard/CohereForAI__aya-23-8B-details',
]

# Qwen2.5 family.
qwen25_model = [
    'open-llm-leaderboard/Qwen__Qwen2.5-0.5B-Instruct-details',
    'open-llm-leaderboard/Qwen__Qwen2.5-1.5B-Instruct-details',
    'open-llm-leaderboard/Qwen__Qwen2.5-7B-Instruct-details',
    'open-llm-leaderboard/Qwen__Qwen2.5-14B-Instruct-details',
    'open-llm-leaderboard/Qwen__Qwen2.5-32B-Instruct-details',
    'open-llm-leaderboard/Qwen__Qwen2.5-72B-Instruct-details',
    
    # 'open-llm-leaderboard/Qwen__Qwen2.5-0.5B-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-1.5B-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-14B-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-32B-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-3B-Instruct-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-3B-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-72B-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-7B-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-Coder-7B-Instruct-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-Coder-7B-details',
    # 'open-llm-leaderboard/Qwen__Qwen2.5-Math-7B-Instruct-details',
]

# InternLM 2.5 chat models.
internlm_model = [
    'open-llm-leaderboard/internlm__internlm2_5-1_8b-chat-details',
    'open-llm-leaderboard/internlm__internlm2_5-20b-chat-details',
    'open-llm-leaderboard/internlm__internlm2_5-7b-chat-details',
]

# Full list of datasets processed by the download loop, in family order.
selected_models = llama_models + yi_models + qwen_models + mistral_models + gemma_models + phi_model + openai_model + infinity_model + cohere_model + qwen25_model + internlm_model

# Maps each leaderboard subset name to the dataset column that
# download_openllm_predictions reads as the per-example score.
# The keys (in this order) also define which configs get_configs generates.
METRIC_MAPPING = {
 'bbh_boolean_expressions': 'acc_norm',
 'bbh_causal_judgement': 'acc_norm',
 'bbh_date_understanding': 'acc_norm',
 'bbh_disambiguation_qa': 'acc_norm',
 'bbh_formal_fallacies': 'acc_norm',
 'bbh_geometric_shapes': 'acc_norm',
 'bbh_hyperbaton': 'acc_norm',
 'bbh_logical_deduction_five_objects': 'acc_norm',
 'bbh_logical_deduction_seven_objects': 'acc_norm',
 'bbh_logical_deduction_three_objects': 'acc_norm',
 'bbh_movie_recommendation': 'acc_norm',
 'bbh_navigate': 'acc_norm',
 'bbh_object_counting': 'acc_norm',
 'bbh_penguins_in_a_table': 'acc_norm',
 'bbh_reasoning_about_colored_objects': 'acc_norm',
 'bbh_ruin_names': 'acc_norm',
 'bbh_salient_translation_error_detection': 'acc_norm',
 'bbh_snarks': 'acc_norm',
 'bbh_sports_understanding': 'acc_norm',
 'bbh_temporal_sequences': 'acc_norm',
 'bbh_tracking_shuffled_objects_five_objects': 'acc_norm',
 'bbh_tracking_shuffled_objects_seven_objects': 'acc_norm',
 'bbh_tracking_shuffled_objects_three_objects': 'acc_norm',
 'bbh_web_of_lies': 'acc_norm',
 'gpqa_diamond': 'acc_norm',
 'gpqa_extended': 'acc_norm',
 'gpqa_main': 'acc_norm',
 'ifeval': 'prompt_level_strict_acc',
 'math_algebra_hard': 'exact_match',
 'math_counting_and_prob_hard': 'exact_match',
 'math_geometry_hard': 'exact_match',
 'math_intermediate_algebra_hard': 'exact_match',
 'math_num_theory_hard': 'exact_match',
 'math_prealgebra_hard': 'exact_match',
 'math_precalculus_hard': 'exact_match',
 'mmlu_pro': 'acc',
 'musr_murder_mysteries': 'acc_norm',
 'musr_object_placements': 'acc_norm',
 'musr_team_allocation': 'acc_norm'
}

# Expected number of rows per subset. download_openllm_predictions treats a
# subset whose score column has a different length as incomplete and drops
# the whole model.
SIZES = {
 'bbh_boolean_expressions': 250,
 'bbh_causal_judgement': 187,
 'bbh_date_understanding': 250,
 'bbh_disambiguation_qa': 250,
 'bbh_formal_fallacies': 250,
 'bbh_geometric_shapes': 250,
 'bbh_hyperbaton': 250,
 'bbh_logical_deduction_five_objects': 250,
 'bbh_logical_deduction_seven_objects': 250,
 'bbh_logical_deduction_three_objects': 250,
 'bbh_movie_recommendation': 250,
 'bbh_navigate': 250,
 'bbh_object_counting': 250,
 'bbh_penguins_in_a_table': 146,
 'bbh_reasoning_about_colored_objects': 250,
 'bbh_ruin_names': 250,
 'bbh_salient_translation_error_detection': 250,
 'bbh_snarks': 178,
 'bbh_sports_understanding': 250,
 'bbh_temporal_sequences': 250,
 'bbh_tracking_shuffled_objects_five_objects': 250,
 'bbh_tracking_shuffled_objects_seven_objects': 250,
 'bbh_tracking_shuffled_objects_three_objects': 250,
 'bbh_web_of_lies': 250,
 'gpqa_diamond': 198,
 'gpqa_extended': 546,
 'gpqa_main': 448,
 'ifeval': 541,
 'math_algebra_hard': 307,
 'math_counting_and_prob_hard': 123,
 'math_geometry_hard': 132,
 'math_intermediate_algebra_hard': 280,
 'math_num_theory_hard': 154,
 'math_prealgebra_hard': 193,
 'math_precalculus_hard': 135,
 'mmlu_pro': 12032,
 'musr_murder_mysteries': 250,
 'musr_object_placements': 256,
 'musr_team_allocation': 250
}

def get_configs(dataset_name):
    """Build the config name of every leaderboard subset for one details dataset.

    The ``open-llm-leaderboard/`` prefix and the ``-details`` suffix are removed
    from ``dataset_name`` (when present), and ``__leaderboard_<subset>`` is
    appended for each key of ``METRIC_MAPPING``.

    Returns:
        A list of config names, in ``METRIC_MAPPING`` key order.
    """
    # The stripped model id is the same for every subset, so compute it once.
    model_id = dataset_name.removeprefix('open-llm-leaderboard/').removesuffix('-details')
    return [f"{model_id}__leaderboard_{subset}" for subset in METRIC_MAPPING]

def get_prompts(dataset, subset_name):
    """Return the ``arg_0`` entry of ``arguments.gen_args_0`` for each example.

    Args:
        dataset: Iterable of examples, each indexable as
            ``ex["arguments"]["gen_args_0"]["arg_0"]``.
        subset_name: Not used by this function.

    Returns:
        A list with one entry per example, in iteration order.
    """
    return [example["arguments"]["gen_args_0"]["arg_0"] for example in dataset]

def get_input_tokens(dataset, subset_name):
    """Count the tokens of each prompt in ``dataset``.

    Prompts are obtained through ``get_prompts`` and encoded with the
    module-level ``enc`` tokenizer.

    Returns:
        A list of token counts, one per example.
    """
    token_counts = []
    for prompt in get_prompts(dataset, subset_name):
        # disallowed_special=() is passed so that no text is rejected as a
        # special token during encoding.
        token_ids = enc.encode(prompt, disallowed_special=())
        token_counts.append(len(token_ids))
    return token_counts

def get_output_tokens(dataset, subset_name):
    """Count the tokens of the first response of each example.

    The response inspected is ``ex['resps'][0][0]``. String responses are
    encoded with the module-level ``enc`` tokenizer; any non-string response
    is counted as a single token.

    Args:
        dataset: Iterable of examples exposing a ``'resps'`` entry.
        subset_name: Not used by this function.

    Returns:
        A list of token counts, one per example.
    """
    token_counts = []
    for example in dataset:
        first_response = example['resps'][0][0]
        if not isinstance(first_response, str):
            # Non-text responses are assigned a fixed count of 1.
            token_counts.append(1)
            continue
        token_counts.append(len(enc.encode(first_response, disallowed_special=())))
    return token_counts

def download_openllm_predictions(dataset_name):
    """Load per-example scores and token counts for every subset of one model.

    For each config produced by ``get_configs`` the ``latest`` split is loaded
    (cached under ``save_dir``), the score column named in ``METRIC_MAPPING``
    is read and validated, and token counts are computed for prompts and
    responses.

    Args:
        dataset_name: Id of the details dataset to load.

    Returns:
        A dict mapping subset name to
        ``{"scores": float tensor, "input_tokens": list, "output_tokens": list}``,
        or ``None`` as soon as any subset fails to load, lacks its metric
        column, has a row count different from ``SIZES``, or contains NaN
        scores. The reason is printed before returning ``None``.
    """
    outputs = {}

    for config in get_configs(dataset_name):
        try:
            dataset = load_dataset(dataset_name, config, split='latest', cache_dir=save_dir, token=True)
        except Exception as e:
            print(f"{dataset_name} => {e}")
            return None

        subset_name = config.split('__leaderboard_')[1]
        metric_name = METRIC_MAPPING[subset_name]
        try:
            predictions = dataset[metric_name]
        except Exception:
            # Previously a bare ``except:``, which also swallowed
            # KeyboardInterrupt/SystemExit and made the run hard to abort.
            print(f"{dataset_name} {subset_name} do not have metric {metric_name}")
            return None

        if len(predictions) != SIZES[subset_name]:
            print(f"{dataset_name} {subset_name} is incomplete")
            return None

        predictions = torch.tensor(predictions).float()

        if torch.any(torch.isnan(predictions)):
            print(f"{dataset_name} {subset_name} has nan values")
            return None

        # Only record the subset once it has passed every check above.
        outputs[subset_name] = {
            "scores": predictions,
            "input_tokens": get_input_tokens(dataset, subset_name),
            "output_tokens": get_output_tokens(dataset, subset_name),
        }

    return outputs

# Collect results for every selected model. A model for which the download
# helper returned None (it has already printed the reason) is left out.
data = {}
for dataset_name in selected_models:
    res = download_openllm_predictions(dataset_name)
    if res is None:
        continue
    data[dataset_name] = res

# Persist the collected mapping with torch.save.
torch.save(data, os.path.join(root_dir, "data/openllm2/data.pth"))

def get_prompts_embeddings(dataset_name, embed):
    """Collect the prompts of every subset of one dataset and embed them.

    Each config from ``get_configs`` is loaded (``latest`` split, cached under
    ``save_dir``); its prompts are extracted with ``get_prompts`` and passed to
    ``get_embeddings`` together with ``embed``.

    Args:
        dataset_name: Id of the details dataset to read prompts from.
        embed: Object forwarded as the first argument of ``get_embeddings``.

    Returns:
        A dict mapping subset name to ``{"prompts": ..., "embeddings": ...}``,
        or ``None`` if any config fails to load (the error is printed).
    """
    results = {}

    for config_name in get_configs(dataset_name):
        try:
            subset = load_dataset(dataset_name, config_name, split='latest', cache_dir=save_dir, token=True)
        except Exception as err:
            print(f"{dataset_name} => {err}")
            return None

        name = config_name.split('__leaderboard_')[1]
        subset_prompts = get_prompts(subset, name)
        # NOTE(review): the literal 4 is presumably a batch size — confirm
        # against get_embeddings' signature.
        subset_embeddings = get_embeddings(embed, subset_prompts, 4)
        results[name] = dict(prompts=subset_prompts, embeddings=subset_embeddings)

    return results

# Embed the prompts of every subset, reading them from the gpt2 details
# dataset and using the object built by MODELS["bert"]() as the embedder.
# NOTE(review): get_prompts_embeddings returns None when a config fails to
# load, in which case None is what gets written to bert.pth — confirm this
# is acceptable or guard the save.
embeddings = get_prompts_embeddings("open-llm-leaderboard/openai-community__gpt2-details", MODELS["bert"]())
torch.save(embeddings, os.path.join(root_dir, "data/openllm2/bert.pth"))