from abc import ABC
import glob
from typing import Union, Literal, List, Dict, Any
import pandas as pd
import numpy as np
import json
from sklearn.model_selection import train_test_split
import os
import re

class MMLUProDataset(ABC):
    """MMLU-Pro multiple-choice questions loaded from per-category JSON files.

    Files are read from ``<data_dir>/<category>/<split>.json``; each file is
    expected to hold a JSON list of question records (keys ``question_id``,
    ``question``, ``options``, ``answer``, ``answer_index``, ``cot_content``,
    ``category``, ``src``). Every record becomes one DataFrame row whose
    ``options`` list is expanded into letter columns ``A``, ``B``, ``C``, ...
    """

    # "... answer is (C)" / "Answer: C" style statements (applied to upper-cased text).
    _EXPLICIT_ANSWER_RE = re.compile(r"ANSWER\s*(?:IS\s*)?:?\s*\(?([A-J])\b")
    # An option letter standing alone as a word: "C", "(C)", "C.", "C is ...".
    _STANDALONE_LETTER_RE = re.compile(r"\b([A-J])\b")
    # Any option letter at all (legacy fallback, e.g. "ABC" -> "C").
    _ANY_OPTION_LETTER_RE = re.compile(r"[A-J]")

    def __init__(self,
                 split: Union[Literal['train'], Literal['test']],
                 data_dir: str = "local_datasets/MMLU-Pro/data/split_by_category",
                 type: Union[str, List[str], None] = None):
        """Load one split of the dataset.

        Args:
            split: ``'train'`` or ``'test'``; selects ``<split>.json`` files.
            data_dir: Directory containing one sub-directory per category.
            type: A category name, a list of category names (their questions
                are concatenated in the given order), or ``None`` to load every
                category sub-directory of ``data_dir`` that has the split file.
                (The name shadows the builtin but is kept for callers.)

        Raises:
            ValueError: if ``split`` is not ``'train'`` or ``'test'``.
            FileNotFoundError: if a requested file does not exist.
            json.JSONDecodeError: if a file is not valid JSON.
        """
        self._split = split
        self._type = type
        self._data_dir = data_dir
        self._total_df: pd.DataFrame = self._load_data()

    @staticmethod
    def get_domain() -> str:
        """Return the identifier of this dataset's domain."""
        return 'mmlu-pro'

    def _load_data(self) -> pd.DataFrame:
        """Load the configured split/categories into one DataFrame (one row per question)."""
        if self._split not in ('train', 'test'):
            # Previously an unknown split left `file_path` unbound (UnboundLocalError).
            raise ValueError(f"split must be 'train' or 'test', got {self._split!r}")

        rows: List[Dict[str, Any]] = []
        for file_path in self._resolve_file_paths():
            print(f"Loading data from: {file_path}")
            rows.extend(self._records_to_rows(self._read_json(file_path)))

        total_df = pd.DataFrame(rows)
        print(f"Total number of questions in {self._split}: {len(total_df)}")
        return total_df

    def _resolve_file_paths(self) -> List[str]:
        """Return the JSON file path(s) to read for the configured split and type(s)."""
        file_name = f"{self._split}.json"
        if self._type is None:
            # No category given: use every category sub-directory holding this split.
            pattern = os.path.join(glob.escape(self._data_dir), "*", file_name)
            paths = sorted(glob.glob(pattern))
            if not paths:
                raise FileNotFoundError(
                    f"No {file_name} found in any sub-directory of {self._data_dir}")
            return paths
        categories = [self._type] if isinstance(self._type, str) else list(self._type)
        return [os.path.join(self._data_dir, category, file_name) for category in categories]

    @staticmethod
    def _read_json(file_path: str) -> Any:
        """Parse one UTF-8 JSON file, re-raising errors with the file path attached."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {file_path}") from err
        except json.JSONDecodeError as err:
            # JSONDecodeError requires (msg, doc, pos); the old one-argument
            # construction raised TypeError instead of the intended error.
            raise json.JSONDecodeError(
                f"Error decoding JSON from file: {file_path}: {err.msg}",
                err.doc, err.pos) from err

    @staticmethod
    def _records_to_rows(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Flatten question records into row dicts with one column per option letter."""
        rows = []
        for item in data:
            # Expand the options list into letter-keyed columns: A, B, C, ...
            option_dict = {chr(ord('A') + i): option
                           for i, option in enumerate(item["options"])}
            rows.append({
                "question_id": item["question_id"],
                "question": item["question"],
                "answer": item["answer"],
                "answer_index": item["answer_index"],
                "cot_content": item["cot_content"],
                "category": item["category"],
                "src": item["src"],
                **option_dict,
            })
        return rows

    @property
    def split(self) -> str:
        """The split this dataset was loaded from (``'train'`` or ``'test'``)."""
        return self._split

    def __len__(self) -> int:
        """Number of loaded questions."""
        return len(self._total_df)

    def __getitem__(self, index: int) -> pd.Series:
        """Return the question at positional ``index`` as a Series."""
        record = self._total_df.iloc[index]
        assert isinstance(record, pd.Series)
        return record

    @staticmethod
    def record_to_input(record: pd.Series) -> Dict[str, Any]:
        """Build the model prompt ``{"task": ...}`` from a question record.

        The prompt is the question followed by one ``Option X: ...`` line per
        option present in the record (letters A-J), each line ending in "\\n".
        """
        lines = [record['question']]
        for option_label in "ABCDEFGHIJ":
            if option_label not in record:
                continue
            value = record[option_label]
            # A DataFrame built from questions with different option counts
            # fills the unused letter columns with NaN; those are not options.
            if pd.api.types.is_scalar(value) and pd.isna(value):
                continue
            lines.append(f"Option {option_label}: {value}")
        return {"task": "\n".join(lines) + "\n"}

    def postprocess_answer(self, answer: Union[str, List[str]]) -> str:
        """Extract a single option letter (A-J) from a model answer.

        A list is reduced to its first element (an empty list becomes "").
        The text is stripped and upper-cased, then the first matching rule wins:

        1. the last explicit statement such as "answer is (C)" or "Answer: C";
        2. the last option letter standing alone as a word ("C", "(C)", "C.");
        3. the last A-J character anywhere (legacy behaviour, e.g. "ABC" -> "C");
        4. the last character, if it is alphabetic;
        5. otherwise the normalised text itself.

        Raises:
            TypeError: if ``answer`` (or its first element) is not a string.
        """
        if isinstance(answer, list):
            answer = answer[0] if answer else ""
        if not isinstance(answer, str):
            raise TypeError("Expected string or list of strings, got {}".format(type(answer)))

        text = answer.strip().upper()
        for pattern in (self._EXPLICIT_ANSWER_RE,
                        self._STANDALONE_LETTER_RE,
                        self._ANY_OPTION_LETTER_RE):
            found = pattern.findall(text)
            if found:
                return found[-1]

        if text and text[-1].isalpha():
            return text[-1]
        return text

    @staticmethod
    def record_to_target_answer(record: pd.Series) -> str:
        """Return the record's correct option letter, upper-cased.

        Raises:
            TypeError: if the record's ``answer`` field is not a string.
        """
        correct_answer = record['answer']
        # Check before calling .upper(): previously a non-string failed with
        # AttributeError before the (assert-based) check could run.
        if not isinstance(correct_answer, str):
            raise TypeError(
                f"String expected but got {correct_answer} "
                f"of type {type(correct_answer)} record={record}")
        return correct_answer.upper()

    @classmethod
    def split_dataset(cls,
                      input_file: str,
                      output_dir: str,
                      train_ratio: float = 0.8,
                      random_state: int = 42) -> None:
        """Split a JSON list of questions into ``train.json`` and ``test.json``.

        The split is stratified by ``category`` when possible. The two files are
        written as JSON record lists into ``output_dir``, which is created if
        needed.

        Args:
            input_file: JSON file holding a list of question records.
            output_dir: Destination directory for ``train.json`` / ``test.json``.
            train_ratio: Fraction of questions for the train split, in (0, 1).
            random_state: Seed passed to ``train_test_split``.

        Raises:
            ValueError: if ``train_ratio`` is out of range or the file cannot be
                decoded as JSON with any of the tried encodings.
            OSError: if ``input_file`` cannot be opened (e.g. it does not exist).
        """
        # Validate arguments before doing any I/O.
        if not 0 < train_ratio < 1:
            raise ValueError(f"train_ratio Error {train_ratio}")

        encodings = ['utf-8', 'utf-8-sig', 'latin-1', 'iso-8859-1']
        data = None
        last_error = None
        for encoding in encodings:
            try:
                with open(input_file, 'r', encoding=encoding) as f:
                    data = json.load(f)
                break
            except (UnicodeDecodeError, json.JSONDecodeError) as err:
                # Only decoding problems are worth retrying with another
                # encoding; OS errors (missing file, permissions) propagate.
                print(f"Could not read {input_file} as {encoding}: {err}")
                last_error = err

        if data is None:
            raise ValueError(f"Error in  {input_file}") from last_error

        df = pd.DataFrame(data)
        print(df['category'].value_counts())

        try:
            train_df, test_df = train_test_split(
                df,
                train_size=train_ratio,
                stratify=df['category'],
                random_state=random_state
            )
        except ValueError:
            # sklearn raises ValueError when stratification is impossible
            # (e.g. a category with a single question); fall back to a plain split.
            train_df, test_df = train_test_split(
                df,
                train_size=train_ratio,
                random_state=random_state
            )

        os.makedirs(output_dir, exist_ok=True)
        train_path = os.path.join(output_dir, 'train.json')
        test_path = os.path.join(output_dir, 'test.json')
        train_df.to_json(train_path, orient='records', indent=4, force_ascii=False)
        test_df.to_json(test_path, orient='records', indent=4, force_ascii=False)