"""
Data processing class for the datasets used in the experiments.
"""

class DataProcessor:
    """
    Builds text prompts from a batch of dataset columns and tokenizes them.

    The dataset is selected by case-insensitive substring match on
    ``dataset_name``; see ``format_and_tokenize``.

    In every method below, ``examples`` is indexed by column name and each
    column is iterated in lockstep with the others (presumably a batch as
    produced by ``datasets.Dataset.map(batched=True)`` -- confirm with the
    caller). ``tokenizer`` is any callable accepting a list of strings plus
    the keywords ``truncation``, ``padding``, ``max_length`` and
    ``return_tensors`` (presumably a Hugging Face tokenizer).
    """

    # (substrings, method name) pairs, checked in order; the first entry with
    # a substring contained in the lower-cased dataset name wins. Method names
    # (not function objects) are stored so subclass overrides are honoured.
    _DISPATCH = (
        (("gsm8k",), "_format_and_tokenize_gsm8k"),
        (("metamathqa",), "_format_and_tokenize_MetaMathQA"),
        (("human_eval",), "_format_and_tokenize_HumanEval"),
        (("mbpp",), "_format_and_tokenize_MBPP"),
        (("codealpaca",), "_format_and_tokenize_codealpaca"),
        (("pubmedqa",), "_format_and_tokenize_PubMedQA"),
        (("medical-o1-reasoning",), "_format_and_tokenize_medicalo1"),
        # Alias for names spelled without the underscore (e.g.
        # "openai_humaneval"). Kept last so that only names which matched
        # nothing above can resolve to it.
        (("humaneval",), "_format_and_tokenize_HumanEval"),
    )

    def __init__(self, dataset_name):
        """
        Args:
            dataset_name: Name used to pick the formatting routine.
        """
        self.dataset_name = dataset_name

    @staticmethod
    def _tokenize(prompts, tokenizer, max_length):
        """
        Tokenizes ``prompts`` with the settings shared by every dataset.

        Each prompt is truncated and padded to exactly ``max_length`` and the
        tokenizer is asked for PyTorch tensors (``return_tensors="pt"``).

        Returns:
            Whatever ``tokenizer`` returns.
        """
        return tokenizer(
            prompts,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",  # Return PyTorch tensors
        )

    @staticmethod
    def _question_answer_prompts(questions, answers):
        """
        Pairs each question with its answer in the shared Q/A prompt layout.

        Pairing uses ``zip``, so the result is as long as the shorter input.
        """
        # The space before "Answer:" is part of the established prompt format
        # and is kept verbatim.
        return [
            f"Question:\n{question}\n\n Answer:\n" + answer
            for question, answer in zip(questions, answers)
        ]

    def _format_and_tokenize_gsm8k(self, examples, tokenizer, max_length=512):
        """
        Formats the ``question`` / ``answer`` columns as Q/A prompts and tokenizes them.
        """
        prompts = self._question_answer_prompts(
            examples["question"], examples["answer"]
        )
        return self._tokenize(prompts, tokenizer, max_length)

    def _format_and_tokenize_MetaMathQA(self, examples, tokenizer, max_length=512):
        """
        Formats the ``query`` / ``response`` columns as Q/A prompts and tokenizes them.
        """
        prompts = self._question_answer_prompts(
            examples["query"], examples["response"]
        )
        return self._tokenize(prompts, tokenizer, max_length)

    def _format_and_tokenize_HumanEval(self, examples, tokenizer, max_length=512):
        """
        Formats the ``prompt`` / ``canonical_solution`` columns as Q/A prompts and tokenizes them.
        """
        prompts = self._question_answer_prompts(
            examples["prompt"], examples["canonical_solution"]
        )
        return self._tokenize(prompts, tokenizer, max_length)

    def _format_and_tokenize_MBPP(self, examples, tokenizer, max_length=512):
        """
        Formats the ``text`` / ``code`` columns as Q/A prompts and tokenizes them.
        """
        prompts = self._question_answer_prompts(examples["text"], examples["code"])
        return self._tokenize(prompts, tokenizer, max_length)

    def _format_and_tokenize_codealpaca(self, examples, tokenizer, max_length=512):
        """
        Formats the ``instruction`` / ``input`` / ``output`` columns as a
        three-section prompt and tokenizes it.

        The "Input:" section is always emitted, even when the input value is
        an empty string.
        """
        prompts = [
            f"Instruction:\n{instruction}\n\n"
            f"Input:\n{context}\n\n"
            f"Output:\n{output}\n\n"
            for instruction, context, output in zip(
                examples["instruction"], examples["input"], examples["output"]
            )
        ]
        return self._tokenize(prompts, tokenizer, max_length)

    def _format_and_tokenize_PubMedQA(self, examples, tokenizer, max_length=512):
        """
        Formats the ``question`` column as the question and combines
        ``long_answer`` with ``final_decision`` as the answer, then tokenizes.
        """
        questions = examples["question"]
        answers = [
            f"{long_answer}\n\nFinal Decision:\n{final_decision}"
            for long_answer, final_decision in zip(
                examples["long_answer"], examples["final_decision"]
            )
        ]
        prompts = self._question_answer_prompts(questions, answers)
        return self._tokenize(prompts, tokenizer, max_length)

    def _format_and_tokenize_medicalo1(self, examples, tokenizer, max_length=1024):
        """
        Formats the ``Question`` / ``Response`` columns as Q/A prompts and tokenizes them.

        NOTE: the 1024 default only applies when this method is called
        directly; ``format_and_tokenize`` always forwards its own
        ``max_length`` (default 512).
        """
        prompts = self._question_answer_prompts(
            examples["Question"], examples["Response"]
        )
        return self._tokenize(prompts, tokenizer, max_length)

    def format_and_tokenize(self, examples, tokenizer, max_length=512):
        """
        Formats and tokenizes ``examples`` with the routine matching ``self.dataset_name``.

        The lower-cased dataset name is tested against the substrings in
        ``_DISPATCH`` in order, and the first match selects the routine.

        Args:
            examples: Batch of columns, indexed by column name.
            tokenizer: Callable used to tokenize the formatted prompts.
            max_length: Length every prompt is truncated/padded to. It is
                always forwarded, so it overrides the per-dataset defaults.

        Returns:
            The tokenizer output for the formatted prompts.

        Raises:
            ValueError: If the dataset name matches no supported dataset.
        """
        name = self.dataset_name.lower()
        for substrings, method_name in self._DISPATCH:
            if any(substring in name for substring in substrings):
                handler = getattr(self, method_name)
                return handler(examples, tokenizer, max_length)
        raise ValueError(f"Unsupported dataset: {self.dataset_name!r}")
