from abc import ABC
import json
import os
import re
from typing import Union, Literal, List, Dict, Any
import pandas as pd

class GSM8KDataset(ABC):
    """GSM8K math word-problem dataset loaded from a local JSONL file.

    Each non-blank line of ``<data_dir>/<split>.jsonl`` is parsed as one JSON
    object, and all records are held in a pandas DataFrame. Records are read
    through their ``question`` field (``record_to_input``) and ``answer``
    field (``record_to_target_answer``).
    """

    # A comma used as a thousands separator ("1,234" -> "1234"). The comma is
    # only removed when it sits between a digit and exactly three digits, so
    # plain lists such as "1,2,3" are left untouched.
    _THOUSANDS_SEP_RE = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")
    # Explicit answer markers, searched case-insensitively. The first
    # alternative that matches leftmost in the text wins.
    _ANSWER_MARKER_RE = re.compile(
        r"####(\d+)|answer is\s*(\d+)|answer:\s*(\d+)", re.IGNORECASE)
    # Any run of digits; used as the last-resort extraction.
    _NUMBER_RE = re.compile(r"\d+")

    def __init__(self,
                 split: Union[Literal['train'], Literal['test']],
                 data_dir: str = "local_datasets/gsm8k",  # Default data directory
                 ):
        """Load the given split eagerly from ``data_dir``.

        Args:
            split: Which split file to read (``"<split>.jsonl"``).
            data_dir: Directory that contains the split JSONL files.
        """
        self._split = split
        self._data_dir = data_dir
        self._total_df: pd.DataFrame = self._load_data()

    @staticmethod
    def get_domain() -> str:
        """Return the domain identifier for this dataset."""
        return 'gsm8k_new'

    def _load_data(self) -> pd.DataFrame:
        """Load the GSM8K split from its JSONL file into a DataFrame.

        Blank lines (for example a trailing newline at end of file) are
        skipped.

        Returns:
            One row per JSON record.

        Raises:
            FileNotFoundError: If the split file does not exist.
            json.JSONDecodeError: If a non-blank line is not valid JSON. The
                message names the file and the 1-based line number.
        """
        file_path = os.path.join(self._data_dir, f"{self._split}.jsonl")
        print(f"Loading data from: {file_path}")

        data: List[Dict[str, Any]] = []
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_no, line in enumerate(f, start=1):
                    if not line.strip():
                        continue  # tolerate empty/whitespace-only lines
                    try:
                        data.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        # JSONDecodeError requires (msg, doc, pos); reuse the
                        # original doc/pos so line/column info stays accurate.
                        raise json.JSONDecodeError(
                            f"Error decoding JSON from file: {file_path} "
                            f"(line {line_no}): {e.msg}",
                            e.doc, e.pos) from e
        except FileNotFoundError as e:
            raise FileNotFoundError(f"File not found: {file_path}") from e

        total_df = pd.DataFrame(data)
        print(f"Total number of questions in {self._split}: {len(total_df)}")
        return total_df

    @property
    def split(self) -> str:
        """The split name this dataset was loaded with."""
        return self._split

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self._total_df)

    def __getitem__(self, index: int) -> pd.Series:
        """Return the record at positional ``index`` (via ``DataFrame.iloc``)."""
        record = self._total_df.iloc[index]
        # iloc with a scalar position yields a Series; a slice would not.
        assert isinstance(record, pd.Series)
        return record

    @staticmethod
    def record_to_input(record: pd.Series) -> Dict[str, Any]:
        """Build the language-model input dict from a record.

        Only the ``question`` field is used, under the key ``"task"``.
        """
        return {"task": record['question']}

    def postprocess_answer(self, answer: Union[str, List[str]]) -> str:
        """Extract a bare integer answer string from a language-model response.

        Thousands separators are removed first ("1,234" -> "1234"), which
        matches ``record_to_target_answer``, since that method also strips
        commas from the target. The extraction methods are then tried in
        priority order:

        1. The text after the last ``####`` marker, if it is all digits.
        2. The leftmost ``####N``, ``answer is N`` or ``answer: N``
           (case-insensitive; any whitespace is allowed before ``N`` in the
           last two forms).
        3. The last run of digits anywhere in the text.

        Args:
            answer: The model output, or a list of outputs. For a list, only
                the first element is used.

        Returns:
            The extracted digit string. Returns ``"0"`` if ``answer`` is an
            empty list or contains no number.

        Raises:
            TypeError: If ``answer`` (or its first element) is not a string.
        """
        if isinstance(answer, list):
            if not answer:
                return "0"  # Or "" , depending on your desired default
            answer = answer[0]
        if not isinstance(answer, str):
            raise TypeError("Expected string or list of strings, got {}".format(type(answer)))

        text = self._THOUSANDS_SEP_RE.sub("", answer)

        parts = text.split("####")
        if len(parts) > 1:
            candidate = parts[-1].strip()
            if candidate.isdigit():
                return candidate

        match = self._ANSWER_MARKER_RE.search(text)
        if match:
            # Exactly one group participates in a match, and \d+ is non-empty.
            return match.group(1) or match.group(2) or match.group(3)

        numbers = self._NUMBER_RE.findall(text)
        if numbers:
            return numbers[-1]

        return "0"  # Return 0 if no number is found, or "" depending on your needs

    @staticmethod
    def record_to_target_answer(record: pd.Series) -> str:
        """Return the reference answer of ``record`` with commas removed.

        The answer is the text after the last ``####`` in the ``answer``
        field. If that field is not a string, the sentinel
        ``'Can not find correct answer'`` is returned.
        """
        raw_answer = record['answer']
        try:
            correct_answer = raw_answer.split("####")[-1].strip()
        except (AttributeError, TypeError):
            # Non-string answer (e.g. None/NaN for a missing field).
            correct_answer = 'Can not find correct answer'
        assert isinstance(correct_answer, str), (
            f"String expected but got {correct_answer} "
            f"of type {type(correct_answer)} record={record}")
        return correct_answer.replace(",", "")

    @classmethod
    def split_dataset(cls,
                      input_file: str,
                      output_dir: str,
                      train_ratio: float = 0.8,
                      random_state: int = 42) -> None:
        """Split ``input_file`` into train/test files under ``output_dir``.

        NOTE(review): not implemented; this is currently a no-op that returns
        ``None``. The parameters describe the intended interface only.
        """
        # TODO: implement the train/test split.
        return None