import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, load_metric
from transformers import (AutoModel, 
    AutoModelForSequenceClassification,
   AutoTokenizer,
    AdamW,
    AutoTokenizer,
    DataCollatorWithPadding,
    TrainingArguments, 
    Trainer)

from sklearn import preprocessing

#define dataset class
class SDDataset(Dataset):
    """Map-style dataset over pre-tokenized encodings and their labels.

    Args:
        encodings: mapping from model-input name (e.g. "input_ids") to a
            per-example indexable sequence; in this script it is the tokenizer
            output created with ``return_tensors='pt'``.
        labels: per-example labels, indexable by example position.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The number of examples is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the model inputs of example ``idx`` plus its label under "labels"."""
        # torch.as_tensor reuses values that are already tensors instead of
        # copy-constructing them; torch.tensor(tensor) makes a redundant copy
        # and emits a UserWarning on every item fetched. Lists still work.
        item = {key: torch.as_tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.as_tensor(self.labels[idx])
        return item

def filter_language(x_stance_data, language):
    """Select the rows of ``x_stance_data`` whose "language" equals ``language``.

    Args:
        x_stance_data: column-indexable data (e.g. a dataset split or a dict
            of lists) providing "language", "question", "comment" and "label".
        language: language code to keep, compared with ``==`` (e.g. "de").

    Returns:
        dict of parallel lists under "question", "comment" and "label",
        preserving the original row order.
    """
    rows = zip(
        x_stance_data["language"],
        x_stance_data["question"],
        x_stance_data["comment"],
        x_stance_data["label"],
    )
    kept = [(q, c, lbl) for lang, q, c, lbl in rows if lang == language]
    return {
        "question": [row[0] for row in kept],
        "comment": [row[1] for row in kept],
        "label": [row[2] for row in kept],
    }

def filter_comments(x_stance_data, language):
    """Select comments and labels of the rows written in ``language``.

    Unlike ``filter_language`` the question text is not collected.

    Args:
        x_stance_data: column-indexable data providing "language", "comment"
            and "label".
        language: language code to keep, compared with ``==``.

    Returns:
        dict of parallel lists under "comment" and "label", in row order.
    """
    comments = []
    labels = []
    columns = zip(x_stance_data["language"], x_stance_data["comment"], x_stance_data["label"])
    for row_lang, row_comment, row_label in columns:
        if row_lang != language:
            continue
        comments.append(row_comment)
        labels.append(row_label)
    return {"comment": comments, "label": labels}

# compute metrics function
def compute_metrics(eval_preds):
    """Compute accuracy and binary F1 for a Trainer evaluation.

    Returns the same keys and values as the GLUE/MRPC metric
    (accuracy plus F1 of class id 1). It is computed directly with numpy, so
    no metric script is (re)loaded on every evaluation call. Intended for
    binary labels (ids 0 and 1).

    Args:
        eval_preds: ``(logits, labels)`` pair, e.g. a ``transformers.EvalPrediction``;
            class scores are along the last axis of ``logits``.

    Returns:
        dict with "accuracy" and "f1" as Python floats.
    """
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    accuracy = float(np.mean(predictions == labels))

    # F1 of the positive class: 2*TP / (2*TP + FP + FN), defined as 0.0 when
    # there are neither predicted nor true positives.
    true_pos = int(np.sum((predictions == 1) & (labels == 1)))
    false_pos = int(np.sum((predictions == 1) & (labels != 1)))
    false_neg = int(np.sum((predictions != 1) & (labels == 1)))
    denom = 2 * true_pos + false_pos + false_neg
    f1 = 2 * true_pos / denom if denom else 0.0

    return {"accuracy": accuracy, "f1": float(f1)}


# Try out different pre-trained models by switching the active checkpoint.
#checkpoint = 'bert-base-german-cased'
#checkpoint = 'bert-base-multilingual-cased'
#checkpoint = 'bert-base-multilingual-uncased'
checkpoint = 'dbmdz/bert-base-german-uncased'

# Initialize tokenizer and data collator.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


x_stance = load_dataset("x_stance")
# Keep only the German ("de") rows of every split.
x_stance_train = filter_language(x_stance["train"], "de")
x_stance_test = filter_language(x_stance["test"], "de")
x_stance_val = filter_language(x_stance["validation"], "de")

# Encode the string labels as integer class ids. The encoder is fitted on the
# training split only and then reused for validation/test. Re-fitting it per
# split (fit_transform) can map the same label to different ids whenever a
# split is missing a class. transform() instead raises on labels unseen in training.
label_encoder = preprocessing.LabelEncoder()
train_labels = torch.as_tensor(label_encoder.fit_transform(x_stance_train["label"]))
test_labels = torch.as_tensor(label_encoder.transform(x_stance_test["label"]))
val_labels = torch.as_tensor(label_encoder.transform(x_stance_val["label"]))

# Class names in id order: all_labels[i] is the label encoded as i.
all_labels = label_encoder.classes_
print(all_labels)

# Tokenize (question, comment) pairs; each example is truncated/padded to 512 tokens.
train_encodings = tokenizer(x_stance_train["question"], x_stance_train["comment"], return_tensors='pt',
                            max_length=512, padding='max_length', truncation=True)
val_encodings = tokenizer(x_stance_val["question"], x_stance_val["comment"], return_tensors='pt',
                          max_length=512, padding='max_length', truncation=True)
test_encodings = tokenizer(x_stance_test["question"], x_stance_test["comment"], return_tensors='pt',
                           max_length=512, padding='max_length', truncation=True)


traindata = SDDataset(train_encodings, train_labels)
testdata = SDDataset(test_encodings, test_labels)
valdata = SDDataset(val_encodings, val_labels)

# Load the pre-trained checkpoint with a sequence-classification head sized to
# the number of label classes.
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, num_labels=len(label_encoder.classes_)
)


# Train using a Trainer. The Trainer builds its own optimizer from these
# arguments; a separately constructed optimizer is only used when passed via
# `optimizers=`. The learning rate therefore has to be configured here.
training_args = TrainingArguments(
    # NOTE(review): the "_only_comments" suffix does not describe this run,
    # which trains on (question, comment) pairs. Confirm the intended path.
    output_dir='./output/'+checkpoint+'_only_comments',    # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=20,  # batch size per device during training
    per_device_eval_batch_size=14,   # batch size for evaluation
    learning_rate=1e-5,              # peak learning rate of the Trainer's optimizer
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10000,
    save_total_limit=2,
    evaluation_strategy="epoch",     # evaluate at the end of every epoch
)

trainer = Trainer(
    model=model,                                 # the instantiated 🤗 Transformers model to be trained
    args=training_args,                          # training arguments, defined above
    train_dataset=traindata,                     # training dataset
    eval_dataset=valdata,                        # evaluation dataset
    tokenizer=tokenizer,                         # used tokenizer
    data_collator=data_collator,
    compute_metrics=compute_metrics,             # define metric for evaluation
)

trainer.train()

# Score the held-out test split with the same metric used during evaluation.
predictions = trainer.predict(testdata)
final_score = compute_metrics((predictions.predictions, predictions.label_ids))
print(final_score)

# The Trainer also references the model, so both must be dropped before
# releasing cached GPU memory.
del trainer, model
torch.cuda.empty_cache()

