In [1]:
# Uncomment line below to install exlib
# !pip install exlib
In [ ]:
import torch
from transformers import AutoModel, AutoTokenizer
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import load_dataset
import torch.nn as nn
import sentence_transformers
import sys
sys.path.insert(0, "../../src")
import exlib
from exlib.datasets.emotion_helper import project_points_onto_axes, load_emotions
from exlib.datasets.emotion import load_data, load_model, EmotionDataset, EmotionClassifier, EmotionFixScore, get_emotion_scores
from exlib.features.text import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Load datasets and pre-trained models
In [3]:
dataset = EmotionDataset("test")
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
model = EmotionClassifier().eval().to(device)
SamLowe/roberta-base-go_emotions
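As a quick sanity check, the first test example can be inspected directly. The 'input_ids' key and dataset.tokenizer used here are inferred from how the batches are consumed in the prediction loop below, so treat this cell as an optional sketch rather than part of the benchmark.
In [ ]:
# Optional sanity check: peek at the first example of the test split.
# The 'input_ids' key and dataset.tokenizer are assumptions based on how
# the batches are used in the next cell.
example = dataset[0]
print(example.keys())
print(dataset.tokenizer.decode(example['input_ids'], skip_special_tokens=True))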
Model prediction
In [4]:
for batch in tqdm(dataloader):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    output = model(input_ids, attention_mask)
    # Decode the token ids back to text and print the top predicted emotion for each utterance
    utterances = [dataset.tokenizer.decode(input_id, skip_special_tokens=True) for input_id in input_ids]
    for utterance, label in zip(utterances, output.logits):
        id_str = model.model.config.id2label[label.argmax().item()]
        print("Text: {}\nEmotion: {}\n".format(utterance, id_str))
    break  # only show the first batch
Text: I’m really sorry about your situation :( Although I love the names Sapphira, Cirilla, and Scarlett!
Emotion: remorse

Text: It's wonderful because it's awful. At not with.
Emotion: admiration
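The loop above only prints the single highest-scoring emotion. To see how confident the model is, the logits from the last batch can be turned into a ranked list. The sketch below reuses output and model.model.config.id2label from the previous cell; normalizing with a softmax is an assumption that matches the single-label argmax used above.
In [ ]:
# Optional: rank the top-3 emotions for the first utterance of the batch above.
# Softmax normalization is an assumption; the benchmark itself only uses argmax.
probs = output.logits[0].softmax(dim=-1)
top = probs.topk(3)
for p, idx in zip(top.values.tolist(), top.indices.tolist()):
    print(f"{model.model.config.id2label[idx]}: {p:.3f}")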
In [5]:
# Compute scores for each baseline feature-grouping method
all_baseline_scores = get_emotion_scores([
    "identity", "random", "word", "phrase", "sentence", "clustering", "archipelago"
])
SamLowe/roberta-base-go_emotions
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [00:43<00:00, 31.44it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [02:00<00:00, 11.29it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [01:21<00:00, 16.62it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [01:34<00:00, 14.43it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [00:47<00:00, 28.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [02:59<00:00, 7.57it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [18:36<00:00, 1.21it/s]
In [6]:
for name, score in all_baseline_scores.items():
    print(f'BASELINE {name} mean score: {score.mean()}')
BASELINE identity mean score: 0.010318498686651098
BASELINE random mean score: 0.030460640761845705
BASELINE word mean score: 0.11819195071168308
BASELINE phrase mean score: 0.019752760732233695
BASELINE sentence mean score: 0.0119969120149827
BASELINE clustering mean score: 0.08897856287357343
BASELINE archipelago mean score: 0.052713106135909224
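A quick bar chart makes the ranking across baselines easier to read. This sketch assumes matplotlib is installed and that each entry in all_baseline_scores supports .mean(), as used in the cell above.
In [ ]:
# Optional: visualize the mean score per baseline (matplotlib assumed available).
import matplotlib.pyplot as plt

names = list(all_baseline_scores.keys())
means = [float(score.mean()) for score in all_baseline_scores.values()]

plt.figure(figsize=(8, 3))
plt.bar(names, means)
plt.ylabel("Mean score")
plt.xticks(rotation=30)
plt.tight_layout()
plt.show()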
In [ ]: