import random
import torch
from torch.utils.data import Dataset
import tiktoken

from llm_router.data.utils import get_costs

# Module-level tokenizer shared by estimate_token_count(); created once at import time.
enc = tiktoken.get_encoding("o200k_base")

def estimate_token_count(text):
    """Return the number of ``o200k_base`` tokens in ``text``.

    The same tokenizer is applied to every model's text, so for models that
    use a different tokenizer the result is only an estimate.

    ``disallowed_special=()`` makes special-token strings such as
    ``"<|endoftext|>"`` count as ordinary text. With tiktoken's default
    (``disallowed_special="all"``), ``encode`` raises ``ValueError`` on such
    input, which would crash on any prompt or model output containing one.
    """
    return len(enc.encode(text, disallowed_special=()))


class AlpacaEval(Dataset):
    """Routing dataset built from pre-computed AlpacaEval results.

    Each item pairs one AlpacaEval instruction with its embedding, the
    per-model evaluation scores, estimated prompt/completion token counts and
    the costs computed by ``get_costs`` for the models selected by ``config``
    (a key of :attr:`CONFIGS`).

    Prompts are split into ``"train"`` (a seeded random ``train_fraction`` of
    them, 0.7 by default) and ``"test"`` (the rest, in ascending index order).
    ``"all"`` yields every prompt in a seeded random order.
    """

    # Maps a config name to an ordered {model_name: pricing} dict. Key order
    # fixes the column order of every per-model tensor (scores, token counts,
    # costs). Pricing entries are {"prompt": ..., "completion": ...};
    # presumably USD per 1M tokens -- confirm against ``get_costs``. The "all"
    # config carries no pricing (None) for any model.
    CONFIGS = {
        "all": {
            'Conifer-7B-DPO': None,
            'Contextual-KTO-Mistral-PairRM': None,
            'Ein-70B-v0.1': None,
            'FsfairX-Zephyr-Chat-v0.1': None,
            'Infinity-Instruct-3M-0613-Llama3-70B': None,
            'Infinity-Instruct-3M-0613-Mistral-7B': None,
            'Infinity-Instruct-3M-0625-Llama3-70B': None,
            'Infinity-Instruct-3M-0625-Llama3-8B': None,
            'Infinity-Instruct-3M-0625-Mistral-7B': None,
            'Infinity-Instruct-3M-0625-Qwen2-7B': None,
            'Infinity-Instruct-3M-0625-Yi-1.5-9B': None,
            'LMCocktail-10.7B-v1': None,
            'Llama-3-Instruct-8B-SimPO': None,
            'Llama-3-Instruct-8B-SimPO-ExPO': None,
            'Llama-3-Instruct-8B-WPO-HB-v2': None,
            'Meta-Llama-3-70B-Instruct': None,
            'Meta-Llama-3-8B-Instruct': None,
            'Meta-Llama-3.1-405B-Instruct-Turbo': None,
            'Meta-Llama-3.1-70B-Instruct-Turbo': None,
            'Meta-Llama-3.1-8B-Instruct-Turbo': None,
            'Mistral-7B+RAHF-DUAL+LoRA': None,
            'Mistral-7B-Instruct-v0.2': None,
            'Mistral-7B-ReMax-v0.1': None,
            'Mixtral-8x22B-Instruct-v0.1': None,
            'Mixtral-8x7B-Instruct-v0.1': None,
            'Mixtral-8x7B-Instruct-v0.1_concise': None,
            'Mixtral-8x7B-Instruct-v0.1_verbose': None,
            'Nanbeige-Plus-Chat-v0.1': None,
            'Nanbeige2-16B-Chat': None,
            'Nanbeige2-8B-Chat': None,
            'OpenHermes-2.5-Mistral-7B': None,
            'Qwen-14B-Chat': None,
            'Qwen1.5-1.8B-Chat': None,
            'Qwen1.5-110B-Chat': None,
            'Qwen1.5-14B-Chat': None,
            'Qwen1.5-72B-Chat': None,
            'Qwen1.5-7B-Chat': None,
            'REBEL-Llama-3-8B-Instruct': None,
            'SPPO-Gemma-2-9B-It-PairRM': None,
            'SPPO-Llama-3-Instruct-8B-PairRM': None,
            'SPPO-Mistral7B-PairRM': None,
            'SPPO-Mistral7B-PairRM-ExPO': None,
            'Samba-CoE-v0.1': None,
            'Samba-CoE-v0.2': None,
            'Samba-CoE-v0.2-best-of-16': None,
            'Snorkel-Mistral-PairRM-DPO': None,
            'Snorkel-Mistral-PairRM-DPO-best-of-16': None,
            'Starling-LM-7B-alpha': None,
            'Starling-LM-7B-alpha-ExPO': None,
            'Starling-LM-7B-beta-ExPO': None,
            'Storm-7B': None,
            'Storm-7B-best-of-64': None,
            'TempNet-LLaMA2-Chat-13B-v0.1': None,
            'TempNet-LLaMA2-Chat-70B-v0.1': None,
            'TempNet-LLaMA2-Chat-7B-v0.1': None,
            'Together-MoA': None,
            'Together-MoA-Lite': None,
            'Yi-34B-Chat': None,
            'airoboros-33b': None,
            'airoboros-65b': None,
            'aligner-2b_claude-3-opus-20240229': None,
            'aligner-2b_gpt-4-turbo-2024-04-09': None,
            'aligner-2b_qwen1.5-72b-chat': None,
            'alpaca-7b': None,
            'alpaca-7b-neft': None,
            'alpaca-7b_concise': None,
            'alpaca-7b_verbose': None,
            'alpaca-farm-ppo-human': None,
            'alpaca-farm-ppo-sim-gpt4-20k': None,
            'baichuan-13b-chat': None,
            'baize-v2-13b': None,
            'baize-v2-7b': None,
            'bedrock_claude': None,
            'causallm-14b': None,
            'chatglm2-6b': None,
            'claude': None,
            'claude-2': None,
            'claude-2.1': None,
            'claude-2.1_concise': None,
            'claude-2.1_verbose': None,
            'claude-3-5-sonnet-20240620': None,
            'claude-3-opus-20240229': None,
            'claude-3-sonnet-20240229': None,
            'claude-instant-1.2': None,
            'claude2-alpaca-13b': None,
            'cohere': None,
            'cut-13b': None,
            'dbrx-instruct': None,
            'deepseek-llm-67b-chat': None,
            'deita-7b-v1.0': None,
            'dolphin-2.2.1-mistral-7b': None,
            'evo-7b': None,
            'evo-v2-7b': None,
            'falcon-40b-instruct': None,
            'falcon-7b-instruct': None,
            'gemini-pro': None,
            'gemma-2-9b-it-DPO': None,
            'gemma-2-9b-it-SimPO': None,
            'gemma-2b-it': None,
            'gemma-7b-it': None,
            'ghost-7b-alpha': None,
            'ghost-8b-beta-disl-0x5': None,
            'gpt-3.5-turbo-0301': None,
            'gpt-3.5-turbo-0613': None,
            'gpt-3.5-turbo-1106': None,
            'gpt-3.5-turbo-1106_concise': None,
            'gpt-3.5-turbo-1106_verbose': None,
            'gpt-4-0125-preview': None,
            'gpt-4-turbo-2024-04-09': None,
            'gpt-4o-2024-05-13': None,
            'gpt35_turbo_instruct': None,
            'gpt4': None,
            'gpt4_0314': None,
            'gpt4_0613': None,
            'gpt4_0613_concise': None,
            'gpt4_0613_verbose': None,
            'gpt4_1106_preview': None,
            'gpt4_1106_preview_concise': None,
            'gpt4_1106_preview_verbose': None,
            'gpt4_gamed': None,
            'guanaco-13b': None,
            'guanaco-33b': None,
            'guanaco-65b': None,
            'guanaco-7b': None,
            'higgs-llama-3-70b-v2': None,
            'humpback-llama-65b': None,
            'humpback-llama2-70b': None,
            'internlm2-chat-20b-ExPO': None,
            'internlm2-chat-20b-ppo': None,
            'internlm2-chat-7b-ExPO': None,
            'jina-chat': None,
            'llama-2-13b-chat-hf': None,
            'llama-2-70b-chat-hf': None,
            'llama-2-7b-chat-hf': None,
            'llama-2-chat-7b-evol70k-neft': None,
            'merlinite-7B-AOT': None,
            'minichat-1.5-3b': None,
            'minichat-3b': None,
            'minotaur-13b': None,
            'mistral-large-2402': None,
            'mistral-medium': None,
            'mistral-orpo-beta': None,
            'nous-hermes-13b': None,
            'oasst-rlhf-llama-33b': None,
            'oasst-sft-llama-33b': None,
            'oasst-sft-pythia-12b': None,
            'openbuddy-falcon-40b-v9': None,
            'openbuddy-falcon-7b-v6': None,
            'openbuddy-llama-30b-v7.1': None,
            'openbuddy-llama-65b-v8': None,
            'openbuddy-llama2-13b-v11.1': None,
            'openbuddy-llama2-70b-v10.1': None,
            'openchat-13b': None,
            'openchat-v2-13b': None,
            'openchat-v2-w-13b': None,
            'openchat-v3.1-13b': None,
            'openchat8192-13b': None,
            'opencoderplus-15b': None,
            'openpipe-moa-gpt-4-turbo-v1': None,
            'pairrm-Yi-34B-Chat': None,
            'pairrm-tulu-2-13b': None,
            'pairrm-tulu-2-70b': None,
            'pairrm-zephyr-7b-beta': None,
            'phi-2': None,
            'phi-2-dpo': None,
            'phi-2-sft': None,
            'platolm-7b': None,
            'pythia-12b-mix-sft': None,
            'recycled-wizardlm-7b-v1.0': None,
            'recycled-wizardlm-7b-v2.0': None,
            'text_davinci_001': None,
            'text_davinci_003': None,
            'tulu-2-dpo-13b': None,
            'tulu-2-dpo-13b-ExPO': None,
            'tulu-2-dpo-70b': None,
            'tulu-2-dpo-70b-ExPO': None,
            'tulu-2-dpo-7b': None,
            'tulu-2-dpo-7b-ExPO': None,
            'ultralm-13b': None,
            'ultralm-13b-best-of-16': None,
            'ultralm-13b-v2.0': None,
            'ultralm-13b-v2.0-best-of-16': None,
            'vicuna-13b': None,
            'vicuna-13b-v1.3': None,
            'vicuna-13b-v1.5': None,
            'vicuna-13b-v1.5-togetherai': None,
            'vicuna-33b-v1.3': None,
            'vicuna-7b': None,
            'vicuna-7b-v1.3': None,
            'vicuna-7b-v1.5': None,
            'wizardlm-13b': None,
            'wizardlm-13b-v1.1': None,
            'wizardlm-13b-v1.2': None,
            'wizardlm-70b': None,
            'xwinlm-13b-v0.1': None,
            'xwinlm-70b-v0.1': None,
            'xwinlm-70b-v0.3': None,
            'xwinlm-7b-v0.1': None,
            'yi-large-preview': None,
            'zephyr-7b-alpha': None,
            'zephyr-7b-alpha-ExPO': None,
            'zephyr-7b-beta': None,
            'zephyr-7b-beta-ExPO': None
        },
        "gpt": {
            'gpt-3.5-turbo-0301': {"prompt": 1.5, "completion": 2.0},
            'gpt-3.5-turbo-0613': {"prompt": 1.5, "completion": 2.0},
            'gpt-3.5-turbo-1106': {"prompt": 1.0, "completion": 2.0},
            'gpt-4-0125-preview': {"prompt": 10, "completion": 30},
            'gpt-4o-2024-05-13': {"prompt": 5, "completion": 15},
            'gpt4': {"prompt": 30, "completion": 60},
            'gpt4_0314': {"prompt": 30, "completion": 60},
            'gpt4_0613': {"prompt": 30, "completion": 60},
            'gpt4_1106_preview': {"prompt": 10, "completion": 30},
        },
        "gpt_selected": {
            'gpt-3.5-turbo-0613': {"prompt": 1.5, "completion": 2.0},
            'gpt-3.5-turbo-1106': {"prompt": 1.0, "completion": 2.0},
            'gpt-4-0125-preview': {"prompt": 10, "completion": 30},
            'gpt-4o-2024-05-13': {"prompt": 5, "completion": 15},
            'gpt4': {"prompt": 30, "completion": 60},
        },
        "mistral": {
            'Mistral-7B-Instruct-v0.2': {"prompt": 0.25, "completion": 0.25},
            'Mixtral-8x22B-Instruct-v0.1': {"prompt": 2, "completion": 6},
            'Mixtral-8x7B-Instruct-v0.1': {"prompt": 0.7, "completion": 0.7},
            'mistral-large-2402': {"prompt": 8, "completion": 24},
            'mistral-medium': {"prompt": 2.7, "completion": 8.1},
        },
        "claude": {
            'claude-2': {"prompt": 8, "completion": 24},
            'claude-2.1': {"prompt": 8, "completion": 24},
            'claude-3-5-sonnet-20240620': {"prompt": 3, "completion": 15},
            'claude-3-opus-20240229': {"prompt": 15, "completion": 75},
            'claude-3-sonnet-20240229': {"prompt": 3, "completion": 15},
            'claude-instant-1.2': {"prompt": 0.8, "completion": 2.4},
        }
    }

    # Split names accepted by __init__.
    SPLITS = ("train", "test", "all")

    def __init__(self, data_file: str, embed_file: str, config: str, split: str, seed: int,
                 train_fraction: float = 0.7):
        """Load the data/embedding files and select this split's prompt indices.

        Args:
            data_file: path to a ``torch.save``-d dict with keys
                ``'instructions'`` (per-prompt text), ``'eval_results'``
                (model -> per-prompt scores) and ``'outputs'``
                (model -> per-prompt completion text).
            embed_file: path to ``torch.save``-d embeddings indexable by prompt
                index (assumed one row per prompt -- TODO confirm).
            config: key of :attr:`CONFIGS` selecting the candidate models.
            split: one of ``"train"``, ``"test"`` or ``"all"``.
            seed: seed for the train/test split and the ``"all"`` ordering.
            train_fraction: fraction of prompts assigned to the train split.

        Raises:
            ValueError: if ``config`` or ``split`` is unknown, or
                ``train_fraction`` is outside ``[0, 1]``.
        """
        super().__init__()

        # Validate up front: an unknown split used to leave ``prompt_idx``
        # unset (AttributeError later), and an unknown config only failed on
        # first access of ``routing_config``.
        if config not in AlpacaEval.CONFIGS:
            raise ValueError(
                f"Unknown config {config!r}; expected one of {sorted(AlpacaEval.CONFIGS)}"
            )
        if split not in AlpacaEval.SPLITS:
            raise ValueError(f"Unknown split {split!r}; expected one of {AlpacaEval.SPLITS}")
        if not 0.0 <= train_fraction <= 1.0:
            raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")

        # NOTE(security): torch.load unpickles its input; only load trusted files.
        self.data = torch.load(data_file)
        self.embed = torch.load(embed_file)
        self.config = config
        self.split = split

        num_prompts = len(self.data['instructions'])
        num_train = int(num_prompts * train_fraction)

        # A private RNG yields exactly the same indices as seeding the global
        # ``random`` module did, without clobbering the caller's global state.
        rng = random.Random(seed)
        # The train sample is always drawn first so every split sees the same
        # RNG stream regardless of which split was requested.
        train_idx = rng.sample(range(num_prompts), k=num_train)

        if split == "train":
            self.prompt_idx = train_idx
        elif split == "test":
            train_set = set(train_idx)  # O(1) membership; a list made this O(n^2)
            self.prompt_idx = [i for i in range(num_prompts) if i not in train_set]
        else:  # "all": every prompt in a seeded random order
            self.prompt_idx = rng.sample(range(num_prompts), k=num_prompts)

    @property
    def name(self):
        """Identifier of the form ``alpaca_eval::<config>::<split>``."""
        return f"alpaca_eval::{self.config}::{self.split}"

    @property
    def routing_config(self):
        """The ordered ``{model_name: pricing}`` dict for this config."""
        return AlpacaEval.CONFIGS[self.config]

    def __len__(self):
        """Number of prompts in this split."""
        return len(self.prompt_idx)

    def _get_prompt(self, idx):
        """Return the instruction text for underlying prompt index ``idx``."""
        return self.data['instructions'][idx]

    def _get_embedding(self, idx):
        """Return the embedding for prompt ``idx`` as a float32 CPU tensor."""
        return self.embed[idx].cpu().float()

    def _get_scores(self, idx):
        """Return a float32 tensor of each configured model's score for prompt ``idx``."""
        eval_results = self.data['eval_results']
        return torch.stack([eval_results[model][idx] for model in self.routing_config]).float()

    def _get_input_length(self, idx):
        """Return the estimated prompt token count, repeated once per configured model."""
        prompt_tokens = estimate_token_count(self._get_prompt(idx))
        # Every model sees the same prompt, so the estimate is simply broadcast.
        return torch.full(
            (len(self.routing_config),), float(prompt_tokens), dtype=torch.float32
        )

    def _get_output_length(self, idx):
        """Return each configured model's estimated completion token count for prompt ``idx``."""
        outputs = self.data['outputs']
        return torch.tensor(
            [estimate_token_count(outputs[model][idx]) for model in self.routing_config]
        ).float()

    def __getitem__(self, idx):
        """Return one example for position ``idx`` within this split.

        The returned ``"idx"`` is that split position as passed in, not the
        underlying prompt index it maps to.
        """
        index = self.prompt_idx[idx]
        routing_config = self.routing_config

        prompt = self._get_prompt(index)
        embedding = self._get_embedding(index)
        scores = self._get_scores(index)
        input_tokens = self._get_input_length(index)
        output_tokens = self._get_output_length(index)
        costs = get_costs(input_tokens, output_tokens, routing_config)

        return {
            "idx": idx,
            "routing_config": routing_config,
            "prompt": prompt,
            "embedding": embedding,
            "scores": scores,
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }

    def get_benchmark(self):
        """Return the whole split batched: stacked tensors plus the list of prompts.

        Rows follow the order of ``self.prompt_idx``; columns of the per-model
        tensors follow the key order of ``routing_config``.
        """
        routing_config = self.routing_config
        indices = self.prompt_idx

        prompts = [self._get_prompt(i) for i in indices]
        embeddings = torch.stack([self._get_embedding(i) for i in indices])
        scores = torch.stack([self._get_scores(i) for i in indices])
        input_tokens = torch.stack([self._get_input_length(i) for i in indices])
        output_tokens = torch.stack([self._get_output_length(i) for i in indices])
        costs = [get_costs(inp, out, routing_config) for inp, out in zip(input_tokens, output_tokens)]

        return {
            "routing_config": routing_config,
            "prompts": prompts,
            "embeddings": embeddings,
            "scores": scores,
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }