""" """

from functools import lru_cache
from typing import (
    Text, List, Union, Optional,
    Iterable, Generator
)
import ujson as json
from dataclasses import dataclass, field


@dataclass
class Round:
    """Outcome of one round inside a ``RoundIteration``.

    Instances are built by ``IterativeClusteringDataReader._parse_item`` from
    one entry of a record's ``'rounds'`` list; the per-field comments name the
    JSON key each value is copied from.
    """

    # Copied from the entry's 'round_idx' key.
    # NOTE(review): annotated as Text, but the key name suggests an integer
    # index -- confirm against the data files.
    round_id: Text
    # Copied from round_result['selected'].
    selected: List[Text]
    # Copied from round_result['candidates'] (note the differing key name).
    not_selected: List[Text]
    # Copied from round_result['selected_multiplicity'].
    # Presumably aligned element-wise with `selected` -- TODO confirm.
    selected_multiplicity: List[int]
    # Copied from round_result['candidates_multiplicity'].
    # Presumably aligned element-wise with `not_selected` -- TODO confirm.
    not_selected_multiplicity: List[int]


@dataclass
class RoundIteration:
    """One parsed record: a question, its answer template and its rounds.

    ``IterativeClusteringDataReader._parse_item`` creates one instance per
    JSON line it parses.
    """

    # Copied from the record's 'question' key.
    question: Text
    # Copied from the record's 'answer_template' key.
    answer_template: Text
    # One `Round` per entry of the record's 'rounds' list, in file order.
    rounds: List[Round]
    # Copied from the record's 'topic' key; None when the key is absent.
    topic: Optional[Text] = None


class IterativeClusteringDataReader:
    """Iterate over ``RoundIteration`` records stored in JSON-lines files.

    Every non-blank line of each configured file is parsed as one JSON object
    and converted into a ``RoundIteration``. Iterating the reader yields the
    records of all files, in the order the paths were given.

    Parsed files are memoised per reader instance (least-recently-used, at
    most ``_CACHE_SIZE`` files), so iterating the same reader repeatedly does
    not re-read files that are still cached.
    """

    # Maximum number of parsed files a single reader keeps in memory.
    _CACHE_SIZE = 3

    def __init__(self, data_path: Union[List[Text], Text]):
        """Create a reader over one or several JSON-lines files.

        Args:
            data_path: A single file path, or a list (or tuple) of file
                paths. Anything that is not a list/tuple is treated as one
                path.
        """
        if isinstance(data_path, (list, tuple)):
            # Copy so later changes to the caller's sequence do not silently
            # alter what this reader iterates over.
            data_path = list(data_path)
        else:
            data_path = [data_path]

        self._data_paths = data_path

        # The cache is bound to this instance on purpose: decorating the
        # method itself with lru_cache would store `self` in a class-level
        # cache, keeping every reader (and its parsed data) alive and making
        # all readers compete for the same few slots.
        self._cached_load = lru_cache(maxsize=self._CACHE_SIZE)(self._load_file)

    def _parse_item(self, line_text: Text) -> "RoundIteration":
        """Convert one JSON line into a ``RoundIteration``.

        Args:
            line_text: A JSON object with the keys ``question``,
                ``answer_template`` and ``rounds``, and optionally ``topic``.
                Each entry of ``rounds`` must provide ``round_idx`` and a
                ``round_result`` mapping with ``selected``, ``candidates``,
                ``selected_multiplicity`` and ``candidates_multiplicity``.

        Returns:
            The parsed record; ``topic`` is ``None`` when the key is absent.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``line_text`` is not valid JSON.
        """
        record = json.loads(line_text)
        question = record['question']
        answer_template = record['answer_template']

        round_list = []
        for round_record in record['rounds']:
            result = round_record['round_result']
            round_list.append(
                Round(
                    round_id=round_record['round_idx'],
                    selected=result['selected'],
                    # The file calls the non-selected items "candidates".
                    not_selected=result['candidates'],
                    selected_multiplicity=result['selected_multiplicity'],
                    not_selected_multiplicity=result['candidates_multiplicity'],
                )
            )

        return RoundIteration(
            question=question,
            answer_template=answer_template,
            topic=record.get('topic'),
            rounds=round_list,
        )

    def _load_file(self, file_path: Text) -> List["RoundIteration"]:
        """Read and parse ``file_path`` without consulting the cache.

        Whitespace-only lines (for example a trailing empty line) are
        skipped instead of being handed to the JSON parser.
        """
        with open(file_path, "r", encoding='utf-8') as file_:
            return [self._parse_item(line) for line in file_ if line.strip()]

    def _read_file(self, file_path: Text) -> List["RoundIteration"]:
        """Return the parsed records of ``file_path``, using the cache.

        The returned list is the cached object itself; callers should treat
        it as read-only.
        """
        return self._cached_load(file_path)

    def __iter__(self) -> Generator["RoundIteration", None, None]:
        """Yield every record of every configured file, in path order."""
        for file_path in self._data_paths:
            yield from self._read_file(file_path)