import random
import torch
from torch.utils.data import Dataset

from llm_router.data.utils import get_costs

class HELMLite(Dataset):
    """Per-prompt HELM Lite results exposed as a routing dataset.

    Every item pairs one prompt (text plus its precomputed embedding) with,
    for each model of the selected routing config, the recorded score and the
    input/output token counts, plus the costs computed by ``get_costs``.

    Two files are read with ``torch.load``:

    * ``data_file``: ``data[model][scenario]`` is a mapping with the entries
      ``"values"``, ``"input_tokens"`` and ``"output_tokens"``, each indexable
      by prompt position.
    * ``embed_file``: ``embed[scenario]`` is a mapping with the entries
      ``"prompts"`` and ``"embeddings"``, each indexable by prompt position.

    Within every scenario the prompts are split 70/30 into train/test using
    ``seed``; the ``"all"`` split is a shuffled list of every prompt.
    """

    # Routing configurations: model name -> pricing entry, in the order used
    # for the columns of every per-model tensor returned by this class.
    # Pricing entries are passed unchanged to ``get_costs``.
    # NOTE(review): prices are presumably USD per million tokens - confirm
    # against ``get_costs``. The "all" config carries no pricing (None).
    CONFIGS = {
        "all": {
            'meta_llama-3-8b': None,
            'meta_llama-3-70b': None,
            'meta_llama-2-7b': None,
            'meta_llama-2-70b': None,
            'meta_llama-2-13b': None,
            'meta_llama-65b': None,

            'google_gemini-1.0-pro-002': None,
            'google_gemini-1.0-pro-001': None,
            'google_gemini-1.5-pro-001': None,
            'google_gemini-1.5-pro-preview-0409': None,
            'google_gemini-1.5-flash-001': None,
            'google_text-bison@001': None,
            'google_text-unicorn@001': None,
            'google_gemma-2-9b-it': None,
            'google_gemma-2-27b-it': None,
            'google_gemma-7b': None,

            'cohere_command-r': None,
            'cohere_command': None,
            'cohere_command-r-plus': None,
            'cohere_command-light': None,

            'anthropic_claude-3-5-sonnet-20240620': None,
            'anthropic_claude-3-opus-20240229': None,
            'anthropic_claude-3-sonnet-20240229': None,
            'anthropic_claude-3-haiku-20240307': None,
            'anthropic_claude-2.0': None,
            'anthropic_claude-instant-v1': None,
            'anthropic_claude-v1.3': None,
            'anthropic_claude-2.1': None,
            'anthropic_claude-instant-1.2': None,

            'mistralai_mistral-7b-instruct-v0.3': None,
            'mistralai_mixtral-8x7b-32kseqlen': None,
            'mistralai_mistral-medium-2312': None,
            'mistralai_mistral-large-2407': None,
            'mistralai_open-mistral-nemo-2407': None,
            'mistralai_mixtral-8x22b': None,
            'mistralai_mistral-7b-v0.1': None,

            'openai_gpt-4o-2024-05-13': None,
            'openai_gpt-4o-mini-2024-07-18': None,
            'openai_gpt-3.5-turbo-0613': None,
            'openai_gpt-4-0613': None,
            'openai_gpt-4-turbo-2024-04-09': None,
            'openai_gpt-4-1106-preview': None,
            'openai_text-davinci-002': None,
            'openai_text-davinci-003': None,

            'microsoft_phi-3-medium-4k-instruct': None,
            'microsoft_phi-2': None,

            '01-ai_yi-6b': None,
            '01-ai_yi-large-preview': None,
            '01-ai_yi-34b': None,

            'writer_palmyra-x-v3': None,
            'writer_palmyra-x-v2': None,

            'tiiuae_falcon-7b': None,
            'tiiuae_falcon-40b': None,

            'AlephAlpha_luminous-base': None,
            'AlephAlpha_luminous-supreme': None,
            'AlephAlpha_luminous-extended': None,

            'qwen_qwen1.5-72b': None,
            'qwen_qwen1.5-14b': None,
            'qwen_qwen1.5-7b': None,
            'qwen_qwen1.5-32b': None,
            'qwen_qwen1.5-110b-chat': None,
            'qwen_qwen2-72b-instruct': None,

            'ai21_j2-grande': None,
            'ai21_j2-jumbo': None,

            'allenai_olmo-7b': None,
            'databricks_dbrx-instruct': None,
            'deepseek-ai_deepseek-llm-67b-chat': None,
            'snowflake_snowflake-arctic-instruct': None,
        },
        "google": {
            'google_gemini-1.0-pro-002': {"prompt": 0.5, "completion": 1.5},
            'google_gemini-1.0-pro-001': {"prompt": 0.5, "completion": 1.5},
            'google_gemini-1.5-pro-001': {"prompt": 3.5, "completion": 10.5},
            'google_gemini-1.5-flash-001': {"prompt": 0.075, "completion": 0.3},
            'google_text-bison@001': {"prompt": 0.5, "completion": 1.5},
            'google_text-unicorn@001': {"prompt": 7.0, "completion": 21.0},
            'google_gemma-2-9b-it': {"prompt": 0.2, "completion": 0.2},
            'google_gemma-2-27b-it': {"prompt": 0.6, "completion": 0.6},
            'google_gemma-7b': {"prompt": 0.1, "completion": 0.1},
        },
        "claude": {
            'anthropic_claude-3-5-sonnet-20240620': {"prompt": 3, "completion": 15},
            'anthropic_claude-3-opus-20240229': {"prompt": 15, "completion": 75},
            'anthropic_claude-3-sonnet-20240229': {"prompt": 3, "completion": 15},
            'anthropic_claude-3-haiku-20240307': {"prompt": 0.25, "completion": 1.25},
            'anthropic_claude-2.0': {"prompt": 8, "completion": 24},
            'anthropic_claude-instant-v1': {"prompt": 0.8, "completion": 2.4},
            'anthropic_claude-v1.3': {"prompt": 8, "completion": 24},
            'anthropic_claude-2.1': {"prompt": 8, "completion": 24},
            'anthropic_claude-instant-1.2': {"prompt": 0.8, "completion": 2.4},
        },
        "gpt": {
            'openai_gpt-4o-2024-05-13': {"prompt": 5.0, "completion": 15.0},
            'openai_gpt-4o-mini-2024-07-18': {"prompt": 0.15, "completion": 0.6},
            'openai_gpt-3.5-turbo-0613': {"prompt": 1.5, "completion": 2.0},
            'openai_gpt-4-0613': {"prompt": 30.0, "completion": 60.0},
            'openai_gpt-4-turbo-2024-04-09': {"prompt": 10.0, "completion": 30.0},
            'openai_gpt-4-1106-preview': {"prompt": 10.0, "completion": 30.0},
        }
    }

    # Scenario keys expected in both loaded files. Their order fixes the
    # order in which the random generator is consumed and the item order.
    DATASETS = [
        "narrative_qa",
        "natural_qa:mode=closedbook",
        "natural_qa:mode=openbook_longans",
        "commonsense:dataset=openbookqa,method=multiple_choice_joint",
        "mmlu:subject=abstract_algebra,method=multiple_choice_joint",
        "mmlu:subject=college_chemistry,method=multiple_choice_joint",
        "mmlu:subject=computer_security,method=multiple_choice_joint",
        "mmlu:subject=econometrics,method=multiple_choice_joint",
        "mmlu:subject=us_foreign_policy,method=multiple_choice_joint",
        'math:subject=algebra,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=counting_and_probability,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=geometry,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=intermediate_algebra,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=number_theory,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=prealgebra,level=1,use_official_examples=False,use_chain_of_thought=True',
        'math:subject=precalculus,level=1,use_official_examples=False,use_chain_of_thought=True',
        'gsm',
        'legalbench:subset=abercrombie',
        'legalbench:subset=corporate_lobbying',
        'legalbench:subset=function_of_decision_section',
        'legalbench:subset=international_citizenship_questions',
        'legalbench:subset=proa',
        'med_qa',
        'wmt_14:language_pair=cs-en',
        'wmt_14:language_pair=de-en',
        'wmt_14:language_pair=fr-en',
        'wmt_14:language_pair=hi-en',
        'wmt_14:language_pair=ru-en',
    ]

    # Accepted values of the ``split`` constructor argument.
    _SPLITS = ("train", "test", "all")

    # Fraction of each scenario's prompts placed in the train split.
    _TRAIN_FRACTION = 0.7

    def __init__(self, data_file: str, embed_file: str, config: str, split: str, seed: int):
        """Load both files and build the prompt indices of the requested split.

        Args:
            data_file: Path of the per-model results file read by ``torch.load``.
            embed_file: Path of the prompts/embeddings file read by ``torch.load``.
            config: Key of ``HELMLite.CONFIGS`` selecting the routed models.
            split: One of ``"train"``, ``"test"`` or ``"all"``.
            seed: Seed for the random train/test partition.

        Raises:
            ValueError: If ``split`` is not one of the accepted values.
        """
        super().__init__()

        # Fail before any file is read; an unknown split previously surfaced
        # as an AttributeError on ``self.prompt_idx`` further down.
        if split not in HELMLite._SPLITS:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {HELMLite._SPLITS}"
            )

        # NOTE(review): torch.load unpickles its input - only pass files from
        # a trusted source.
        self.data = torch.load(data_file)
        self.embed = torch.load(embed_file)
        self.config = config
        self.split = split

        n_prompts = {d: len(self.embed[d]["prompts"]) for d in HELMLite.DATASETS}

        # The module-level generator is seeded (a process-wide side effect)
        # and consumed in a fixed order - every train sample first, then
        # every "all" sample - so a given seed always yields the same splits.
        random.seed(seed)
        train_idx = {
            d: random.sample(
                range(n_prompts[d]),
                k=int(n_prompts[d] * HELMLite._TRAIN_FRACTION),
            )
            for d in HELMLite.DATASETS
        }

        # Test prompts are the ones not drawn for training, in ascending
        # order; a set keeps the membership test constant-time.
        test_idx = {}
        for d in HELMLite.DATASETS:
            in_train = set(train_idx[d])
            test_idx[d] = [i for i in range(n_prompts[d]) if i not in in_train]

        all_idx = {
            d: random.sample(range(n_prompts[d]), k=n_prompts[d])
            for d in HELMLite.DATASETS
        }

        self.prompt_idx = {"train": train_idx, "test": test_idx, "all": all_idx}[split]

        # Flat item index -> (scenario, prompt position within the scenario).
        self.idx_mapping = dict(enumerate(
            (d, index)
            for d, indices in self.prompt_idx.items()
            for index in indices
        ))

    @property
    def name(self):
        """Identifier combining the benchmark name, config and split."""
        return f"helm_lite::{self.config}::{self.split}"

    @property
    def routing_config(self):
        """Model -> pricing mapping of the selected config (KeyError if unknown)."""
        return HELMLite.CONFIGS[self.config]

    def __len__(self):
        """Return the number of prompts in the selected split."""
        return sum(len(v) for v in self.prompt_idx.values())

    def _get_prompt(self, dataset, index):
        """Return the prompt stored at ``index`` of scenario ``dataset``."""
        return self.embed[dataset]["prompts"][index]

    def _get_embedding(self, dataset, index):
        """Return the embedding of one prompt as a float32 CPU tensor."""
        return self.embed[dataset]["embeddings"][index].cpu().float()

    def _per_model(self, dataset, index, field):
        """Collect ``data[model][dataset][field][index]`` over the routed models.

        Returns a float32 tensor with one entry per model, in the order of
        the routing config.
        """
        values = [
            self.data[model][dataset][field][index]
            for model in self.routing_config
        ]
        return torch.tensor(values).float()

    def _get_scores(self, dataset, index):
        """Return each routed model's recorded value for one prompt."""
        return self._per_model(dataset, index, "values")

    def _get_input_tokens(self, dataset, index):
        """Return each routed model's input token count for one prompt."""
        return self._per_model(dataset, index, "input_tokens")

    def _get_output_tokens(self, dataset, index):
        """Return each routed model's output token count for one prompt."""
        return self._per_model(dataset, index, "output_tokens")

    def __getitem__(self, idx):
        """Return the item at flat position ``idx``.

        Negative positions count from the end. The returned ``"idx"`` entry
        is the resolved non-negative position.

        Raises:
            IndexError: If ``idx`` is out of range. Raising IndexError (not
                KeyError) is what lets plain ``for item in dataset`` stop
                after the last item.
        """
        size = len(self.idx_mapping)
        requested = idx
        if idx < 0:
            idx = idx + size
        if not 0 <= idx < size:
            raise IndexError(
                f"index {requested} is out of range for a dataset of size {size}"
            )

        dataset, index = self.idx_mapping[idx]

        prompt = self._get_prompt(dataset, index)
        embedding = self._get_embedding(dataset, index)
        scores = self._get_scores(dataset, index)
        input_tokens = self._get_input_tokens(dataset, index)
        output_tokens = self._get_output_tokens(dataset, index)
        costs = get_costs(input_tokens, output_tokens, self.routing_config)

        return {
            "idx": idx,
            "routing_config": self.routing_config,
            "prompt": prompt,
            "embedding": embedding,
            "scores": scores,
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }

    def get_benchmark(self):
        """Return the whole split at once, stacked along a leading prompt axis.

        Rows follow the same order as ``__getitem__``. ``"prompts"`` is a
        list; embeddings, scores, token counts and costs are tensors with
        one row per prompt.
        """
        prompts = []
        embeddings = []
        scores = []
        input_tokens = []
        output_tokens = []
        for dataset, indices in self.prompt_idx.items():
            for index in indices:
                prompts.append(self._get_prompt(dataset, index))
                embeddings.append(self._get_embedding(dataset, index))
                scores.append(self._get_scores(dataset, index))
                input_tokens.append(self._get_input_tokens(dataset, index))
                output_tokens.append(self._get_output_tokens(dataset, index))
        embeddings = torch.stack(embeddings)
        scores = torch.stack(scores)
        input_tokens = torch.stack(input_tokens)
        output_tokens = torch.stack(output_tokens)
        # Costs are computed row by row from the stacked token counts.
        costs = [
            get_costs(inp, out, self.routing_config)
            for inp, out in zip(input_tokens, output_tokens)
        ]

        return {
            "routing_config": self.routing_config,
            "prompts": prompts,
            "embeddings": embeddings,
            "scores": scores,
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }