""" """

import ast
import csv
import re
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (
    Text, List, Union, Optional,
    Iterable, Generator
)

import ujson as json


@dataclass
class SimpleQAInstance:
    """A single SimpleQA example built from one CSV row.

    Instances compare equal field-by-field (the dataclass default ``eq=True``).
    """

    # 0-based position of the row within the CSV file it was read from.
    index: int
    # ``topic`` value taken from the row's metadata.
    topic: Text
    # ``answer_type`` value taken from the row's metadata.
    answer_type: Text
    # Question text (the CSV ``problem`` column).
    question: Text
    # Reference answer (the CSV ``answer`` column).
    answer: Text

    
class SimpleQADataReader:
    """Iterate over ``SimpleQAInstance`` records loaded from SimpleQA CSV files.

    Each CSV file must have ``problem``, ``answer`` and ``metadata`` columns,
    where ``metadata`` holds a Python dict literal with at least ``topic`` and
    ``answer_type`` keys. Parsed files are cached on the reader, so iterating
    it more than once does not re-read the files.
    """

    def __init__(self, data_path: Union[List[Text], Text]):
        """
        Args:
            data_path: a single CSV path, or a list/tuple of CSV paths that
                are read in order.
        """
        # Copy sequences so later mutation of the caller's list does not
        # change what this reader iterates over; anything else (a str or a
        # path-like object) is treated as one path.
        if isinstance(data_path, (list, tuple)):
            self._data_paths = list(data_path)
        else:
            self._data_paths = [data_path]
        # Per-instance cache of parsed files, keyed by path. An lru_cache on
        # the method would key on ``self`` and keep every reader alive for
        # the lifetime of the process.
        self._file_cache = {}

    @staticmethod
    def _parse_metadata(raw: Text):
        """Return ``(topic, answer_type)`` from a row's ``metadata`` cell.

        The cell is parsed with ``ast.literal_eval``, which evaluates Python
        literals only (no arbitrary code) and handles whatever quoting
        ``repr`` produced, unlike rewriting quotes into JSON by hand.

        Raises:
            ValueError: if *raw* is missing, is not a dict literal, or lacks
                the ``topic`` / ``answer_type`` keys.
        """
        if not isinstance(raw, str):
            raise ValueError("metadata cell is missing")
        try:
            metadata = ast.literal_eval(raw.strip())
        except (ValueError, TypeError, SyntaxError) as err:
            raise ValueError(f"cannot parse metadata {raw!r}") from err
        if not isinstance(metadata, dict):
            raise ValueError(f"metadata is not a dict: {raw!r}")
        try:
            return metadata["topic"], metadata["answer_type"]
        except KeyError as err:
            raise ValueError(f"metadata missing key {err}: {raw!r}") from err

    def _read_file(self, file_path: Text) -> List["SimpleQAInstance"]:
        """Parse one SimpleQA CSV file, caching the result on this reader.

        ``index`` is the 0-based data-row number within *this* file, so
        indices restart at 0 for every file in ``data_path``.

        Raises:
            ValueError: if a row's ``metadata`` cannot be parsed.
        """
        cached = self._file_cache.get(file_path)
        if cached is not None:
            return cached

        data_list = []
        # newline='' lets the csv module handle line breaks inside quoted
        # fields itself, as the csv documentation requires.
        with open(file_path, "r", encoding="utf-8", newline="") as file_:
            for ridx, row in enumerate(csv.DictReader(file_)):
                try:
                    topic, answer_type = self._parse_metadata(row["metadata"])
                except ValueError as err:
                    raise ValueError(
                        f"{file_path}: invalid metadata in data row {ridx}: {err}"
                    ) from err
                data_list.append(SimpleQAInstance(
                    index=ridx,
                    topic=topic,
                    answer_type=answer_type,
                    question=row["problem"],
                    answer=row["answer"],
                ))

        self._file_cache[file_path] = data_list
        return data_list

    def __iter__(self) -> Generator["SimpleQAInstance", None, None]:
        """Yield every instance from each data file, in ``data_path`` order."""
        for file_path in self._data_paths:
            yield from self._read_file(file_path)