from abc import ABC
import json
import os
import re
import string
from typing import Union, Literal, List, Dict, Any, Tuple
import pandas as pd


class DROPDataset(ABC):
    """Question-answering dataset loaded from ``<data_dir>/<split>.json``.

    The whole split is read eagerly into a DataFrame in which each row is one
    QA record carrying (at least) ``passage``, ``question`` and
    ``answers_spans`` fields.
    """

    def __init__(self,
                 split: Union[Literal['train'], Literal['validation']],
                 data_dir: str = "local_datasets/drop/json_data",  # Default data directory
                 ):
        self._split = split
        self._data_dir = data_dir
        self._total_df: pd.DataFrame = self._load_data()

    @staticmethod
    def get_domain() -> str:
        """Return the domain identifier for this dataset."""
        return 'drop'

    def _load_data(self) -> pd.DataFrame:
        """Load ``<data_dir>/<split>.json`` into a DataFrame.

        The file is expected to hold a JSON array whose elements are individual
        QA records; each becomes one DataFrame row.

        Raises:
            FileNotFoundError: if the split file does not exist.
            json.JSONDecodeError: if the file is not valid JSON (the message
                includes the file path).
            KeyError: if the loaded records have no ``answers_spans`` column.
        """
        file_path = os.path.join(self._data_dir, f"{self._split}.json")
        print(f"Loading data from: {file_path}")

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data: List[Dict[str, Any]] = json.load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {file_path}") from err
        except json.JSONDecodeError as err:
            # JSONDecodeError's constructor requires (msg, doc, pos); forward the
            # original document/position so line/column info is preserved.
            raise json.JSONDecodeError(
                f"Error decoding JSON from file: {file_path}: {err.msg}",
                err.doc,
                err.pos,
            ) from err

        total_df = pd.DataFrame(data)

        if 'answers_spans' not in total_df.columns:
            raise KeyError("DataFrame does not contain 'answers_spans' column. Check your JSON structure.")

        # Cheap sanity check on the first record only: record_to_target_answer
        # indexes answers_spans as a dict with 'spans' and 'types' keys.
        if not total_df.empty and not isinstance(total_df.iloc[0]['answers_spans'], dict):
            print("Unexpected 'answers_spans' format (expected dict).")

        return total_df

    def split(self) -> str:
        """Return the split name this instance was constructed with."""
        return self._split

    def __len__(self) -> int:
        return len(self._total_df)

    def __getitem__(self, index: int) -> pd.Series:
        """Return the record at positional ``index`` as a Series."""
        record = self._total_df.iloc[index]
        # iloc with a slice would return a DataFrame; only single rows are supported.
        assert isinstance(record, pd.Series)
        return record

    @staticmethod
    def record_to_input(record: pd.Series) -> Dict[str, Any]:
        """Build the model input dict for ``record``.

        Returns:
            ``{"task": prompt}`` where the prompt embeds the record's passage
            and question.
        """
        task = f"Given the following passage {record['passage']} Answer the question {record['question']}"
        return {"task": task}

    def postprocess_answer(self, answer: Union[str, List[str]]) -> str:
        """Extract the final answer string from a language model's output.

        If ``answer`` is a list, the first string that is non-empty after
        stripping is used; non-string entries are skipped, and ``""`` is
        returned if there is none. The result is then the stripped content
        of the last ``[...]`` group when that content is non-empty; otherwise
        the last non-blank line; otherwise ``""``.

        Raises:
            TypeError: if ``answer`` is neither a string nor a list.
        """
        if isinstance(answer, list):
            answer = next(
                (item for item in answer if isinstance(item, str) and item.strip()),
                None,
            )
            if answer is None:
                return ""

        if not isinstance(answer, str):
            raise TypeError(f"Expected string or list of strings, but final processed 'answer' is {type(answer)}")

        # Note: '.' does not match newlines, so bracket groups spanning lines are ignored.
        bracket_matches = re.findall(r"\[(.*?)\]", answer)
        if bracket_matches:
            extracted_from_bracket = bracket_matches[-1].strip()
            if extracted_from_bracket:
                return extracted_from_bracket

        for line in reversed(answer.split('\n')):
            stripped_line = line.strip()
            if stripped_line:
                return stripped_line

        return ""

    @staticmethod
    def record_to_target_answer(record: pd.Series) -> Tuple[List[str], List[str]]:
        """Return the gold answer spans and their types for ``record``.

        Raises:
            TypeError: if ``answers_spans['spans']`` or ``answers_spans['types']``
                is not a list.
        """
        raw_answers = record['answers_spans']
        spans = raw_answers['spans']
        types = raw_answers['types']

        # Explicit checks rather than `assert`, so validation survives `python -O`.
        if not isinstance(spans, list):
            raise TypeError(f"Expected spans to be a list, got {type(spans)}")
        if not isinstance(types, list):
            raise TypeError(f"Expected types to be a list, got {type(types)}")

        return spans, types

    @classmethod
    def split_dataset(cls,
                      input_file: str,
                      output_dir: str,
                      train_ratio: float = 0.8,
                      random_state: int = 42) -> None:
        """No-op: the data is already provided as separate per-split files
        (``train.json`` / ``validation.json``), so no splitting is needed."""
        pass


if __name__ == '__main__':
    # Smoke test: load the training split and report how many records it has.
    train_set = DROPDataset(split='train')
    n_train = len(train_set)
    print(f"Number of training examples: {n_train}")