import json
import logging
from typing import Any, Dict, List, Tuple, Optional

from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, SpacyTokenizer

from allennlp_models.rc.dataset_readers import utils

logger = logging.getLogger(__name__)


@DatasetReader.register("squad")
class SquadReader(DatasetReader):
    """
    Reads a JSON-formatted SQuAD file and returns a ``Dataset`` where the ``Instances`` have four
    fields: ``question``, a ``TextField``, ``passage``, another ``TextField``, and ``span_start``
    and ``span_end``, both ``IndexFields`` into the ``passage`` ``TextField``.  We also add a
    ``MetadataField`` that stores the instance's ID, the original passage text, gold answer strings,
    and token offsets into the original passage, accessible as ``metadata['id']``,
    ``metadata['original_passage']``, ``metadata['answer_texts']`` and
    ``metadata['token_offsets']``.  This is so that we can more easily use the official SQuAD
    evaluation script to get metrics.

    We also support limiting the maximum length for both passage and question. However, some gold
    answer spans may exceed the maximum passage length, which will cause error in making instances.
    We simply skip these spans to avoid errors. If all of the gold answer spans of an example
    are skipped, during training, we will skip this example. During validating or testing, since
    we cannot skip examples, we use the last token as the pseudo gold answer span instead. The
    computed loss will not be accurate as a result. But this will not affect the answer evaluation,
    because we keep all the original gold answer texts.

    # Parameters

    tokenizer : `Tokenizer`, optional (default=`SpacyTokenizer()`)
        We use this ``Tokenizer`` for both the question and the passage.  See :class:`Tokenizer`.
        Default is `SpacyTokenizer()`.
    token_indexers : `Dict[str, TokenIndexer]`, optional
        We similarly use this for both the question and the passage.  See :class:`TokenIndexer`.
        Default is `{"tokens": SingleIdTokenIndexer()}`.
    passage_length_limit : `int`, optional (default=`None`)
        if specified, we will cut the passage if the length of passage exceeds this limit.
    question_length_limit : `int`, optional (default=`None`)
        if specified, we will cut the question if the length of passage exceeds this limit.
    skip_invalid_examples: `bool`, optional (default=`False`)
        if this is true, we will skip those invalid examples
    """

    def __init__(
        self,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        passage_length_limit: int = None,
        question_length_limit: int = None,
        skip_invalid_examples: bool = False,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self._tokenizer = tokenizer or SpacyTokenizer()
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.passage_length_limit = passage_length_limit
        self.question_length_limit = question_length_limit
        self.skip_invalid_examples = skip_invalid_examples

    @overrides
    def _read(self, file_path: str):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)

        logger.info("Reading file at %s", file_path)
        with open(file_path) as dataset_file:
            dataset_json = json.load(dataset_file)
            dataset = dataset_json["data"]
        logger.info("Reading the dataset")
        for article in dataset:
            for paragraph_json in article["paragraphs"]:
                paragraph = paragraph_json["context"]
                tokenized_paragraph = self._tokenizer.tokenize(paragraph)

                for question_answer in paragraph_json["qas"]:
                    question_text = question_answer["question"].strip().replace("\n", "")
                    answer_texts = [answer["text"] for answer in question_answer["answers"]]
                    span_starts = [answer["answer_start"] for answer in question_answer["answers"]]
                    span_ends = [
                        start + len(answer) for start, answer in zip(span_starts, answer_texts)
                    ]
                    additional_metadata = {"id": question_answer.get("id", None)}
                    instance = self.text_to_instance(
                        question_text,
                        paragraph,
                        zip(span_starts, span_ends),
                        answer_texts,
                        tokenized_paragraph,
                        additional_metadata,
                    )
                    if instance is not None:
                        yield instance

    @overrides
    def text_to_instance(
        self,  # type: ignore
        question_text: str,
        passage_text: str,
        char_spans: List[Tuple[int, int]] = None,
        answer_texts: List[str] = None,
        passage_tokens: List[Token] = None,
        additional_metadata: Dict[str, Any] = None,
    ) -> Optional[Instance]:

        if not passage_tokens:
            passage_tokens = self._tokenizer.tokenize(passage_text)
        question_tokens = self._tokenizer.tokenize(question_text)
        if self.passage_length_limit is not None:
            passage_tokens = passage_tokens[: self.passage_length_limit]
        if self.question_length_limit is not None:
            question_tokens = question_tokens[: self.question_length_limit]
        char_spans = char_spans or []
        # We need to convert character indices in `passage_text` to token indices in
        # `passage_tokens`, as the latter is what we'll actually use for supervision.
        token_spans: List[Tuple[int, int]] = []
        passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
        for char_span_start, char_span_end in char_spans:
            if char_span_end > passage_offsets[-1][1]:
                continue
            (span_start, span_end), error = utils.char_span_to_token_span(
                passage_offsets, (char_span_start, char_span_end)
            )
            if error:
                logger.debug("Passage: %s", passage_text)
                logger.debug("Passage tokens: %s", passage_tokens)
                logger.debug("Question text: %s", question_text)
                logger.debug("Answer span: (%d, %d)", char_span_start, char_span_end)
                logger.debug("Token span: (%d, %d)", span_start, span_end)
                logger.debug("Tokens in answer: %s", passage_tokens[span_start : span_end + 1])
                logger.debug("Answer: %s", passage_text[char_span_start:char_span_end])
            token_spans.append((span_start, span_end))
        # The original answer is filtered out
        if char_spans and not token_spans:
            if self.skip_invalid_examples:
                return None
            else:
                token_spans.append((len(passage_tokens) - 1, len(passage_tokens) - 1))
        return utils.make_reading_comprehension_instance(
            question_tokens,
            passage_tokens,
            self._token_indexers,
            passage_text,
            token_spans,
            answer_texts,
            additional_metadata,
        )
