import os

# Parent directory of the directory that contains this file; all data paths
# below are built relative to it.
root_dir = os.path.join(os.path.dirname(__file__), "../")

import ast
import torch
import pandas as pd

from text_embeddings import MODELS, get_embeddings

import tiktoken
# Shared tokenizer used by get_input_tokens / get_output_tokens to count tokens.
enc = tiktoken.get_encoding("o200k_base")

# Directory holding the raw RouterBench pickles. Despite the name, this script
# reads its inputs from here (see the process(...) calls at the bottom of the
# file). It is created at import time if it does not exist yet.
save_dir = os.path.join(root_dir, "data/routerbench/raw")
os.makedirs(save_dir, exist_ok=True)

# Maps model identifiers as they appear in the raw dataframe to display names.
# Each key is used as a column name (``df[key]`` for the scores and
# ``df[f"{key}|model_response"]`` for the responses); each value becomes the
# key for that model in the results dict built by process_dataset.
MODEL_MAPPING = {
    "gpt-3.5-turbo-1106": "GPT-3.5",
    "claude-instant-v1": "Claude Instant V1",
    "claude-v1": "Claude V1",
    "claude-v2": "Claude V2",
    "gpt-4-1106-preview": "GPT-4",
    "meta/llama-2-70b-chat": "Llama 70B",
    "mistralai/mixtral-8x7b-chat": "Mixtral 8x7B",
    "zero-one-ai/Yi-34B-Chat": "Yi 34B",
    "WizardLM/WizardLM-13B-V1.2": "WizardLM 13B",
    "meta/code-llama-instruct-34b-chat": "Code Llama 34B",
    "mistralai/mistral-7b-chat": "Mistral 7B",
}

def get_input_tokens(prompts, encoding=None):
    """Count the tokens in each serialized prompt.

    Every item of ``prompts`` must be the string form of a Python list
    literal whose elements are text segments (e.g. ``"['part one', 'part
    two']"``). The item is parsed with ``ast.literal_eval`` and the token
    counts of its segments are summed.

    Args:
        prompts: Iterable of strings, each one a list literal of text segments.
        encoding: Optional tokenizer object exposing ``encode(text)`` and
            returning a sized sequence. Defaults to the module-level ``enc``.

    Returns:
        A list with one total token count per prompt, in input order.

    Raises:
        ValueError: If an item is not a string that starts with ``[`` and
            ends with ``]``.
    """
    tokenizer = enc if encoding is None else encoding

    input_tokens = []
    for prompt in prompts:
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, and ``prompt[0]`` raised IndexError on an empty string.
        if not (
            isinstance(prompt, str)
            and prompt.startswith("[")
            and prompt.endswith("]")
        ):
            raise ValueError(
                f"expected a list-literal string, got {prompt!r:.80}"
            )
        segments = ast.literal_eval(prompt)
        input_tokens.append(
            sum(len(tokenizer.encode(segment)) for segment in segments)
        )

    return input_tokens

def get_output_tokens(responses, encoding=None):
    """Count the tokens in each serialized model response.

    Every item of ``responses`` must be the string form of a Python list
    literal whose elements are text segments (e.g. ``"['answer text']"``).
    The item is parsed with ``ast.literal_eval`` and the token counts of its
    segments are summed.

    Args:
        responses: Iterable of strings, each one a list literal of text
            segments.
        encoding: Optional tokenizer object exposing ``encode(text)`` and
            returning a sized sequence. Defaults to the module-level ``enc``.

    Returns:
        A list with one total token count per response, in input order.

    Raises:
        ValueError: If an item is not a string that starts with ``[`` and
            ends with ``]``.
    """
    tokenizer = enc if encoding is None else encoding

    output_tokens = []
    for resp in responses:
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, and ``resp[0]`` raised IndexError on an empty string.
        if not (
            isinstance(resp, str)
            and resp.startswith("[")
            and resp.endswith("]")
        ):
            raise ValueError(
                f"expected a list-literal string, got {resp!r:.80}"
            )
        segments = ast.literal_eval(resp)
        output_tokens.append(
            sum(len(tokenizer.encode(segment)) for segment in segments)
        )

    return output_tokens

def process_dataset(df):
    """Build per-model results and prompt embeddings for one dataframe slice.

    Rows with a missing value in any of the ``MODEL_MAPPING`` key columns are
    dropped first. For every model, the returned results hold the model's
    score column, the prompt token counts, and the token counts of that
    model's ``"<model>|model_response"`` column.

    Args:
        df: DataFrame with a ``"prompt"`` column plus, for each key of
            ``MODEL_MAPPING``, a score column and a ``"<key>|model_response"``
            column.

    Returns:
        A ``(results, embeddings)`` tuple. ``results`` maps each display name
        to a dict with ``"scores"``, ``"input_tokens"`` and
        ``"output_tokens"`` lists. ``embeddings`` is a dict with the
        ``"prompts"`` list and the ``"embeddings"`` returned by
        ``get_embeddings``.
    """
    model_ids = list(MODEL_MAPPING)
    # Keep only rows for which every model has a score.
    complete = df.dropna(subset=model_ids)

    prompt_list = complete["prompt"].tolist()
    # Prompt token counts are model-independent: computed once, and the same
    # list object is stored under every model.
    prompt_token_counts = get_input_tokens(prompt_list)

    per_model = {
        MODEL_MAPPING[model_id]: {
            "scores": complete[model_id].tolist(),
            "input_tokens": prompt_token_counts,
            "output_tokens": get_output_tokens(
                complete[f"{model_id}|model_response"].tolist()
            ),
        }
        for model_id in model_ids
    }

    # The third positional argument (8) is presumably a batch size — confirm
    # against text_embeddings.get_embeddings.
    encoder = MODELS["bert"]()
    prompt_embeddings = {
        "prompts": prompt_list,
        "embeddings": get_embeddings(encoder, prompt_list, 8),
    }

    return per_model, prompt_embeddings

def process(raw_file, save_dir):
    """Convert one raw pickle into per-dataset result and embedding files.

    The pickled DataFrame is split by the distinct values of its
    ``"eval_name"`` column, and each slice is passed to ``process_dataset``.
    The outputs are collected into two dicts keyed by dataset name and saved
    with ``torch.save`` as ``data.pth`` and ``bert.pth`` inside ``save_dir``.

    Args:
        raw_file: Path to a pickle readable by ``pandas.read_pickle``.
            NOTE: unpickling executes arbitrary code, so only use trusted
            files.
        save_dir: Output directory; created if it does not exist.
    """
    frame = pd.read_pickle(raw_file)

    data, embed = {}, {}
    for name in frame['eval_name'].unique():
        subset = frame[frame['eval_name'] == name]
        data[name], embed[name] = process_dataset(subset)

    os.makedirs(save_dir, exist_ok=True)
    data_path = os.path.join(save_dir, "data.pth")
    embed_path = os.path.join(save_dir, "bert.pth")
    torch.save(data, data_path)
    torch.save(embed, embed_path)
    
# Script entry: these calls run whenever the module is executed or imported.
# They convert the raw 0-shot and 5-shot pickles found in ``save_dir`` and
# write data.pth / bert.pth under data/routerbench/0shot and .../5shot.
process(os.path.join(save_dir, "routerbench_0shot.pkl"), os.path.join(root_dir, "data/routerbench/0shot"))
process(os.path.join(save_dir, "routerbench_5shot.pkl"), os.path.join(root_dir, "data/routerbench/5shot"))