import random
import torch
from torch.utils.data import Dataset

from llm_router.data.utils import get_costs

class VHELM(Dataset):
    """Per-prompt model scores, token counts and embeddings for model routing.

    Two files are read with ``torch.load``:

    * ``data_file``: ``data[dataset][model][field][index]`` where ``field`` is
      one of ``"values"``, ``"input_tokens"`` or ``"output_tokens"``.
    * ``embed_file``: ``embed[dataset]["contents"][index]`` (the prompt) and
      ``embed[dataset]["embeddings"][index]`` (its embedding tensor).

    ``CONFIGS`` maps a routing-config name to the ordered models it covers;
    every per-model tensor returned by this class follows that order.
    ``DATASETS`` lists the dataset keys expected in both files.
    """

    # Model name -> pricing, in the order used for every per-model tensor.
    # NOTE(review): "prompt"/"completion" look like prices per million tokens;
    # confirm against ``get_costs``. The "all" config carries ``None`` instead
    # of prices; verify that ``get_costs`` accepts that before requesting
    # costs with config="all".
    CONFIGS = {
        "all": {
            'anthropic_claude-3-5-sonnet-20240620': None,
            'anthropic_claude-3-5-sonnet-20241022': None,
            'anthropic_claude-3-7-sonnet-20250219': None,
            'anthropic_claude-3-7-sonnet-20250219-thinking-64k': None,
            'anthropic_claude-3-haiku-20240307': None,
            'anthropic_claude-3-opus-20240229': None,
            'anthropic_claude-3-sonnet-20240229': None,
            'google_gemini-1.0-pro-vision-001': None,
            'google_gemini-1.5-flash-001-safety-block-none': None,
            'google_gemini-1.5-flash-002': None,
            'google_gemini-1.5-flash-preview-0514': None,
            'google_gemini-1.5-pro-001-safety-block-none': None,
            'google_gemini-1.5-pro-002': None,
            'google_gemini-1.5-pro-preview-0409': None,
            'google_gemini-1.5-pro-preview-0514': None,
            'google_gemini-2.0-flash-001': None,
            'google_gemini-2.0-flash-exp': None,
            'google_gemini-2.0-flash-lite-001': None,
            'google_gemini-2.0-flash-lite-preview-02-05': None,
            'google_gemini-2.0-flash-thinking-exp-01-21': None,
            'google_gemini-2.0-pro-exp-02-05': None,
            'google_gemini-2.5-pro-exp-03-25': None,
            'openai_gpt-4-turbo-2024-04-09': None,
            'openai_gpt-4.1-2025-04-14': None,
            'openai_gpt-4.1-mini-2025-04-14': None,
            'openai_gpt-4.1-nano-2025-04-14': None,
            'openai_gpt-4.5-preview-2025-02-27': None,
            'openai_gpt-4o-2024-05-13': None,
            'openai_gpt-4o-2024-08-06': None,
            'openai_gpt-4o-2024-11-20': None,
            'openai_gpt-4o-mini-2024-07-18': None,
            'openai_o1-2024-12-17': None,
            'openai_o3-2025-04-16-high-reasoning-effort': None,
            'openai_o4-mini-2025-04-16-high-reasoning-effort': None,
            'qwen_qwen-vl-chat': None,
            'qwen_qwen2-vl-72b-instruct': None,
            'qwen_qwen2-vl-7b-instruct': None,
            'qwen_qwen2.5-vl-32b-instruct': None,
            'qwen_qwen2.5-vl-3b-instruct': None,
            'qwen_qwen2.5-vl-72b-instruct': None,
            'qwen_qwen2.5-vl-7b-instruct': None,
        },
        "claude": {
            'anthropic_claude-3-5-sonnet-20240620': {"prompt": 3, "completion": 15},
            'anthropic_claude-3-5-sonnet-20241022': {"prompt": 3, "completion": 15},
            'anthropic_claude-3-7-sonnet-20250219': {"prompt": 3, "completion": 15},
            'anthropic_claude-3-7-sonnet-20250219-thinking-64k': {"prompt": 3, "completion": 15},
            'anthropic_claude-3-haiku-20240307': {"prompt": 0.8, "completion": 4},
            'anthropic_claude-3-opus-20240229': {"prompt": 15, "completion": 75},
            'anthropic_claude-3-sonnet-20240229': {"prompt": 3, "completion": 15},
        },
        "gpt": {
            'openai_gpt-4-turbo-2024-04-09': {"prompt": 10, "completion": 30},
            'openai_gpt-4.1-2025-04-14': {"prompt": 2, "completion": 8},
            'openai_gpt-4.1-mini-2025-04-14': {"prompt": 0.4, "completion": 1.6},
            'openai_gpt-4.1-nano-2025-04-14': {"prompt": 0.1, "completion": 0.4},
            'openai_gpt-4.5-preview-2025-02-27': {"prompt": 75, "completion": 150},
            'openai_gpt-4o-2024-05-13': {"prompt": 5, "completion": 15},
            'openai_gpt-4o-2024-08-06': {"prompt": 2.5, "completion": 10},
            'openai_gpt-4o-2024-11-20': {"prompt": 2.5, "completion": 10},
            'openai_gpt-4o-mini-2024-07-18': {"prompt": 0.15, "completion": 0.6},
            'openai_o1-2024-12-17': {"prompt": 15, "completion": 60},
            'openai_o3-2025-04-16-high-reasoning-effort': {"prompt": 10, "completion": 40},
            'openai_o4-mini-2025-04-16-high-reasoning-effort': {"prompt": 1.1, "completion": 4.4},
        },
        "google": {
            'google_gemini-1.0-pro-vision-001': {"prompt": 1.25, "completion": 10},
            'google_gemini-1.5-flash-002': {"prompt": 0.15, "completion": 0.6},
            'google_gemini-1.5-flash-preview-0514': {"prompt": 0.15, "completion": 0.6},
            'google_gemini-1.5-pro-002': {"prompt": 1.25, "completion": 10},
            'google_gemini-1.5-pro-preview-0409': {"prompt": 1.25, "completion": 10},
            'google_gemini-1.5-pro-preview-0514': {"prompt": 1.25, "completion": 10},
            'google_gemini-2.0-flash-001': {"prompt": 0.15, "completion": 0.6},
            'google_gemini-2.0-flash-exp': {"prompt": 0.15, "completion": 0.6},
            'google_gemini-2.0-flash-lite-001': {"prompt": 0.075, "completion": 0.3},
            'google_gemini-2.0-flash-lite-preview-02-05': {"prompt": 0.075, "completion": 0.3},
            'google_gemini-2.0-flash-thinking-exp-01-21': {"prompt": 0.15, "completion": 0.6},
            'google_gemini-2.0-pro-exp-02-05': {"prompt": 1.25, "completion": 10},
            'google_gemini-2.5-pro-exp-03-25': {"prompt": 1.25, "completion": 10},
        }
    }

    # Dataset keys expected in both loaded files. The text before ":" is the
    # benchmark family used by the "<family>_train" / "<family>_test" splits.
    DATASETS = [
        'vqa',
        'vibe_eval:subject=difficulty-hard',
        'vibe_eval:subject=difficulty-normal',
        'unicorn:subject=OODCV-VQA',
        'unicorn:subject=Sketchy-VQA',
        'seed_bench:subject=visual-reasoning',
        'seed_bench:subject=instance-interaction',
        'real_world_qa',
        'mme:subject=posters',
        'mme:subject=landmark',
        'mme:subject=artwork',
        'mme:subject=celebrity',
        'math_vista:grade=elementary_school,question_type=free_form',
        'math_vista:grade=high_school,question_type=free_form',
        'math_vista:grade=college,question_type=free_form',
        'math_vista:grade=daily_life,question_type=free_form',
        'gqa',
        'flickr30k',
        'mmmu:subject=Pharmacy,question_type=multiple-choice',
        'mmmu:subject=Diagnostics_and_Laboratory_Medicine,question_type=multiple-choice',
        'mmmu:subject=Biology,question_type=multiple-choice',
        'mmmu:subject=Manage,question_type=multiple-choice',
        'mmmu:subject=Basic_Medical_Science,question_type=multiple-choice',
        'mmmu:subject=Psychology,question_type=multiple-choice',
        'mmmu:subject=Music,question_type=multiple-choice',
        'mmmu:subject=Art_Theory,question_type=multiple-choice',
        'mmmu:subject=Literature,question_type=multiple-choice',
        'mmmu:subject=Agriculture,question_type=multiple-choice',
        'mmmu:subject=Art,question_type=multiple-choice',
        'mmmu:subject=Mechanical_Engineering,question_type=multiple-choice',
        'mmmu:subject=Architecture_and_Engineering,question_type=multiple-choice',
        'mmmu:subject=Electronics,question_type=multiple-choice',
        'mmmu:subject=Accounting,question_type=multiple-choice',
        'mmmu:subject=Physics,question_type=multiple-choice',
        'mmmu:subject=Geography,question_type=multiple-choice',
        'mmmu:subject=Math,question_type=multiple-choice',
        'mmmu:subject=Design,question_type=multiple-choice',
        'mmmu:subject=Materials,question_type=multiple-choice',
        'mmmu:subject=Marketing,question_type=multiple-choice',
        'mmmu:subject=History,question_type=multiple-choice',
        'mmmu:subject=Sociology,question_type=multiple-choice',
        'mmmu:subject=Energy_and_Power,question_type=multiple-choice',
        'mmmu:subject=Chemistry,question_type=multiple-choice',
        'mmmu:subject=Computer_Science,question_type=multiple-choice',
        'mmmu:subject=Economics,question_type=multiple-choice',
        'mmmu:subject=Clinical_Medicine,question_type=multiple-choice',
        'mmmu:subject=Public_Health,question_type=multiple-choice',
        'mmmu:subject=Finance,question_type=multiple-choice',
        'blink:category=Relative_Depth',
        'blink:category=Visual_Similarity',
        'blink:category=Jigsaw',
        'blink:category=Forensic_Detection',
        'blink:category=IQ_Test',
        'blink:category=Semantic_Correspondence',
        'blink:category=Visual_Correspondence',
        'blink:category=Multi-view_Reasoning',
        'blink:category=Spatial_Relation',
        'blink:category=Functional_Correspondence',
        'blink:category=Object_Localization',
        'blink:category=Art_Style',
        'blink:category=Counting',
        'blink:category=Relative_Reflectance'
    ]

    def __init__(self, data_file: str, embed_file: str, config: str, split: str, seed: int):
        """Load both files and build the index for the requested split.

        Args:
            data_file: Path passed to ``torch.load`` for scores/token counts.
            embed_file: Path passed to ``torch.load`` for prompts/embeddings.
            config: Key of ``CONFIGS`` selecting the models to expose.
            split: ``"train"``, ``"test"``, ``"all"``, or ``"<family>_train"``
                / ``"<family>_test"`` where ``<family>`` is the part of a
                ``DATASETS`` key before ``":"`` (e.g. ``"blink_test"``).
            seed: Seed for the 70/30 train/test partition of every dataset.

        Raises:
            ValueError: If ``split`` is not one of the names described above.
        """
        super().__init__()

        self.data = torch.load(data_file)
        self.embed = torch.load(embed_file)
        self.config = config
        self.split = split

        sizes = {d: len(self.embed[d]["contents"]) for d in VHELM.DATASETS}

        # Seeds the module-level RNG, exactly as before, so the partition and
        # the RNG state left behind for the caller are unchanged. The order of
        # the draws below (all train samples, then all "all" samples) is what
        # determines the partition and must not be reordered.
        random.seed(seed)
        train_idx = {
            d: random.sample(range(sizes[d]), k=int(sizes[d] * 0.7))
            for d in VHELM.DATASETS
        }
        test_idx = {}
        for d in VHELM.DATASETS:
            # Set membership keeps this linear; result stays in ascending order.
            in_train = set(train_idx[d])
            test_idx[d] = [i for i in range(sizes[d]) if i not in in_train]
        # "all" is a seeded shuffle of every index, drawn after the train draws.
        all_idx = {
            d: random.sample(range(sizes[d]), k=sizes[d])
            for d in VHELM.DATASETS
        }

        self.prompt_idx = VHELM._select_split(split, train_idx, test_idx, all_idx)

        # Flat position -> (dataset key, index within that dataset).
        idx_mapping = {}
        for d, indices in self.prompt_idx.items():
            for index in indices:
                idx_mapping[len(idx_mapping)] = (d, index)
        self.idx_mapping = idx_mapping

    @staticmethod
    def _select_split(split: str, train_idx: dict, test_idx: dict, all_idx: dict) -> dict:
        """Return the ``{dataset: [indices]}`` mapping named by ``split``."""
        whole = {"train": train_idx, "test": test_idx, "all": all_idx}
        if split in whole:
            return whole[split]

        for suffix, pool in (("_train", train_idx), ("_test", test_idx)):
            if not split.endswith(suffix):
                continue
            family = split[:-len(suffix)]
            chosen = {
                d: pool[d]
                for d in VHELM.DATASETS
                if d.split(":", 1)[0] == family
            }
            if chosen:
                return chosen

        families = sorted({d.split(":", 1)[0] for d in VHELM.DATASETS})
        raise ValueError(
            f"unknown split {split!r}; expected 'train', 'test', 'all', or "
            f"'<family>_train' / '<family>_test' with <family> in {families}"
        )

    @property
    def name(self):
        """Identifier of the form ``vhelm::<config>::<split>``."""
        return f"vhelm::{self.config}::{self.split}"

    @property
    def routing_config(self):
        """The ``CONFIGS`` entry (model name -> pricing) for this instance."""
        return VHELM.CONFIGS[self.config]

    def __len__(self):
        """Total number of prompts across all datasets in the split."""
        return sum(len(v) for v in self.prompt_idx.values())

    def _get_prompt(self, dataset, index):
        """Return the stored prompt content for ``dataset[index]``."""
        return self.embed[dataset]["contents"][index]

    def _get_embedding(self, dataset, index):
        """Return the embedding for ``dataset[index]`` as a CPU float tensor."""
        return self.embed[dataset]["embeddings"][index].cpu().float()

    def _per_model(self, field, dataset, index):
        """Collect ``field`` for every configured model into a float tensor."""
        values = [
            self.data[dataset][model][field][index]
            for model in VHELM.CONFIGS[self.config]
        ]
        return torch.tensor(values).float()

    def _get_scores(self, dataset, index):
        """Return each configured model's ``"values"`` entry for the prompt."""
        return self._per_model("values", dataset, index)

    def _get_input_tokens(self, dataset, index):
        """Return each configured model's ``"input_tokens"`` for the prompt."""
        return self._per_model("input_tokens", dataset, index)

    def _get_output_tokens(self, dataset, index):
        """Return each configured model's ``"output_tokens"`` for the prompt."""
        return self._per_model("output_tokens", dataset, index)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: ``idx``, ``routing_config``, ``prompt``, ``embedding``,
        ``scores``, ``costs``, ``input_tokens``, ``output_tokens``. ``costs``
        is whatever ``get_costs`` returns, converted to a float tensor.

        Raises:
            IndexError: If ``idx`` is not a valid flat position. Raising
                ``IndexError`` (rather than the mapping's ``KeyError``) lets
                plain ``for sample in dataset`` iteration terminate.
        """
        try:
            dataset, index = self.idx_mapping[idx]
        except KeyError:
            raise IndexError(
                f"index {idx} is out of range for a dataset of size {len(self)}"
            ) from None

        prompt = self._get_prompt(dataset, index)
        embedding = self._get_embedding(dataset, index)
        scores = self._get_scores(dataset, index)
        input_tokens = self._get_input_tokens(dataset, index)
        output_tokens = self._get_output_tokens(dataset, index)
        costs = get_costs(input_tokens, output_tokens, self.routing_config)

        return {
            "idx": idx,
            "routing_config": self.routing_config,
            "prompt": prompt,
            "embedding": embedding,
            "scores": scores,
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }

    def get_benchmark(self):
        """Return the whole split at once, stacked along a leading axis.

        Rows follow the same order as flat indexing (datasets in
        ``prompt_idx`` order, then indices within each dataset). ``prompts``
        is a list; every other per-sample field is a stacked tensor.
        """
        prompts = []
        embeddings = []
        scores = []
        input_tokens = []
        output_tokens = []
        for dataset, indices in self.prompt_idx.items():
            for index in indices:
                prompts.append(self._get_prompt(dataset, index))
                embeddings.append(self._get_embedding(dataset, index))
                scores.append(self._get_scores(dataset, index))
                input_tokens.append(self._get_input_tokens(dataset, index))
                output_tokens.append(self._get_output_tokens(dataset, index))
        embeddings = torch.stack(embeddings)
        scores = torch.stack(scores)
        input_tokens = torch.stack(input_tokens)
        output_tokens = torch.stack(output_tokens)
        costs = [get_costs(inp, out, self.routing_config) for inp, out in zip(input_tokens, output_tokens)]

        return {
            "routing_config": self.routing_config,
            "prompts": prompts,
            "embeddings": embeddings,
            "scores": scores,
            "costs": torch.tensor(costs).float(),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }