import re

import pandas as pd
import torch
from textblob import TextBlob
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from transformers import BertTokenizer, BertForSequenceClassification
#from utils import *

# Device for the transformer-based labeling functions: GPU when available.
# This replaces the call to DefineDevice, which is undefined here because the
# `utils` star-import is commented out.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label values emitted by the labeling functions in this module.
NEUTRAL = 0
TOXIC = 1
# Returned by the keyword LFs when they have no opinion on an example. The
# value -1 follows the Snorkel "abstain" convention (assumed; confirm it
# matches the label model used downstream).
ABSTAIN = -1

###
# Lexicon of toxic n-grams, loaded once at import time. dropna() discards
# blank rows: pandas reads them as float NaN, and `nan in text` would raise
# TypeError inside terms_lf.
terms = list(
    pd.read_csv('data/hate-speech-and-offensive-language/lexicons/refined_ngram_dict.csv')
    .ngram.dropna()
    .astype(str)
)


def terms_lf(text, terms=terms, min_matches=2):
    """Label *text* TOXIC when enough lexicon entries occur in it.

    Each entry of *terms* is tested with a plain, case-sensitive substring
    check, so an entry can also match inside a longer word. Each entry
    contributes at most one hit, however often it occurs in *text*.

    Args:
        text: The string to label.
        terms: Iterable of n-gram strings. Defaults to the lexicon loaded
            from CSV above.
        min_matches: Minimum number of matching entries needed for TOXIC.
            The default of 2 reproduces the original ``> 1`` threshold.

    Returns:
        TOXIC if at least ``min_matches`` entries match, otherwise NEUTRAL.
    """
    # The builtin sum replaces np.sum; numpy was never imported in this module.
    hits = sum(1 for t in terms if t in text)
    return TOXIC if hits >= min_matches else NEUTRAL

###
def racism_lf(text):
    """Flag text containing racial slurs or racially charged keywords.

    The search is case-insensitive and matches substrings anywhere in *text*.
    Returns TOXIC on any hit and ABSTAIN otherwise.
    """
    match = re.search(r"(nigg|monkey|jungle bunny|slave|savage|wetback)", text, flags=re.IGNORECASE)
    if match is None:
        return ABSTAIN
    return TOXIC

def offensive_lf(text):
    """Flag text containing common profanity or insults.

    The keyword search is case-insensitive and substring-based. Returns
    TOXIC when a keyword is found and ABSTAIN otherwise.
    """
    profanity = re.compile(r"(fuck|bitch|stupid|idiot)", re.I)
    return ABSTAIN if profanity.search(text) is None else TOXIC

def lgbt_lf(text):
    """Flag text containing the anti-LGBT slur stem ``fagg``.

    The match is case-insensitive and substring-based. Returns TOXIC on a
    match and ABSTAIN otherwise.
    """
    hit = re.search(r"(fagg)", text, re.I)
    if hit:
        return TOXIC
    else:
        return ABSTAIN

###
def racism2_lf(text):
    """Flag text that mentions a racial or ethnic group.

    These are identity terms, not slurs. NOTE(review): this LF presumably
    trades precision for recall; confirm that is intended. The matching is
    case-insensitive and substring-based. Returns TOXIC on a mention and
    ABSTAIN otherwise.
    """
    mention = re.search(r"(black|latino|latina|asian)", text, flags=re.I)
    if not mention:
        return ABSTAIN
    return TOXIC

def offensive2_lf(text):
    """Flag text that mentions a religious group.

    Despite the name, the pattern matches religion terms (christian, jewish,
    muslim) rather than profanity. Matching is case-insensitive and
    substring-based. Returns TOXIC on a mention and ABSTAIN otherwise.
    """
    religion = re.compile(r"(christian|jewish|muslim)", flags=re.I)
    found = religion.search(text)
    return TOXIC if found is not None else ABSTAIN

def lgbt2_lf(text):
    """Flag text that mentions sexual orientation.

    The pattern stems (homosex, gay, lesbian, queer) are matched
    case-insensitively as substrings. Returns TOXIC on a mention and ABSTAIN
    otherwise.
    """
    if re.search(r"(homosex|gay|lesbian|queer)", text, re.IGNORECASE) is None:
        return ABSTAIN
    return TOXIC

###
def textblob_sentiment_lf(text):
    """Label text TOXIC when TextBlob scores its sentiment as negative.

    A polarity of exactly zero, or any positive polarity, maps to NEUTRAL.
    """
    polarity = TextBlob(text).sentiment.polarity
    if polarity < 0:
        return TOXIC
    return NEUTRAL

#https://huggingface.co/IMSyPP/hate_speech_en
bert_tokenizer = BertTokenizer.from_pretrained('IMSyPP/hate_speech_en')
bert_model = BertForSequenceClassification.from_pretrained('IMSyPP/hate_speech_en').to(device)


def bert_hate_lf(text, model=bert_model, tokenizer=bert_tokenizer, device=device):
    """Label *text* with the IMSyPP hate-speech BERT classifier.

    The text is tokenized with truncation enabled and scored in one forward
    pass. It is labelled TOXIC when the softmax probability of class 0 is
    below 0.5, i.e. when the remaining classes are together more likely.
    NOTE(review): class 0 is assumed to be the model's "acceptable" label per
    its model card; confirm against ``model.config.id2label``.

    Returns:
        TOXIC or NEUTRAL.
    """
    batch = tokenizer.encode(text, return_tensors='pt', truncation=True).to(device)
    # Inference only: without no_grad every call would build and retain an
    # autograd graph, wasting memory and time.
    with torch.no_grad():
        logits = model(batch).logits
    p_class0 = torch.softmax(logits, dim=1)[0, 0].item()
    return TOXIC if p_class0 < .5 else NEUTRAL

#https://huggingface.co/s-nlp/roberta_toxicity_classifier?text=I+like+you.+I+love+you
roberta_tokenizer = RobertaTokenizer.from_pretrained('SkolkovoInstitute/roberta_toxicity_classifier')
roberta_model = RobertaForSequenceClassification.from_pretrained('SkolkovoInstitute/roberta_toxicity_classifier').to(device)


def roberta_toxicity_lf(text, model=roberta_model, tokenizer=roberta_tokenizer, device=device):
    """Label *text* with the s-nlp RoBERTa toxicity classifier.

    The text is tokenized with truncation enabled and scored in one forward
    pass. It is labelled TOXIC when the softmax probability of class 1 is
    above 0.5. NOTE(review): class 1 is assumed to be the model's "toxic"
    label per its model card; confirm against ``model.config.id2label``.

    Returns:
        TOXIC or NEUTRAL.
    """
    batch = tokenizer.encode(text, return_tensors='pt', truncation=True).to(device)
    # Inference only: skip autograd bookkeeping so calls don't accumulate
    # computation graphs.
    with torch.no_grad():
        logits = model(batch).logits
    p_toxic = torch.softmax(logits, dim=1)[0, 1].item()
    return TOXIC if p_toxic > .5 else NEUTRAL