# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Named entity recognition fine-tuning: utilities to work with CoNLL-2003 task. """

from __future__ import absolute_import, division, print_function

import logging
import os
from io import open
import random
import functools
from sklearn.metrics import classification_report
import src.core.configuration.fine_tuning_conf as ft_conf 

logger = logging.getLogger(__name__)


class InputExample(object):
    """One sentence of a token-classification dataset.

    Attributes are stored exactly as passed in; no copying or validation
    is performed.
    """

    def __init__(self, guid, words, labels):
        """Build an example from its identifier, words and labels.

        Args:
            guid: Identifier of the example.
            words: list. The words of the sentence.
            labels: list. The label of each word in ``words``.
        """
        self.guid, self.words, self.labels = guid, words, labels


class InputFeatures(object):
    """Model-ready features of one example.

    Holds the four parallel sequences handed to the constructor
    (``input_ids``, ``input_mask``, ``segment_ids``, ``label_ids``) as
    attributes of the same names, without copying or validating them.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_ids):
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_ids = segment_ids, label_ids


def _read_examples_from_path(file_path, mode, max_examples):
    """Parse one space-separated, blank-line-delimited file into InputExamples.

    Reading stops as soon as more than ``max_examples`` sentences have been
    completed. Words are lower-cased; the last column of a line is its label
    and lines with a single column get the label "O".
    """
    examples = []
    guid_index = 1
    words = []
    labels = []
    with open(file_path, encoding="utf-8") as f:
        for line in f:
            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                # Sentence boundary: flush the sentence collected so far.
                if words:
                    examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
                                                 words=words,
                                                 labels=labels))
                    guid_index += 1
                    words = []
                    labels = []
            else:
                # Drop the line terminator first so that a label-less line
                # does not leave a trailing "\n" glued to the word.
                splits = line.rstrip("\n").split(" ")
                words.append(splits[0].lower())
                if len(splits) > 1:
                    labels.append(splits[-1].strip())
                else:
                    # Examples could have no label for mode = "test"
                    labels.append("O")
            if guid_index > max_examples:
                break
    if words:
        # File did not end with a blank line: flush the last sentence.
        examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
                                     words=words,
                                     labels=labels))
    return examples


def read_examples_from_file(data_dirs, mode, batch_size, is_hyper_parameter_search=False):
    """Read ``<mode>.txt`` from every directory and return one shuffled list.

    Each directory contributes the same number of examples (the size of the
    smallest one, after shuffling each directory's examples), so that the
    sources are balanced. The merged list is shuffled again.

    Args:
        data_dirs: Iterable of directories, each containing ``<mode>.txt``.
        mode: File stem, e.g. "train", "dev" or "test"; also used as the
            prefix of each example's guid.
        batch_size: The returned list is truncated to a multiple of this
            value (unless ``is_hyper_parameter_search`` is set).
        is_hyper_parameter_search: When true, return only the first
            ``ft_conf.HP_SEARCH_DATA_SIZE`` examples of the shuffled list,
            without truncating to a multiple of ``batch_size``.

    Returns:
        list of InputExample.

    Raises:
        ValueError: If ``data_dirs`` is empty.

    NOTE(review): every file is read only up to
    ``4 * ft_conf.HP_SEARCH_DATA_SIZE`` sentences, independently of
    ``is_hyper_parameter_search`` -- confirm this cap is intended for
    regular training runs as well.
    """
    max_examples_per_file = 4 * ft_conf.HP_SEARCH_DATA_SIZE

    per_dir_examples = []
    for data_dir in data_dirs:
        file_path = os.path.join(data_dir, "{}.txt".format(mode))
        print(file_path)
        per_dir_examples.append(
            _read_examples_from_path(file_path, mode, max_examples_per_file))

    if not per_dir_examples:
        raise ValueError("data_dirs must contain at least one directory")

    # Balance the sources: shuffle each one, then keep as many examples from
    # each as the smallest source provides.
    min_len = min(len(examples) for examples in per_dir_examples)
    for examples in per_dir_examples:
        random.shuffle(examples)
    total_examples = [example
                      for examples in per_dir_examples
                      for example in examples[:min_len]]

    # Number of trailing examples that do not fill a complete batch.
    extra_example_size = len(total_examples) % batch_size
    usable_size = len(total_examples) - extra_example_size
    print("Number of examples: {}".format(usable_size))

    random.shuffle(total_examples)
    if is_hyper_parameter_search:
        return total_examples[:ft_conf.HP_SEARCH_DATA_SIZE]
    # Slice by the positive end index: ``[:-0]`` would yield an empty list.
    return total_examples[:usable_size]


def convert_examples_to_features(examples,
                                 label_list,
                                 max_seq_length,
                                 tokenizer,
                                 cls_token_at_end=False,
                                 cls_token="[CLS]",
                                 cls_token_segment_id=1,
                                 sep_token="[SEP]",
                                 sep_token_extra=False,
                                 pad_on_left=False,
                                 pad_token=0,
                                 pad_token_label_id=-1,
                                 pad_token_segment_id=0,
                                 sequence_a_segment_id=0,
                                 mask_padding_with_zero=True):
    """ Converts a list of examples into a list of fixed-length `InputFeatures`.

        Every word is split with `tokenizer.tokenize`. The first sub-token of a
        word carries the word's label id; the remaining sub-tokens, the special
        tokens and the padding carry `pad_token_label_id`. Words or labels that
        are empty or a single space are skipped, and so are words for which the
        tokenizer returns no sub-token (otherwise their label would shift every
        following label by one position relative to the tokens).

        Sequences longer than `max_seq_length` minus the special tokens are
        truncated; shorter ones are padded up to `max_seq_length`.

        Args:
            examples: iterable of objects with `words` and `labels` attributes.
            label_list: list of label strings; a label's id is its index.
            max_seq_length: length of every returned sequence.
            tokenizer: object providing `tokenize(word)` and
                `convert_tokens_to_ids(tokens)`.
            cls_token_at_end: location of the CLS token:
                - False (Default, BERT/XLM pattern): [CLS] + A + [SEP]
                - True (XLNet/GPT pattern): A + [SEP] + [CLS]
            cls_token: text of the CLS token.
            cls_token_segment_id: segment id associated to the CLS token
                (0 for BERT, 2 for XLNet). Note that the default here is 1.
            sep_token: text of the separator token.
            sep_token_extra: append a second separator (RoBERTa convention).
            pad_on_left: pad at the start instead of at the end.
            pad_token: id used to pad `input_ids`.
            pad_token_label_id: label id used for every position that is not
                the first sub-token of a word.
            pad_token_segment_id: id used to pad `segment_ids`.
            sequence_a_segment_id: segment id of the sentence tokens and of the
                separator(s).
            mask_padding_with_zero: if True the mask is 1 for real tokens and 0
                for padding; if False the two values are swapped.

        Returns:
            list of `InputFeatures`, one per example.

        Raises:
            KeyError: if an example uses a label that is not in `label_list`.
    """

    label_map = {label: i for i, label in enumerate(label_list)}
    print('label map', label_map)

    # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
    special_tokens_count = 3 if sep_token_extra else 2
    max_word_pieces = max_seq_length - special_tokens_count

    real_mask_value = 1 if mask_padding_with_zero else 0
    pad_mask_value = 0 if mask_padding_with_zero else 1

    features = []
    for example in examples:
        tokens = []
        label_ids = []
        for word, label in zip(example.words, example.labels):
            if label == '' or label == ' ' or word == '' or word == ' ':
                continue
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                # Nothing to attach the label to; adding it anyway would
                # misalign labels and tokens for the rest of the sentence.
                continue
            tokens.extend(word_tokens)
            # Use the real label id for the first token of the word, and padding ids for the remaining tokens
            label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))

        # Leave room for the special tokens added below.
        if len(tokens) > max_word_pieces:
            tokens = tokens[:max_word_pieces]
            label_ids = label_ids[:max_word_pieces]

        # The convention in BERT for single sequences is:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids:   0   0   0   0  0     0   0
        #
        # "type_ids" (segment ids) indicate which sequence a token belongs to.
        # Only one sequence is used here, so every sentence token and
        # separator gets `sequence_a_segment_id`.
        tokens += [sep_token]
        label_ids += [pad_token_label_id]
        if sep_token_extra:
            # roberta uses an extra separator b/w pairs of sentences
            tokens += [sep_token]
            label_ids += [pad_token_label_id]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        if cls_token_at_end:
            tokens += [cls_token]
            label_ids += [pad_token_label_id]
            segment_ids += [cls_token_segment_id]
        else:
            tokens = [cls_token] + tokens
            label_ids = [pad_token_label_id] + label_ids
            segment_ids = [cls_token_segment_id] + segment_ids

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask marks real tokens (1 by default) versus padding (0 by
        # default). Only real tokens are attended to.
        input_mask = [real_mask_value] * len(input_ids)

        # Pad up to the sequence length. The label padding is computed from
        # the label sequence itself so that it always reaches max_seq_length.
        padding_length = max_seq_length - len(input_ids)
        padding_length_label = max_seq_length - len(label_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            input_mask = ([pad_mask_value] * padding_length) + input_mask
            segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
            label_ids = ([pad_token_label_id] * padding_length_label) + label_ids
        else:
            input_ids += ([pad_token] * padding_length)
            input_mask += ([pad_mask_value] * padding_length)
            segment_ids += ([pad_token_segment_id] * padding_length)
            label_ids += ([pad_token_label_id] * padding_length_label)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length

        features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_ids=label_ids))
    return features


def get_labels(path):
    """Return the list of label strings used for token classification.

    Args:
        path: Path of a text file with one label per line, or a falsy value
            (``None`` / ``""``) to use the built-in default label set.

    Returns:
        list of str. When read from a file, "O" is prepended if the file does
        not already contain it.
    """
    if not path:
        return ["O", "B-MISC", "I-MISC",  "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC",
                "I-LOC", "B-ENT", "I-ENT", "B-ATR", "I-ATR", "B-AGG", "I-AGG", "B-FLT",
                "I-FLT", "B-FLO", "I-FLO", "B-PRW", "I-PRW", "B-NUM", "I-NUM"]

    # Decode explicitly as UTF-8 (as the data files are) instead of relying on
    # the platform's locale-dependent default encoding.
    with open(path, "r", encoding="utf-8") as f:
        labels = f.read().splitlines()
    if "O" not in labels:
        labels = ["O"] + labels
    return labels



#compute precision, recall and f1 from prediction and ground truth.
def compute_f1_precision_recall(predictions, ground_truths):
    """Compute token-level precision, recall and F1 over all non-'O' tags.

    For every position of every predicted sequence:
      * predicted tag is not 'O' and equals the gold tag -> true positive
      * predicted tag is not 'O' and differs from the gold tag -> false positive
      * predicted tag is 'O' and the gold tag is not -> false negative

    Per-tag tallies are kept for a fixed set of 'I-*' tags; some of them are
    printed for inspection. The overall counts drive the returned metrics.

    Args:
        predictions: sequence of predicted tag sequences.
        ground_truths: sequence of gold tag sequences, indexed in parallel
            with ``predictions``.

    Returns:
        Tuple ``(precision, recall, f_1)``. A metric whose denominator is
        zero is returned as 0.
    """
    tracked_tags = ('I-AGG', 'I-FLO', 'I-FLT', 'I-ATR', 'I-ENT')
    hits = dict.fromkeys(tracked_tags, 0)
    spurious = dict.fromkeys(tracked_tags, 0)
    missed = dict.fromkeys(tracked_tags, 0)
    hit_total = 0
    spurious_total = 0
    missed_total = 0

    for seq_idx, gold_seq in enumerate(ground_truths):
        guess_seq = predictions[seq_idx]
        for tok_idx, guess in enumerate(guess_seq):
            gold = gold_seq[tok_idx]

            if guess == 'O':
                # Nothing predicted here: it only counts if a tag was expected.
                if gold != 'O':
                    missed_total += 1
                    if gold in tracked_tags:
                        missed[gold] += 1
                continue

            # A tag was predicted: it is either right or wrong.
            if guess == gold:
                hit_total += 1
                tally = hits
            else:
                spurious_total += 1
                tally = spurious
            if guess in tracked_tags:
                tally[guess] += 1

    print('true_positive_overall', hit_total,
          'true_positive_ent', hits['I-ENT'],
          'true_positive_atr', hits['I-ATR'],
          'false_positive_overall', spurious_total,
          'false_positive_ent', spurious['I-ENT'],
          'false_positive_atr', spurious['I-ATR'],
          'false_positive_flt', spurious['I-FLT'],
          'false_negetive_overall', missed_total,
          'false_negetive_ent', missed['I-ENT'],
          'false_negetive_atr', missed['I-ATR'])

    predicted_positive = hit_total + spurious_total
    actual_positive = hit_total + missed_total
    precision = hit_total / predicted_positive if predicted_positive != 0 else 0
    recall = hit_total / actual_positive if actual_positive != 0 else 0

    if precision != 0 or recall != 0:
        f_1 = (2 * precision * recall) / (precision + recall)
    else:
        f_1 = 0

    print(('\nPrecision: ' + str(precision)))
    print(('\nRecall: ' + str(recall)))
    print(('\nF1: ' + str(f_1)))
    return precision, recall, f_1

    # if (true_positive_ent + false_positive_ent) == 0:
    #     precision = 0
    # else:
    #     precision = true_positive_ent / (true_positive_ent + false_positive_ent)
    # if (true_positive_ent + false_negetive_ent) == 0:
    #     recall = 0
    # else:
    #     recall = true_positive_ent / (true_positive_ent + false_negetive_ent)
    # if precision ==0 and recall == 0:
    #     f_1 = 0
    # else:
    #     f_1 = (2 * precision * recall) / (precision + recall)
    # file.write('\nEntity:')
    # file.write(('\nPrecision: ' + str(precision)))
    # file.write(('\nRecall: ' + str(recall)))
    # file.write(('\nF1: ' + str(f_1)))
    #
    # if (true_positive_atr + false_positive_atr) == 0:
    #     precision = 0
    # else:
    #     precision = true_positive_atr / (true_positive_atr + false_positive_atr)
    # if (true_positive_atr + false_negetive_atr) ==0:
    #     recall = 0
    # else:
    #     recall = true_positive_atr / (true_positive_atr + false_negetive_atr)
    # if precision ==0 and recall == 0:
    #     f_1 = 0
    # else:
    #     f_1 = (2 * precision * recall) / (precision + recall)
    # file.write('\nAttribute:')
    # file.write(('\nPrecision: ' + str(precision)))
    # file.write(('\nRecall: ' + str(recall)))
    # file.write(('\nF1: ' + str(f_1)))