from abc import ABC
import json
import os
import re
from typing import Union, Literal, List, Dict, Any
import pandas as pd

class MGSMDataset(ABC):
    """Loader and answer-normalization helpers for the MGSM math word-problem benchmark.

    Records are read from ``<data_dir>/<split>.jsonl`` (one JSON object per line)
    into a pandas DataFrame. Each record must provide a ``question`` field (used
    as the model input by :meth:`record_to_input`) and an ``answer`` field (the
    reference answer read by :meth:`record_to_target_answer`).
    """

    # Integer or decimal literal. Thousands separators are accepted only in
    # well-formed groups ("1,234,567"), so comma-separated lists such as
    # "1,2,3" are still read as separate numbers.
    _NUMBER_PATTERN = r"\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+(?:\.\d+)?"
    _NUMBER_RE = re.compile(_NUMBER_PATTERN)
    # GSM8K-style final-answer marker, e.g. "#### 42". Whitespace, ':', '*'
    # (markdown bold) and '$' may sit between the marker and the number.
    _FINAL_MARKER_RE = re.compile(r"####[\s:*$]*(" + _NUMBER_PATTERN + r")")
    # Explicit natural-language cue, e.g. "The answer is 42" / "Answer: $42".
    # The number must follow the cue directly, so "Answer: First, 3 + 4 ..."
    # does not pick up the first intermediate value.
    _ANSWER_CUE_RE = re.compile(
        r"answer(?:\s+is|\s*:)[\s:*$]*(" + _NUMBER_PATTERN + r")", re.IGNORECASE)

    def __init__(self,
                 split: Union[Literal['train'], Literal['test']],
                 data_dir: str = "local_datasets/MGSM",  # Default data directory for MGSM
                 ):
        """Load the requested split eagerly.

        Args:
            split: Which split to load; selects the file ``<data_dir>/<split>.jsonl``.
            data_dir: Directory that contains the split files.
        """
        self._split = split
        self._data_dir = data_dir
        self._total_df: pd.DataFrame = self._load_data()

    @staticmethod
    def get_domain() -> str:
        """Return the short domain identifier of this dataset (``'mgsm'``)."""
        return 'mgsm'

    @staticmethod
    def _read_jsonl(file_path: str) -> List[Dict[str, Any]]:
        """Parse a JSON Lines file into a list of objects, skipping blank lines.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
            json.JSONDecodeError: If a non-blank line is not valid JSON; the
                message names the 1-based line number within the file.
        """
        records: List[Dict[str, Any]] = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # tolerate blank lines, e.g. extra newlines at EOF
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as e:
                    # JSONDecodeError requires (msg, doc, pos); keep the parser's
                    # position info and add the file-level line number.
                    raise json.JSONDecodeError(
                        f"Error decoding JSON on line {line_no} of file: {file_path}. "
                        f"Check for malformed JSON lines. Error: {e.msg}",
                        e.doc, e.pos) from e
        return records

    def _load_data(self) -> pd.DataFrame:
        """Read ``<data_dir>/<split>.jsonl`` into a DataFrame, one row per record.

        Raises:
            FileNotFoundError: If the split file does not exist.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            Exception: For any other failure while reading the file; the
                original error is chained as ``__cause__``.
        """
        file_path = os.path.join(self._data_dir, f"{self._split}.jsonl")
        print(f"Loading data from: {file_path}")

        try:
            data = self._read_jsonl(file_path)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"File not found: {file_path}. Please ensure '{self._split}.jsonl' "
                f"is in the '{self._data_dir}' directory.") from e
        except json.JSONDecodeError:
            # Already carries the file path and line number from _read_jsonl.
            raise
        except Exception as e:
            raise Exception(
                f"An unexpected error occurred while loading data from {file_path}: {e}") from e

        total_df = pd.DataFrame(data)
        print(f"Total number of questions in MGSM {self._split} split: {len(total_df)}")
        return total_df

    @property
    def split(self) -> str:
        """The split name this instance was loaded with."""
        return self._split

    def __len__(self) -> int:
        """Return the number of records in the loaded split."""
        return len(self._total_df)

    def __getitem__(self, index: int) -> pd.Series:
        """Return the record at positional *index* as a pandas Series."""
        record = self._total_df.iloc[index]
        # A slice index would yield a DataFrame; only single-row access is supported.
        assert isinstance(record, pd.Series)
        return record

    @staticmethod
    def record_to_input(record: pd.Series) -> Dict[str, Any]:
        """Build the model input dict ``{"task": <question>}`` from a record."""
        input_dict = {"task": record['question']}
        return input_dict

    @staticmethod
    def _normalize_number(token: str) -> str:
        """Canonicalize a numeric token matched by ``_NUMBER_RE``.

        Commas are dropped, the integer part is re-rendered via ``int`` (which
        removes leading zeros), and trailing zeros of the fractional part are
        removed, together with the decimal point when nothing is left:
        ``"1,234.50"`` -> ``"1234.5"``, ``"007.00"`` -> ``"7"``.
        """
        integer_part, _, fraction = token.replace(",", "").partition(".")
        integer_part = str(int(integer_part))
        fraction = fraction.rstrip("0")
        return f"{integer_part}.{fraction}" if fraction else integer_part

    def postprocess_answer(self, answer: Union[str, List[str]]) -> str:
        """Extract the final numeric answer from a model response.

        Strategies are tried in order; the first that finds a number wins:

        1. the number right after the last GSM8K-style ``####`` marker;
        2. the number right after the last ``answer is`` / ``answer:`` cue
           (case-insensitive);
        3. the last number anywhere in the text.

        The extracted number is normalized with :meth:`_normalize_number`, so
        ``"1,234.00"`` becomes ``"1234"`` and matches the comma-free targets
        produced by :meth:`record_to_target_answer`.

        Args:
            answer: The raw response, or a list of responses of which only the
                first element is used.

        Returns:
            The normalized number as a string, or ``"0"`` when the list is empty
            or no number is found.

        Raises:
            TypeError: If *answer* (or its first element) is not a string.
        """
        if isinstance(answer, list):
            if not answer:
                return "0"
            answer = answer[0]
        if not isinstance(answer, str):
            raise TypeError("Expected string or list of strings, got {}".format(type(answer)))

        # The last marker/cue wins: step-by-step responses may revise an earlier value.
        for pattern in (self._FINAL_MARKER_RE, self._ANSWER_CUE_RE):
            matches = pattern.findall(answer)
            if matches:
                return self._normalize_number(matches[-1])

        numbers = self._NUMBER_RE.findall(answer)
        if numbers:
            return self._normalize_number(numbers[-1])
        return "0"

    @staticmethod
    def record_to_target_answer(record: pd.Series) -> str:
        """Return the reference answer of *record* as a normalized string.

        The ``answer`` value is stringified, with commas and surrounding
        whitespace removed. When what remains is a plain number, it is
        canonicalized exactly like predictions in :meth:`postprocess_answer`
        (e.g. a float-typed ``12.0`` becomes ``"12"``), so the two can be
        compared with ``==``.
        """
        correct_answer = str(record['answer']).replace(",", "").strip()
        if MGSMDataset._NUMBER_RE.fullmatch(correct_answer):
            return MGSMDataset._normalize_number(correct_answer)
        return correct_answer

    @classmethod
    def split_dataset(cls,
                      input_file: str,
                      output_dir: str,
                      train_ratio: float = 0.8,
                      random_state: int = 42) -> None:
        """Placeholder for splitting *input_file* into train/test files.

        NOTE(review): currently a no-op stub — the body does nothing and none of
        the arguments are used. Implement or remove before relying on it.
        """
        pass


