# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import json
import torch.nn.functional as F
import torch
import pdb
import random
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

logger = logging.getLogger(__name__)


class InputExample:
    """One word-level sequence for token classification.

    Attributes:
        guid: Unique id for the example.
        words: list. The words of the sequence.
        labels: list. One label per word. Should be given for train and dev
            examples; test examples may carry placeholder labels.
        hp_labels: list. A second per-word label sequence stored alongside
            ``labels``. Readers elsewhere in this file fill it with ``None``
            entries when no such labels are available.
    """

    def __init__(self, guid, words, labels, hp_labels):
        """Keep the four fields on the instance exactly as they were passed."""
        self.guid, self.words, self.labels, self.hp_labels = guid, words, labels, hp_labels


class InputFeatures:
    """Plain container for the per-token id lists of one converted example.

    Attributes:
        input_ids: token ids.
        input_mask: attention mask values, one per token position.
        segment_ids: segment (token type) ids.
        label_ids: label ids used for training/evaluation.
        full_label_ids: label ids in their "full" variant (see the converter
            that produces these features for how the two variants differ).
        hp_label_ids: ids of the secondary (``hp``) label sequence.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_ids, full_label_ids, hp_label_ids):
        """Store every list on the instance without copying or validating it."""
        self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids
        self.label_ids, self.full_label_ids, self.hp_label_ids = label_ids, full_label_ids, hp_label_ids


def read_examples_from_file(data_dir, mode, is_json):
    """Read ``<data_dir>/<mode>.json`` or ``<data_dir>/<mode>.txt`` into examples.

    Args:
        data_dir: directory that contains the data file.
        mode: file stem (for example ``"train"``); also used as the guid prefix.
        is_json: when true, read ``<mode>.json``; otherwise read ``<mode>.txt``.

    Returns:
        A list of ``InputExample`` objects with guids ``"<mode>-1"``,
        ``"<mode>-2"``, ... in file order.

    File formats:
        * JSON: a list of records with ``"str_words"`` and ``"tags"``, and an
          optional ``"tags_hp"`` list. When ``"tags_hp"`` is absent the example
          gets ``None`` for every word.
        * Text: one token per line, fields separated by single spaces, the
          first field being the word and the last the label. Blank lines and
          lines starting with ``-DOCSTART-`` end the current sentence. A line
          with a single field gets the label ``"O"``.
    """
    extension = "json" if is_json else "txt"
    file_path = os.path.join(data_dir, "{}.{}".format(mode, extension))
    examples = []

    def add_example(words, labels, hp_labels=None):
        # Examples without secondary labels carry one ``None`` per word.
        if hp_labels is None:
            hp_labels = [None] * len(labels)
        # Guids are 1-based and follow the order in which examples are added.
        guid = f"{mode}-{len(examples) + 1}"
        examples.append(InputExample(guid=guid, words=words, labels=labels, hp_labels=hp_labels))

    if is_json:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for item in data:
            # Look for the optional key on the record itself, not in its tag list.
            add_example(item["str_words"], item["tags"], item.get("tags_hp"))
        return examples

    with open(file_path, encoding="utf-8") as f:
        words = []
        labels = []
        for line in f:
            if line.startswith("-DOCSTART-") or not line.strip():
                # Sentence boundary: flush whatever has been collected so far.
                if words:
                    add_example(words, labels)
                    words = []
                    labels = []
                continue
            # Drop the line terminator first so that a label-less line does not
            # leave a trailing newline inside the word.
            splits = line.rstrip("\r\n").split(" ")
            words.append(splits[0])
            if len(splits) > 1:
                labels.append(splits[-1])
            else:
                # Examples could have no label for mode = "test"
                labels.append("O")

    # The file may end without a trailing blank line.
    if words:
        add_example(words, labels)
    return examples


def convert_examples_to_features(
        examples,
        label_list,
        max_seq_length,
        tokenizer,
        cls_token_at_end=False,
        cls_token="[CLS]",
        cls_token_segment_id=1,
        sep_token="[SEP]",
        sep_token_extra=False,
        pad_on_left=False,
        pad_token=0,
        pad_token_segment_id=0,
        pad_token_label_id=-100,
        sequence_a_segment_id=0,
        mask_padding_with_zero=True,
        show_exnum=-1,
):
    """Convert word-level examples into padded, fixed-length ``InputFeatures``.

    Every word is split with ``tokenizer.tokenize``. For each word:
      * ``label_ids`` / ``hp_label_ids`` hold the word's label on its first
        sub-token and ``pad_token_label_id`` on the remaining ones (a ``None``
        hp label is also replaced by ``pad_token_label_id``);
      * ``full_label_ids`` repeats the word's label on every sub-token.
    Words for which the tokenizer returns no tokens are skipped so that all
    lists stay aligned with the token sequence.

    Args:
        examples: iterable of objects with ``guid``, ``words``, ``labels`` and
            ``hp_labels`` attributes.
        label_list: label names; a string label is mapped to its index in this
            list, a non-string label is used as the id directly.
        max_seq_length: length of every output list, special tokens included.
            Longer sequences are truncated.
        tokenizer: object providing ``tokenize(word)`` and
            ``convert_tokens_to_ids(tokens)``.
        cls_token_at_end: location of the CLS token:
            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP]
            - True (XLNet/GPT pattern): A + [SEP] + [CLS]
        cls_token: text of the CLS token.
        cls_token_segment_id: segment id given to the CLS token
            (0 for BERT, 2 for XLNet).
        sep_token: text of the separator token.
        sep_token_extra: append a second separator (RoBERTa pattern).
        pad_on_left: pad at the start of the sequence instead of the end.
        pad_token: input id used for padding.
        pad_token_segment_id: segment id used for padding.
        pad_token_label_id: label id used for special tokens, non-first
            sub-tokens and padding.
        sequence_a_segment_id: segment id of the regular tokens.
        mask_padding_with_zero: when true the mask is 1 for real tokens and 0
            for padding; when false the values are swapped.
        show_exnum: log the content of the first ``show_exnum`` examples.

    Returns:
        A list with one ``InputFeatures`` per example.

    Raises:
        KeyError: a string label is not in ``label_list``.
        AssertionError: a produced list does not have ``max_seq_length`` items.
    """
    label_map = {label: i for i, label in enumerate(label_list)}
    # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
    special_tokens_count = 3 if sep_token_extra else 2
    max_tokens = max_seq_length - special_tokens_count
    # The mask has 1 for real tokens and 0 for padding tokens (or the reverse).
    real_mask_value = 1 if mask_padding_with_zero else 0
    pad_mask_value = 0 if mask_padding_with_zero else 1

    features = []
    extra_long_samples = 0
    for (ex_index, example) in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        tokens = []
        label_ids = []
        full_label_ids = []
        hp_label_ids = []
        for word, label, hp_label in zip(example.words, example.labels, example.hp_labels):
            if isinstance(label, str):
                label = label_map[label]
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                # Some tokenizers return nothing for e.g. whitespace or control
                # characters. Emitting a label for such a word would leave the
                # label lists longer than the token list.
                continue
            if hp_label is None:
                hp_label = pad_token_label_id
            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
            continuation = [pad_token_label_id] * (len(word_tokens) - 1)
            tokens.extend(word_tokens)
            label_ids.extend([label] + continuation)
            hp_label_ids.extend([hp_label] + continuation)
            full_label_ids.extend([label] * len(word_tokens))

        if len(tokens) > max_tokens:
            tokens = tokens[:max_tokens]
            label_ids = label_ids[:max_tokens]
            hp_label_ids = hp_label_ids[:max_tokens]
            full_label_ids = full_label_ids[:max_tokens]
            extra_long_samples += 1

        # Single-sequence layout (BERT convention):
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0   0   0   0  0     0   0
        # Special tokens never carry a label, so they get pad_token_label_id.
        tokens += [sep_token]
        label_ids += [pad_token_label_id]
        hp_label_ids += [pad_token_label_id]
        full_label_ids += [pad_token_label_id]
        if sep_token_extra:
            # roberta uses an extra separator b/w pairs of sentences
            tokens += [sep_token]
            label_ids += [pad_token_label_id]
            hp_label_ids += [pad_token_label_id]
            full_label_ids += [pad_token_label_id]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        if cls_token_at_end:
            tokens += [cls_token]
            label_ids += [pad_token_label_id]
            hp_label_ids += [pad_token_label_id]
            full_label_ids += [pad_token_label_id]
            segment_ids += [cls_token_segment_id]
        else:
            tokens = [cls_token] + tokens
            label_ids = [pad_token_label_id] + label_ids
            hp_label_ids = [pad_token_label_id] + hp_label_ids
            full_label_ids = [pad_token_label_id] + full_label_ids
            segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [real_mask_value] * len(input_ids)

        # Pad up to the sequence length.
        padding_length = max_seq_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = ([pad_mask_value] * padding_length) + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            label_ids = ([pad_token_label_id] * padding_length) + label_ids
            hp_label_ids = ([pad_token_label_id] * padding_length) + hp_label_ids
            full_label_ids = ([pad_token_label_id] * padding_length) + full_label_ids
        else:
            input_ids += [pad_token] * padding_length
            input_mask += [pad_mask_value] * padding_length
            segment_ids += [pad_token_segment_id] * padding_length
            label_ids += [pad_token_label_id] * padding_length
            hp_label_ids += [pad_token_label_id] * padding_length
            full_label_ids += [pad_token_label_id] * padding_length

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        assert len(hp_label_ids) == max_seq_length
        assert len(full_label_ids) == max_seq_length

        if ex_index < show_exnum:
            logger.info("*** Example ***")
            logger.info("guid: %s", example.guid)
            logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
            logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
            logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
            logger.info("hp_label_ids: %s", " ".join([str(x) for x in hp_label_ids]))
            logger.info("full_label_ids: %s", " ".join([str(x) for x in full_label_ids]))

        features.append(
            InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_ids,
                          full_label_ids=full_label_ids, hp_label_ids=hp_label_ids)
        )
    logger.info("Extra long example %d of %d", extra_long_samples, len(examples))
    return features


def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode, raw_feature=False, final=True,
                            evaluate=None):
    """Load the features of one data split, building and caching them if needed.

    Features are read from ``<args.data_dir>/cached_<mode>_<model>_<max_seq_length>_<n_shot>``
    when that file exists and ``args.overwrite_cache`` is false. Otherwise they
    are built from the raw data file and, on ranks -1 and 0, saved to that path.

    Args:
        args: namespace providing ``local_rank``, ``data_dir``,
            ``model_name_or_path``, ``max_seq_length``, ``n_shot``,
            ``overwrite_cache``, ``is_json`` and ``model_type``.
        tokenizer: tokenizer providing ``cls_token``, ``sep_token``,
            ``pad_token`` and ``convert_tokens_to_ids``.
        labels: list of label names.
        pad_token_label_id: label id used for positions that carry no label.
        mode: name of the split; selects the data file and the cache file.
        raw_feature: return the list of feature objects instead of a dataset.
        final: when false and there are more than 100000 features, keep a
            random 5% of them.
        evaluate: when true, the distributed barriers are skipped. ``None``
            (default) infers it from ``mode``.

    Returns:
        The list of features when ``raw_feature`` is true, otherwise a
        ``TensorDataset`` of (input_ids, input_mask, segment_ids, label_ids,
        full_label_ids, hp_label_ids, example index).
    """
    if evaluate is None:
        # NOTE(review): the original code referenced an undefined `evaluate`
        # name here. "dev"/"test" are assumed to be the evaluation splits that
        # are loaded without cross-process synchronisation -- confirm against
        # the callers, or pass `evaluate` explicitly.
        evaluate = mode in ("dev", "test")

    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            mode, list(filter(None, args.model_name_or_path.split("/"))).pop(), str(args.max_seq_length),
            str(args.n_shot)
        ),
    )

    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE: torch.load unpickles the file, so only load caches written by
        # this code. NOTE(review): recent PyTorch releases default to
        # weights_only=True, which may reject pickled feature objects --
        # confirm against the installed version.
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        examples = read_examples_from_file(args.data_dir, mode, args.is_json)
        is_xlnet = args.model_type in ["xlnet"]
        features = convert_examples_to_features(
            examples,
            labels,
            args.max_seq_length,
            tokenizer,
            # xlnet has a cls token at the end
            cls_token_at_end=bool(is_xlnet),
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=2 if is_xlnet else 0,
            sep_token=tokenizer.sep_token,
            # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            sep_token_extra=bool(args.model_type in ["roberta"]),
            # pad on the left for xlnet
            pad_on_left=bool(is_xlnet),
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=4 if is_xlnet else 0,
            pad_token_label_id=pad_token_label_id,
        )
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    if raw_feature:
        return features

    if not final and len(features) > 100000:
        # Work on a random 5% subset of very large datasets.
        percentage = 0.05
        random.shuffle(features)
        count = int(percentage * len(features))
        logger.info("Use a partial dataset with size %d", count)
        logger.info("Whole dataset with size %d", len(features))

        features = features[:count]

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
    all_full_label_ids = torch.tensor([f.full_label_ids for f in features], dtype=torch.long)
    all_hp_label_ids = torch.tensor([f.hp_label_ids for f in features], dtype=torch.long)
    # Position of each example inside `features`, kept as the last tensor.
    all_ids = torch.tensor(list(range(len(features))), dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_full_label_ids,
                            all_hp_label_ids, all_ids)
    return dataset


def get_labels(path=None):
    """Return the list of label names, always containing ``"O"``.

    ``path`` is used as a plain string prefix (file names are appended with
    ``+``), so a directory must end with a path separator.

    Lookup order:
      1. no/empty ``path``: the built-in CoNLL-style label list;
      2. ``<path>tag_to_id.json`` exists: its keys, in file order;
      3. otherwise ``<path>labels.txt``: its lines, with empty lines dropped.
    In cases 2 and 3 ``"O"`` is put in front when it is missing.
    """
    if not path:
        return ["O", "B-LOC", "B-ORG", "B-PER", "B-MISC", "I-PER", "I-MISC", "I-ORG", "I-LOC"]

    json_file = path + "tag_to_id.json"
    if os.path.exists(json_file):
        with open(json_file, "r") as handle:
            names = list(json.load(handle).keys())
        return names if "O" in names else ["O"] + names

    with open(path + "labels.txt", "r") as handle:
        names = handle.read().splitlines()
    if "O" not in names:
        names = ["O"] + names
    # Blank lines in the file do not denote labels.
    return [name for name in names if name]


def tag_to_id(path=None):
    """Return the mapping from label name to label id.

    ``path`` is used as a plain string prefix (file names are appended with
    ``+``), so a directory must end with a path separator.

    Lookup order:
      1. ``<path>tag_to_id.json`` exists: its content, returned unchanged;
      2. ``path`` given: ``<path>labels.txt``, one label per line. ``"O"`` is
         put in front when missing, empty lines are ignored, and ids are
         assigned by position, so the ids agree with the positions in the
         list returned by ``get_labels`` for the same file;
      3. no/empty ``path``: the built-in CoNLL-style mapping.
    """
    if path and os.path.exists(path + "tag_to_id.json"):
        with open(path + "tag_to_id.json", 'r') as f:
            data = json.load(f)
        return data
    elif path:
        with open(path + "labels.txt", "r") as f:
            labels = f.read().splitlines()
        if "O" not in labels:
            labels = ["O"] + labels
        # Skip blank lines before numbering; otherwise an empty line would get
        # its own id and shift every label that follows it.
        labels = [label for label in labels if label]
        return {label: i for i, label in enumerate(labels)}
    else:
        return {"O": 0, "B-LOC": 1, "B-ORG": 2, "B-PER": 3, "B-MISC": 4, "I-PER": 5, "I-MISC": 6, "I-ORG": 7,
                "I-LOC": 8}


def get_chunk_type(tok, idx_to_tag):
    """
    Look up the tag name of a label id and split it at its hyphens.

    Args:
        tok: id of token, ex 4
        idx_to_tag: dictionary {4: "B-PER", ...}

    Returns:
        tuple (tag_class, tag_type): the text before the first "-" and the
        text after the last "-", ex ("B", "PER"). A name without "-" is
        returned for both parts.

    """
    name = idx_to_tag[tok]
    prefix, _, _ = name.partition('-')
    _, _, entity = name.rpartition('-')
    return prefix, entity


def get_chunks(seq, tags):
    """Group a sequence of label ids into typed entity spans.

    Args:
        seq: sequence of label ids, ex [4, 4, 0, 0, ...]
        tags: dict mapping tag name to label id; must contain the key "O"

    Returns:
        list of (chunk_type, chunk_start, chunk_end) with chunk_end exclusive

    Example:
        seq = [4, 5, 0, 3]
        tags = {"O": 0, "B-PER": 4, "I-PER": 5, "B-LOC": 3}
        result = [("PER", 0, 2), ("LOC", 3, 4)]

    """
    outside = tags["O"]
    id_to_name = {idx: name for name, idx in tags.items()}
    found = []

    open_type = None
    open_start = None
    for position, tag_id in enumerate(seq):
        if tag_id == outside:
            # An "O" label ends the span that is currently open, if any.
            if open_type is not None:
                found.append((open_type, open_start, position))
                open_type = None
                open_start = None
            continue

        prefix, entity = get_chunk_type(tag_id, id_to_name)
        if open_type is None:
            open_type, open_start = entity, position
        elif prefix == "B" or entity != open_type:
            # A "B" prefix or a change of entity type closes the open span and
            # immediately starts a new one at this position.
            found.append((open_type, open_start, position))
            open_type, open_start = entity, position

    # A span that is still open runs to the end of the sequence.
    if open_type is not None:
        found.append((open_type, open_start, len(seq)))
    return found


def split_data(features, split_ratio=0.5, mode='random'):
    """Split features into a "labeled" and a "meta" ``TensorDataset``.

    With ``mode == 'random'``, ``int(split_ratio * len(features))`` positions
    are drawn with ``random.sample`` for the labeled set and all remaining
    positions form the meta set. With any other ``mode`` both sets are empty.

    Args:
        features: sequence of objects with ``input_ids``, ``input_mask``,
            ``segment_ids``, ``label_ids``, ``full_label_ids`` and
            ``hp_label_ids`` attributes.
        split_ratio: fraction of the features that goes to the labeled set.
        mode: sampling strategy; only ``'random'`` selects anything.

    Returns:
        (labeled_dataset, meta_dataset). Each holds seven long tensors: the six
        feature fields in the order listed above, followed by the positions of
        the rows in ``features``. Rows keep the order they have in ``features``.
    """
    total = len(features)
    labeled_positions = set()
    meta_positions = set()
    if mode == 'random':
        labeled_positions = set(random.sample(range(total), int(split_ratio * total)))
        meta_positions = set(range(total)) - labeled_positions

    field_names = ("input_ids", "input_mask", "segment_ids", "label_ids", "full_label_ids", "hp_label_ids")

    def build_subset(chosen):
        # Walk the features in their original order and keep the chosen ones.
        kept = [(position, feature) for position, feature in enumerate(features) if position in chosen]
        columns = [
            torch.tensor([getattr(feature, field) for _, feature in kept], dtype=torch.long)
            for field in field_names
        ]
        columns.append(torch.tensor([position for position, _ in kept], dtype=torch.long))
        return TensorDataset(*columns)

    labeled_dataset = build_subset(labeled_positions)
    meta_dataset = build_subset(meta_positions)
    return labeled_dataset, meta_dataset


# Script entry point.
# NOTE(review): neither `save` nor `args` is defined or imported in the visible
# part of this module, so running the file as a script presumably raises
# NameError here -- confirm whether this entry point is still needed.
if __name__ == '__main__':
    save(args)