'''
- gmb_trainer.py
- This file handles training various models for the annotated GMB corpus
  - *WARNING* - This module only seems to work with torch==1.7.1, not torch==1.8.0; the cause is still being investigated
'''

# External imports
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from transformers import Trainer, TrainingArguments, BertTokenizerFast, BertForTokenClassification, RobertaTokenizerFast, RobertaForTokenClassification, XLNetTokenizerFast, XLNetForTokenClassification

# Internal imports
from src.core.configuration.datagen_conf import *
from src.core.interface import data_generation

# Encodings for the tokens and tags (These will be used globally)
# Set of tag strings; empty at import time
unique_tags = set()
# Fixed mapping from GMB tag string to the integer label id used for training
tag2id = {'O': 0, 'art': 1, 'tim': 2, 'geo': 3, 'org': 4, 'nat': 5, 'eve': 6, 'per': 7, 'gpe': 8}
# Reverse mapping (label id -> tag string); empty here, filled in from tag2id by train_model
id2tag = {}

# Dataset object (Makes life easier in the functions below)
class GMBDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with per-token label ids.

    Args:
        encodings: Mapping from field name to a per-example indexable
            collection (anything exposing ``.items()``).
        labels: Indexable collection holding one label sequence per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The number of examples is defined by the labels, not the encodings
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, including a 'labels' entry."""
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

'''
----------train_model----------
- Fine-tunes an existing model on the GMB corpus
-----Inputs-----
- model_name - the name of the model to train
-----Output-----
- N/A - This function writes the trained models to models/gmb_corpus
'''
def train_model(model_name):
    """Fine-tune a pretrained token-classification model on the GMB corpus.

    Loads the corpus through ``data_generation.get_data("gmb")``, splits it
    80/20 into training and validation sets (no ``random_state`` is passed to
    ``train_test_split``), tokenizes both, trains for 3 epochs with the
    HuggingFace ``Trainer`` and saves the result.

    Args:
        model_name: Which model to fine-tune: "bert", "roberta" or "xlnet".

    Returns:
        None. Trainer output goes to ``models/gmb_corpus/<model_name>``, the
        final model to ``models/gmb_corpus/<model_name>/trained_model`` and
        logs to ``src/data/output/fine_tuning/logs/gmb_corpus``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Without this declaration the assignment further down only creates a
    # local and the module-level set stays empty
    global unique_tags

    # Tokenizer class, model class, checkpoint name and extra tokenizer kwargs
    supported_models = {
        "bert": (BertTokenizerFast, BertForTokenClassification, 'bert-base-cased', {}),
        "roberta": (RobertaTokenizerFast, RobertaForTokenClassification, 'roberta-base', {"add_prefix_space": True}),
        "xlnet": (XLNetTokenizerFast, XLNetForTokenClassification, 'xlnet-base-cased', {}),
    }
    # Fail before the corpus is loaded instead of crashing later on tokenizer = ""
    if model_name not in supported_models:
        raise ValueError(
            "Unsupported model_name {!r}; expected one of {}".format(model_name, sorted(supported_models))
        )
    tokenizer_cls, model_cls, checkpoint, tokenizer_kwargs = supported_models[model_name]

    # Start by loading in the data
    print("TRAINING: Loading GMB Corpus for fine-tuning")
    training_data = data_generation.get_data("gmb")
    print("TRAINING: GMB Corpus loaded. Initializing trainer")

    # Split it into train and test (80% train, 20% test)
    train_texts, val_texts, train_tags, val_tags = train_test_split(
        training_data["texts"], training_data["tags"], test_size=.2
    )

    # Assign values to the global encodings
    unique_tags = set(tag for doc in training_data["tags"] for tag in doc)
    id2tag.update({tag_id: tag for tag, tag_id in tag2id.items()})
    print("TRAINING: Tag to ID conversions set:\n", tag2id)

    # Initialize the transformer tokenizer
    tokenizer = tokenizer_cls.from_pretrained(checkpoint, **tokenizer_kwargs)

    # Initialize the encodings
    train_encodings = tokenizer(train_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
    val_encodings = tokenizer(val_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)

    # Encode the training and validation sets (needs the offset mappings)
    train_labels = encode_tags(train_tags, train_encodings)
    val_labels = encode_tags(val_tags, val_encodings)

    # Remove the offset mappings (Don't want to pass that to the model)
    train_encodings.pop("offset_mapping")
    val_encodings.pop("offset_mapping")

    # Initialize both datasets
    train_dataset = GMBDataset(train_encodings, train_labels)
    val_dataset = GMBDataset(val_encodings, val_labels)

    # Initialize the model. The label ids come from the fixed tag2id table, so
    # the classifier head is sized from it; sizing it from the tags seen in
    # the corpus would leave ids out of range whenever a tag is missing.
    model = model_cls.from_pretrained(checkpoint, num_labels=len(tag2id))

    # Now, we can start training. First, initialize the training arguments
    training_args = TrainingArguments(
        output_dir="models/gmb_corpus/"+model_name,
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='src/data/output/fine_tuning/logs/gmb_corpus',
        logging_steps=10,
    )

    # Initialize the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset
    )

    # Finally, start training
    print("TRAINING: Beginning {} fine-tuning on the GMB Corpus".format(model_name))
    trainer.train()

    # Save the finished model
    trainer.save_model("models/gmb_corpus/{}/trained_model".format(model_name))


'''
----------encode_tags----------
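- Converts each document's tag strings to integer ids and aligns them with the tokenizer output
-----Inputs-----
- tags - The tags to encode (one list of tag strings per document)
- encodings - The tokenizer output for the same documents; it must still contain its offset_mapping
-----Output-----
- encoded_labels - One list of label ids per document. Only positions whose offset is (0, non-zero) receive a tag id; every other position is -100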
'''
def encode_tags(tags, encodings):
    """Turn per-document tag strings into label ids aligned with the tokenizer output.

    Args:
        tags: One sequence of tag strings per document; every tag must be a
            key of the module-level ``tag2id`` table.
        encodings: Object exposing ``offset_mapping``, holding one sequence of
            (start, end) pairs per document.

    Returns:
        One list of ints per document, as long as that document's offset
        list. Positions whose offset starts at 0 and has a non-zero end
        receive the document's tag ids in order; all other positions are -100.
        Surplus tag ids are dropped from the end, and positions left without
        a tag id stay -100.
    """
    label_ids = [[tag2id[tag] for tag in doc] for doc in tags]

    aligned = []
    for doc_ids, doc_offsets in zip(label_ids, encodings.offset_mapping):
        offsets = np.array(doc_offsets)

        # Positions that start at 0 and have a non-zero end are the ones that take a label
        takes_label = (offsets[:, 0] == 0) & (offsets[:, 1] != 0)
        slot_count = int(np.count_nonzero(takes_label))

        # Cut surplus labels off the end, or pad the shortfall with -100
        # (a negative repeat count yields an empty list, so only one side applies)
        fitted = doc_ids[:slot_count] + [-100] * (slot_count - len(doc_ids))

        # Everything defaults to -100; only the selected positions are overwritten
        row = np.full(len(doc_offsets), -100, dtype=int)
        row[takes_label] = fitted
        aligned.append(row.tolist())

    return aligned

#train_model("bert")
#train_model("roberta")
#train_model("xlnet")