""" Read the data for scoring in the SimpleQA task. """

from functools import lru_cache
from typing import (
    Text, List, Union, Optional,
    Iterable, Generator
)
import ujson as json
import re
from dataclasses import dataclass, field


@dataclass
class BackOff:
    """A single back-off entry belonging to a scored SimpleQA question.

    Instances are plain field containers; two entries compare equal when
    ``score``, ``backoff`` and ``multiplicity`` are all equal.
    """

    # Numeric score recorded for this back-off.
    score: float
    # The back-off text itself.
    backoff: Text
    # Multiplicity count stored alongside the back-off.
    multiplicity: int


@dataclass
class SimpleQAScoredQuestion:
    """One SimpleQA question with its reference answer and scored back-offs.

    Equality compares every field, including the full ``backoffs`` list.
    """

    # Value taken from the record's "index" key.
    index: int
    # The question text.
    question: Text
    # Reference (gold) answer for the question.
    gold_answer: Text
    # Answer-type label carried by the record.
    answer_type: Text
    # Back-off entries built from the record's "claims" list.
    backoffs: List[BackOff]

    
class SimpleQAScoringDataReader:
    """Iterate over scored SimpleQA questions stored in JSONL files.

    Every non-blank line of each input file must be a JSON object with the
    keys ``index``, ``question``, ``gold_answer``, ``answer_type`` and
    ``claims``; ``claims`` is a list of objects whose keys match the fields
    of ``BackOff``. Files are read in the order given, and records are
    yielded in file order.
    """

    def __init__(self, data_path: Union[List[Text], Text]):
        """Remember which file(s) to read.

        Args:
            data_path: A single file path (str, bytes or path object), or an
                iterable (list, tuple, ...) of file paths.
        """
        # Strings/bytes are iterable but denote ONE path; path objects are not
        # iterable at all. Everything else is treated as a collection of paths.
        is_single_path = (
            isinstance(data_path, (str, bytes))
            or not isinstance(data_path, Iterable)
        )
        if is_single_path:
            self._data_paths = [data_path]
        else:
            self._data_paths = list(data_path)

    def _read_file(self, file_path: Text) -> "List[SimpleQAScoredQuestion]":
        """Parse one JSONL file into a list of scored questions.

        Blank (whitespace-only) lines are skipped.

        Args:
            file_path: Path of the UTF-8 encoded JSONL file to read.

        Returns:
            The parsed questions, in file order.

        Raises:
            ValueError: If a line is not valid JSON; the message names the
                file and the 1-based line number.
            KeyError: If a record lacks one of the required keys.
            TypeError: If a ``claims`` entry has keys that do not match the
                fields of ``BackOff``.
        """
        datalist = []
        with open(file_path, "r", encoding="utf-8") as file_:
            for line_no, line in enumerate(file_, start=1):
                if not line.strip():
                    # Tolerate empty lines, e.g. extra newlines at end of file.
                    continue
                try:
                    data = json.loads(line)
                except ValueError as err:
                    raise ValueError(
                        f"{file_path}:{line_no}: invalid JSON line"
                    ) from err
                # Records store their back-off entries under "claims".
                backoffs = [BackOff(**backoff) for backoff in data["claims"]]
                datalist.append(
                    SimpleQAScoredQuestion(
                        index=data["index"],
                        question=data["question"],
                        gold_answer=data["gold_answer"],
                        answer_type=data["answer_type"],
                        backoffs=backoffs,
                    )
                )
        return datalist

    def __iter__(self) -> "Generator[SimpleQAScoredQuestion, None, None]":
        """Yield every question from every configured file, in order."""
        for file_path in self._data_paths:
            yield from self._read_file(file_path)