#! /usr/bin/env python3
# coding=utf-8



from simpletransformers.classification import ClassificationModel, ClassificationArgs
from sklearn.model_selection import train_test_split
import os
import torch
import pandas as pd

import logging
# Process-wide runtime configuration. This runs at import time (outside any
# main guard), so worker processes that import this module see it as well.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
# Only WARNING and above from the `transformers` logger; its INFO records are dropped.
transformers_logger.setLevel(logging.WARNING)

# Explicitly enable Hugging Face tokenizers parallelism.
os.environ.update(TOKENIZERS_PARALLELISM="true")
# Share tensors between worker processes through the file system rather than
# the default strategy; presumably chosen to avoid "too many open files"
# failures during multiprocess evaluation — confirm.
torch.multiprocessing.set_sharing_strategy("file_system")

def main():
    """Train, evaluate and smoke-test a RoBERTa formality classifier.

    Reads ``./data/formality_classifier.csv``, holds out 20% of the rows
    (fixed seed) for evaluation, fine-tunes ``roberta-base`` for two epochs
    with checkpoints written to ``./formality_classifier``, evaluates on the
    held-out split, and prints the predicted label for one sample sentence.
    """
    # Prepare train/eval data: a deterministic 80/20 split of one CSV.
    data_df = pd.read_csv("./data/formality_classifier.csv", header=0)
    train_df, eval_df = train_test_split(data_df, test_size=0.2, random_state=25536)

    # Model configuration.
    model_args = ClassificationArgs()
    model_args.num_train_epochs = 2
    model_args.labels_list = [0, 1]
    model_args.fp16 = True
    model_args.n_gpu = 3
    # Evaluation is parallelised across worker processes; this is why the
    # pipeline must only run under the `__main__` guard below.
    model_args.use_multiprocessing_for_evaluation = True
    model_args.use_cached_eval_features = False
    model_args.overwrite_output_dir = True
    model_args.output_dir = "./formality_classifier"

    # Create a ClassificationModel.
    model = ClassificationModel("roberta", "roberta-base", args=model_args)

    # Train the model.
    model.train_model(train_df)

    # Evaluate the model on the held-out split.
    result, model_outputs, wrong_predictions = model.eval_model(eval_df)

    # Make predictions with the model.
    predictions, raw_outputs = model.predict(
        ["Trump tries to shift focus to victims of undocumented immigrants"]
    )

    print(predictions)


# Guard the pipeline: under the "spawn" start method, multiprocessing workers
# re-import this module, and without the guard each one would retrain the model.
if __name__ == "__main__":
    main()