import os
import json
import torch
from collections import defaultdict

# Set the credentials path to an empty string before google.cloud.storage is
# imported further down.
# NOTE(review): presumably this keeps the storage client from loading a local
# service-account key file when reading the public bucket -- confirm.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""

# Parent of the directory that contains this file; every data/ path used by
# this script is built relative to it.
root_dir = os.path.join(os.path.dirname(__file__), "../")

from google.cloud import storage

from text_embeddings import MODELS, get_embeddings

# download
# Copies three JSON files per run from each listed release folder of the
# bucket into data/helm_lite/raw, with the release prefix removed from the
# local path.
bucket_name = "crfm-helm-public"
folders = [
    "lite/benchmark_output/runs/v1.0.0/",
    "lite/benchmark_output/runs/v1.1.0/",
    "lite/benchmark_output/runs/v1.2.0/",
    "lite/benchmark_output/runs/v1.3.0/",
    "lite/benchmark_output/runs/v1.4.0/",
    "lite/benchmark_output/runs/v1.5.0/",
    "lite/benchmark_output/runs/v1.6.0/",
]

save_dir = os.path.join(root_dir, "data/helm_lite/raw")
os.makedirs(save_dir, exist_ok=True)

# The only files the later steps of this script read.
wanted_suffixes = (
    "per_instance_stats.json",
    "instances.json",
    "display_requests.json",
)

storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
for folder_name in folders:
    for blob in bucket.list_blobs(prefix=folder_name):
        if not blob.name.endswith(wanted_suffixes):
            continue
        # The release prefix is dropped, so a run that appears in several
        # releases maps to the same local path and is written once per release.
        relative_name = blob.name.replace(folder_name, '')
        file_path = os.path.join(save_dir, relative_name)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        blob.download_to_filename(file_path)
            
            
# parse model and dataset names
# Every entry of the raw directory is expected to be named
# "<dataset spec>model=<model name>", optionally containing ",stop=none".
models = set()
model_to_datasets = defaultdict(list)

save_dir = os.path.join(root_dir, "data/helm_lite/raw")
for d in os.listdir(save_dir):
    # Unpacking fails with ValueError unless "model=" occurs exactly once.
    dataset, model = (
        part.strip(',: ')
        for part in d.replace(',stop=none', '').split("model=")
    )
    models.add(model)
    # Keep the original directory name: it is needed later to open the files.
    model_to_datasets[model].append(d)

print(f"downloaded {len(models)} models in total")
print(sorted(models))

'''
downloaded 72 models in total
['01-ai_yi-34b', '01-ai_yi-6b', '01-ai_yi-large-preview', 'AlephAlpha_luminous-base', 'AlephAlpha_luminous-extended', 'AlephAlpha_luminous-supreme', 'ai21_j2-grande', 'ai21_j2-jumbo', 'ai21_jamba-instruct', 'allenai_olmo-7b', 'anthropic_claude-2.0', 'anthropic_claude-2.1', 'anthropic_claude-3-5-sonnet-20240620', 'anthropic_claude-3-haiku-20240307', 'anthropic_claude-3-opus-20240229', 'anthropic_claude-3-sonnet-20240229', 'anthropic_claude-instant-1.2', 'anthropic_claude-instant-v1', 'anthropic_claude-v1.3', 'cohere_command', 'cohere_command-light', 'cohere_command-r', 'cohere_command-r-plus', 'databricks_dbrx-instruct', 'deepseek-ai_deepseek-llm-67b-chat', 'google_gemini-1.0-pro-001', 'google_gemini-1.0-pro-002', 'google_gemini-1.5-flash-001', 'google_gemini-1.5-pro-001', 'google_gemini-1.5-pro-preview-0409', 'google_gemma-2-27b-it', 'google_gemma-2-9b-it', 'google_gemma-7b', 'google_text-bison@001', 'google_text-unicorn@001', 'meta_llama-2-13b', 'meta_llama-2-70b', 'meta_llama-2-7b', 'meta_llama-3-70b', 'meta_llama-3-8b', 'meta_llama-65b', 'microsoft_phi-2', 'microsoft_phi-3-medium-4k-instruct', 'microsoft_phi-3-small-8k-instruct', 'mistralai_mistral-7b-instruct-v0.3', 'mistralai_mistral-7b-v0.1', 'mistralai_mistral-large-2402', 'mistralai_mistral-large-2407', 'mistralai_mistral-medium-2312', 'mistralai_mistral-small-2402', 'mistralai_mixtral-8x22b', 'mistralai_mixtral-8x7b-32kseqlen', 'mistralai_open-mistral-nemo-2407', 'openai_gpt-3.5-turbo-0613', 'openai_gpt-4-0613', 'openai_gpt-4-1106-preview', 'openai_gpt-4-turbo-2024-04-09', 'openai_gpt-4o-2024-05-13', 'openai_gpt-4o-mini-2024-07-18', 'openai_text-davinci-002', 'openai_text-davinci-003', 'qwen_qwen1.5-110b-chat', 'qwen_qwen1.5-14b', 'qwen_qwen1.5-32b', 'qwen_qwen1.5-72b', 'qwen_qwen1.5-7b', 'qwen_qwen2-72b-instruct', 'snowflake_snowflake-arctic-instruct', 'tiiuae_falcon-40b', 'tiiuae_falcon-7b', 'writer_palmyra-x-v2', 'writer_palmyra-x-v3']
'''

# exclude incomplete models
def _dataset_of(run_dir):
    """Return the dataset part of a run directory name.

    The ",stop=none" marker and everything from "model=" onwards are removed,
    then surrounding commas, colons and spaces are stripped.
    """
    return run_dir.replace(',stop=none', '').split('model=')[0].strip(',: ')


for k, v in model_to_datasets.items():
    model_to_datasets[k] = sorted(v)

# Sorted dataset names covered by each model.
datasets_per_model = {
    k: sorted(_dataset_of(d) for d in v) for k, v in model_to_datasets.items()
}

# The reference is the dataset list shared by the largest number of models.
# It used to be taken from `list(models)[0]`, i.e. an arbitrary element of a
# set of strings, whose order changes with hash randomization: whenever an
# incomplete model came first, every complete model was dropped instead.
# Sorting the candidates first makes ties resolve the same way on every run.
coverage = [tuple(names) for names in datasets_per_model.values()]
datasets = list(max(sorted(set(coverage)), key=coverage.count)) if coverage else []

for k, names in datasets_per_model.items():
    if names != datasets:
        models.remove(k)

print(len(models))
print(sorted(models))
print(datasets)

'''
70
['01-ai_yi-34b', '01-ai_yi-6b', '01-ai_yi-large-preview', 'AlephAlpha_luminous-base', 'AlephAlpha_luminous-extended', 'AlephAlpha_luminous-supreme', 'ai21_j2-grande', 'ai21_j2-jumbo', 'allenai_olmo-7b', 'anthropic_claude-2.0', 'anthropic_claude-2.1', 'anthropic_claude-3-5-sonnet-20240620', 'anthropic_claude-3-haiku-20240307', 'anthropic_claude-3-opus-20240229', 'anthropic_claude-3-sonnet-20240229', 'anthropic_claude-instant-1.2', 'anthropic_claude-instant-v1', 'anthropic_claude-v1.3', 'cohere_command', 'cohere_command-light', 'cohere_command-r', 'cohere_command-r-plus', 'databricks_dbrx-instruct', 'deepseek-ai_deepseek-llm-67b-chat', 'google_gemini-1.0-pro-001', 'google_gemini-1.0-pro-002', 'google_gemini-1.5-flash-001', 'google_gemini-1.5-pro-001', 'google_gemini-1.5-pro-preview-0409', 'google_gemma-2-27b-it', 'google_gemma-2-9b-it', 'google_gemma-7b', 'google_text-bison@001', 'google_text-unicorn@001', 'meta_llama-2-13b', 'meta_llama-2-70b', 'meta_llama-2-7b', 'meta_llama-3-70b', 'meta_llama-3-8b', 'meta_llama-65b', 'microsoft_phi-2', 'microsoft_phi-3-medium-4k-instruct', 'mistralai_mistral-7b-instruct-v0.3', 'mistralai_mistral-7b-v0.1', 'mistralai_mistral-large-2402', 'mistralai_mistral-large-2407', 'mistralai_mistral-medium-2312', 'mistralai_mistral-small-2402', 'mistralai_mixtral-8x22b', 'mistralai_mixtral-8x7b-32kseqlen', 'mistralai_open-mistral-nemo-2407', 'openai_gpt-3.5-turbo-0613', 'openai_gpt-4-0613', 'openai_gpt-4-1106-preview', 'openai_gpt-4-turbo-2024-04-09', 'openai_gpt-4o-2024-05-13', 'openai_gpt-4o-mini-2024-07-18', 'openai_text-davinci-002', 'openai_text-davinci-003', 'qwen_qwen1.5-110b-chat', 'qwen_qwen1.5-14b', 'qwen_qwen1.5-32b', 'qwen_qwen1.5-72b', 'qwen_qwen1.5-7b', 'qwen_qwen2-72b-instruct', 'snowflake_snowflake-arctic-instruct', 'tiiuae_falcon-40b', 'tiiuae_falcon-7b', 'writer_palmyra-x-v2', 'writer_palmyra-x-v3']
['commonsense:dataset=openbookqa,method=multiple_choice_joint', 'gsm', 'legalbench:subset=abercrombie', 'legalbench:subset=corporate_lobbying', 'legalbench:subset=function_of_decision_section', 'legalbench:subset=international_citizenship_questions', 'legalbench:subset=proa', 'math:subject=algebra,level=1,use_official_examples=False,use_chain_of_thought=True', 'math:subject=counting_and_probability,level=1,use_official_examples=False,use_chain_of_thought=True', 'math:subject=geometry,level=1,use_official_examples=False,use_chain_of_thought=True', 'math:subject=intermediate_algebra,level=1,use_official_examples=False,use_chain_of_thought=True', 'math:subject=number_theory,level=1,use_official_examples=False,use_chain_of_thought=True', 'math:subject=prealgebra,level=1,use_official_examples=False,use_chain_of_thought=True', 'math:subject=precalculus,level=1,use_official_examples=False,use_chain_of_thought=True', 'med_qa', 'mmlu:subject=abstract_algebra,method=multiple_choice_joint', 'mmlu:subject=college_chemistry,method=multiple_choice_joint', 'mmlu:subject=computer_security,method=multiple_choice_joint', 'mmlu:subject=econometrics,method=multiple_choice_joint', 'mmlu:subject=us_foreign_policy,method=multiple_choice_joint', 'narrative_qa', 'natural_qa:mode=closedbook', 'natural_qa:mode=openbook_longans', 'wmt_14:language_pair=cs-en', 'wmt_14:language_pair=de-en', 'wmt_14:language_pair=fr-en', 'wmt_14:language_pair=hi-en', 'wmt_14:language_pair=ru-en']
'''

# validate all models used the same inputs and the same order
def _read_ids(run_dir, file_name, id_key):
    """Return the `id_key` value of every record of one run file, in file order.

    The file is read from `<save_dir>/<run_dir>/<file_name>` and must contain a
    JSON list of objects that all have `id_key`.
    """
    with open(os.path.join(save_dir, run_dir, file_name)) as f:
        return [ex[id_key] for ex in json.load(f)]


examples = {}
models_to_exclude = set()
# The first run seen for a dataset becomes that dataset's reference. Iterating
# in sorted order keeps that choice identical across machines and runs
# (os.listdir returns entries in arbitrary order).
for d in sorted(os.listdir(save_dir)):
    dataset, model = d.replace(',stop=none', '').split("model=")
    # Same strip set as the rest of the script, so dataset keys carry no
    # trailing colon (e.g. "gsm", not "gsm:").
    dataset = dataset.strip(',: ')
    model = model.strip(',: ')

    instances = _read_ids(d, "instances.json", "id")
    if dataset in examples:
        if instances != examples[dataset]:
            print(f"{d} is different")
            models_to_exclude.add(model)
    else:
        examples[dataset] = instances

    # validate evaluation results and display_requests files are complete and
    # use the same order as the reference instances
    for file_name, label in (
        ("per_instance_stats.json", "stats"),
        ("display_requests.json", "display"),
    ):
        if _read_ids(d, file_name, "instance_id") != examples[dataset]:
            print(f"{d} has different instances in {label}")
            models_to_exclude.add(model)

print(f"{len(models_to_exclude)} models to exclude: {models_to_exclude}")
models = models.difference(models_to_exclude)
print(len(models))
print(models)

'''
2 models to exclude: {'mistralai_mistral-small-2402', 'mistralai_mistral-large-2402'}
68
{'meta_llama-65b', 'google_gemini-1.0-pro-002', 'google_text-bison@001', 'google_gemma-2-9b-it', 'cohere_command-r', 'anthropic_claude-3-opus-20240229', 'mistralai_mistral-7b-instruct-v0.3', 'openai_text-davinci-003', 'meta_llama-3-8b', 'anthropic_claude-3-haiku-20240307', 'microsoft_phi-3-medium-4k-instruct', 'openai_gpt-4o-2024-05-13', 'anthropic_claude-2.0', 'mistralai_mixtral-8x7b-32kseqlen', 'anthropic_claude-instant-v1', 'mistralai_mistral-medium-2312', 'anthropic_claude-3-sonnet-20240229', '01-ai_yi-6b', 'writer_palmyra-x-v3', 'tiiuae_falcon-7b', 'mistralai_mistral-large-2407', 'google_gemini-1.5-pro-preview-0409', 'deepseek-ai_deepseek-llm-67b-chat', 'google_gemini-1.5-flash-001', 'cohere_command', 'cohere_command-r-plus', 'AlephAlpha_luminous-base', 'anthropic_claude-v1.3', 'AlephAlpha_luminous-supreme', 'openai_gpt-4o-mini-2024-07-18', 'microsoft_phi-2', 'allenai_olmo-7b', 'anthropic_claude-2.1', 'qwen_qwen1.5-72b', 'qwen_qwen1.5-14b', 'openai_gpt-3.5-turbo-0613', 'meta_llama-2-7b', 'google_gemini-1.5-pro-001', 'meta_llama-2-70b', 'meta_llama-2-13b', 'qwen_qwen1.5-7b', '01-ai_yi-large-preview', 'qwen_qwen1.5-32b', 'mistralai_open-mistral-nemo-2407', 'qwen_qwen1.5-110b-chat', 'mistralai_mixtral-8x22b', 'google_gemma-2-27b-it', 'databricks_dbrx-instruct', 'tiiuae_falcon-40b', 'openai_gpt-4-0613', 'google_gemma-7b', 'meta_llama-3-70b', 'ai21_j2-grande', 'cohere_command-light', 'openai_gpt-4-turbo-2024-04-09', 'anthropic_claude-instant-1.2', 'openai_gpt-4-1106-preview', 'AlephAlpha_luminous-extended', 'qwen_qwen2-72b-instruct', 'openai_text-davinci-002', 'anthropic_claude-3-5-sonnet-20240620', 'writer_palmyra-x-v2', 'google_gemini-1.0-pro-001', 'snowflake_snowflake-arctic-instruct', '01-ai_yi-34b', 'ai21_j2-jumbo', 'mistralai_mistral-7b-v0.1', 'google_text-unicorn@001'}
'''

# get evaluation results
# Datasets grouped by the (stat name, value type) pair used as their primary
# metric. Group and dataset order define the key order of PRIMARY_METRICS.
_METRIC_GROUPS = (
    (("f1_score", "real"), (
        "narrative_qa",
        "natural_qa:mode=closedbook",
        "natural_qa:mode=openbook_longans",
    )),
    (("exact_match", "binary"), (
        "commonsense:dataset=openbookqa,method=multiple_choice_joint",
        "mmlu:subject=abstract_algebra,method=multiple_choice_joint",
        "mmlu:subject=college_chemistry,method=multiple_choice_joint",
        "mmlu:subject=computer_security,method=multiple_choice_joint",
        "mmlu:subject=econometrics,method=multiple_choice_joint",
        "mmlu:subject=us_foreign_policy,method=multiple_choice_joint",
    )),
    (("math_equiv_chain_of_thought", "binary"), (
        'math:subject=algebra,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=counting_and_probability,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=geometry,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=intermediate_algebra,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=number_theory,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=prealgebra,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=precalculus,level=1,use_official_examples=False,use_chain_of_thought=True',
    )),
    (("final_number_exact_match", "binary"), (
        'gsm',
    )),
    (("quasi_exact_match", "binary"), (
        'legalbench:subset=abercrombie',
        'legalbench:subset=corporate_lobbying',
        'legalbench:subset=function_of_decision_section',
        'legalbench:subset=international_citizenship_questions',
        'legalbench:subset=proa',
        'med_qa',
    )),
    (("bleu_4", "real"), (
        'wmt_14:language_pair=cs-en',
        'wmt_14:language_pair=de-en',
        'wmt_14:language_pair=fr-en',
        'wmt_14:language_pair=hi-en',
        'wmt_14:language_pair=ru-en',
    )),
)

# dataset name -> (stat name, value type)
PRIMARY_METRICS = {
    dataset_name: metric
    for metric, dataset_names in _METRIC_GROUPS
    for dataset_name in dataset_names
}


def get_primary_metrics(dataset, stats):
    """Extract the primary metric and the token counts of one dataset run.

    Args:
        dataset: A key of PRIMARY_METRICS.
        stats: One list of stat records per instance. A record is read through
            record['name']['name'] and record['mean']; records with any other
            name are ignored.

    Returns:
        A dict with the keys "stat_type", "stat_name", "values",
        "input_tokens" and "output_tokens". The three lists hold the 'mean'
        of the matching records in the order they are encountered.

    Raises:
        KeyError: If `dataset` is not in PRIMARY_METRICS.
        AssertionError: If the number of metric values, prompt-token counts
            and output-token counts are not all equal.
    """
    stat_name, stat_type = PRIMARY_METRICS[dataset]

    def means_of(wanted):
        # 'mean' of every record called `wanted`, scanning instances in order.
        return [
            record['mean']
            for per_instance in stats
            for record in per_instance
            if record['name']['name'] == wanted
        ]

    values = means_of(stat_name)
    input_tokens = means_of("num_prompt_tokens")
    assert len(input_tokens) == len(values)

    output_tokens = means_of("num_output_tokens")
    assert len(output_tokens) == len(input_tokens)

    return {
        "stat_type": stat_type,
        "stat_name": stat_name,
        "values": values,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
    }
    
def read_eval_results(model):
    """Load the primary-metric results of every run recorded for `model`.

    For each run directory listed in model_to_datasets[model], the run's
    per_instance_stats.json is read from save_dir and its per-instance
    'stats' lists are passed to get_primary_metrics.

    Returns:
        A dict mapping the dataset name parsed from the run directory name to
        the dict returned by get_primary_metrics.
    """
    results = {}
    for run_dir in model_to_datasets[model]:
        # Directory name without the ",stop=none" marker and the model part.
        run_name = run_dir.replace(',stop=none', '')
        dataset = run_name.split('model=')[0].strip(',: ')

        stats_path = os.path.join(save_dir, run_dir, "per_instance_stats.json")
        with open(stats_path) as f:
            per_instance = json.load(f)

        results[dataset] = get_primary_metrics(
            dataset, [ex['stats'] for ex in per_instance]
        )

    return results

# Evaluation results of every model still in `models`, keyed by model name,
# saved as one file with torch.save.
data = {model: read_eval_results(model) for model in models}

torch.save(data, os.path.join(root_dir, "data/helm_lite/data.pth"))


def embed_prompts(model, embed):
    """Embed the request prompts of every run recorded for `model`.

    Args:
        model: Key of model_to_datasets whose run directories are read.
        embed: Passed unchanged as the first argument of get_embeddings.

    Returns:
        A dict mapping the dataset name parsed from each run directory name to
        {"prompts": list of prompt strings, "embeddings": whatever
        get_embeddings returned for them}.
    """
    def load_prompts(run_dir):
        # Prompt of every request in the run's display_requests.json, in file order.
        with open(os.path.join(save_dir, run_dir, "display_requests.json")) as f:
            return [ex['request']['prompt'] for ex in json.load(f)]

    outputs = {}
    for run_dir in model_to_datasets[model]:
        dataset = run_dir.replace(',stop=none', '').split('model=')[0].strip(',: ')
        prompts = load_prompts(run_dir)
        # NOTE(review): the third argument (4) is presumably a batch size --
        # confirm against text_embeddings.get_embeddings.
        vectors = get_embeddings(embed, prompts, 4)
        outputs[dataset] = {"prompts": prompts, "embeddings": vectors}

    return outputs

# Embed the prompts of a single model with the "bert" entry of MODELS and save
# the result.
# The model used to be `list(models)[0]`, an arbitrary element of a set of
# strings whose order changes between interpreter runs. The validation in this
# script only compares instance ids, not prompt text, so the saved prompts and
# embeddings could differ from run to run. Taking the alphabetically smallest
# name makes the output reproducible.
if not models:
    raise RuntimeError("no model passed validation; nothing to embed")
reference_model = min(models)
embeddings = embed_prompts(reference_model, MODELS["bert"]())
torch.save(embeddings, os.path.join(root_dir, "data/helm_lite/bert.pth"))