"""Task that turns the NQ dev-set JSONL file into short-answer question records."""

import glob
import random
import logging
from tqdm import tqdm
import numpy as np
import ujson as json
import os
import copy
from overrides import overrides
from typing import (
    Text, List, Dict, Any
)
from tasker import BaseTask


# Delimiter used to join all extracted short-answer strings of one question
# into the single "answer" field of an output record.
__SEPARATOR__ = " [@@SEP@@] "


@BaseTask.register("nq-question-sampling")
class NQQuestionSamplingTask(BaseTask):
    """Build a short-answer question subset of the NQ dataset.

    Reads ``v1.0-simplified_nq-dev-all.jsonl`` from ``input_dir``, keeps every
    question that has at least one annotation with non-empty
    ``short_answers`` and a ``yes_no_answer`` of ``"NONE"``, and produces one
    record per kept question.

    NOTE: despite the task name, no sub-sampling or shuffling is performed at
    the moment (the ``num_samples`` / ``seed`` options are disabled); every
    qualifying question is kept, in file order.
    """

    __VERSION__ = "0.2.1"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
    ):
        """Remember where to read from and forward ``output_dir`` to ``BaseTask``.

        Args:
            input_dir: Directory containing ``v1.0-simplified_nq-dev-all.jsonl``.
            output_dir: Passed to ``BaseTask.__init__``; ``_write`` places
                ``output.jsonl`` under ``self._output_dir``.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir
        # Sampling options (num_samples, seed) are currently disabled.

    @staticmethod
    def _extract_short_answers(dp: Dict[Text, Any]) -> List[Text]:
        """Return the text of every short answer over all annotations of ``dp``.

        Each short answer selects ``dp['document_tokens'][start_token:end_token]``;
        the ``token`` fields of that slice are joined with single spaces.
        """
        tokens = dp['document_tokens']
        extractions = []
        for da in dp['annotations']:
            for short_answer in da['short_answers']:
                span = tokens[short_answer['start_token']:short_answer['end_token']]
                extractions.append(' '.join(item['token'] for item in span))
        return extractions

    @overrides
    def _run(self):
        """Parse the input file and collect one record per short-answer question.

        Returns:
            A list of dicts with the keys ``index`` (position of the record in
            the returned list), ``topic``, ``question``, ``answer_type`` and
            ``answer`` (all extracted short answers joined by
            ``__SEPARATOR__``).
        """
        input_path = os.path.join(self._input_dir, "v1.0-simplified_nq-dev-all.jsonl")
        data = []

        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's default locale encoding.
        with open(input_path, "r", encoding="utf-8") as file_:
            for line in tqdm(file_):
                if not line.strip():
                    # Tolerate stray blank lines instead of failing in json.loads.
                    continue

                dp = json.loads(line)

                has_short_answer = any(
                    da['short_answers'] and da['yes_no_answer'] == "NONE"
                    for da in dp['annotations']
                )
                if not has_short_answer:
                    continue

                data.append({
                    "index": len(data),
                    "topic": "UNIFIED_TOPIC_NQ",
                    "question": dp['question_text'],
                    "answer_type": "SHORT_COMBINED",
                    "answer": __SEPARATOR__.join(self._extract_short_answers(dp)),
                })

        return data

    @overrides
    def _write(self, outputs):
        """Write ``outputs`` to ``output.jsonl``, one JSON object per line.

        Args:
            outputs: Iterable of JSON-serialisable records, as returned by ``_run``.
        """
        output_path = os.path.join(self._output_dir, "output.jsonl")
        with open(output_path, "w", encoding="utf-8") as file_:
            for item in outputs:
                file_.write(json.dumps(item) + "\n")