"""Data structures and a JSON-lines reader for SimpleQA iterative-clustering rounds."""

from functools import lru_cache
from typing import (
    Text, List, Union, Optional,
    Iterable, Generator
)
import ujson as json
from dataclasses import dataclass, field


@dataclass
class Round:
    """Contents of a single round of a SimpleQA iterative-clustering record.

    The reader in this module fills ``selected`` from a round's ``selected``
    key and ``not_selected`` from its ``candidates`` key (both converted to
    strings), and copies the two multiplicity lists through unchanged.
    """

    # NOTE: a ``round_id: Text`` field is currently disabled.

    # Items stored under the round's ``selected`` key.
    selected: List[str]
    # Items stored under the round's ``candidates`` key.
    not_selected: List[str]
    # Presumably one count per entry of ``selected`` -- TODO confirm.
    selected_multiplicity: List[int]
    # Presumably one count per entry of ``not_selected`` -- TODO confirm.
    not_selected_multiplicity: List[int]


@dataclass
class SimpleQARoundIteration:
    """One SimpleQA record: a question plus the rounds recorded for it."""

    # Text of the question.
    question: str
    # Value of the record's ``answer_type`` key (semantics not visible
    # here -- presumably a category label for the answer; TODO confirm).
    answer_type: str
    # Value of the record's ``gold_answer`` key.
    gold_answer: str
    # Rounds in the order they appear in the source record.
    rounds: List[Round]


class SimpleQAIterativeClusteringDataReader:
    """Read SimpleQA iterative-clustering records from JSON-lines files.

    Every non-blank line of each configured file is decoded as one JSON
    object and converted into a ``SimpleQARoundIteration``.  Iterating the
    reader yields the records of all files in the order the paths were given.

    Parsed files are kept in a small per-instance LRU cache so that repeated
    iteration does not re-read and re-parse them.
    """

    # Maximum number of parsed files each reader instance keeps in memory.
    _MAX_CACHED_FILES = 3

    def __init__(self, data_path: Union[List[Text], Text]):
        """Remember which file(s) to read; nothing is opened yet.

        Args:
            data_path: A single path, or a list/tuple of paths, pointing to
                JSON-lines files.
        """
        if isinstance(data_path, tuple):
            data_path = list(data_path)
        elif not isinstance(data_path, list):
            data_path = [data_path]

        self._data_paths = data_path
        # Maps file path -> parsed records, ordered from least to most
        # recently used.  Kept on the instance (rather than in a class-level
        # ``lru_cache``) so the cache dies together with the reader.
        self._file_cache = {}

    def _parse_item(self, line_text: Text) -> "SimpleQARoundIteration":
        """Convert one JSON-encoded line into a ``SimpleQARoundIteration``.

        Args:
            line_text: A JSON object with the keys ``question``,
                ``answer_type``, ``gold_answer`` and ``rounds``.  Each entry
                of ``rounds`` must provide ``selected``, ``candidates``,
                ``selected_multiplicity`` and ``candidates_multiplicity``.

        Returns:
            The parsed record.  Entries of ``selected`` and ``candidates``
            are converted to ``str``; the multiplicity lists are passed
            through unchanged.

        Raises:
            KeyError: If a required key is missing from the record or from
                one of its rounds.
            ValueError: If ``line_text`` is not valid JSON.
        """
        record = json.loads(line_text)

        return SimpleQARoundIteration(
            question=record['question'],
            answer_type=record['answer_type'],
            gold_answer=record['gold_answer'],
            rounds=[
                Round(
                    selected=[str(s) for s in round_data['selected']],
                    # The file calls these "candidates"; the dataclass
                    # exposes them as ``not_selected``.
                    not_selected=[str(c) for c in round_data['candidates']],
                    selected_multiplicity=round_data['selected_multiplicity'],
                    not_selected_multiplicity=round_data['candidates_multiplicity'],
                )
                for round_data in record['rounds']
            ],
        )

    def _read_file(self, file_path: Text) -> "List[SimpleQARoundIteration]":
        """Return all records of ``file_path``, parsing it at most once while cached.

        Blank lines (for example a trailing empty line) are skipped.  The
        returned list is the cached object itself, so callers should not
        mutate it.

        Args:
            file_path: Path of the JSON-lines file to load.

        Returns:
            The parsed records in file order.
        """
        # Fall back to a fresh cache if a subclass bypassed ``__init__``.
        cache = getattr(self, "_file_cache", None)
        if cache is None:
            cache = self._file_cache = {}

        # Popping and re-inserting moves a hit to the "most recent" end.
        items = cache.pop(file_path, None)
        if items is None:
            with open(file_path, "r", encoding='utf-8') as file_:
                items = [
                    self._parse_item(line) for line in file_ if line.strip()
                ]
            if len(cache) >= self._MAX_CACHED_FILES:
                # Dicts preserve insertion order, so the first key is the
                # least recently used entry.
                del cache[next(iter(cache))]
        cache[file_path] = items
        return items

    def __iter__(self) -> "Generator[SimpleQARoundIteration, None, None]":
        """Yield every record of every configured file, in path order."""
        for file_path in self._data_paths:
            yield from self._read_file(file_path)