""" Take a set of examples, and filter out the ones that can't be used as exemplars. """

import os
try:
    import ujson as json
except ModuleNotFoundError:
    import json
from glob import glob
from typing import List, Tuple, Text, Dict, Any
from overrides import overrides
from tasker import BaseTask


@BaseTask.register("examplar-filtering")
class ExamplarFilteringTask(BaseTask):
    """Filter datapoints from ``*.jsonl`` files down to usable exemplars.

    Every ``*.jsonl`` file directly inside ``input_dir`` is read line by line.
    A datapoint is kept only if it passes ``_is_usable_exemplar``. Each kept
    datapoint is reduced to its ``question``, ``answer_type``,
    ``gold_answer`` and ``claims`` fields. ``_write`` stores the result as
    ``filtered.jsonl`` in the output directory.
    """

    __VERSION__ = "0.0.2"

    # Answer types that are never used as exemplars.
    _EXCLUDED_ANSWER_TYPES = frozenset({'Number', 'Date', 'Duration'})
    # Minimum number of claims a datapoint must have to be kept.
    _MIN_CLAIMS = 5

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
    ):
        """
        Args:
            input_dir: Directory whose ``*.jsonl`` files are filtered.
            output_dir: Passed to ``BaseTask``; ``_write`` places its
                output file here.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir

    @overrides
    def _run(self) -> List[Dict[Text, Any]]:
        """Collect the usable exemplars from every ``*.jsonl`` file in ``input_dir``.

        Returns:
            The kept records from all files, concatenated in file order.
        """
        # glob's result order depends on the filesystem; sort the paths so
        # the output order is reproducible across runs and machines.
        file_paths = sorted(glob(os.path.join(self._input_dir, "*.jsonl")))

        results = []
        for file_path in file_paths:
            results.extend(self._filter_file(file_path))
        return results

    @classmethod
    def _filter_file(cls, file_path: Text) -> List[Dict[Text, Any]]:
        """Read one JSONL file and return its datapoints that pass the filter.

        Raises:
            KeyError: if a datapoint lacks ``question``, ``answer_type``,
                ``gold_answer`` or ``claims``.
        """
        kept = []
        with open(file_path, 'r', encoding='utf-8') as file_:
            for line in file_:
                # Skip blank lines (e.g. a trailing empty line at end of file),
                # which json.loads would otherwise reject.
                if not line.strip():
                    continue
                datapoint = json.loads(line)
                # All four fields are read up front so a malformed datapoint
                # fails loudly even if it would have been filtered out.
                record = {
                    "question": datapoint['question'],
                    "answer_type": datapoint['answer_type'],
                    "gold_answer": datapoint['gold_answer'],
                    "claims": datapoint['claims'],
                }
                if cls._is_usable_exemplar(record):
                    kept.append(record)
        return kept

    @classmethod
    def _is_usable_exemplar(cls, record: Dict[Text, Any]) -> bool:
        """Return True if ``record`` may be used as an exemplar.

        A record is rejected when:
          * its answer type is in ``_EXCLUDED_ANSWER_TYPES``,
          * it has fewer than ``_MIN_CLAIMS`` claims,
          * its last claim has ``score == 0``, or
          * its second claim (index 1) has ``score == 1``.
        """
        if record['answer_type'] in cls._EXCLUDED_ANSWER_TYPES:
            return False

        claims = record['claims']
        # This check also guarantees the claims[1] / claims[-1] lookups
        # below are in range.
        if len(claims) < cls._MIN_CLAIMS:
            return False

        if claims[-1]['score'] == 0:
            return False

        # NOTE(review): this checks the second claim (index 1), not the
        # first; confirm that is intentional.
        if claims[1]['score'] == 1:
            return False

        return True

    @overrides
    def _write(self, outputs: List[Dict[Text, Any]]):
        """Write the filtered records to ``<output_dir>/filtered.jsonl``.

        NOTE(review): despite the ``.jsonl`` extension, this writes a single
        indented JSON array, not one object per line. Confirm downstream
        readers expect a JSON array before changing either the name or the
        format.
        """
        with open(os.path.join(self._output_dir, "filtered.jsonl"), 'w', encoding='utf-8') as file_:
            json.dump(outputs, file_, ensure_ascii=False, indent=2)