""" """

import os
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (
    Text, List, Union, Optional,
    Iterable, Generator
)

import ujson as json


@dataclass(eq=True)
class SimpleQAAnswerAtRound:
    """A single answer record for one question at one round.

    The reader in this module builds instances by keyword expansion of JSON
    objects (``SimpleQAAnswerAtRound(**item)``), so the field names must match
    the keys in the input files exactly. The declaration order below is also
    the positional constructor order; do not reorder.
    """

    # Question identifier.
    index: int
    question: Text
    # NOTE(review): presumably a category label for the expected answer — confirm.
    answer_type: Text
    # NOTE(review): presumably the reference answer string — confirm.
    gold_answer: Text
    # NOTE(review): presumably positive / negative candidate answers — confirm.
    pos: List[Text]
    neg: List[Text]
    # Integer counts associated with ``pos`` / ``neg``; their exact meaning is
    # not established in this module — TODO confirm.
    pos_multiplicity: int
    neg_multiplicity: int
    # NOTE(review): presumably the backoff answer for this round — confirm.
    backoff: Text


@dataclass(eq=True)
class SimpleQAAnswerAtRoundList:
    """All rounds recorded for one question, plus per-question fields.

    Only ``rounds`` is passed to the constructor. ``question``,
    ``answer_type``, ``gold_answer`` and ``index`` are copied from
    ``rounds[0]`` in ``__post_init__``. The remaining rounds are not checked
    for agreement. The copied values are not refreshed if ``rounds`` is
    mutated later. An empty ``rounds`` list raises ``IndexError``.
    """

    rounds: List[SimpleQAAnswerAtRound]
    question: Text = field(init=False)
    answer_type: Text = field(init=False)
    gold_answer: Text = field(init=False)
    index: int = field(init=False)

    def __post_init__(self):
        # The first round is treated as the source of truth for the
        # question-level fields (raises IndexError when ``rounds`` is empty).
        first_round = self.rounds[0]
        self.question, self.answer_type, self.gold_answer, self.index = (
            first_round.question,
            first_round.answer_type,
            first_round.gold_answer,
            first_round.index,
        )
        
        
class SimpleQAAnswerBackoffDataReader:
    """Iterate over ``SimpleQAAnswerAtRoundList`` records stored in JSON files.

    Each file must contain a JSON array of arrays. Every inner array becomes
    one ``SimpleQAAnswerAtRoundList``, whose rounds are built with
    ``SimpleQAAnswerAtRound(**item)``. Each item must therefore be an object
    whose keys match that dataclass's fields.

    Parsed files are memoized in a per-reader LRU cache, so iterating the
    reader again does not re-read from disk while the file is still cached.
    """

    def __init__(
        self,
        data_path: Union[List[Text], Text, Iterable[Text], os.PathLike],
        cache_size: Optional[int] = 3,
    ):
        """
        Args:
            data_path: a single file path (``str``, ``bytes`` or
                ``os.PathLike``) or an iterable of such paths. A single path
                is wrapped in a one-element list.
            cache_size: maximum number of parsed files kept in this reader's
                cache (``None`` for unbounded). With more files than
                ``cache_size``, sequential iteration evicts each file before it
                is reused, so every pass re-reads every file. Raise it to cover
                all files if memory allows.
        """
        # A single path must be detected first: ``str``/``bytes`` are
        # themselves iterable and must not be expanded into characters/ints.
        if isinstance(data_path, (str, bytes, os.PathLike)):
            data_path = [data_path]
        self._data_paths = list(data_path)

        # The cache is per instance rather than ``@lru_cache`` on the method.
        # A class-level cache keys on ``self``: it keeps readers and their
        # parsed data alive after they are dropped, and all instances share
        # the same few slots.
        self._cached_load = lru_cache(maxsize=cache_size)(self._load_file)

    def _read_file(self, file_path: Text) -> "List[SimpleQAAnswerAtRoundList]":
        """Return the parsed records of ``file_path``, from the cache when possible.

        The returned list is the cached object itself, shared across calls.
        Callers must not mutate it or its items.
        """
        return self._cached_load(file_path)

    @staticmethod
    def _load_file(file_path: Text) -> "List[SimpleQAAnswerAtRoundList]":
        """Read ``file_path`` as UTF-8 JSON and build the record objects."""
        with open(file_path, "r", encoding="utf-8") as file_:
            raw_lists = json.load(file_)
        # Build the objects after the file is closed; the handle is no longer needed.
        return [
            SimpleQAAnswerAtRoundList(
                rounds=[SimpleQAAnswerAtRound(**item_) for item_ in list_]
            )
            for list_ in raw_lists
        ]

    def __iter__(self) -> Generator["SimpleQAAnswerAtRoundList", None, None]:
        """Yield every record from every configured file, in file order."""
        for file_path in self._data_paths:
            yield from self._read_file(file_path)