import random
import torch
from torch.utils.data import Dataset

from llm_router.data.utils import get_costs

class HELMMMLU(Dataset):
    """Per-prompt HELM MMLU results for a pool of models, exposed as a Dataset.

    Two files are loaded with ``torch.load``:

    * ``embed_file`` -- indexed as ``embed[dataset]["prompts"][i]`` and
      ``embed[dataset]["embeddings"][i]``.
    * ``data_file`` -- indexed as ``data[model][dataset][field][i]`` with
      ``field`` in ``"values"``, ``"input_tokens"``, ``"output_tokens"``.

    Available splits (``split`` argument):

    * ``"train"`` -- a seeded random 70% sample of every dataset.
    * ``"test"``  -- the remaining prompts of every dataset, ascending.
    * ``"all"``   -- every prompt of every dataset, in seeded random order.
    * ``"seen"``  -- the ``"all"`` ordering restricted to every dataset
      except the last 15 entries of ``DATASETS``.
    * ``"unseen"`` -- the ``"all"`` ordering restricted to the last 15
      entries of ``DATASETS``.
    """

    # Routing configurations: config name -> {model name: None}.
    CONFIGS = {
        "all": {
            '01-ai_yi-34b': None,
            '01-ai_yi-6b': None,
            '01-ai_yi-large-preview': None,
            'ai21_jamba-instruct': None,
            'allenai_olmo-1.7-7b': None,
            'allenai_olmo-7b': None,
            'anthropic_claude-2.1': None,
            'anthropic_claude-3-5-sonnet-20240620': None,
            'anthropic_claude-3-haiku-20240307': None,
            'anthropic_claude-instant-1.2': None,
            'cohere_command-r': None,
            'cohere_command-r-plus': None,
            'databricks_dbrx-instruct': None,
            'deepseek-ai_deepseek-llm-67b-chat': None,
            'google_gemini-1.0-pro-001': None,
            'google_gemini-1.5-flash-001': None,
            'google_gemini-1.5-flash-preview-0514': None,
            'google_gemini-1.5-pro-001': None,
            'google_gemini-1.5-pro-preview-0409': None,
            'google_gemma-2-27b': None,
            'google_gemma-2-9b': None,
            'google_gemma-7b': None,
            'google_text-bison@001': None,
            'google_text-unicorn@001': None,
            'meta_llama-2-13b': None,
            'meta_llama-2-70b': None,
            'meta_llama-2-7b': None,
            'meta_llama-3-70b': None,
            'meta_llama-3-8b': None,
            'microsoft_phi-2': None,
            'microsoft_phi-3-medium-4k-instruct': None,
            'microsoft_phi-3-small-8k-instruct': None,
            'mistralai_mistral-7b-instruct-v0.3': None,
            'mistralai_mistral-7b-v0.1': None,
            'mistralai_mistral-large-2402': None,
            'mistralai_mistral-small-2402': None,
            'mistralai_mixtral-8x22b': None,
            'mistralai_mixtral-8x7b-32kseqlen': None,
            'openai_gpt-3.5-turbo-0613': None,
            'openai_gpt-4-0613': None,
            'openai_gpt-4-1106-preview': None,
            'openai_gpt-4-turbo-2024-04-09': None,
            'openai_gpt-4o-2024-05-13': None,
            'qwen_qwen1.5-110b-chat': None,
            'qwen_qwen1.5-14b': None,
            'qwen_qwen1.5-32b': None,
            'qwen_qwen1.5-72b': None,
            'qwen_qwen1.5-7b': None,
            'qwen_qwen2-72b-instruct': None,
            'snowflake_snowflake-arctic-instruct': None,
            'writer_palmyra-x-v3': None
        }
    }

    # Dataset keys, in a fixed order. The order matters: it drives the order
    # of the seeded random draws and decides which datasets are "unseen".
    DATASETS = [
        'mmlu:subject=abstract_algebra',
        'mmlu:subject=anatomy',
        'mmlu:subject=astronomy',
        'mmlu:subject=business_ethics',
        'mmlu:subject=clinical_knowledge',
        'mmlu:subject=college_biology',
        'mmlu:subject=college_chemistry',
        'mmlu:subject=college_computer_science',
        'mmlu:subject=college_mathematics',
        'mmlu:subject=college_medicine',
        'mmlu:subject=college_physics',
        'mmlu:subject=computer_security',
        'mmlu:subject=conceptual_physics',
        'mmlu:subject=econometrics',
        'mmlu:subject=electrical_engineering',
        'mmlu:subject=elementary_mathematics',
        'mmlu:subject=formal_logic',
        'mmlu:subject=global_facts',
        'mmlu:subject=high_school_biology',
        'mmlu:subject=high_school_chemistry',
        'mmlu:subject=high_school_computer_science',
        'mmlu:subject=high_school_european_history',
        'mmlu:subject=high_school_geography',
        'mmlu:subject=high_school_government_and_politics',
        'mmlu:subject=high_school_macroeconomics',
        'mmlu:subject=high_school_mathematics',
        'mmlu:subject=high_school_microeconomics',
        'mmlu:subject=high_school_physics',
        'mmlu:subject=high_school_psychology',
        'mmlu:subject=high_school_statistics',
        'mmlu:subject=high_school_us_history',
        'mmlu:subject=high_school_world_history',
        'mmlu:subject=human_aging',
        'mmlu:subject=human_sexuality',
        'mmlu:subject=international_law',
        'mmlu:subject=jurisprudence',
        'mmlu:subject=logical_fallacies',
        'mmlu:subject=machine_learning',
        'mmlu:subject=management',
        'mmlu:subject=marketing',
        'mmlu:subject=medical_genetics',
        'mmlu:subject=miscellaneous',
        'mmlu:subject=moral_disputes',
        'mmlu:subject=moral_scenarios',
        'mmlu:subject=nutrition',
        'mmlu:subject=philosophy',
        'mmlu:subject=prehistory',
        'mmlu:subject=professional_accounting',
        'mmlu:subject=professional_law',
        'mmlu:subject=professional_medicine',
        'mmlu:subject=professional_psychology',
        'mmlu:subject=public_relations',
        'mmlu:subject=security_studies',
        'mmlu:subject=sociology',
        'mmlu:subject=us_foreign_policy',
        'mmlu:subject=virology',
        'mmlu:subject=world_religions',
    ]

    # Accepted values for the ``split`` constructor argument.
    _SPLITS = ("train", "test", "all", "seen", "unseen")
    # Fraction of each dataset's prompts sampled into the train split.
    _TRAIN_FRACTION = 0.7
    # Number of trailing DATASETS entries that form the "unseen" split.
    _NUM_UNSEEN_DATASETS = 15

    def __init__(self, data_file: str, embed_file: str, config: str, split: str, seed: int):
        """Load the result/embedding files and build the requested split.

        Args:
            data_file: Path to the per-model results file (``torch.load``).
            embed_file: Path to the prompts/embeddings file (``torch.load``).
            config: Key into ``CONFIGS`` selecting the model pool.
            split: One of "train", "test", "all", "seen", "unseen".
            seed: Seed for the global ``random`` module; the same seed
                reproduces the same splits.

        Raises:
            ValueError: If ``split`` is not one of the supported names.
        """
        super().__init__()

        # Fail fast, before the (potentially large) files are loaded.
        if split not in HELMMMLU._SPLITS:
            raise ValueError(
                f"unknown split {split!r}; expected one of {HELMMMLU._SPLITS}"
            )

        # NOTE(security): torch.load unpickles its input; only point this at
        # trusted files.
        self.data = torch.load(data_file)
        self.embed = torch.load(embed_file)
        self.config = config
        self.split = split

        # Number of prompts per dataset, in DATASETS order.
        sizes = {d: len(self.embed[d]["prompts"]) for d in HELMMMLU.DATASETS}

        # The global RNG is seeded (a side effect callers may observe) and the
        # draws happen in a fixed order -- all train samples first, then all
        # "all" permutations -- so a given seed always yields the same splits.
        random.seed(seed)
        train_idx = {
            d: random.sample(range(n), k=int(n * HELMMMLU._TRAIN_FRACTION))
            for d, n in sizes.items()
        }

        # Test = complement of train, in ascending order. A set makes the
        # membership check O(1) instead of scanning the train list each time.
        test_idx = {}
        for d, n in sizes.items():
            in_train = set(train_idx[d])
            test_idx[d] = [i for i in range(n) if i not in in_train]

        # Full random permutation of every dataset.
        all_idx = {d: random.sample(range(n), k=n) for d, n in sizes.items()}

        held_out = HELMMMLU._NUM_UNSEEN_DATASETS
        seen_idx = {d: all_idx[d] for d in HELMMMLU.DATASETS[:-held_out]}
        unseen_idx = {d: all_idx[d] for d in HELMMMLU.DATASETS[-held_out:]}

        by_split = {
            "train": train_idx,
            "test": test_idx,
            "all": all_idx,
            "seen": seen_idx,
            "unseen": unseen_idx,
        }
        self.prompt_idx = by_split[split]

        # Flat position -> (dataset, index within dataset), used by __getitem__.
        flat = [
            (d, index)
            for d, indices in self.prompt_idx.items()
            for index in indices
        ]
        self.idx_mapping = dict(enumerate(flat))

    @property
    def name(self):
        """Identifier of the form ``helm_mmlu::<config>::<split>``."""
        return f"helm_mmlu::{self.config}::{self.split}"

    @property
    def routing_config(self):
        """The ``CONFIGS`` entry for this instance's config (model -> None)."""
        return HELMMMLU.CONFIGS[self.config]

    def __len__(self):
        """Total number of prompts across all datasets in the active split."""
        return sum(len(v) for v in self.prompt_idx.values())

    def _get_prompt(self, dataset, index):
        """Return the prompt stored at ``index`` of ``dataset``."""
        return self.embed[dataset]["prompts"][index]

    def _get_embedding(self, dataset, index):
        """Return the embedding at ``index`` of ``dataset`` as a CPU float tensor."""
        return self.embed[dataset]["embeddings"][index].cpu().float()

    def _get_model_values(self, dataset, index, field):
        """Collect ``data[model][dataset][field][index]`` for every model.

        Models are visited in the iteration order of the routing config, so
        position ``k`` of the result belongs to the ``k``-th configured model.
        Returns a 1-D float tensor with one entry per model.
        """
        values = [
            self.data[model][dataset][field][index]
            for model in HELMMMLU.CONFIGS[self.config]
        ]
        return torch.tensor(values).float()

    def _get_scores(self, dataset, index):
        """Per-model ``"values"`` entry for one prompt, as a float tensor."""
        return self._get_model_values(dataset, index, "values")

    def _get_input_tokens(self, dataset, index):
        """Per-model ``"input_tokens"`` entry for one prompt, as a float tensor."""
        return self._get_model_values(dataset, index, "input_tokens")

    def _get_output_tokens(self, dataset, index):
        """Per-model ``"output_tokens"`` entry for one prompt, as a float tensor."""
        return self._get_model_values(dataset, index, "output_tokens")

    def __getitem__(self, idx):
        """Return the sample at flat position ``idx`` as a dict.

        ``idx`` is looked up in ``idx_mapping`` (a dict keyed 0..len-1), so a
        position outside that range raises ``KeyError``.
        """
        dataset, index = self.idx_mapping[idx]

        input_tokens = self._get_input_tokens(dataset, index)
        output_tokens = self._get_output_tokens(dataset, index)
        costs = get_costs(input_tokens, output_tokens, self.routing_config)

        return {
            "idx": idx,
            "routing_config": self.routing_config,
            "prompt": self._get_prompt(dataset, index),
            "embedding": self._get_embedding(dataset, index),
            "scores": self._get_scores(dataset, index),
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }

    def get_benchmark(self):
        """Return the whole active split at once.

        Prompts are a list; embeddings, scores and token counts are stacked
        along a new leading dimension in the same order as ``idx_mapping``.
        ``costs`` is a list holding one ``get_costs`` result per prompt.
        """
        prompts = []
        embeddings = []
        scores = []
        input_tokens = []
        output_tokens = []
        for dataset, indices in self.prompt_idx.items():
            for index in indices:
                prompts.append(self._get_prompt(dataset, index))
                embeddings.append(self._get_embedding(dataset, index))
                scores.append(self._get_scores(dataset, index))
                input_tokens.append(self._get_input_tokens(dataset, index))
                output_tokens.append(self._get_output_tokens(dataset, index))

        embeddings = torch.stack(embeddings)
        scores = torch.stack(scores)
        input_tokens = torch.stack(input_tokens)
        output_tokens = torch.stack(output_tokens)
        routing_config = self.routing_config
        costs = [
            get_costs(inp, out, routing_config)
            for inp, out in zip(input_tokens, output_tokens)
        ]

        return {
            "routing_config": routing_config,
            "prompts": prompts,
            "embeddings": embeddings,
            "scores": scores,
            "costs": costs,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }