import random
import torch
from torch.utils.data import Dataset

from llm_router.data.utils import get_costs

class OpenLLM2(Dataset):
    """Routing dataset over per-sample results of the models listed in ``CONFIGS``.

    Each item pairs one benchmark prompt (and its precomputed embedding) with
    the per-model scores and input/output token counts of every model in the
    selected routing config, plus the costs computed from them by ``get_costs``.

    Prompts of every task in ``DATASETS`` are split per task: ``train`` holds a
    seeded random ``train_ratio`` fraction (floored) of each task's prompts,
    ``test`` holds the remaining prompts in ascending index order, and ``all``
    holds every prompt of every task in a seeded random order. Flat item indices
    run over tasks in ``DATASETS`` order.
    """

    # Routing configs: config name -> {leaderboard "-details" repo: cost entry}.
    # Cost entries are None (config "all") or {"prompt": ..., "completion": ...};
    # the whole selected dict is handed to ``get_costs``.
    # NOTE(review): the numbers look like prices per token volume (e.g. USD per
    # 1M tokens) -- confirm against get_costs. How get_costs treats None entries
    # is not visible here.
    CONFIGS = {
        "all": {
            'open-llm-leaderboard/01-ai__Yi-1.5-34B-Chat-details': None,
            'open-llm-leaderboard/01-ai__Yi-1.5-6B-Chat-details': None,
            'open-llm-leaderboard/01-ai__Yi-1.5-9B-Chat-details': None,
            'open-llm-leaderboard/01-ai__Yi-34B-Chat-details': None,
            'open-llm-leaderboard/01-ai__Yi-6B-Chat-details': None,
            'open-llm-leaderboard/Qwen__Qwen1.5-0.5B-Chat-details': None,
            'open-llm-leaderboard/Qwen__Qwen1.5-1.8B-Chat-details': None,
            'open-llm-leaderboard/Qwen__Qwen1.5-110B-Chat-details': None,
            'open-llm-leaderboard/Qwen__Qwen1.5-14B-Chat-details': None,
            'open-llm-leaderboard/Qwen__Qwen1.5-32B-Chat-details': None,
            'open-llm-leaderboard/Qwen__Qwen1.5-4B-Chat-details': None,
            'open-llm-leaderboard/Qwen__Qwen1.5-7B-Chat-details': None,
            'open-llm-leaderboard/Qwen__Qwen1.5-MoE-A2.7B-Chat-details': None,
            'open-llm-leaderboard/Qwen__Qwen2-0.5B-Instruct-details': None,
            'open-llm-leaderboard/Qwen__Qwen2-1.5B-Instruct-details': None,
            'open-llm-leaderboard/Qwen__Qwen2-72B-Instruct-details': None,
            'open-llm-leaderboard/Qwen__Qwen2-7B-Instruct-details': None,
            'open-llm-leaderboard/google__flan-t5-small-details': None,
            'open-llm-leaderboard/google__gemma-2b-it-details': None,
            'open-llm-leaderboard/google__gemma-7b-it-details': None,
            'open-llm-leaderboard/google__recurrentgemma-2b-it-details': None,
            'open-llm-leaderboard/google__recurrentgemma-9b-it-details': None,
            'open-llm-leaderboard/meta-llama__Llama-2-13b-chat-hf-details': None,
            'open-llm-leaderboard/meta-llama__Llama-2-70b-chat-hf-details': None,
            'open-llm-leaderboard/meta-llama__Llama-2-7b-chat-hf-details': None,
            'open-llm-leaderboard/meta-llama__Meta-Llama-3-70B-Instruct-details': None,
            'open-llm-leaderboard/meta-llama__Meta-Llama-3-8B-Instruct-details': None,
            'open-llm-leaderboard/microsoft__Phi-3-medium-4k-instruct-details': None,
            'open-llm-leaderboard/microsoft__Phi-3-mini-4k-instruct-details': None,
            'open-llm-leaderboard/mistralai__Mistral-7B-Instruct-v0.1-details': None,
            'open-llm-leaderboard/mistralai__Mistral-7B-Instruct-v0.2-details': None,
            'open-llm-leaderboard/mistralai__Mistral-7B-Instruct-v0.3-details': None,
            'open-llm-leaderboard/mistralai__Mixtral-8x7B-Instruct-v0.1-details': None,
            'open-llm-leaderboard/openai-community__gpt2-details': None,
            'open-llm-leaderboard/openai-community__gpt2-large-details': None,
            'open-llm-leaderboard/openai-community__gpt2-medium-details': None,
            'open-llm-leaderboard/openai-community__gpt2-xl-details': None,
            'open-llm-leaderboard/BAAI__Infinity-Instruct-3M-0625-Llama3-70B-details': None,
            'open-llm-leaderboard/BAAI__Infinity-Instruct-3M-0625-Llama3-8B-details': None,
            'open-llm-leaderboard/CohereForAI__aya-23-35B-details': None,
            'open-llm-leaderboard/CohereForAI__aya-23-8B-details': None,
            'open-llm-leaderboard/Qwen__Qwen2.5-0.5B-Instruct-details': None,
            'open-llm-leaderboard/Qwen__Qwen2.5-1.5B-Instruct-details': None,
            'open-llm-leaderboard/Qwen__Qwen2.5-7B-Instruct-details': None,
            'open-llm-leaderboard/Qwen__Qwen2.5-14B-Instruct-details': None,
            'open-llm-leaderboard/Qwen__Qwen2.5-32B-Instruct-details': None,
            'open-llm-leaderboard/Qwen__Qwen2.5-72B-Instruct-details': None,
            'open-llm-leaderboard/internlm__internlm2_5-1_8b-chat-details': None,
            'open-llm-leaderboard/internlm__internlm2_5-20b-chat-details': None,
            'open-llm-leaderboard/internlm__internlm2_5-7b-chat-details': None
        },
        "qwen2.5": {
            'open-llm-leaderboard/Qwen__Qwen2.5-0.5B-Instruct-details': {"prompt": 0.08, "completion": 0.08},
            'open-llm-leaderboard/Qwen__Qwen2.5-1.5B-Instruct-details': {"prompt": 0.2, "completion": 0.2},
            'open-llm-leaderboard/Qwen__Qwen2.5-7B-Instruct-details': {"prompt": 0.3, "completion": 0.3},
            'open-llm-leaderboard/Qwen__Qwen2.5-14B-Instruct-details': {"prompt": 0.8, "completion": 0.8},
            'open-llm-leaderboard/Qwen__Qwen2.5-32B-Instruct-details': {"prompt": 0.8, "completion": 0.8},
            'open-llm-leaderboard/Qwen__Qwen2.5-72B-Instruct-details': {"prompt": 1.2, "completion": 1.2},
        },
        "yi1.5": {
            'open-llm-leaderboard/01-ai__Yi-1.5-34B-Chat-details': {"prompt": 0.8, "completion": 0.8},
            'open-llm-leaderboard/01-ai__Yi-1.5-6B-Chat-details': {"prompt": 0.3, "completion": 0.3},
            'open-llm-leaderboard/01-ai__Yi-1.5-9B-Chat-details': {"prompt": 0.4, "completion": 0.4},
        },
        "llama3": {
            'open-llm-leaderboard/meta-llama__Meta-Llama-3-70B-Instruct-details': {"prompt": 0.9, "completion": 0.9},
            'open-llm-leaderboard/meta-llama__Meta-Llama-3-8B-Instruct-details': {"prompt": 0.2, "completion": 0.2},
        }
    }

    # Task name -> metric name. Only the keys (and their order) are used by this
    # class. NOTE(review): the metric names are not read here; presumably they
    # were used when building ``data_file`` -- verify.
    DATASETS = {
        'bbh_boolean_expressions': 'acc_norm',
        'bbh_causal_judgement': 'acc_norm',
        'bbh_date_understanding': 'acc_norm',
        'bbh_disambiguation_qa': 'acc_norm',
        'bbh_formal_fallacies': 'acc_norm',
        'bbh_geometric_shapes': 'acc_norm',
        'bbh_hyperbaton': 'acc_norm',
        'bbh_logical_deduction_five_objects': 'acc_norm',
        'bbh_logical_deduction_seven_objects': 'acc_norm',
        'bbh_logical_deduction_three_objects': 'acc_norm',
        'bbh_movie_recommendation': 'acc_norm',
        'bbh_navigate': 'acc_norm',
        'bbh_object_counting': 'acc_norm',
        'bbh_penguins_in_a_table': 'acc_norm',
        'bbh_reasoning_about_colored_objects': 'acc_norm',
        'bbh_ruin_names': 'acc_norm',
        'bbh_salient_translation_error_detection': 'acc_norm',
        'bbh_snarks': 'acc_norm',
        'bbh_sports_understanding': 'acc_norm',
        'bbh_temporal_sequences': 'acc_norm',
        'bbh_tracking_shuffled_objects_five_objects': 'acc_norm',
        'bbh_tracking_shuffled_objects_seven_objects': 'acc_norm',
        'bbh_tracking_shuffled_objects_three_objects': 'acc_norm',
        'bbh_web_of_lies': 'acc_norm',
        'gpqa_diamond': 'acc_norm',
        'gpqa_extended': 'acc_norm',
        'gpqa_main': 'acc_norm',
        'ifeval': 'prompt_level_strict_acc',
        'math_algebra_hard': 'exact_match',
        'math_counting_and_prob_hard': 'exact_match',
        'math_geometry_hard': 'exact_match',
        'math_intermediate_algebra_hard': 'exact_match',
        'math_num_theory_hard': 'exact_match',
        'math_prealgebra_hard': 'exact_match',
        'math_precalculus_hard': 'exact_match',
        'mmlu_pro': 'acc',
        'musr_murder_mysteries': 'acc_norm',
        'musr_object_placements': 'acc_norm',
        'musr_team_allocation': 'acc_norm'
    }

    def __init__(self, data_file: str, embed_file: str, config: str, split: str, seed: int,
                 train_ratio: float = 0.7):
        """Load results and embeddings and build the requested split.

        Args:
            data_file: Anything ``torch.load`` accepts. The loaded object is
                indexed as ``data[model][task][field][prompt_index]`` with
                fields ``"scores"``, ``"input_tokens"`` and ``"output_tokens"``.
            embed_file: Anything ``torch.load`` accepts. The loaded object is
                indexed as ``embed[task]["prompts"]`` and
                ``embed[task]["embeddings"]`` per prompt index, for every task
                in ``DATASETS``.
            config: Key into ``CONFIGS`` selecting the routed models. It is
                not validated here; an unknown key raises ``KeyError`` on
                first access.
            split: ``"train"``, ``"test"`` or ``"all"``.
            seed: Seed applied to the global ``random`` module before
                sampling. This also re-seeds ``random`` for the caller.
            train_ratio: Fraction of each task's prompts assigned to
                ``train``. The count is floored, and the rest go to ``test``.
                The default of 0.7 reproduces the historical split.

        Raises:
            ValueError: If ``split`` is not one of the three names, or if
                ``train_ratio`` lies outside ``[0, 1]``.
        """
        super().__init__()

        if split not in ("train", "test", "all"):
            raise ValueError(f"split must be 'train', 'test' or 'all', got {split!r}")
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

        # NOTE(review): torch.load unpickles its input; only load trusted files.
        self.data = torch.load(data_file)
        self.embed = torch.load(embed_file)
        self.config = config
        self.split = split

        num_prompts = {d: len(self.embed[d]["prompts"]) for d in OpenLLM2.DATASETS}

        # The global RNG is seeded on purpose, for reproducibility with earlier runs.
        # The draw order must not change, or existing splits would no longer be
        # reproduced: every task's train sample first, then every task's "all" permutation.
        random.seed(seed)
        train_idx = {
            d: random.sample(range(n), k=int(n * train_ratio))
            for d, n in num_prompts.items()
        }
        # The test split is the complement of train, in ascending index order.
        # Checking membership against a set keeps this linear.
        test_idx = {}
        for d, n in num_prompts.items():
            in_train = set(train_idx[d])
            test_idx[d] = [i for i in range(n) if i not in in_train]
        # This is drawn for every split, so the global RNG ends in the same state regardless of split.
        all_idx = {d: random.sample(range(n), k=n) for d, n in num_prompts.items()}

        self.prompt_idx = {"train": train_idx, "test": test_idx, "all": all_idx}[split]

        # Maps each flat item index to (task, prompt index), walking tasks in DATASETS order.
        self.idx_mapping = dict(enumerate(
            (d, index) for d, indices in self.prompt_idx.items() for index in indices
        ))

    @property
    def name(self):
        """Identifier of the form ``openllm2::<config>::<split>``."""
        return f"openllm2::{self.config}::{self.split}"

    @property
    def routing_config(self):
        """Return the ``CONFIGS`` entry (model repo -> cost entry) for this config."""
        return OpenLLM2.CONFIGS[self.config]

    def __len__(self):
        """Return the number of prompts in the selected split, summed over tasks."""
        return sum(len(v) for v in self.prompt_idx.values())

    def _get_prompt(self, dataset, index):
        """Return the prompt stored at ``embed[dataset]["prompts"][index]``."""
        return self.embed[dataset]["prompts"][index]

    def _get_embedding(self, dataset, index):
        """Return the prompt embedding as a float32 tensor on the CPU."""
        return self.embed[dataset]["embeddings"][index].cpu().float()

    def _per_model(self, dataset, index, field):
        """Collect ``data[model][dataset][field][index]`` for each routed model, in config order."""
        return [self.data[model][dataset][field][index] for model in self.routing_config]

    def _get_scores(self, dataset, index):
        """Return the models' scores as a float tensor, one entry per routed model.

        Each stored score must be a tensor, because they are combined with
        ``torch.stack``.
        """
        return torch.stack(self._per_model(dataset, index, "scores")).float()

    def _get_input_tokens(self, dataset, index):
        """Return the models' input-token counts as a 1-D float tensor."""
        return torch.tensor(self._per_model(dataset, index, "input_tokens")).float()

    def _get_output_tokens(self, dataset, index):
        """Return the models' output-token counts as a 1-D float tensor."""
        return torch.tensor(self._per_model(dataset, index, "output_tokens")).float()

    def __getitem__(self, idx):
        """Return one routing example.

        The returned dict has these keys: ``idx``, ``routing_config``,
        ``prompt``, ``embedding``, ``scores``, ``costs``, ``input_tokens`` and
        ``output_tokens``. The costs come from ``get_costs`` and are converted
        to a float tensor.
        """
        dataset, index = self.idx_mapping[idx]

        prompt = self._get_prompt(dataset, index)
        embedding = self._get_embedding(dataset, index)
        scores = self._get_scores(dataset, index)
        input_tokens = self._get_input_tokens(dataset, index)
        output_tokens = self._get_output_tokens(dataset, index)
        costs = get_costs(input_tokens, output_tokens, self.routing_config)

        return {
            "idx": idx,
            "routing_config": self.routing_config,
            "prompt": prompt,
            "embedding": embedding,
            "scores": scores,
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }

    def get_benchmark(self):
        """Return the whole split as batched tensors, in item-index order.

        ``prompts`` is a list. ``embeddings``, ``scores``, ``input_tokens`` and
        ``output_tokens`` are stacked with one row per prompt. ``costs`` holds
        one ``get_costs`` result per row.
        """
        prompts = []
        embeddings = []
        scores = []
        input_tokens = []
        output_tokens = []
        for dataset, indices in self.prompt_idx.items():
            for index in indices:
                prompts.append(self._get_prompt(dataset, index))
                embeddings.append(self._get_embedding(dataset, index))
                scores.append(self._get_scores(dataset, index))
                input_tokens.append(self._get_input_tokens(dataset, index))
                output_tokens.append(self._get_output_tokens(dataset, index))
        embeddings = torch.stack(embeddings)
        scores = torch.stack(scores)
        input_tokens = torch.stack(input_tokens)
        output_tokens = torch.stack(output_tokens)
        costs = [get_costs(inp, out, self.routing_config) for inp, out in zip(input_tokens, output_tokens)]

        return {
            "routing_config": self.routing_config,
            "prompts": prompts,
            "embeddings": embeddings,
            "scores": scores,
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }