import random
import torch
from torch.utils.data import Dataset

from llm_router.data.utils import get_costs

class RouterBench(Dataset):
    """Map-style dataset over pre-computed routing benchmark results.

    Two files are loaded with ``torch.load``:

    * ``data_file``  -- indexed as ``data[dataset][model][field][i]`` with
      ``field`` in ``"scores"``, ``"input_tokens"``, ``"output_tokens"``.
    * ``embed_file`` -- indexed as ``embed[dataset]["prompts"][i]`` and
      ``embed[dataset]["embeddings"][i]``.

    Every dataset listed in ``DATASETS`` is split 70/30 into train/test with
    a seeded shuffle; ``split`` selects which indices the instance exposes.
    """

    # Per-model prompt/completion prices handed to ``get_costs``.
    # NOTE(review): the price unit is not visible in this file -- confirm
    # against ``get_costs``.  The order of the models defines the order of
    # the per-model vectors returned by this class.
    CONFIGS = {
        "all": {
            'GPT-3.5': {"prompt": 1.0, "completion": 2.0},
            'Claude Instant V1': {"prompt": 0.8, "completion": 2.4},
            'Claude V1': {"prompt": 8.0, "completion": 24.0},
            'Claude V2': {"prompt": 8.0, "completion": 24.0},
            'GPT-4': {"prompt": 10.0, "completion": 30.0},
            'Llama 70B': {"prompt": 0.9, "completion": 0.9},
            'Mixtral 8x7B': {"prompt": 0.6, "completion": 0.6},
            'Yi 34B': {"prompt": 0.8, "completion": 0.8},
            # Key was misspelled "promot", leaving this model without a
            # "prompt" price.
            'WizardLM 13B': {"prompt": 0.3, "completion": 0.3},
            'Code Llama 34B': {"prompt": 0.776, "completion": 0.776},
            'Mistral 7B': {"prompt": 0.2, "completion": 0.2},
        }
    }

    # The order matters: random indices are drawn dataset by dataset in this
    # order, so reordering the list changes every split for a given seed.
    DATASETS = [
        'arc-challenge',
        'grade-school-math',
        'hellaswag',
        'mbpp',
        'winogrande',
        'mmlu-abstract-algebra',
        'mmlu-anatomy',
        'mmlu-astronomy',
        'mmlu-business-ethics',
        'mmlu-clinical-knowledge',
        'mmlu-college-biology',
        'mmlu-college-chemistry',
        'mmlu-college-computer-science',
        'mmlu-college-mathematics',
        'mmlu-college-medicine',
        'mmlu-college-physics',
        'mmlu-computer-security',
        'mmlu-conceptual-physics',
        'mmlu-econometrics',
        'mmlu-electrical-engineering',
        'mmlu-elementary-mathematics',
        'mmlu-formal-logic',
        'mmlu-global-facts',
        'mmlu-high-school-biology',
        'mmlu-high-school-chemistry',
        'mmlu-high-school-computer-science',
        'mmlu-high-school-european-history',
        'mmlu-high-school-geography',
        'mmlu-high-school-government-and-politics',
        'mmlu-high-school-macroeconomics',
        'mmlu-high-school-mathematics',
        'mmlu-high-school-microeconomics',
        'mmlu-high-school-physics',
        'mmlu-high-school-psychology',
        'mmlu-high-school-statistics',
        'mmlu-high-school-us-history',
        'mmlu-high-school-world-history',
        'mmlu-human-aging',
        'mmlu-human-sexuality',
        'mmlu-international-law',
        'mmlu-jurisprudence',
        'mmlu-logical-fallacies',
        'mmlu-machine-learning',
        'mmlu-management',
        'mmlu-marketing',
        'mmlu-medical-genetics',
        'mmlu-miscellaneous',
        'mmlu-moral-disputes',
        'mmlu-moral-scenarios',
        'mmlu-nutrition',
        'mmlu-philosophy',
        'mmlu-prehistory',
        'mmlu-professional-accounting',
        'mmlu-professional-law',
        'mmlu-professional-medicine',
        'mmlu-professional-psychology',
        'mmlu-public-relations',
        'mmlu-security-studies',
        'mmlu-sociology',
        'mmlu-us-foreign-policy',
        'mmlu-virology',
        'mmlu-world-religions',
    ]

    # Split-name prefix -> the single dataset that "<prefix>_train" and
    # "<prefix>_test" select.  ("mmlu" is handled separately because it
    # selects every dataset whose name starts with "mmlu-".)
    _SINGLE_DATASET_SPLITS = {
        "arcc": "arc-challenge",
        "gsm": "grade-school-math",
        "hellaswag": "hellaswag",
        "mbpp": "mbpp",
        "winogrande": "winogrande",
    }

    def __init__(self, data_file: str, embed_file: str, config: str, split: str, seed: int):
        """Load both files and build the index for the requested split.

        Args:
            data_file: path to the per-model results file.
            embed_file: path to the prompts/embeddings file.
            config: key into ``CONFIGS`` (looked up lazily by the accessors).
            split: ``"train"``, ``"test"``, ``"all"``, or
                ``"<family>_train"`` / ``"<family>_test"`` with family in
                ``arcc``, ``gsm``, ``mmlu``, ``hellaswag``, ``mbpp``,
                ``winogrande``.
            seed: seed for the train/test shuffle.

        Raises:
            ValueError: if ``split`` is not one of the names above.
        """
        super().__init__()

        # SECURITY: torch.load unpickles its input -- only pass trusted files.
        self.data = torch.load(data_file)
        self.embed = torch.load(embed_file)
        self.config = config
        self.split = split

        # Number of prompts per dataset, in DATASETS order.
        sizes = {d: len(self.embed[d]["prompts"]) for d in RouterBench.DATASETS}

        # Seeds the module-level RNG.  The draw order below (every train
        # sample first, then every "all" permutation) is kept exactly so a
        # given seed keeps producing the same splits.
        random.seed(seed)
        train_idx = {
            d: random.sample(range(n), k=int(n * 0.7))
            for d, n in sizes.items()
        }

        # Test indices are whatever train did not take, in ascending order.
        # A set makes the membership test O(1) instead of a list scan.
        test_idx = {}
        for d, n in sizes.items():
            taken = set(train_idx[d])
            test_idx[d] = [i for i in range(n) if i not in taken]

        # "all" exposes every index, in a seeded random order.
        all_idx = {d: random.sample(range(n), k=n) for d, n in sizes.items()}

        self.prompt_idx = RouterBench._select_split(split, train_idx, test_idx, all_idx)

        # Flat position -> (dataset, index within dataset), used by
        # __getitem__.
        flat = (
            (d, index)
            for d, indices in self.prompt_idx.items()
            for index in indices
        )
        self.idx_mapping = dict(enumerate(flat))

    @staticmethod
    def _select_split(split, train_idx, test_idx, all_idx):
        """Return the ``{dataset: [indices]}`` mapping named by ``split``.

        Raises:
            ValueError: if ``split`` is not a recognised split name.
        """
        if split == "all":
            return all_idx

        pools = {"train": train_idx, "test": test_idx}
        if split in pools:
            return pools[split]

        # "<family>_train" / "<family>_test"
        family, sep, part = split.rpartition("_")
        if sep and part in pools:
            pool = pools[part]
            if family == "mmlu":
                return {d: pool[d] for d in RouterBench.DATASETS if d.startswith("mmlu-")}
            if family in RouterBench._SINGLE_DATASET_SPLITS:
                dataset = RouterBench._SINGLE_DATASET_SPLITS[family]
                return {dataset: pool[dataset]}

        families = sorted(["mmlu", *RouterBench._SINGLE_DATASET_SPLITS])
        raise ValueError(
            f"Unknown split {split!r}; expected 'train', 'test', 'all', or "
            f"'<family>_train' / '<family>_test' with family in {families}"
        )

    @property
    def name(self):
        """Identifier of the form ``routerbench::<config>::<split>``."""
        return f"routerbench::{self.config}::{self.split}"

    @property
    def routing_config(self):
        """The ``CONFIGS`` entry (model -> prices) selected by ``config``."""
        return RouterBench.CONFIGS[self.config]

    def __len__(self):
        """Total number of prompts across the datasets in this split."""
        return sum(len(v) for v in self.prompt_idx.values())

    def _get_prompt(self, dataset, index):
        """Return the stored prompt at ``index`` of ``dataset``."""
        return self.embed[dataset]["prompts"][index]

    def _get_embedding(self, dataset, index):
        """Return the prompt embedding as a float32 CPU tensor."""
        return self.embed[dataset]["embeddings"][index].cpu().float()

    def _per_model(self, dataset, index, field):
        """Collect ``data[dataset][model][field][index]`` for every model.

        Models are visited in the order of the active config, and the result
        is a float32 tensor with one entry per model.
        """
        models = RouterBench.CONFIGS[self.config]
        values = [self.data[dataset][model][field][index] for model in models]
        return torch.tensor(values).float()

    def _get_scores(self, dataset, index):
        """Per-model scores for one prompt."""
        return self._per_model(dataset, index, "scores")

    def _get_input_tokens(self, dataset, index):
        """Per-model input token counts for one prompt."""
        return self._per_model(dataset, index, "input_tokens")

    def _get_output_tokens(self, dataset, index):
        """Per-model output token counts for one prompt."""
        return self._per_model(dataset, index, "output_tokens")

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: ``idx``, ``routing_config``, ``prompt``, ``embedding``,
        ``scores``, ``costs``, ``input_tokens``, ``output_tokens``.
        """
        dataset, index = self.idx_mapping[idx]

        input_tokens = self._get_input_tokens(dataset, index)
        output_tokens = self._get_output_tokens(dataset, index)
        # get_costs is defined elsewhere; its result is converted to a
        # float32 tensor below.
        costs = get_costs(input_tokens, output_tokens, self.routing_config)

        return {
            "idx": idx,
            "routing_config": self.routing_config,
            "prompt": self._get_prompt(dataset, index),
            "embedding": self._get_embedding(dataset, index),
            "scores": self._get_scores(dataset, index),
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }

    def get_benchmark(self):
        """Return the whole split at once, stacked along a leading axis.

        Rows follow the same order as ``__getitem__`` (datasets in
        ``prompt_idx`` order, then indices within each dataset).  ``prompts``
        is a plain list; the other per-sample fields are stacked tensors.
        """
        pairs = [
            (dataset, index)
            for dataset, indices in self.prompt_idx.items()
            for index in indices
        ]

        prompts = [self._get_prompt(d, i) for d, i in pairs]
        embeddings = torch.stack([self._get_embedding(d, i) for d, i in pairs])
        scores = torch.stack([self._get_scores(d, i) for d, i in pairs])
        input_tokens = torch.stack([self._get_input_tokens(d, i) for d, i in pairs])
        output_tokens = torch.stack([self._get_output_tokens(d, i) for d, i in pairs])

        # One get_costs call per sample, as in __getitem__.
        costs = [
            get_costs(inp, out, self.routing_config)
            for inp, out in zip(input_tokens, output_tokens)
        ]

        return {
            "routing_config": self.routing_config,
            "prompts": prompts,
            "embeddings": embeddings,
            "scores": scores,
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }