
from datasets import load_dataset, load_metric, Dataset
import datasets
from datasets import fingerprint
from transformers import AutoTokenizer, BertTokenizerFast, RobertaTokenizerFast, XLNetTokenizerFast, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification, pipeline
#from src.utils.training_data_generation.query_generation import generate_queries
import numpy as np
import json
#inernal imports
# from src.core.configuration.datagen_conf import *
from src.core.configuration.datagen_conf import FOLD_COUNT, BATCH_SIZE
from src.utils.fine_tuning.compute_f1 import *


def finetune_model(model_name, tokenizer_name, fold, decay = 0.01, learning_rate = 5e-5):
    task = "ner" # Should be one of "ner", "pos" or "chunk"
    path = 'src/data/output/fine_tuning/logs/htest/f1s/' + model_name + '-d-' + str(decay) + '-lr-' + str(learning_rate)  + '.txt'
    # Get the proper model checkpoint and tokenizer
    file = open(path, 'a')
    file.write('\nFinetuning of :' + model_name + '- Fold: ' + fold +
               '-weight decay- ' + str(decay) + ' -learning rate- ' + str(learning_rate) + '\n\n\n')

    model_checkpoint = ''
    batch_size = BATCH_SIZE
    num_epochs = 1
    weight_decay = decay
    saved_model_name = model_name + '-d-' + str(decay) + '-lr-' + str(learning_rate)
    save_model_path = 'models/'
    if (model_name == 'bert'):
        model_checkpoint = 'bert-base-cased'
        # batch_size = BATCH_SIZE
        # learn_rate = 1e-5
        # num_epochs = 3
        # weight_decay = 0.3
        save_model_path = save_model_path + 'bert/'
    elif (model_name == 'roberta'):
        model_checkpoint = 'roberta-base'
        #batch_size = 64
        #learn_rate = 5e-4
        #num_epochs = 3
        #weight_decay = .05
        save_model_path = save_model_path + 'roberta/'
    elif (model_name == 'xlnet'):
        model_checkpoint = 'xlnet-base-cased'
        #batch_size = 64
        #learn_rate = 1e-3
        #num_epochs = 3
        #weight_decay = .25
        save_model_path = save_model_path + 'xlnet/'

    label_all_tokens = False
    file.write('\n Model saving path: ' + save_model_path)
    # Generate queries if needed

    # Load the dataset
    # fingerprint.set_caching_enabled(False)
    datasets = load_dataset("src/utils/fine_tuning/custom_conll2003.py")#, download_mode="force_redownload", ignore_verifications=True)

    # Truncate the dataset to fit batch size
    extra_examples_val = len(datasets["validation"]) % batch_size
    extra_examples_train = len(datasets["train"]) % batch_size
    extra_examples_test = len(datasets["test"]) % batch_size

    if (extra_examples_train != 0):
        datasets["train"] = datasets["train"]\
            .filter(lambda example, indice: indice >= extra_examples_val, with_indices=True)

    if (extra_examples_val != 0):
        datasets["validation"] = datasets["validation"]\
            .filter(lambda example, indice: indice >= extra_examples_val, with_indices=True)

    if (extra_examples_test != 0):
        datasets["test"] = datasets["test"]\
            .filter(lambda example, indice: indice >= extra_examples_test, with_indices=True)

        # datasets = Dataset.from_dict(datasets)

    datasets["validation"] = datasets["validation"] \
        .filter(lambda example, indice: indice %2 == 0, with_indices=True)
    datasets["validation"] = datasets["validation"] \
        .filter(lambda example, indice: indice %2 == 0, with_indices=True)
    # Get the label list
    label_list = datasets["train"].features[f"{task}_tags"].feature.names

    # Get the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    if (tokenizer_name == 'bert'):
        tokenizer = BertTokenizerFast.from_pretrained(model_checkpoint)
    elif (tokenizer_name == 'roberta'):
        tokenizer = RobertaTokenizerFast.from_pretrained(model_checkpoint, add_prefix_space=True)
    elif (tokenizer_name == 'xlnet'):
        tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

    def tokenize_and_align_labels(examples):
        tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
        labels = []
        for i, label in enumerate(examples[f"{task}_tags"]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
                # ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label[word_idx])
                # For the other tokens in a word, we set the label to either the current label or -100, depending on
                # the label_all_tokens flag.
                else:
                    label_ids.append(label[word_idx] if label_all_tokens else -100)
                previous_word_idx = word_idx

            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    # Tokenize the dataset
    tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True, batch_size=batch_size)
    file.close()
    # Start prepping the fine-tuning stage
    model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

    args = TrainingArguments(
        f"test-{task}",
        evaluation_strategy = "epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        warmup_steps=500,
        weight_decay=weight_decay,
        seed=12345
    )

    data_collator = DataCollatorForTokenClassification(tokenizer)

    metric = load_metric('seqeval')

    # labels = [label_list[i] for i in example[f"{task}_tags"]]
    labels = label_list
    metric.compute(predictions=[labels], references=[labels])

    def compute_metrics(p):
        predictions, labels = p
        predictions = np.argmax(predictions, axis=2)
        # Remove ignored index (special tokens)
        true_predictions = [
            [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        true_labels = [
            [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]

        results = metric.compute(predictions=true_predictions, references=true_labels)

        compute_f1_precision_recall(true_predictions, true_labels, path)

        # file.write("\nprecision: " + str(results["overall_precision"]) +
        #     "\nrecall: " + str(results["overall_recall"]) +
        #     "\nf1: " + str(results["overall_f1"]) +
        #     "\naccuracy: " + str(results["overall_accuracy"]) + '\n\n\n')
        # file.close()


        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"]
        }

    # Fire up the trainer
    trainer = Trainer(
        model,
        args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()

    trainer.evaluate()

    file = open(path, 'a')
    file.write('\nEvaluation of :' + model_name + '- Fold: ' + fold + '\n\nOn handcrafted data\n')
    file.close()


    trainer.save_model(save_model_path + saved_model_name)


    trainer.predict(test_dataset=tokenized_datasets["test"])

    check_result(save_model_path + saved_model_name, tokenizer_name, path)

    # # path to handcraft eval = src/data/fine_tuning/eval-hand-craft.txt
    # path = 'src/data/fine_tuning/hand-crafted/'
    # file = open("conll_data_path.txt", "w")
    # file.write(path)
    # file.close()
    # eval_dataset = load_dataset("src/utils/fine_tuning/custom_conll2003.py")
    # tokenized_eval_dataset = eval_dataset.map(tokenize_and_align_labels, batched=True)
    # trainer.predict(test_dataset=tokenized_eval_dataset["test"])


def check_result(model_path, tokenizer_name, path):
    # Get the tokenizer
    if (tokenizer_name == 'bert'):
        tokenizer = BertTokenizerFast.from_pretrained(model_path)
    elif (tokenizer_name == 'roberta'):
        tokenizer = RobertaTokenizerFast.from_pretrained(model_path, add_prefix_space=True)
    elif (tokenizer_name == 'xlnet'):
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Prep the model
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    ner_model = pipeline('ner', model=model, tokenizer=tokenizer, grouped_entities=True)
    file = open(path, 'a')
    easy_queries = open('src/data/test_data/easy_queries/flight_delay.txt', 'r')
    # Start testing sequences
    file.write('\n\nEasy queries: ')
    for line in easy_queries:
        # sequence = "Help me predict the average weather delay for each destination airport over the next week"
        sequence = line.split('|')
        entity = sequence[1]
        attribute = sequence[2]
        l = ner_model(sequence[0])
        # print(l)
        file.write('\n' + sequence[0])
        file.write('\n' + str(l))
        labels = ['LABEL_9', 'LABEL_10','LABEL_11', 'LABEL_12']

            # Iterate over the label sequence, concatenating the desired labels together
        result = ''
        for x in l:
            if x['entity_group'] in labels:
                result += x['word']
        file.write('\n' + result)
    hard_queries = open('src/data/test_data/hard_queries/flight_delay.txt', 'r')
    file.write('\n\nHard queries: ')
    for line in hard_queries:
        # sequence = "Help me predict the average security delay for each origin airport for the next week"
        sequence = line.split('|')
        entity = sequence[1]
        attribute = sequence[2]

        l = ner_model(sequence[0])
        # print(l)
        file.write('\n' + sequence[0])
        file.write('\n' + str(l))
        labels = ['LABEL_9', 'LABEL_10', 'LABEL_11', 'LABEL_12']

        # Iterate over the label sequence, concatenating the desired labels together
        result = ''
        for x in l:
            if x['entity_group'] in labels:
                result += x['word']
        file.write('\n' + result)
    file.close()


def loop_through_dataset_folds(model_name, decay = 0.01, learning_rate = 5e-5, fold_ids=range(1, 2)):
    """Fine-tune ``model_name`` once per requested fold.

    For each fold id, this writes the fold's data directory
    (``src/data/fine_tuning/ner/fold<N>/``) to ``conll_data_path.txt`` and then
    calls ``finetune_model``. Presumably the custom dataset loading script reads
    that file to find the fold; confirm in custom_conll2003.py.

    Args:
        model_name: Model family; also used as the tokenizer name.
        decay: Weight decay forwarded to ``finetune_model``.
        learning_rate: Learning rate forwarded to ``finetune_model``.
        fold_ids: Iterable of fold numbers. Defaults to fold 1 only, as before.
            Pass e.g. ``range(1, FOLD_COUNT + 1)`` to cover every configured fold
            (assuming FOLD_COUNT is the number of folds).
    """
    for fold_id in fold_ids:
        fold_path = 'src/data/fine_tuning/ner/fold' + str(fold_id) + '/'
        with open("conll_data_path.txt", "w") as path_file:
            path_file.write(fold_path)
        print(fold_path)
        finetune_model(model_name=model_name, tokenizer_name=model_name, fold='fold' + str(fold_id),
                       decay=decay, learning_rate=learning_rate)

def combination_huperparams(model_name):
    """Grid-search weight decay x learning rate for ``model_name``.

    Reads ``src/data/fine_tuning/finetune_confs.json``, which must contain
    ``weight_decay`` and ``learning_rate`` lists. Runs
    ``loop_through_dataset_folds`` once for every (weight_decay, learning_rate)
    pair. Any ``batch_size`` entry in the file is not used here; the batch size
    used for training comes from ``BATCH_SIZE``.
    """
    with open("src/data/fine_tuning/finetune_confs.json", 'r', encoding="utf8") as conf_file:
        param_dict = json.load(conf_file)

    for weight_decay in param_dict['weight_decay']:
        for learning_rate in param_dict['learning_rate']:
            loop_through_dataset_folds(model_name, weight_decay, learning_rate)


if __name__ == "__main__":
    # Run the hyper-parameter grid search for each supported model family.
    # Guarded so that importing this module does not start training.
    combination_huperparams('bert')
    combination_huperparams('roberta')
    combination_huperparams('xlnet')