#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# file: mrc_ner_dataset.py
import bz2
import json
import os

import torch
from tqdm import tqdm
from torch.utils.data import Dataset
from transformers.tokenization_utils_base import TruncationStrategy
from transformers.utils import logging
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
# Lower-cased tokenizer family names (tokenizer class name with "Tokenizer"
# removed) for which the dataset classes in this file count one extra special
# token when computing `sequence_added_tokens`.
# NOTE(review): presumably these tokenizers insert two separator tokens between
# the segments of a pair -- confirm against the tokenizer documentation.
MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart", "mpnet"}


class PMRDataset(Dataset):
    """
    Dataset over pre-computed features stored in one or more buffer files.

    Buffer files are discovered by listing the directory part of ``data_path``
    and keeping every entry whose name starts with the last path component.
    They are ordered by the integer found after the last ``_`` of each path.

    Args:
        data_path: path prefix shared by all buffer files
        tokenizer: tokenizer object; ``__getitem__`` reads its class name,
            ``model_max_length`` and ``max_len_single_sentence``
        max_length: int, stored on the instance (not used by this class itself)
        max_query_length: int, stored on the instance (not used by this class itself)
        pad_to_maxlen: stored on the instance (not used by this class itself)
        context_first: selects which label mask layout ``__getitem__`` builds
        evaluate: if True, exactly one buffer file must match ``data_path``
        lazy_load: if True, keep a single buffer in memory and switch with
            ``next_buffer()``; if False, load every buffer up-front
    """
    def __init__(self, data_path, tokenizer: None, max_length: int = 512, max_query_length = 64, pad_to_maxlen=False, context_first=False, evaluate=False, lazy_load=True):
        drive, tail = os.path.split(data_path)
        # os.listdir("") raises, so a prefix without a directory part is
        # resolved against the current working directory instead.
        candidates = os.listdir(drive or os.curdir)
        self.all_buffer_files = sorted(
            (os.path.join(drive, x) for x in candidates if x.startswith(tail)),
            key=lambda x: int(x.split("_")[-1]),
        )
        self.buffer_id = 0
        self.lazy_load = lazy_load
        if evaluate and len(self.all_buffer_files) != 1:
            logger.error("please save evaluate file in one single buffer file")
            # Raised explicitly (instead of a bare assert) so that the check
            # is still enforced when Python runs with -O.
            raise AssertionError("please save evaluate file in one single buffer file")
        if self.lazy_load:
            self.all_data = self._load_buffer(self.buffer_id)
        else:
            self.all_data = []
            for buffer_id in range(len(self.all_buffer_files)):
                self.all_data.extend(self._load_buffer(buffer_id))
            self.all_buffer_files = [1] # set the length of self.all_buffer_files to be 1

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_query_length = max_query_length
        self.pad_to_maxlen = pad_to_maxlen
        self.context_first = context_first

    def _load_buffer(self, buffer_id):
        """Load buffer file number ``buffer_id`` and return its values as a list."""
        # NOTE(review): torch.load unpickles the file, so only trusted buffer
        # files should be used here. Recent PyTorch releases may also default
        # to weights_only=True, which could reject pickled feature objects --
        # confirm against the installed version.
        features = torch.load(self.all_buffer_files[buffer_id])
        keys = list(features.keys())
        # Entries are popped one by one so the loaded mapping shrinks while the
        # list grows.
        return [features.pop(x) for x in tqdm(keys, desc="prepare dataset at buffer {}".format(buffer_id))]

    def next_buffer(self):
        """Replace the in-memory data with the next buffer file (wraps around)."""
        if not self.lazy_load:
            # Everything is already in memory and all_buffer_files only holds
            # a placeholder, so there is nothing to switch to.
            return
        # Drop the current buffer first so two buffers are not held at once.
        del self.all_data
        self.buffer_id = (self.buffer_id + 1) % len(self.all_buffer_files)
        self.all_data = self._load_buffer(self.buffer_id)

    def __len__(self):
        """Number of examples currently held in memory."""
        return len(self.all_data)

    def __getitem__(self, item):
        """
        Args:
            item: int, idx
        Returns:
            tokens: input ids stored in the feature, [seq_len]
            attention_mask: stored attention mask, or all ones if the feature has none, [seq_len]
            token_type_ids: stored token type ids, or all zeros if the feature has none, [seq_len]
            label_mask: label mask, 1 for counting into loss, 0 for ignoring. [seq_len]
            match_labels: match labels, [seq_len, seq_len]
        """
        data = self.all_data[item]
        seq_len = len(data.input_ids)

        pair_input_ids = data.input_ids
        if data.attention_mask is None:
            pair_attention_mask = [1] * len(pair_input_ids)
        else:
            pair_attention_mask = data.attention_mask
        if data.token_type_ids is None:
            pair_token_types_ids = [0] * len(pair_input_ids)
        else:
            pair_token_types_ids = data.token_type_ids
        offset = data.doc_offset
        tokenizer_type = type(self.tokenizer).__name__.replace("Tokenizer", "").lower()
        # Special tokens the tokenizer adds to a single sequence, plus one for
        # the tokenizer families listed in MULTI_SEP_TOKENS_TOKENIZERS_SET.
        sequence_added_tokens = (
            self.tokenizer.model_max_length - self.tokenizer.max_len_single_sentence + 1
            if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
            else self.tokenizer.model_max_length - self.tokenizer.max_len_single_sentence
        )
        if self.context_first:
            label_mask = [1] + [1] * (offset - sequence_added_tokens) + [0] * (seq_len - offset + sequence_added_tokens - 1) # allow cls in DLM loss
        else:
            # Position 0 stays unmasked, the next offset - 1 positions are
            # masked, and the final position is masked as well.
            label_mask = [1] + [0] * (offset - 1) + [1] * (seq_len - offset - 1) + [0]

        # Every annotated boundary has to fall on an unmasked position.
        assert all(label_mask[p] != 0 for p in data.start_positions)
        assert all(label_mask[p] != 0 for p in data.end_positions)
        assert len(label_mask) == seq_len

        match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
        if data.start_positions != [] and data.end_positions != []:
            # Cell (0, 0) is set whenever the example has at least one span.
            match_labels[0, 0] = 1
        for start, end in zip(data.start_positions, data.end_positions):
            if start >= seq_len or end >= seq_len:
                continue
            match_labels[start, end] = 1

        return [
            torch.LongTensor(pair_input_ids),
            torch.LongTensor(pair_attention_mask),
            torch.LongTensor(pair_token_types_ids),
            torch.LongTensor(label_mask),
            match_labels,
        ]

class MRCNERDataset(Dataset):
    """
    MRC NER Dataset
    Args:
        json_path: path to mrc-ner style json
        tokenizer: tokenizer used to split the context words and encode the query/context pair
        max_length: int, max length of query+context
        max_query_length: int, max number of query tokens kept when the query comes first
        possible_only: stored on the instance; not used by the code in this class
        pad_to_maxlen: if True, pad every encoded example to max_length
        is_chinese: if True, build the query from the Chinese prompt template
        context_first: if True, encode (context, query) instead of (query, context)
        max_context_length: int, max number of context tokens kept when the
            context comes first. If None, it is derived in ``__getitem__`` as
            max_length - max_query_length - (sequence_added_tokens + 1).
    """
    # Tokenizer classes whose tokenize() is called with add_prefix_space=True.
    _PREFIX_SPACE_TOKENIZERS = frozenset({
        "RobertaTokenizer",
        "LongformerTokenizer",
        "BartTokenizer",
        "RobertaTokenizerFast",
        "LongformerTokenizerFast",
        "BartTokenizerFast",
    })

    def __init__(self, json_path, tokenizer: None, max_length: int = 512, max_query_length = 64, possible_only=False, pad_to_maxlen=False, is_chinese=False, context_first=False, max_context_length=None):
        # Context manager so the file handle is closed right after parsing.
        with open(json_path, encoding="utf-8") as reader:
            self.all_data = json.load(reader)
        self.is_chinese = is_chinese
        self.prompt()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_query_length = max_query_length
        self.possible_only = possible_only
        self.pad_to_maxlen = pad_to_maxlen
        self.context_first = context_first
        self.max_context_length = max_context_length

    def prompt(self):
        """Rewrite every raw example in place: build the prompt query and split the context into words.

        Also sets ``self.neg_ratio`` to (#examples - #impossible) / #impossible,
        or to ``float("inf")`` when no example is marked impossible.
        """
        num_examples = len(self.all_data)
        num_impossible = sum(1 for x in self.all_data if x["impossible"])
        if num_impossible:
            self.neg_ratio = (num_examples - num_impossible) / num_impossible
        else:
            # Avoid ZeroDivisionError on data without impossible examples.
            self.neg_ratio = float("inf")
        new_datas = []
        for data in self.all_data:
            label = data['entity_label']
            details = data['query']
            context = data['context']
            start_positions = data["start_position"]
            end_positions = data["end_position"]
            words = context.split()
            # The word-level positions are only valid if the context is
            # separated by single spaces.
            assert len(words) == len(context.split(" "))
            if self.is_chinese:
                query = '突出显示与“{}”相关的部分（如果有）。 含义：{}'.format(label, details)
            else:
                query = 'Highlight the parts (if any) related to "{}". Details: {}'.format(label, details)
            # Maps "start;end" to the surface text of the span (end inclusive).
            span_positions = {
                "{};{}".format(start, end_positions[i]): " ".join(words[start: end_positions[i] + 1])
                for i, start in enumerate(start_positions)
            }
            new_datas.append({
                'context': words,
                'end_position': end_positions,
                'entity_label': label,
                'impossible': data['impossible'],
                'qas_id': data['qas_id'],
                'query': query,
                'span_position': span_positions,
                'start_position': start_positions,
            })
        self.all_data = new_datas

    def __len__(self):
        """Number of examples."""
        return len(self.all_data)

    def _tokenize_context(self, context):
        """Split every context word into sub-tokens.

        Returns:
            all_doc_tokens: flat list of sub-tokens
            orig_to_tok_index: for each word, the index of its first sub-token
        """
        tokenizer = self.tokenizer
        tokenizer_name = tokenizer.__class__.__name__
        # Pick the per-word tokenization call once instead of once per word.
        if tokenizer_name in self._PREFIX_SPACE_TOKENIZERS:
            def tokenize_word(word):
                return tokenizer.tokenize(word, add_prefix_space=True)
        elif tokenizer_name == 'BertWordPieceTokenizer':
            def tokenize_word(word):
                return tokenizer.encode(word, add_special_tokens=False).tokens
        else:
            tokenize_word = tokenizer.tokenize

        orig_to_tok_index = []
        all_doc_tokens = []
        for word in context:
            orig_to_tok_index.append(len(all_doc_tokens))
            all_doc_tokens.extend(tokenize_word(word))
        return all_doc_tokens, orig_to_tok_index

    @staticmethod
    def _shift_spans(starts, ends, offset, last_valid):
        """Shift spans by ``offset``, dropping spans that start after ``last_valid`` and clamping ends to it."""
        shifted_starts, shifted_ends = [], []
        # Starts and ends are filtered together so each kept start stays
        # paired with its own end.
        for start, end in zip(starts, ends):
            if start + offset > last_valid:
                continue
            shifted_starts.append(start + offset)
            shifted_ends.append(min(end + offset, last_valid))
        return shifted_starts, shifted_ends

    def __getitem__(self, item):
        """
        Args:
            item: int, idx
        Returns:
            tokens: tokens of the encoded query/context pair, [seq_len]
            attention_mask: attention mask returned by the tokenizer, [seq_len]
            token_type_ids: token type ids returned by the tokenizer, [seq_len]
            label_mask: label mask, 1 for counting into loss, 0 for ignoring. [seq_len]
            match_labels: match labels, [seq_len, seq_len]
        """
        data = self.all_data[item]
        tokenizer = self.tokenizer

        query = data["query"]
        context = data["context"]
        start_positions = data["start_position"]
        end_positions = data["end_position"]

        tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
        # Special tokens the tokenizer adds to a single sequence, plus one for
        # the tokenizer families listed in MULTI_SEP_TOKENS_TOKENIZERS_SET.
        sequence_added_tokens = (
            tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
            if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
            else tokenizer.model_max_length - tokenizer.max_len_single_sentence
        )

        all_doc_tokens, orig_to_tok_index = self._tokenize_context(context)

        # Word-level span boundaries -> sub-token boundaries. A span end maps
        # to the last sub-token of its word.
        tok_start_positions = [orig_to_tok_index[x] for x in start_positions]
        tok_end_positions = [
            orig_to_tok_index[x + 1] - 1 if x < len(context) - 1 else len(all_doc_tokens) - 1
            for x in end_positions
        ]

        truncation = TruncationStrategy.ONLY_SECOND.value
        padding_strategy = "max_length" if self.pad_to_maxlen else "do_not_pad"

        if self.context_first:
            max_context_length = getattr(self, "max_context_length", None)
            if max_context_length is None:
                # Leave room for a full-length query plus the special tokens
                # of a pair.
                # NOTE(review): assumes a pair uses sequence_added_tokens + 1
                # special tokens -- confirm for the tokenizer in use.
                max_context_length = self.max_length - self.max_query_length - sequence_added_tokens - 1
            truncated_context = tokenizer.encode(
                all_doc_tokens, add_special_tokens=False, truncation=True, max_length=max_context_length
            )
            encoded_dict = tokenizer.encode_plus(  # TODO(thom) update this logic
                truncated_context,
                query,
                truncation=truncation,
                padding=padding_strategy,
                max_length=self.max_length,
                return_overflowing_tokens=True,
                return_token_type_ids=True,
            )
            tokens = encoded_dict['input_ids']
            type_ids = encoded_dict['token_type_ids']
            attn_mask = encoded_dict['attention_mask']
            # find new start_positions/end_positions, considering
            # 1. we add cls token at the beginning
            doc_offset = 1
            new_start_positions, new_end_positions = self._shift_spans(
                tok_start_positions, tok_end_positions, doc_offset, max_context_length
            )
            label_mask = [0] * doc_offset + [1] * len(truncated_context) + [0] * (len(tokens) - len(truncated_context) - 1)
        else:
            truncated_query = tokenizer.encode(
                query, add_special_tokens=False, truncation=True, max_length=self.max_query_length
            )
            encoded_dict = tokenizer.encode_plus(  # TODO(thom) update this logic
                truncated_query,
                all_doc_tokens,
                truncation=truncation,
                padding=padding_strategy,
                max_length=self.max_length,
                return_overflowing_tokens=True,
                return_token_type_ids=True,
            )
            tokens = encoded_dict['input_ids']
            type_ids = encoded_dict['token_type_ids']
            attn_mask = encoded_dict['attention_mask']

            # find new start_positions/end_positions, considering
            # 1. we add query tokens at the beginning
            # 2. special tokens
            doc_offset = len(truncated_query) + sequence_added_tokens
            # The last usable position is max_length - 2; the final slot is
            # excluded from the label mask below.
            new_start_positions, new_end_positions = self._shift_spans(
                tok_start_positions, tok_end_positions, doc_offset, self.max_length - 2
            )

            # NOTE(review): with pad_to_maxlen=True this also marks padding
            # positions with 1 -- confirm that downstream code masks them.
            label_mask = [0] * doc_offset + [1] * (len(tokens) - doc_offset - 1) + [0]

        # Every kept boundary has to fall on an unmasked position.
        assert all(label_mask[p] != 0 for p in new_start_positions)
        assert all(label_mask[p] != 0 for p in new_end_positions)

        assert len(label_mask) == len(tokens)

        seq_len = len(tokens)
        match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)
        for start, end in zip(new_start_positions, new_end_positions):
            if start >= seq_len or end >= seq_len:
                continue
            match_labels[start, end] = 1

        return [
            torch.LongTensor(tokens),
            torch.LongTensor(attn_mask),
            torch.LongTensor(type_ids),
            torch.LongTensor(label_mask),
            match_labels,
        ]

def read_bz2(addr):
    """Read a bz2-compressed file with one JSON value per line and return ``dict`` of the parsed values.

    Each parsed line is handed to ``dict`` as one item, so every line has to
    decode to a two-element (key, value) sequence.
    """
    progress_label = "reading from {}".format(addr)
    with bz2.open(addr) as reader:
        pairs = [json.loads(raw_line) for raw_line in tqdm(reader, desc=progress_label)]
    return dict(pairs)

def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)