import os
import json
import torch
from collections import defaultdict

# NOTE(review): the empty value is presumably a placeholder. Set it to a
# service-account key path (or drop this line to use the default credential
# lookup) before running; TODO confirm how google-auth treats an empty value.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ""

# Parent directory of this script; every data/ path below is resolved from it.
root_dir = os.path.join(os.path.dirname(__file__), "../")

from google.cloud import storage

from text_embeddings import MODELS, get_embeddings

# download
# Mirror the three per-run JSON files used below from the public HELM bucket.
# The version prefix is stripped from each blob name, so a run directory that
# appears in several versions maps to the same local path and the version listed
# last in `folders` is the one left on disk.
bucket_name = "crfm-helm-public"
folders = [
    "mmlu/benchmark_output/runs/v1.0.0/",
    "mmlu/benchmark_output/runs/v1.1.0/",
    "mmlu/benchmark_output/runs/v1.2.0/",
    "mmlu/benchmark_output/runs/v1.3.0/",
    "mmlu/benchmark_output/runs/v1.4.0/",
    "mmlu/benchmark_output/runs/v1.5.0/",
    "mmlu/benchmark_output/runs/v1.6.0/",
]
wanted_files = ("per_instance_stats.json", "instances.json", "display_requests.json")

save_dir = os.path.join(root_dir, "data/helm-mmlu/raw")
os.makedirs(save_dir, exist_ok=True)

storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
for folder_name in folders:
    for blob in bucket.list_blobs(prefix=folder_name):
        # Match the exact file name; a suffix test like endswith("instances.json")
        # would also accept any other file whose name merely ends that way.
        if os.path.basename(blob.name) not in wanted_files:
            continue
        # list_blobs(prefix=...) guarantees the leading prefix, so slice it off;
        # str.replace would also delete any later occurrence of the same text.
        file_path = os.path.join(save_dir, blob.name[len(folder_name):])
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        blob.download_to_filename(file_path)
            
# parse the dataset and model names out of a HELM run directory name
def parse_dirname(dname):
    """Split a HELM run directory name into ``(dataset, model)``.

    The dataset is everything before the first ``method=`` and the model is the
    text between the first ``model=`` and the following ``eval_split=``. Both
    parts have surrounding commas, colons and spaces removed.

    Raises:
        IndexError: If ``dname`` contains no ``model=``.
    """
    separators = ',: '
    dataset_part = dname.split('method=')[0]
    after_model = dname.split('model=')[1]
    model_part = after_model.split('eval_split=')[0]
    return dataset_part.strip(separators), model_part.strip(separators)

# Group the downloaded run directories by model; every directory name encodes
# one (dataset, model) pair, and the model set is the set of grouping keys.
model_to_datasets = defaultdict(list)
for d in os.listdir(save_dir):
    dataset, model = parse_dirname(d)
    model_to_datasets[model].append(d)
models = set(model_to_datasets)

print(f"downloaded {len(models)} models in total")
print(sorted(models))

'''
downloaded 54 models in total
['01-ai_yi-34b', '01-ai_yi-6b', '01-ai_yi-large-preview', 'ai21_jamba-instruct', 'allenai_olmo-1.7-7b', 'allenai_olmo-7b', 'anthropic_claude-2.1', 'anthropic_claude-3-5-sonnet-20240620', 'anthropic_claude-3-haiku-20240307', 'anthropic_claude-3-opus-20240229', 'anthropic_claude-3-sonnet-20240229', 'anthropic_claude-instant-1.2', 'cohere_command-r', 'cohere_command-r-plus', 'databricks_dbrx-instruct', 'deepseek-ai_deepseek-llm-67b-chat', 'google_gemini-1.0-pro-001', 'google_gemini-1.5-flash-001', 'google_gemini-1.5-flash-preview-0514', 'google_gemini-1.5-pro-001', 'google_gemini-1.5-pro-preview-0409', 'google_gemini-pro', 'google_gemma-2-27b', 'google_gemma-2-9b', 'google_gemma-7b', 'google_text-bison@001', 'google_text-unicorn@001', 'meta_llama-2-13b', 'meta_llama-2-70b', 'meta_llama-2-7b', 'meta_llama-3-70b', 'meta_llama-3-8b', 'microsoft_phi-2', 'microsoft_phi-3-medium-4k-instruct', 'microsoft_phi-3-small-8k-instruct', 'mistralai_mistral-7b-instruct-v0.3', 'mistralai_mistral-7b-v0.1', 'mistralai_mistral-large-2402', 'mistralai_mistral-small-2402', 'mistralai_mixtral-8x22b', 'mistralai_mixtral-8x7b-32kseqlen', 'openai_gpt-3.5-turbo-0613', 'openai_gpt-4-0613', 'openai_gpt-4-1106-preview', 'openai_gpt-4-turbo-2024-04-09', 'openai_gpt-4o-2024-05-13', 'qwen_qwen1.5-110b-chat', 'qwen_qwen1.5-14b', 'qwen_qwen1.5-32b', 'qwen_qwen1.5-72b', 'qwen_qwen1.5-7b', 'qwen_qwen2-72b-instruct', 'snowflake_snowflake-arctic-instruct', 'writer_palmyra-x-v3']
'''

# exclude incomplete models
from collections import Counter

for k, v in model_to_datasets.items():
    model_to_datasets[k] = sorted(v)

# A model is complete when its sorted dataset list equals the list shared by the
# most models. Voting keeps a single incomplete model from becoming the reference
# (which would drop every complete model), and the result does not depend on the
# iteration order of the `models` set, which varies with string-hash randomization.
dataset_lists = {
    k: tuple(sorted(parse_dirname(d)[0] for d in v))
    for k, v in model_to_datasets.items()
}
reference, _ = Counter(dataset_lists.values()).most_common(1)[0]
datasets = list(reference)
for k, model_datasets in dataset_lists.items():
    if model_datasets != reference:
        models.remove(k)

print(len(models))
print(sorted(models))
print(datasets)

'''
51
['01-ai_yi-34b', '01-ai_yi-6b', '01-ai_yi-large-preview', 'ai21_jamba-instruct', 'allenai_olmo-1.7-7b', 'allenai_olmo-7b', 'anthropic_claude-2.1', 'anthropic_claude-3-5-sonnet-20240620', 'anthropic_claude-3-haiku-20240307', 'anthropic_claude-instant-1.2', 'cohere_command-r', 'cohere_command-r-plus', 'databricks_dbrx-instruct', 'deepseek-ai_deepseek-llm-67b-chat', 'google_gemini-1.0-pro-001', 'google_gemini-1.5-flash-001', 'google_gemini-1.5-flash-preview-0514', 'google_gemini-1.5-pro-001', 'google_gemini-1.5-pro-preview-0409', 'google_gemma-2-27b', 'google_gemma-2-9b', 'google_gemma-7b', 'google_text-bison@001', 'google_text-unicorn@001', 'meta_llama-2-13b', 'meta_llama-2-70b', 'meta_llama-2-7b', 'meta_llama-3-70b', 'meta_llama-3-8b', 'microsoft_phi-2', 'microsoft_phi-3-medium-4k-instruct', 'microsoft_phi-3-small-8k-instruct', 'mistralai_mistral-7b-instruct-v0.3', 'mistralai_mistral-7b-v0.1', 'mistralai_mistral-large-2402', 'mistralai_mistral-small-2402', 'mistralai_mixtral-8x22b', 'mistralai_mixtral-8x7b-32kseqlen', 'openai_gpt-3.5-turbo-0613', 'openai_gpt-4-0613', 'openai_gpt-4-1106-preview', 'openai_gpt-4-turbo-2024-04-09', 'openai_gpt-4o-2024-05-13', 'qwen_qwen1.5-110b-chat', 'qwen_qwen1.5-14b', 'qwen_qwen1.5-32b', 'qwen_qwen1.5-72b', 'qwen_qwen1.5-7b', 'qwen_qwen2-72b-instruct', 'snowflake_snowflake-arctic-instruct', 'writer_palmyra-x-v3']
['mmlu:subject=abstract_algebra', 'mmlu:subject=anatomy', 'mmlu:subject=astronomy', 'mmlu:subject=business_ethics', 'mmlu:subject=clinical_knowledge', 'mmlu:subject=college_biology', 'mmlu:subject=college_chemistry', 'mmlu:subject=college_computer_science', 'mmlu:subject=college_mathematics', 'mmlu:subject=college_medicine', 'mmlu:subject=college_physics', 'mmlu:subject=computer_security', 'mmlu:subject=conceptual_physics', 'mmlu:subject=econometrics', 'mmlu:subject=electrical_engineering', 'mmlu:subject=elementary_mathematics', 'mmlu:subject=formal_logic', 'mmlu:subject=global_facts', 'mmlu:subject=high_school_biology', 'mmlu:subject=high_school_chemistry', 'mmlu:subject=high_school_computer_science', 'mmlu:subject=high_school_european_history', 'mmlu:subject=high_school_geography', 'mmlu:subject=high_school_government_and_politics', 'mmlu:subject=high_school_macroeconomics', 'mmlu:subject=high_school_mathematics', 'mmlu:subject=high_school_microeconomics', 'mmlu:subject=high_school_physics', 'mmlu:subject=high_school_psychology', 'mmlu:subject=high_school_statistics', 'mmlu:subject=high_school_us_history', 'mmlu:subject=high_school_world_history', 'mmlu:subject=human_aging', 'mmlu:subject=human_sexuality', 'mmlu:subject=international_law', 'mmlu:subject=jurisprudence', 'mmlu:subject=logical_fallacies', 'mmlu:subject=machine_learning', 'mmlu:subject=management', 'mmlu:subject=marketing', 'mmlu:subject=medical_genetics', 'mmlu:subject=miscellaneous', 'mmlu:subject=moral_disputes', 'mmlu:subject=moral_scenarios', 'mmlu:subject=nutrition', 'mmlu:subject=philosophy', 'mmlu:subject=prehistory', 'mmlu:subject=professional_accounting', 'mmlu:subject=professional_law', 'mmlu:subject=professional_medicine', 'mmlu:subject=professional_psychology', 'mmlu:subject=public_relations', 'mmlu:subject=security_studies', 'mmlu:subject=sociology', 'mmlu:subject=us_foreign_policy', 'mmlu:subject=virology', 'mmlu:subject=world_religions']
'''

# validate all models used the same inputs and the same order
from collections import Counter


def _read_ids(run_dir, filename, key):
    """Return ``entry[key]`` for every entry of ``save_dir/run_dir/filename``, in file order."""
    with open(os.path.join(save_dir, run_dir, filename)) as f:
        return [ex[key] for ex in json.load(f)]


run_dirs = sorted(os.listdir(save_dir))
input_ids = {d: _read_ids(d, "instances.json", 'id') for d in run_dirs}

# Reference input order per dataset: the ordering shared by the most runs.
# Using whichever run os.listdir happens to return first would let one deviating
# run flag every other model as different.
runs_by_dataset = defaultdict(list)
for d, ids in input_ids.items():
    runs_by_dataset[parse_dirname(d)[0]].append(tuple(ids))
examples = {
    ds: list(Counter(orders).most_common(1)[0][0])
    for ds, orders in runs_by_dataset.items()
}

models_to_exclude = set()
for d in run_dirs:
    dataset, model = parse_dirname(d)
    if input_ids[d] != examples[dataset]:
        print(f"{d} is different")
        models_to_exclude.add(model)
    # validate evaluation results are complete and use the same order
    if _read_ids(d, "per_instance_stats.json", 'instance_id') != examples[dataset]:
        print(f"{d} has different instances in stats")
        models_to_exclude.add(model)
    # validate display_requests files are complete and use the same order
    if _read_ids(d, "display_requests.json", 'instance_id') != examples[dataset]:
        print(f"{d} has different instances in display")
        models_to_exclude.add(model)

print(f"{len(models_to_exclude)} models to exclude: {models_to_exclude}")
models = models.difference(models_to_exclude)
print(len(models))
print(models)

'''
0 models to exclude: set()
51
{'mistralai_mistral-7b-instruct-v0.3', 'allenai_olmo-7b', 'writer_palmyra-x-v3', '01-ai_yi-34b', 'openai_gpt-4-1106-preview', 'google_gemini-1.5-pro-001', 'google_gemma-7b', 'meta_llama-2-13b', 'qwen_qwen2-72b-instruct', 'google_gemini-1.5-flash-001', 'qwen_qwen1.5-32b', 'meta_llama-3-70b', 'meta_llama-2-70b', 'ai21_jamba-instruct', 'mistralai_mistral-small-2402', '01-ai_yi-large-preview', 'microsoft_phi-2', 'anthropic_claude-3-haiku-20240307', 'meta_llama-3-8b', 'qwen_qwen1.5-72b', 'qwen_qwen1.5-110b-chat', 'deepseek-ai_deepseek-llm-67b-chat', 'mistralai_mixtral-8x7b-32kseqlen', 'openai_gpt-4-turbo-2024-04-09', 'google_gemini-1.0-pro-001', 'anthropic_claude-instant-1.2', 'qwen_qwen1.5-14b', '01-ai_yi-6b', 'allenai_olmo-1.7-7b', 'mistralai_mixtral-8x22b', 'openai_gpt-3.5-turbo-0613', 'cohere_command-r-plus', 'google_gemma-2-27b', 'google_gemini-1.5-flash-preview-0514', 'meta_llama-2-7b', 'google_text-unicorn@001', 'databricks_dbrx-instruct', 'google_text-bison@001', 'anthropic_claude-3-5-sonnet-20240620', 'qwen_qwen1.5-7b', 'snowflake_snowflake-arctic-instruct', 'microsoft_phi-3-medium-4k-instruct', 'openai_gpt-4-0613', 'mistralai_mistral-7b-v0.1', 'google_gemini-1.5-pro-preview-0409', 'cohere_command-r', 'openai_gpt-4o-2024-05-13', 'anthropic_claude-2.1', 'microsoft_phi-3-small-8k-instruct', 'mistralai_mistral-large-2402', 'google_gemma-2-9b'}
'''

# get evaluation results
def get_primary_metrics(dataset, stats, stat_name="exact_match", stat_type="binary"):
    """Extract per-instance scores and token counts from HELM per-instance stats.

    Args:
        dataset: Dataset name. Not used in the computation; kept so existing
            callers keep working.
        stats: One list of stat records per instance. Each record is read as
            ``record['name']['name']`` (the stat name) and ``record['mean']``.
        stat_name: Name of the primary metric to extract.
        stat_type: Label stored alongside the metric in the returned dict.

    Returns:
        Dict with ``stat_type``, ``stat_name`` and three lists aligned with
        ``stats`` (one entry per instance): ``values`` (the primary metric),
        ``input_tokens`` (``num_prompt_tokens``) and ``output_tokens``
        (``num_output_tokens``).

    Raises:
        ValueError: If any instance lacks one of the three required stats.
    """
    required = (stat_name, "num_prompt_tokens", "num_output_tokens")
    values, input_tokens, output_tokens = [], [], []
    for index, instance_stats in enumerate(stats):
        # Look stats up per instance so the three output lists always stay
        # aligned; if a name repeats within one instance, the first record wins.
        means = {}
        for st in instance_stats:
            means.setdefault(st['name']['name'], st['mean'])
        missing = [name for name in required if name not in means]
        if missing:
            # Explicit error instead of `assert`, which is stripped under -O.
            raise ValueError(f"instance {index} is missing stats {missing}")
        values.append(means[stat_name])
        input_tokens.append(means["num_prompt_tokens"])
        output_tokens.append(means["num_output_tokens"])

    return {
        "stat_type": stat_type,
        "stat_name": stat_name,
        "values": values,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
    }
    
def read_eval_results(model):
    """Summarize the per-instance stats of every run directory listed for ``model``.

    Args:
        model: Key into the module-level ``model_to_datasets`` mapping.

    Returns:
        Dict mapping each dataset name (from ``parse_dirname``) to the summary
        produced by ``get_primary_metrics``. If two run directories map to the
        same dataset name, the one listed later wins.
    """
    def _load_stats(run_dir):
        # One entry per record of per_instance_stats.json: its 'stats' field.
        path = os.path.join(save_dir, run_dir, "per_instance_stats.json")
        with open(path) as fh:
            return [record['stats'] for record in json.load(fh)]

    summary = {}
    for run_dir in model_to_datasets[model]:
        name = parse_dirname(run_dir)[0]
        summary[name] = get_primary_metrics(name, _load_stats(run_dir))
    return summary

# Collect every kept model's evaluation results. Iterating in sorted order gives
# the saved dict the same key order on every run; iterating the `models` set
# directly would not, because string-hash randomization changes set order.
data = {model: read_eval_results(model) for model in sorted(models)}

torch.save(data, os.path.join(root_dir, "data/helm-mmlu/data.pth"))

def embed_prompts(model, embed):
    """Embed the request prompts of every run directory listed for ``model``.

    Args:
        model: Key into the module-level ``model_to_datasets`` mapping.
        embed: Embedding model handed to ``get_embeddings``.

    Returns:
        Dict mapping each dataset name (from ``parse_dirname``) to
        ``{"prompts": [...], "embeddings": ...}``. The prompts are the
        ``request.prompt`` fields of ``display_requests.json``, in file order.
        The embeddings are whatever ``get_embeddings`` returns.
    """
    def _load_prompts(run_dir):
        path = os.path.join(save_dir, run_dir, "display_requests.json")
        with open(path) as fh:
            return [entry['request']['prompt'] for entry in json.load(fh)]

    per_dataset = {}
    for run_dir in model_to_datasets[model]:
        name = parse_dirname(run_dir)[0]
        prompt_list = _load_prompts(run_dir)
        # NOTE(review): 4 is presumably a batch size; confirm in text_embeddings.
        vectors = get_embeddings(embed, prompt_list, 4)
        per_dataset[name] = {"prompts": prompt_list, "embeddings": vectors}
    return per_dataset

# Prompts are embedded once, from a single model's display_requests.json. Using
# the alphabetically first model makes that choice reproducible; indexing the
# `models` set directly picks a different model from run to run.
# NOTE(review): the validation step compares instance ids, not prompt text;
# confirm that prompts are identical across models before relying on this.
embeddings = embed_prompts(sorted(models)[0], MODELS["sfr2"]())
torch.save(embeddings, os.path.join(root_dir, "data/helm-mmlu/sfr2.pth"))