from transformers import AutoTokenizer, AutoModelForSequenceClassification  
from collections import defaultdict
from transformers import BertTokenizer, BertForSequenceClassification, TextClassificationPipeline
import torch  
from detoxify import Detoxify
import torch.nn.functional as F
 
def sentiment_init(device):
    """Build a batched sentiment scorer on the nlptown multilingual star-rating model.

    The checkpoint classifies text into 5 classes ([1 star ... 5 stars]).

    Args:
        device: Device the model is moved to and the encoded inputs are sent to.

    Returns:
        ``sentiment_reward(batch_of_texts)``, which returns a CPU tensor with the
        argmax class index for each text. Indices are 0-based, so index ``i``
        corresponds to the ``i + 1`` star class.
    """
    checkpoint = "nlptown/bert-base-multilingual-uncased-sentiment"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    clf = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)

    def sentiment_reward(batch_of_texts):
        batch = tok(batch_of_texts, padding=True, truncation=True, return_tensors="pt")
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        with torch.no_grad():
            # Element 0 of the model output holds the classification logits.
            logits = clf(**batch)[0]

        return logits.argmax(dim=-1).cpu()

    return sentiment_reward

def toxic_init(device):
    """Build a batched safety reward from a two-label BERT text-classification pipeline.

    Args:
        device: Device to run inference on. It is forwarded to the pipeline,
            which places the model and the tokenized inputs there.

    Returns:
        ``safety_reward(batch_of_texts)``, returning a float tensor with one value
        per text. The value is the top label's score, or ``1 - score`` when the
        top label is ``'toxic'``. With two labels this is the probability of the
        non-``'toxic'`` outcome.
    """
    model_path = "GroNLP/hateBERT"
    tokenizer = BertTokenizer.from_pretrained(model_path)
    # NOTE(review): hateBERT appears to be a pre-trained encoder checkpoint. Loading
    # it with num_labels=2 presumably creates a freshly initialised classification
    # head whose labels are "LABEL_0"/"LABEL_1". If so, the 'toxic' comparison below
    # never matches -- confirm the checkpoint and label names before relying on this.
    model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)
    # Forward `device`; previously it was ignored and inference ran on the default device.
    pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer, device=device)

    def safety_reward(batch_of_texts):
        predictions = pipeline(batch_of_texts, padding=True, truncation=True)
        return torch.tensor([
            1 - pred['score'] if pred['label'] == 'toxic' else pred['score']
            for pred in predictions
        ])

    return safety_reward

def safety_init(device, weights = None, temperature = None, binary = False):
    """Build a batched safety reward from Detoxify's multilingual classifier.

    For each text:
    1. Each Detoxify score is multiplied by its key's weight.
    2. The weighted scores are combined with a temperature-sharpened
       softmax-weighted average. This is a smooth stand-in for the largest
       weighted score.
    3. The result is turned into a reward ``(1 - aggregate) ** 2``.

    Args:
        device: Device forwarded to ``Detoxify``.
        weights: Mapping from Detoxify output key to weight. Each weight is capped
            at 1.0. Defaults to weight 1 for every key. A plain dict must contain
            every key Detoxify returns; missing keys raise ``KeyError``.
        temperature: Multiplier applied to the weighted scores inside the softmax.
            Defaults to 5.0. Larger values move the aggregate closer to the
            largest weighted score.
        binary: If True, return an integer tensor with ``1`` where the reward
            exceeds 0.9 and ``-1`` elsewhere, instead of the continuous reward.

    Returns:
        ``safety_reward(batch_of_texts)``, which returns one reward per text. This
        assumes Detoxify returns one score list per key for a list input --
        TODO confirm.
    """
    bert = Detoxify('multilingual', device=device)
    if weights is None:
        weights = defaultdict(lambda: 1)
    if temperature is None:
        temperature = 5.0

    def safety_reward(batch_of_texts):
        raw_output = bert.predict(batch_of_texts)
        keys = list(raw_output)
        # Stack the per-key score lists, then transpose: rows = texts, columns = keys.
        scores = torch.tensor([raw_output[k] for k in keys]).T
        wts = torch.tensor([min(weights[k], 1.0) for k in keys])
        scores = scores * wts[None]
        # Compute the reward once; the binary mode only thresholds it.
        aggregate = (F.softmax(scores * temperature, dim=-1) * scores).sum(-1)
        reward = (1 - aggregate) ** 2
        if binary:
            return (reward > 0.9) * 2 - 1
        return reward

    return safety_reward

def vowel_init(device = None):
    """Build a reward equal to the fraction of each text's letters that are ASCII vowels.

    Only alphabetic characters count toward the denominator. A text with no
    letters scores 0.0.

    Args:
        device: Unused; accepted so the factory matches the other ``*_init`` signatures.
    """
    vowel_set = 'aeiouAEIOU'

    def vowel_reward(batch_of_texts):
        fractions = []
        for text in batch_of_texts:
            letters = [ch for ch in text if ch.isalpha()]
            n_vowels = sum(1 for ch in letters if ch in vowel_set)
            # The epsilon avoids division by zero for texts with no letters.
            fractions.append(n_vowels / (len(letters) + 1e-8))
        return torch.tensor(fractions)

    return vowel_reward

def length_init(device=None, max_length=1024):
    """Build a reward that favours shorter texts.

    Each text scores ``(max_length - len(text)) / max_length``, clipped below
    at 0.0. An empty string scores 1.0, and the score falls linearly to 0.0 at
    ``max_length`` characters and beyond.

    Args:
        device: Unused; accepted so the factory matches the other ``*_init`` signatures.
        max_length: Length at which the reward reaches 0. Defaults to 1024, the
            previously hard-coded value.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")

    def length_reward(batch_of_texts):
        return torch.tensor(
            [max((max_length - len(text)) / max_length, 0.0) for text in batch_of_texts]
        )

    return length_reward

def dummy_init(device=None):
    """Build a fixed placeholder reward that ignores text content.

    The first ``len(batch) // 2`` items score 1.0 and every remaining item
    scores 0.1. The output therefore has exactly one entry per input; in an
    odd-sized batch the extra item scores 0.1.

    Args:
        device: Unused; accepted so the factory matches the other ``*_init`` signatures.
    """
    def dummy_reward(batch_of_texts):
        n_high = len(batch_of_texts) // 2
        # Use the remainder for the low half so odd batch sizes are not one reward short.
        return torch.tensor([1.0] * n_high + [0.1] * (len(batch_of_texts) - n_high))

    return dummy_reward

def readability_init(device=None):
    """Build a reward that favours text at a low-but-nonzero reading grade.

    The grade level comes from ``textstat.text_standard(txt, float_output=True)``.
    - A grade outside ``[2, 8]`` scores 0.
    - A grade inside that range scores ``(8 - grade) / 6``: 1.0 at grade 2,
      falling linearly to 0.0 at grade 8.

    Args:
        device: Unused; accepted so this factory can be called like the other
            ``*_init`` factories, which all take a device argument.
    """
    # Imported here so the rest of the module loads without textstat installed.
    import textstat

    def read_reward(batch_of_texts):
        read_level = torch.tensor(
            [textstat.text_standard(txt, float_output=True) for txt in batch_of_texts]
        )
        out_of_band = (read_level < 2) | (read_level > 8)
        return torch.where(out_of_band, 0, (8 - read_level) / 6)

    return read_reward

#readability_init()(['this is extremely hard', 'i dont even know what to say'])
