# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from numpy.lib.function_base import extract

import torch
from torch.nn.utils.rnn import pad_sequence

from transformers.file_utils import PaddingStrategy
from transformers import PreTrainedModel
from transformers import BertTokenizer, BertTokenizerFast
from transformers import BatchEncoding, PreTrainedTokenizerBase

from pathlib import Path
from tqdm import tqdm, trange
from tempfile import TemporaryDirectory
import shelve
import spacy
from random import random, randrange, randint, shuffle, choice, sample
from pytorch_pretrained_bert.tokenization import BertTokenizer
import numpy as np
import json
import wandb
import networkx as nx
from functools import reduce
import logging
# from svo_extraction import extractSVOs
nlp = spacy.load("en")
from nltk.tokenize import sent_tokenize, word_tokenize
import time
import os
import pickle
import itertools

from predefined_connectives import predefined_connectives
from svo_extraction import extractSVOs
connectives = list(predefined_connectives.keys())
# Bucket the predefined connectives by word count; the longest of them consists of four words.
# Single-word connectives are stored as plain strings, longer ones as the list of their words so they
# can be compared directly against a slice of a word list.
unigram_connective = []
bigram_connective = []
trigram_connective = []
forgram_connective = []  # four-word connectives
_multiword_connective_buckets = {2: bigram_connective, 3: trigram_connective, 4: forgram_connective}

for _connective in predefined_connectives.keys():
    _connective_words = _connective.split(" ")
    if len(_connective_words) == 1:
        unigram_connective.append(_connective)
    elif len(_connective_words) in _multiword_connective_buckets:
        _multiword_connective_buckets[len(_connective_words)].append(_connective_words)


# Placeholder type for a single dataset sample; no structure is enforced on it.
InputDataClass = NewType("InputDataClass", Any)

# A DataCollator is a function that takes a list of samples from a Dataset and collates them into a batch, i.e. a
# dictionary of Tensors.
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])


def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    """
    Collate a list of dict-like samples into a dictionary of batched tensors.

    Samples that are neither ``dict`` nor ``BatchEncoding`` are first converted with ``vars()``. The keys of the
    first sample are assumed to describe the whole batch. Two keys are special-cased and emitted as ``labels``:
        - ``label``: one value (int or float) per sample; an int gives a ``torch.long`` tensor, anything else float.
        - ``label_ids``: a list (or tensor) of values per sample.
    ``label`` takes precedence when both are present and not None. Every other key whose first value is neither
    ``None`` nor a ``str`` is stacked (tensor values) or converted with ``torch.tensor`` (anything else).
    No further preprocessing is done: key names are passed through as model input names.
    """
    if not isinstance(features[0], (dict, BatchEncoding)):
        features = [vars(feature) for feature in features]

    reference = features[0]
    collated = {}

    def gather(key):
        # Column of values for `key` across the batch, in sample order.
        return [feature[key] for feature in features]

    # Labels get an explicit dtype so ints never silently become floats (or vice versa).
    if "label" in reference and reference["label"] is not None:
        sample_label = reference["label"]
        if isinstance(sample_label, torch.Tensor):
            sample_label = sample_label.item()
        label_dtype = torch.long if isinstance(sample_label, int) else torch.float
        collated["labels"] = torch.tensor(gather("label"), dtype=label_dtype)
    elif "label_ids" in reference and reference["label_ids"] is not None:
        sample_ids = reference["label_ids"]
        if isinstance(sample_ids, torch.Tensor):
            collated["labels"] = torch.stack(gather("label_ids"))
        else:
            ids_dtype = torch.long if type(sample_ids[0]) is int else torch.float
            collated["labels"] = torch.tensor(gather("label_ids"), dtype=ids_dtype)

    # Every remaining usable key is batched as-is.
    for key, value in reference.items():
        if key in ("label", "label_ids") or value is None or isinstance(value, str):
            continue
        column = gather(key)
        collated[key] = torch.stack(column) if isinstance(value, torch.Tensor) else torch.tensor(column)

    return collated


def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
    """
    Stack ``examples`` into a 2-D tensor, padding with ``tokenizer.pad_token_id`` only when necessary.

    Lists/tuples are converted to ``torch.long`` tensors first. If all examples already share a length that
    satisfies ``pad_to_multiple_of``, they are stacked directly. Otherwise each row is padded on the side given by
    ``tokenizer.padding_side`` up to the longest example, rounded up to a multiple of ``pad_to_multiple_of``.

    Raises:
        ValueError: if padding is needed but the tokenizer has no pad token.
    """
    if isinstance(examples[0], (list, tuple)):
        examples = [torch.tensor(example, dtype=torch.long) for example in examples]

    lengths = [example.size(0) for example in examples]
    longest = max(lengths)
    uniform = min(lengths) == longest

    if uniform and (pad_to_multiple_of is None or longest % pad_to_multiple_of == 0):
        return torch.stack(examples, dim=0)

    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    target = longest
    if pad_to_multiple_of is not None:
        # Ceiling division: round the target length up to the next multiple.
        target = -(-longest // pad_to_multiple_of) * pad_to_multiple_of

    padded = examples[0].new_full([len(examples), target], tokenizer.pad_token_id)
    pad_left = tokenizer.padding_side != "right"
    for row, example in enumerate(examples):
        width = example.shape[0]
        if pad_left:
            padded[row, -width:] = example
        else:
            padded[row, :width] = example
    return padded


def tolist(x: Union[List[Any], torch.Tensor]):
    """Return ``x`` converted to nested Python lists if it is a tensor; otherwise return ``x`` itself."""
    if not isinstance(x, torch.Tensor):
        return x
    return x.tolist()


@dataclass
class DataCollatorForLanguageModelingPath:
    """
    Data collator for masked language modeling that, on top of random masking, always masks discourse connectives
    and one element of each SVO triple found by ``extractSVOs``, and returns sampled node-pair "path" labels built
    from the graph those triples form.
    Inputs are dynamically padded to the maximum length of a batch if they are not all of the same length.
    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data. Connective matching merges ``##`` continuation pieces and the
            masking logic uses the literal ``[CLS]``/``[SEP]``/``[MASK]`` tokens, i.e. it expects a BERT tokenizer.
        mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
            inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
            non-masked tokens and the value to predict for the masked token.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            Fraction of the non-padding sequence length used as the per-sequence masking budget when :obj:`mlm`
            is :obj:`True`.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
        max_predictions_per_seq (:obj:`int`, `optional`):
            Upper bound on the masking budget of a sequence. :obj:`None` means no upper bound.
        vocab_list (:obj:`List[str]`, `optional`):
            Tokens drawn from when a selected position is replaced by a random token. If :obj:`None` or empty, a
            random id from the tokenizer's vocabulary is used instead.
    .. note::
        Connective and SVO positions are always masked, even when they exceed the budget; random positions only
        fill whatever budget remains. Any precomputed ``special_tokens_mask`` in the inputs is dropped.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    pad_to_multiple_of: Optional[int] = None
    max_predictions_per_seq: Optional[int] = None
    vocab_list: Optional[List] = None

    # Hard caps (not dataclass fields: no annotation) on SVO triples used per sequence and path rows returned.
    _MAX_SVOS = 10
    _MAX_PATHS = 6

    def __post_init__(self):
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """
        Collate ``examples`` into a padded batch.

        With ``mlm=True`` the batch holds the corrupted ``input_ids``, ``labels`` (-100 except at selected
        positions) and ``paths``, a ``(batch, _MAX_PATHS, 3)`` tensor of ``[node_a, node_b, connected]`` rows padded
        with -1. With ``mlm=False`` the labels are the inputs with padding positions set to -100.
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], (dict, BatchEncoding)):
            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}

        # This collator selects masked positions itself, so a precomputed special-tokens mask is discarded.
        batch.pop("special_tokens_mask", None)

        if not self.mlm:
            labels = batch["input_ids"].clone()
            if self.tokenizer.pad_token_id is not None:
                labels[labels == self.tokenizer.pad_token_id] = -100
            batch["labels"] = labels
            return batch

        input_ids, masked_labels, paths = [], [], []
        for ids in batch["input_ids"].tolist():
            tokens = self.tokenizer.convert_ids_to_tokens(ids)
            input_id, label, path = self.create_masked_lm_predictions(
                tokens,
                masked_lm_prob=self.mlm_probability,
                max_predictions_per_seq=self.max_predictions_per_seq,
                vocab_list=self.vocab_list,
            )
            input_ids.append(input_id)
            masked_labels.append(label)
            paths.append(path)
        batch["input_ids"] = torch.tensor(input_ids)
        batch["labels"] = torch.tensor(masked_labels)
        batch["paths"] = torch.tensor(paths)
        return batch

    def create_masked_lm_predictions(self, tokens, masked_lm_prob, max_predictions_per_seq, vocab_list):
        """
        Select and corrupt positions of one tokenized sequence for the masked-LM objective.

        Positions are chosen in two groups:
          * forced: every token of a matched predefined connective, plus either the first or the third element
            (chosen at random) of each of the first ``_MAX_SVOS`` triples returned by ``extractSVOs``;
          * random: positions that are not ``[CLS]``/``[SEP]``/padding and not already forced, drawn to fill the
            rest of the budget ``min(max_predictions_per_seq, max(1, round(n_real * masked_lm_prob)))``, where
            ``n_real`` is the number of non-padding tokens.
        Each selected position becomes ``[MASK]`` 80% of the time, a random token 10% of the time, and is left
        unchanged otherwise.

        Args:
            tokens: token strings of one (possibly padded) sequence. Mutated in place.
            masked_lm_prob: masking rate used to size the budget.
            max_predictions_per_seq: cap on the budget, or ``None`` for no cap.
            vocab_list: tokens to draw random replacements from; the tokenizer vocabulary is used when empty/None.

        Returns:
            ``(input_ids, labels, paths)``: the corrupted token ids; the original ids at selected positions and -100
            elsewhere; and ``_MAX_PATHS`` rows ``[node_a, node_b, connected]`` (see ``_sample_paths``).

        Raises:
            IndexError: if a forced (connective/SVO) position cannot index ``tokens``.
        """
        # NOTE(review): SVO elements are used as positions into `tokens` below; `offset=1` presumably accounts for
        # the leading [CLS] — confirm against svo_extraction.
        svo_ids = extractSVOs(nlp, tokens, offset=1)[: self._MAX_SVOS]
        labels = self.tokenizer.convert_tokens_to_ids(tokens)  # taken before any token is overwritten
        paths = self._sample_paths(svo_ids)

        # A set, so overlapping connective matches / shared SVO elements are neither double-counted nor re-masked.
        forced = set(self._find_connective_indexes(tokens))
        forced.update(svo[0] if np.random.randint(0, 2) else svo[2] for svo in svo_ids)

        n_tokens = len(tokens)
        bad_positions = sorted(idx for idx in forced if not -n_tokens <= idx < n_tokens)
        if bad_positions:
            raise IndexError(
                f"Masking positions {bad_positions} fall outside a sequence of {n_tokens} tokens (svo_ids={svo_ids})."
            )

        # Padding is neither a masking candidate nor counted towards the budget.
        pad_token = self.tokenizer.pad_token
        excluded = {"[CLS]", "[SEP]", pad_token}
        candidates = [i for i, token in enumerate(tokens) if token not in excluded and i not in forced]
        n_real = sum(1 for token in tokens if token != pad_token)

        budget = max(1, int(round(n_real * masked_lm_prob)))
        if max_predictions_per_seq is not None:
            budget = min(max_predictions_per_seq, budget)
        # Forced positions consume the budget first; never ask `sample` for more than is available.
        num_random = min(max(0, budget - len(forced)), len(candidates))
        selected = forced.union(sample(candidates, num_random))

        for index in sorted(selected):
            if random() < 0.8:
                tokens[index] = "[MASK]"
            elif random() >= 0.5:
                # 10% of the time, replace with a random word (the other 10% keep the original token).
                tokens[index] = self._random_token(vocab_list)

        inputs = self.tokenizer.convert_tokens_to_ids(tokens)
        labels = [label if idx in selected else -100 for idx, label in enumerate(labels)]
        return inputs, labels, paths

    def _sample_paths(self, svo_ids):
        """
        Build ``_MAX_PATHS`` rows ``[node_a, node_b, connected]`` from the graph formed by the SVO triples.

        Each triple ``(a, b, c)`` adds edges ``a-b`` and ``b-c``. Every unordered node pair is labelled 1 if the two
        nodes are connected and 0 otherwise. When both classes are non-empty the larger one is downsampled to the
        size of the smaller one; otherwise the non-empty class is used whole. The rows are shuffled, truncated to
        ``_MAX_PATHS`` and padded with ``[-1, -1, -1]``.
        """
        graph = nx.Graph()
        nodes = []  # unique nodes in first-seen order (at most 3 * _MAX_SVOS, so list membership is cheap)
        for svo in svo_ids:
            for node in (svo[0], svo[1], svo[2]):
                if node not in nodes:
                    nodes.append(node)
            graph.add_edge(svo[0], svo[1])
            graph.add_edge(svo[1], svo[2])

        positive, negative = [], []
        for node_a, node_b in itertools.combinations(nodes, 2):
            if nx.has_path(graph, node_a, node_b):
                positive.append([node_a, node_b, 1])
            else:
                negative.append([node_a, node_b, 0])

        if not positive or not negative:
            paths = positive + negative
        elif len(negative) >= len(positive):
            paths = sample(negative, len(positive)) + positive
        else:
            paths = sample(positive, len(negative)) + negative

        shuffle(paths)
        paths = paths[: self._MAX_PATHS]
        paths.extend([-1, -1, -1] for _ in range(self._MAX_PATHS - len(paths)))
        return paths

    @staticmethod
    def _find_connective_indexes(tokens):
        """
        Return the token positions of every predefined connective (1 to 4 words) found in ``tokens``.

        ``##`` continuation pieces are merged into whole words first so connectives are matched word by word;
        all token positions of a matched word are returned. Positions may repeat when matches overlap.
        """
        word_spans, words = [], []
        for i, token in enumerate(tokens):
            if word_spans and token.startswith("##"):
                word_spans[-1].append(i)
                words[-1] += token.replace("##", "")
            else:
                word_spans.append([i])
                words.append(token)

        indexes = [i for word, span in zip(words, word_spans) if word in unigram_connective for i in span]
        for n, ngram_connectives in ((2, bigram_connective), (3, trigram_connective), (4, forgram_connective)):
            # range() is empty when the sequence has fewer than n words, so short inputs cannot index past the end.
            for start in range(len(words) - n + 1):
                if words[start:start + n] in ngram_connectives:
                    for span in word_spans[start:start + n]:
                        indexes.extend(span)
        return indexes

    def _random_token(self, vocab_list):
        """Return a random replacement token from ``vocab_list``, or from the tokenizer vocabulary if it is empty."""
        if vocab_list is not None and len(vocab_list) > 0:
            return choice(vocab_list)
        return self.tokenizer.convert_ids_to_tokens(randrange(len(self.tokenizer)))

    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels