import json
import os
import re
import pandas as pd
from typing import Union, Literal, List, Dict, Any, Optional 



from abc import ABC

class MATHDataset(ABC):
    r"""Dataset wrapper around a JSONL file of math problems.

    Each line of ``<data_dir>/<split>.jsonl`` is parsed as one JSON record.
    The helpers on this class read the ``problem`` key (model input) and the
    ``solution`` key (reference text containing a ``\boxed{...}`` answer).
    """

    # Marker that introduces the final answer in solutions / model outputs.
    _BOXED_MARKER = r'\boxed{'

    # Mangled forms of the marker. When "\boxed" passes through a non-raw
    # string literal, "\b" is interpreted as a backspace (U+0008), so the text
    # contains the *character* '\x08' followed by "oxed{". The literal
    # six-character spelling (backslash, 'x', '0', '8', ...) is also accepted
    # for backward compatibility with the previous behaviour.
    _MANGLED_BOXED_MARKERS = ('\x08oxed{', r'\x08oxed{')

    # Integers or decimals, optionally negative, with optional comma grouping.
    _NUMBER_PATTERN = re.compile(r'(-?\d+(?:,\d+)*\.\d+|-?\d+(?:,\d+)*)')

    def __init__(self,
                 split: Union[Literal['train'], Literal['test'], Literal['validation']],
                 data_dir: str = "local_datasets/math",
                 ):
        """Locate and load the JSONL file for ``split``.

        Args:
            split: Name of the split; the file ``<split>.jsonl`` is loaded.
            data_dir: Directory that contains the split files.

        Raises:
            FileNotFoundError: If the split file does not exist.
        """
        self._split = split
        self._data_dir = data_dir
        self._file_path: str = self._get_problem_file()
        self._total_df: Optional[pd.DataFrame] = None
        self._records = self._load_all_records()

    @staticmethod
    def get_domain() -> str:
        """Return the domain identifier of this dataset (``'math'``)."""
        return 'math'

    def _load_data(self) -> Optional[pd.DataFrame]:
        """Placeholder; not implemented and currently returns ``None``."""
        return None

    def _get_problem_file(self) -> str:
        """Return the path of the split's JSONL file.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        jsonl_file_name = f"{self._split}.jsonl"
        file_path = os.path.join(self._data_dir, jsonl_file_name)

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"dataset file not found: {file_path}")

        print(f"Loading problems from {file_path}")
        return file_path

    def _load_all_records(self) -> List[Dict[str, Any]]:
        """Parse every non-blank line of the split file as a JSON record.

        Loading is best-effort: on any error a message is printed and an
        empty list is returned instead of raising.
        """
        records: List[Dict[str, Any]] = []
        try:
            with open(self._file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    stripped = line.strip()
                    # Blank (e.g. trailing) lines are not records; feeding
                    # them to json.loads would discard the whole dataset.
                    if not stripped:
                        continue
                    records.append(json.loads(stripped))
            print(f"Loaded {len(records)} problems from {self._file_path}.")
            return records
        except FileNotFoundError:
            print(f"Error: Problem file '{self._file_path}' not found.")
            return []
        except json.JSONDecodeError as e:
            print(f"Error: Failed to parse JSON line in '{self._file_path}'. Please check file format. Error: {e}")
            return []
        except Exception as e:
            print(f"Error loading problems from {self._file_path}: {e}")
            return []

    @property
    def split(self) -> str:
        """Name of the split this dataset was created with."""
        return self._split

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self._records)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the record at ``index``.

        Raises:
            IndexError: If ``index`` is negative or past the last record
                (negative indexing is intentionally not supported).
        """
        if 0 <= index < len(self._records):
            return self._records[index]
        raise IndexError("Index out of bounds for dataset records.")

    @staticmethod
    def record_to_input(record: Dict[str, Any]) -> Dict[str, Any]:
        """Build the model input dict from a record's ``problem`` field.

        Raises:
            KeyError: If the record has no ``problem`` key.
        """
        return {"task": record['problem']}

    @staticmethod
    def _balanced_content(text: str, marker: str) -> Optional[str]:
        """Return the text between ``marker`` and its matching ``}``.

        ``marker`` is expected to end with the opening ``{``. Only the first
        occurrence of ``marker`` is considered. Returns ``None`` when the
        marker is absent or its braces never balance.
        """
        marker_index = text.find(marker)
        if marker_index == -1:
            return None

        content_start = marker_index + len(marker)
        depth = 1  # the marker itself opened one brace
        for position, char in enumerate(text[content_start:], start=content_start):
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    return text[content_start:position]

        # Ran off the end of the text: braces were unbalanced.
        return None

    def _extract_balanced_braces(self, text: str, marker: str) -> Optional[str]:
        """
        Extract content following a marker, honouring nested curly braces.

        Args:
            text: The string to search within.
            marker: The starting marker to look for (e.g., r'\\boxed{').

        Returns:
            The extracted content if found and balanced, otherwise None.
        """
        return self._balanced_content(text, marker)

    def _extract_mangled_boxed(self, text: str) -> Optional[str]:
        """Return boxed content introduced by a mangled marker, if any."""
        for marker in self._MANGLED_BOXED_MARKERS:
            extracted = self._extract_balanced_braces(text, marker)
            if extracted is not None:
                return extracted
        return None

    @staticmethod
    def _last_number(text: str) -> Optional[str]:
        """Return the last number found in ``text`` with commas removed.

        Returns ``None`` when ``text`` contains no number.
        """
        numbers = MATHDataset._NUMBER_PATTERN.findall(text)
        if not numbers:
            return None
        candidate = numbers[-1].replace(",", "")
        if not re.fullmatch(r'-?\d+\.?\d*', candidate):
            return None
        return MATHDataset._clean_number_string(candidate).rstrip('.')

    def postprocess_answer(self, answer: Union[str, List[str]]) -> str:
        r"""Extract the final answer from a language model's output.

        Strategies are tried in order, and the first that succeeds wins:

        1. ``\boxed{...}`` inside a fenced ``python`` code block.
        2. ``\boxed{...}`` anywhere in the text.
        3. A mangled ``\boxed`` whose ``\b`` became a backspace character.
        4. Text following an "answer is" phrase: boxed content, then mangled
           boxed content, then the last number on that line.
        5. The last number anywhere in the text.

        Args:
            answer: The model output, or a list whose first element is used
                (an empty list is treated as ``"0"``).

        Returns:
            The extracted answer with commas and spaces removed, or ``"0"``
            when nothing could be extracted.

        Raises:
            TypeError: If the (first) answer is not a string.
        """
        if isinstance(answer, list):
            answer = answer[0] if answer else "0"
        if not isinstance(answer, str):
            raise TypeError(f"Expected string or list of strings, got {type(answer)}")

        text = answer

        # 1. Boxed answer inside a fenced python code block.
        code_block_match = re.search(r"```python\n(.*?)```", text, re.DOTALL)
        if code_block_match:
            code_content = code_block_match.group(1).strip()
            extracted = self._extract_balanced_braces(code_content, self._BOXED_MARKER)
            if extracted is not None:
                return self._clean_number_string(extracted)

        # 2. Boxed answer anywhere in the text.
        extracted = self._extract_balanced_braces(text, self._BOXED_MARKER)
        if extracted is not None:
            return self._clean_number_string(extracted)

        # 3. Boxed answer whose marker was mangled by escape processing.
        extracted = self._extract_mangled_boxed(text)
        if extracted is not None:
            print(f"Warning: Found and corrected misspelled '\\x08oxed'.")
            return self._clean_number_string(extracted)

        # 4. Whatever follows an "answer is" phrase on the same line.
        answer_keywords = ["The answer is", "the answer is", "Answer is", "The final answer is"]
        for keyword in answer_keywords:
            keyword_match = re.search(re.escape(keyword) + r'\s*([^\n]+)', text, re.IGNORECASE)
            if not keyword_match:
                continue
            answer_snippet = keyword_match.group(1).strip()

            boxed_in_snippet = self._extract_balanced_braces(answer_snippet, self._BOXED_MARKER)
            if boxed_in_snippet is not None:
                return self._clean_number_string(boxed_in_snippet)

            misspelled_in_snippet = self._extract_mangled_boxed(answer_snippet)
            if misspelled_in_snippet is not None:
                print(f"Warning: Found and corrected misspelled '\\x08oxed' after keyword.")
                return self._clean_number_string(misspelled_in_snippet)

            number_in_snippet = self._last_number(answer_snippet)
            if number_in_snippet is not None:
                return number_in_snippet

        # 5. Last number anywhere in the text.
        last_number = self._last_number(text)
        if last_number is not None:
            return last_number

        print(f"Warning: Could not extract a clear answer from LLM output: '{text}'")
        return "0"

    @staticmethod
    def _clean_number_string(s: str) -> str:
        """Remove all commas and spaces from ``s`` and strip other edge whitespace."""
        s = s.replace(",", "")
        s = s.replace(" ", "")

        return s.strip()

    @staticmethod
    def record_to_target_answer(record: Dict[str, Any]) -> str:
        r"""Return the reference answer boxed in a record's ``solution``.

        Returns:
            The content of the first ``\boxed{...}`` with commas and spaces
            removed, or ``"N/A"`` when the solution is missing, empty, or
            has no balanced boxed expression.
        """
        # `or ''` also covers an explicit None stored under 'solution'.
        raw_solution = record.get('solution') or ''
        correct_answer = MATHDataset._balanced_content(raw_solution, MATHDataset._BOXED_MARKER)
        if correct_answer is None:
            return "N/A"
        return MATHDataset._clean_number_string(correct_answer)

    def split_dataset(self,
                      input_file: str,
                      output_dir: str,
                      train_ratio: float = 0.8,
                      random_state: int = 42) -> None:
        """Unsupported for this dataset; prints a notice and does nothing.

        All arguments are accepted but ignored.
        """
        print("MATHDataset does not support splitting from a single file as it's pre-split by directory.")

