from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML
from datasets import load_dataset, load_metric

import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

import numpy as np

from transformers import DataCollatorForTokenClassification
from seqeval.scheme import IOB2

#https://github.com/huggingface/datasets/blob/master/metrics/seqeval/seqeval.py
#https://github.com/chakki-works/seqeval/blob/2921931184a98aff0dbbda5ff943214fe50a7847/seqeval/metrics/sequence_labeling.py#L22
#https://github.com/chakki-works/seqeval

task = "ner"  # Should be one of "ner", "pos" or "chunk"
model_checkpoint = "bert-base-uncased"
batch_size = 16


# Local dataset loading script (custom_conll2003.py).
datasets = load_dataset("src/utils/fine_tuning/custom_conll2003.py")

# Use the task-specific tag column (e.g. "ner_tags") so this stays correct
# when `task` is switched to "pos" or "chunk".
print(datasets["test"][0], datasets["train"].features[f"{task}_tags"])

# Ordered tag names; position i is the name of tag id i.
label_list = datasets["test"].features[f"{task}_tags"].feature.names
print(label_list)


def show_random_elements(dataset, num_examples=10):
    """Display ``num_examples`` distinct, randomly chosen rows of ``dataset`` as HTML.

    Columns typed ``ClassLabel``, and ``Sequence`` columns of ``ClassLabel``,
    are decoded from integer ids to their label names before display.

    Args:
        dataset: Dataset supporting ``len``, list indexing and ``.features``.
        num_examples: Number of distinct rows to show; must not exceed
            ``len(dataset)``.

    Raises:
        AssertionError: If ``num_examples`` exceeds the number of rows.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Sample row indices without replacement (replaces a hand-rolled retry loop).
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

# Preview a random sample of training rows, with tag ids decoded to names.
show_random_elements(dataset=datasets["train"])


tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# word_ids(), used for label alignment below, is a fast-tokenizer feature.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

print(tokenizer("predict the average weather delay for each airline"))

# Inspect how one pre-split training example is tokenized. Index 1000 is an
# arbitrary pick; clamp it so the demo also runs on a smaller custom dataset.
example_index = min(1000, len(datasets["train"]) - 1)
example = datasets["train"][example_index]
print(example["tokens"])

tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
print('Tokens: ', tokens)

# word_ids() maps every sub-token back to the index of its source word
# (None for special tokens); compute it once and reuse it.
word_ids = tokenized_input.word_ids()
print('Word ids: ', word_ids)

# Give each sub-token its word's tag; -100 marks positions to ignore.
aligned_labels = [-100 if i is None else example[f"{task}_tags"][i] for i in word_ids]
print(len(aligned_labels), len(tokenized_input["input_ids"]))

# When True, every sub-token of a word receives the word's tag (not only the
# first one); read by tokenize_and_align_labels.
label_all_tokens = True

def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align word tags to sub-tokens.

    Args:
        examples: Batched columns containing ``"tokens"`` (list of word lists)
            and ``f"{task}_tags"`` (one list of tag ids per sentence).

    Returns:
        The tokenizer output with an added ``"labels"`` column. Special tokens
        get -100 so the loss ignores them. The first sub-token of each word gets
        the word's tag. Later sub-tokens get the tag too when
        ``label_all_tokens`` is true, and -100 otherwise.
    """
    encoded = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    def _align(batch_idx, word_tags):
        # Walk the sub-tokens of one sentence, tracking the previous word id
        # to distinguish a word's first sub-token from its continuations.
        aligned = []
        last_word = None
        for word in encoded.word_ids(batch_index=batch_idx):
            if word is None:
                aligned.append(-100)
            elif word == last_word:
                aligned.append(word_tags[word] if label_all_tokens else -100)
            else:
                aligned.append(word_tags[word])
            last_word = word
        return aligned

    encoded["labels"] = [
        _align(batch_idx, word_tags)
        for batch_idx, word_tags in enumerate(examples[f"{task}_tags"])
    ]
    return encoded

# Tokenize every split and attach word-aligned label ids.
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

# Store the tag names in the model config so saved checkpoints report the
# names from label_list rather than generic ids.
id2label = dict(enumerate(label_list))
label2id = {name: idx for idx, name in id2label.items()}
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)

args = TrainingArguments(
    f"test-{task}",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)


# Batches the tokenized examples for the Trainer.
data_collator = DataCollatorForTokenClassification(tokenizer)

metric = load_metric("seqeval")

# Sanity check: scoring the demo example's gold tags against themselves.
labels = [label_list[i] for i in example[f"{task}_tags"]]
print('labels')
print(labels)
print(metric.compute(predictions=[labels], references=[labels]))


def compute_metrics(p):
    """Score token-classification predictions with seqeval, for use by Trainer.

    Args:
        p: ``(predictions, labels)`` pair passed by ``Trainer``. The argmax of
            ``predictions`` over axis 2 is taken as the predicted tag id.
            Positions whose label is -100 are ignored.

    Returns:
        dict with the overall ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    predictions, labels = p
    print('shape of prediction', len(predictions))
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored positions (label -100, e.g. special tokens) and map the
    # remaining ids to tag names.
    true_predictions = [
        [label_list[pred_id] for (pred_id, label_id) in zip(prediction, label) if label_id != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[label_id] for (pred_id, label_id) in zip(prediction, label) if label_id != -100]
        for prediction, label in zip(predictions, labels)
    ]

    # Spot-check the first sequence. A fixed large index would raise
    # IndexError on small evaluation sets.
    if true_predictions:
        print('true prediction and labels:')
        print(true_predictions[0], true_labels[0])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.evaluate()

# Final scoring on the held-out test split.
predictions, labels, _ = trainer.predict(tokenized_datasets["test"])
predictions = np.argmax(predictions, axis=2)

# Keep only positions with a real label (-100 marks ignored tokens) and
# translate both sides to tag names.
true_predictions = []
true_labels = []
for pred_seq, gold_seq in zip(predictions, labels):
    kept_pairs = [(pred_id, gold_id) for pred_id, gold_id in zip(pred_seq, gold_seq) if gold_id != -100]
    true_predictions.append([label_list[pred_id] for pred_id, _gold in kept_pairs])
    true_labels.append([label_list[gold_id] for _pred, gold_id in kept_pairs])

results = metric.compute(predictions=true_predictions, references=true_labels)
print('Result of testing set', results)


def PLSA(lmbda, k, vocabulary, document_list, corpus, max_iter=60, tol=0.0001):
    """Fit a PLSA topic model with expectation-maximization.

    Args:
        lmbda: Forwarded unchanged to ``calculate_likelihood``.
        k: Number of topics.
        vocabulary: Re-iterable collection of words; one ``Theta`` distribution
            over it is fitted per topic.
        document_list: Documents exposing ``index``, ``body`` (text split on
            whitespace) and ``word_freq`` (presumably word -> count; confirm).
        corpus: Forwarded unchanged to ``calculate_likelihood``.
        max_iter: Maximum number of EM iterations (previously hard-coded 60).
        tol: Stop once the relative likelihood improvement drops below this.

    Returns:
        ``(thetas, xs, ys)``: mapping topic index -> ``Theta``; the iteration
        numbers; the relative likelihood improvement at each iteration.
    """
    xs = []
    ys = []

    thetas = {}
    for t in range(k):
        theta = Theta(vocabulary)
        theta.norm_dist()
        thetas[t] = theta

    pis = {}
    for doc in document_list:
        pi = Topic_pi(k)
        pi.norm_dist()
        pis[doc.index] = pi

    # Likelihood of the parameters *before* each update. pis/thetas are
    # mutated in place by the M-step, so keeping references to them would
    # make "previous" and "current" identical.
    likelihood_prev = calculate_likelihood(corpus, lmbda, pis, thetas, k)

    for z in range(max_iter):
        # E-step: expected topic counts per (word, topic) and (document, topic),
        # computed in one pass from the posterior p(topic | document, word).
        n_wk = {(word, t): 0.0 for word in vocabulary for t in range(k)}
        n_dk = {}
        for document in document_list:
            pi = pis[document.index]
            doc_counts = [0.0] * k
            # Visit each distinct word once, weighted by its count. Visiting
            # every occurrence *and* multiplying by the count squares counts.
            for word in dict.fromkeys(document.body.split()):
                count = document.word_freq[word]
                joint = [pi.distribution[t] * thetas[t].distribution[word] for t in range(k)]
                evidence = sum(joint)
                if evidence == 0:
                    continue  # posterior undefined; word contributes nothing
                for t in range(k):
                    expected = count * joint[t] / evidence
                    n_wk[(word, t)] = n_wk.get((word, t), 0.0) + expected
                    doc_counts[t] += expected
            for t in range(k):
                n_dk[(document.index, t)] = doc_counts[t]

        # M-step: theta_t(w) is proportional to n_wk[(w, t)], normalised over
        # the vocabulary.
        for t in range(k):
            topic_total = sum(n_wk[(word, t)] for word in vocabulary)
            if topic_total == 0:
                continue  # topic received no mass; keep its previous values
            theta = thetas[t]
            for word in vocabulary:
                theta.distribution[word] = n_wk[(word, t)] / topic_total

        # M-step: pi_d(t) is proportional to n_dk[(d, t)], normalised over
        # topics.
        for document in document_list:
            doc_total = sum(n_dk[(document.index, t)] for t in range(k))
            if doc_total == 0:
                continue
            pi = pis[document.index]
            for t in range(k):
                pi.distribution[t] = n_dk[(document.index, t)] / doc_total

        likelihood_current = calculate_likelihood(corpus, lmbda, pis, thetas, k)
        # NOTE(review): assumes calculate_likelihood returns a (negative)
        # log-likelihood, so this ratio is positive while the fit improves.
        improvement = (likelihood_prev - likelihood_current) / likelihood_prev
        xs.append(z)
        ys.append(improvement)
        print(improvement)
        if improvement < tol:
            break
        likelihood_prev = likelihood_current

    return thetas, xs, ys