import json
import re
from datasets import load_dataset
from utils.data_creator import MMLUPromptBuilder, TydiqaPromptBuilder, BBHPromptBuilder
import os
import random
# Seed the module-level RNG at import time; the loaders below draw their
# subsets with random.sample, so this fixes which examples get picked.
random.seed(2025)


class DatasetLoader:
    """Load the datasets stored under a single root directory.

    Each supported dataset name maps to a loader method (see
    ``self.dataset_loaders``). Loaders read from fixed sub-paths of
    ``dataset_path`` and most of them delegate prompt construction and
    sampling to builders from ``utils.data_creator``.

    Args:
        dataset_path: Root directory that contains the per-dataset folders
            (``gsm/``, ``bigcodebench/``, ``mmlu/``, ``tydiqa/``, ``BBH/``,
            ``Alpaca_GPT4/``, ``ALMA/``, ``comprehensive/``,
            ``WizardLM_evol_instruct_70k/``).
        ntrain: Number of examples to keep for the loaders that subsample.
            ``None`` keeps everything for the gsm8k and codex loaders; for
            mmlu, tydiqa and bbh the value is forwarded as-is to the
            builder's round-robin sampler. The alpaca_gpt4, translate,
            comprehensive and evol_instruct loaders do not use it.
        tokenizer: Tokenizer forwarded to the prompt builders and used by
            ``_find_answer_index``. It is called like a Hugging Face
            tokenizer and must return a mapping with an ``"input_ids"`` key.
    """

    def __init__(self, dataset_path, ntrain=None, tokenizer=None):
        self.dataset_loaders = {
            "gsm8k": self.gsm8k_data_loader,
            "codex": self.codex_data_loader,
            "mmlu": self.mmlu_data_loader,
            "tydiqa": self.tydiqa_data_loader,
            "bbh": self.bbh_data_loader,
            "alpaca_gpt4": self.alpaca_gpt4_data_loader,
            "translate": self.translate_data_loader,
            "comprehensive": self.comprehensive_data_loader,
            "evol_instruct": self.evol_instruct_data_loader,
        }

        self.dataset_path = dataset_path
        self.ntrain = ntrain
        self.tokenizer = tokenizer
        # Always define ``lang`` so translate_data_loader never hits an
        # AttributeError when it runs before any per-dataset call set it.
        self.lang = None

    def load_dataset(self, dataset_name, lang=None):
        """Load one dataset by name, or every registered dataset.

        Args:
            dataset_name: A key of ``self.dataset_loaders`` or ``"all"``.
            lang: Language argument for the translate loader. Required when
                ``dataset_name`` is ``"translate"``. In ``"all"`` mode it is
                required as well unless a previous call already stored one.

        Returns:
            The value returned by the selected loader, or for ``"all"`` a
            dict mapping each dataset name to its loader's return value.

        Raises:
            ValueError: If the name is not supported, or if the translate
                loader would run without a language.
        """
        if dataset_name == "all":
            if lang is not None:
                self.lang = lang
            # Fail before any expensive loading instead of crashing once the
            # loop reaches the translate loader.
            if "translate" in self.dataset_loaders and self.lang is None:
                raise ValueError(
                    "Loading all datasets includes 'translate', please give the lang parameter"
                )
            all_datasets = {}
            for name, loader in self.dataset_loaders.items():
                print(f"Loading dataset: {name}...")
                all_datasets[name] = loader()
            return all_datasets

        if dataset_name == "translate" and lang is None:
            raise ValueError("In translate mode, please give the lang parameter")

        if dataset_name not in self.dataset_loaders:
            raise ValueError(f"Dataset '{dataset_name}' is not supported. Supported datasets: {list(self.dataset_loaders.keys()) + ['all']}")

        self.lang = lang
        return self.dataset_loaders[dataset_name]()

    def _subsample(self, dataset):
        """Return ``self.ntrain`` random items of ``dataset``, or all of it.

        When ``self.ntrain`` is ``None`` the input is returned unchanged.
        Otherwise ``random.sample`` draws from the module-level RNG and
        raises ``ValueError`` if ``ntrain`` exceeds ``len(dataset)``.
        """
        if self.ntrain is None:
            return dataset
        return random.sample(dataset, self.ntrain)

    def _find_answer_index(self, context, answer):
        """Return the token index at which ``answer`` starts inside ``context``.

        Both strings are tokenized with ``self.tokenizer`` (without special
        tokens) and the first position where the answer's token ids occur as
        a contiguous run in the context's token ids is returned.

        Raises:
            ValueError: If the answer's token sequence does not occur in the
                context's token sequence.
        """
        context_tokens = self.tokenizer(context, return_offsets_mapping=False, add_special_tokens=False)["input_ids"]
        answer_tokens = self.tokenizer(answer, add_special_tokens=False)["input_ids"]
        span = len(answer_tokens)
        # Slide a window of the answer's length over the context tokens.
        for start in range(len(context_tokens) - span + 1):
            if context_tokens[start:start + span] == answer_tokens:
                return start

        raise ValueError(f"Answer not found in the given context.\nContext: {context}\nAnswer: {answer}")

    def translate_data_loader(self):
        """Build the translation dataset for ``self.lang``.

        Delegates to ``translate_creator`` with the ``ALMA/human_written_data``
        folder, the stored language, ``chat_style=False`` and the tokenizer.
        """
        from utils.data_creator import translate_creator

        file_path = os.path.join(self.dataset_path, "ALMA/human_written_data")
        return translate_creator(file_path, self.lang, chat_style=False, tokenizer=self.tokenizer)

    def gsm8k_data_loader(self):
        """Build GSM8K prompts from ``gsm/train.jsonl`` and optionally subsample.

        The global RNG is re-seeded with 2025 right before sampling, so the
        selected subset does not depend on earlier ``random`` calls.
        """
        from utils.GSM_examplars import EXAMPLARS as GSM_EXAMPLARS
        from utils.data_creator import GSMPromptBuilder

        gsm8k_dataset = GSMPromptBuilder(
            data_dir=os.path.join(self.dataset_path, "gsm/train.jsonl"),
            exemplars=GSM_EXAMPLARS,   # few-shot exemplars handed to the builder
            n_shot=0,                  # zero shots requested (the old comment said 8)
            no_cot=False,              # keep chain-of-thought (CoT) enabled
            tokenizer=self.tokenizer,
            use_chat_format=False,     # plain prompts, no chat template
            chat_formatting_function=None,
        ).build_prompt()

        random.seed(2025)
        return self._subsample(gsm8k_dataset)

    def codex_data_loader(self):
        """Load BigCodeBench from its parquet file and optionally subsample.

        Returns:
            A list of dicts with keys ``"question"`` (stripped
            ``instruct_prompt``), ``"answer_with_cot"`` (stripped
            ``code_prompt``, a newline, then ``canonical_solution``) and
            ``"QAPair"`` (question followed directly by the stripped answer).
        """
        coding_dataset = load_dataset(
            'parquet',
            data_files=os.path.join(self.dataset_path, 'bigcodebench/data/v0.1.3-00000-of-00001.parquet'),
        )
        dataset = []
        for item in coding_dataset['train']:
            question = item['instruct_prompt'].strip()
            answer_with_cot = f"{item['code_prompt'].strip()}\n{item['canonical_solution']}"
            dataset.append({
                "question": question,
                "answer_with_cot": answer_with_cot,
                "QAPair": question + answer_with_cot.strip(),
            })
        # The original referenced an undefined bare ``ntrain`` here.
        return self._subsample(dataset)

    def mmlu_data_loader(self):
        """Build zero-shot MMLU prompts for all subjects and sample from them.

        Uses the ``val`` split as target and ``dev`` as the shot source;
        sampling is delegated to the builder's
        ``sample_questions_round_robin`` with ``self.ntrain``.
        """
        mmlu_prompt_builder = MMLUPromptBuilder(
            target_data_type="val",
            shot_data_type="dev",
            data_dir=os.path.join(self.dataset_path, "mmlu/data"),
            n_shot=0,  # Number of few-shot examples
            tokenizer=self.tokenizer,
            use_chat_format=False,
            chat_formatting_function=None,
        )
        mmlu_dataset_prompts = mmlu_prompt_builder.build_all_subject_prompts()
        return mmlu_prompt_builder.sample_questions_round_robin(mmlu_dataset_prompts, self.ntrain)

    def tydiqa_data_loader(self):
        """Build zero-shot TyDiQA prompts (with context) and sample from them.

        The same training JSON file is passed both as the data file and as
        the few-shot source; sampling is delegated to the builder's
        ``sample_round_robin`` with ``self.ntrain``.
        """
        train_file = os.path.join(self.dataset_path, "tydiqa/tydiqa-goldp-v1.1-train.json")
        tydiqa_builder = TydiqaPromptBuilder(
            data_dir=train_file,
            tokenizer=self.tokenizer,
            n_shot=0,  # Number of few-shot examples
            no_context=False,  # Use context
            use_chat_format=False,
            chat_formatting_function=None,
            n_shot_data_dir=train_file,
        )
        tydiqa_dataset = tydiqa_builder.build_prompt()
        return tydiqa_builder.sample_round_robin(tydiqa_dataset, self.ntrain)

    def bbh_data_loader(self):
        """Build BBH prompts (CoT enabled) for all tasks and sample from them.

        Sampling is delegated to the builder's
        ``sample_questions_round_robin`` with ``num_samples=self.ntrain``.
        """
        bbh_prompt_builder = BBHPromptBuilder(
            data_dir=os.path.join(self.dataset_path, "BBH/"),
            no_cot=False,
            use_chat_format=False,
            chat_formatting_function=None,
            tokenizer=self.tokenizer,
        )
        all_prompts = bbh_prompt_builder.get_all_prompts_all_tasks()
        return bbh_prompt_builder.sample_questions_round_robin(all_prompts=all_prompts, num_samples=self.ntrain)

    def alpaca_gpt4_data_loader(self):
        """Return whatever ``alpaca_gpt4_creator`` builds from ``alpaca_gpt4.json``.

        ``self.ntrain`` is not applied here.
        """
        from utils.data_creator import alpaca_gpt4_creator

        data_path = os.path.join(self.dataset_path, "Alpaca_GPT4/alpaca_gpt4.json")
        return alpaca_gpt4_creator(data_path)

    def comprehensive_data_loader(self):
        """Return the parsed contents of the pre-built comprehensive JSON file."""
        json_path = os.path.join(self.dataset_path, "comprehensive/comprehensive_dataset_ntrain_1140.json")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def evol_instruct_data_loader(self):
        """Load WizardLM Evol-Instruct from its arrow file.

        Returns:
            A list of dicts with keys ``"question"`` (stripped instruction),
            ``"answer"`` (stripped output), ``"QAPair"`` (the two joined by a
            blank line) and ``"ori_data_format"`` (the untouched source row).
            ``self.ntrain`` is not applied here.
        """
        evol_instruct_dataset = load_dataset(
            'arrow',
            data_files=os.path.join(self.dataset_path, 'WizardLM_evol_instruct_70k/train/data-00000-of-00001.arrow'),
        )

        dataset = []
        for item in evol_instruct_dataset['train']:
            question = item['instruction'].strip()
            answer = item['output'].strip()
            dataset.append({
                "question": question,
                "answer": answer,
                "QAPair": question + "\n\n" + answer,
                "ori_data_format": item,
            })
        return dataset