

from datasets import load_dataset, load_metric, Dataset
import datasets
from datasets import fingerprint
from transformers import AutoTokenizer, BertTokenizerFast, RobertaTokenizerFast, XLNetTokenizerFast, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification, pipeline
#from src.utils.training_data_generation.query_generation import generate_queries
import numpy as np
import random
import torch
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2
#inernal imports
from src.core.configuration.datagen_conf import PERM_SCALE
from src.utils.fine_tuning.compute_f1 import *


def _drop_leading_examples(split, count):
    """Return *split* without its first *count* rows, preserving row order."""
    return split.filter(lambda example, idx: idx >= count, with_indices=True)


def evaluate_model(model_path, model_name, model_path_name, fold):
    """Fine-tune a saved token-classification checkpoint and log NER scores.

    Despite the name, this calls ``trainer.train()`` for ``num_epochs``
    epoch(s) on a subsample of the training split before evaluating on the
    validation split and predicting on the test split. The reported numbers
    therefore describe the further-tuned model, not the checkpoint as saved.

    Args:
        model_path: Path/identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        model_name: ``'bert'`` or ``'roberta'`` selects a specific fast
            tokenizer class (RoBERTa with ``add_prefix_space=True``); any
            other value falls back to ``AutoTokenizer``.
        model_path_name: Used to build the log file name and the first log
            header.
        fold: Only written into the second log header.

    Side effects:
        Truncates, then appends to,
        ``src/data/output/fine_tuning/logs/eval/<model_path_name>-fsacle-<PERM_SCALE>.txt``.
        Uses ``test-ner`` as the ``Trainer`` output directory.
    """
    task = "ner"  # Should be one of "ner", "pos" or "chunk"
    batch_size = 16
    learn_rate = 5e-5
    num_epochs = 1
    label_all_tokens = False
    # Keep training rows 0, 8, 16, ... -- the same subset produced by three
    # successive "keep every other row" filters, but in a single pass.
    train_stride = 8

    # NOTE(review): "fsacle" looks like a typo of "fscale"; kept so existing
    # log file names do not change.
    fpath = 'src/data/output/fine_tuning/logs/eval/' + model_path_name + '-fsacle-' + str(PERM_SCALE) + '.txt'
    # Start a fresh log. `with` guarantees the handle is closed even if a
    # later step (dataset loading, tokenization, training) raises.
    with open(fpath, 'w') as log_file:
        log_file.write('\nEvaluation of :' + model_path_name + '\n\n\n')

    # Named `raw_datasets` so the imported `datasets` module is not shadowed.
    raw_datasets = load_dataset("src/utils/fine_tuning/custom_conll2003.py", download_mode="force_redownload", ignore_verifications=True)

    # Drop leading rows so the validation/test lengths become exact multiples
    # of batch_size (presumably so every evaluation batch is full).
    extra_examples_val = len(raw_datasets["validation"]) % batch_size
    extra_examples_test = len(raw_datasets["test"]) % batch_size
    print(len(raw_datasets["validation"]), extra_examples_val)
    if extra_examples_val:
        raw_datasets["validation"] = _drop_leading_examples(raw_datasets["validation"], extra_examples_val)
        print(len(raw_datasets["validation"]))
    if extra_examples_test:
        raw_datasets["test"] = _drop_leading_examples(raw_datasets["test"], extra_examples_test)
        print(len(raw_datasets["test"]))

    raw_datasets["train"] = raw_datasets["train"].filter(
        lambda example, idx: idx % train_stride == 0, with_indices=True)

    label_list = raw_datasets["train"].features[f"{task}_tags"].feature.names

    # Load exactly one tokenizer (previously an AutoTokenizer was loaded and
    # then discarded for bert/roberta).
    if model_name == 'bert':
        tokenizer = BertTokenizerFast.from_pretrained(model_path)
    elif model_name == 'roberta':
        # RoBERTa fast tokenizers require add_prefix_space=True when fed
        # pre-split words (is_split_into_words=True below).
        tokenizer = RobertaTokenizerFast.from_pretrained(model_path, add_prefix_space=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_path)

    def tokenize_and_align_labels(examples):
        """Tokenize pre-split words and align word-level tags to sub-tokens."""
        tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
        labels = []
        for i, label in enumerate(examples[f"{task}_tags"]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                if word_idx is None:
                    # Special tokens: -100 is ignored by the loss and metrics.
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    # First sub-token of a word carries the word's label.
                    label_ids.append(label[word_idx])
                else:
                    # Continuation sub-tokens: label or ignore per flag.
                    label_ids.append(label[word_idx] if label_all_tokens else -100)
                previous_word_idx = word_idx
            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    tokenized_datasets = raw_datasets.map(tokenize_and_align_labels, batched=True, batch_size=batch_size)

    model = AutoModelForTokenClassification.from_pretrained(model_path, num_labels=len(label_list))

    args = TrainingArguments(
        f"test-{task}",
        evaluation_strategy="epoch",
        learning_rate=learn_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
    )

    data_collator = DataCollatorForTokenClassification(tokenizer)
    metric = load_metric('seqeval')

    def compute_metrics(p):
        """Turn logits into tag strings, log them, and return seqeval scores."""
        predictions, labels = p
        predictions = np.argmax(predictions, axis=2)
        # Drop positions labelled -100 (special tokens and, unless
        # label_all_tokens, non-first sub-tokens of a word).
        true_predictions = [
            [label_list[pred] for (pred, lab) in zip(prediction, label) if lab != -100]
            for prediction, label in zip(predictions, labels)
        ]
        true_labels = [
            [label_list[lab] for (pred, lab) in zip(prediction, label) if lab != -100]
            for prediction, label in zip(predictions, labels)
        ]

        with open(fpath, 'a') as log_file:
            log_file.write('\nTrue Predictions: ' + str(true_predictions) + '\nTrue Labels: ' + str(true_labels))

        results = metric.compute(predictions=true_predictions, references=true_labels)
        # Project helper from compute_f1; it receives fpath, presumably to
        # append its own per-class scores to the same log -- TODO confirm.
        compute_precision_recall_f1(true_predictions, true_labels, fpath)
        print(results)

        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    trainer = Trainer(
        model,
        args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    trainer.evaluate()

    with open(fpath, 'a') as log_file:
        log_file.write('\nEvaluation of :' + model_name + '- Fold: ' + str(fold) + '\n\nOn handcrafted data\n')

    # The returned PredictionOutput is not used here; predictions and labels
    # reach the log through compute_metrics.
    trainer.predict(test_dataset=tokenized_datasets["test"])

# Point the custom CoNLL loader at the hand-crafted data. Presumably
# src/utils/fine_tuning/custom_conll2003.py reads its data directory from
# conll_data_path.txt -- TODO confirm; it must be written before
# evaluate_model() calls load_dataset().
path = 'src/data/fine_tuning/hand-crafted/'
with open("conll_data_path.txt", "w") as file:
    file.write(path)

# Checkpoints to evaluate; uncomment entries as needed.
# evaluate_model('models/bert/bert-lr-0.0001-dc-0.25', 'bert', 'bert-lr-0.0001-dc-0.25', 1)
# evaluate_model('models/bert/bert-lr-5e-05-dc-0.15', 'bert', 'bert-lr-5e-05-dc-0.15', 1)
# evaluate_model('models/roberta/roberta-lr-7.5e-05-dc-0.05', 'roberta', 'roberta-lr-7.5e-05-dc-0.05', 1)
# evaluate_model('models/roberta/roberta-lr-0.0001-dc-0.25', 'roberta', 'roberta-lr-0.0001-dc-0.25', 1)
evaluate_model('models/xlnet/xlnet-lr-5e-05-dc-0.25', 'xlnet', 'xlnet-lr-5e-05-dc-0.25', 1)
evaluate_model('models/xlnet/xlnet-lr-5e-05-dc-0.3', 'xlnet', 'xlnet-lr-5e-05-dc-0.3', 1)
# finetune_model('xlnet', 'xlnet')
# check_result('models/bert/bertfold1', 'bert')