In [1]:
# Uncomment the line below to install exlib into this kernel's environment.
# (%pip, unlike !pip, installs into the environment the running kernel uses.)
# %pip install exlib
In [ ]:
import torch
from transformers import AutoModel, AutoTokenizer
import numpy as np
import tqdm
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import load_dataset
import torch.nn as nn
import sentence_transformers

import sys
sys.path.insert(0, "../../src")
import exlib
from exlib.datasets.emotion_helper import project_points_onto_axes, load_emotions
from exlib.datasets.emotion import load_data, load_model, EmotionDataset, EmotionClassifier, EmotionFixScore, get_emotion_scores

from exlib.features.text import *

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Load datasets and pre-trained models

In [3]:
# Build the dataset from the "test" argument (presumably the split name —
# confirm against EmotionDataset) and serve it two examples per batch in a
# fixed, unshuffled order.
dataset = EmotionDataset("test")
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
)

# Instantiate the classifier, call .eval() on it and move it to `device`.
# NOTE(review): .eval()/.to() are assumed to be the torch.nn.Module methods —
# confirm EmotionClassifier subclasses nn.Module.
model = EmotionClassifier().eval().to(device)
SamLowe/roberta-base-go_emotions

Model prediction

In [4]:
# Sanity check on the first batch only: decode every input back to text and
# print it next to the label whose logit is largest.

# Index -> label-name mapping; looked up once instead of once per example.
id2label = model.model.config.id2label

# Nothing is back-propagated here, so disable autograd for the forward pass
# to avoid building a computation graph that is never used.
with torch.no_grad():
    for batch in tqdm(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        output = model(input_ids, attention_mask)

        # Turn each row of token ids back into readable text, dropping
        # special tokens.
        utterances = [
            dataset.tokenizer.decode(input_id, skip_special_tokens=True)
            for input_id in input_ids
        ]

        # `label` is one row of output.logits; its argmax index is the
        # key used to look up the predicted label name.
        for utterance, label in zip(utterances, output.logits):
            id_str = id2label[label.argmax().item()]
            print("Text: {}\nEmotion: {}\n".format(utterance, id_str))

        # Stop after the first batch; this cell is only a demonstration.
        break
  0%|                                                                                         | 0/2714 [00:00<?, ?it/s]
Text: I’m really sorry about your situation :( Although I love the names Sapphira, Cirilla, and Scarlett!
Emotion: remorse

Text: It's wonderful because it's awful. At not with.
Emotion: admiration


In [5]:
# Names of the baselines to score.
# NOTE: slow cell — the recorded run shows seven progress bars totalling
# roughly 28 minutes, the longest of them taking 18:36.
baseline_names = [
    "identity",
    "random",
    "word",
    "phrase",
    "sentence",
    "clustering",
    "archipelago",
]
all_baseline_scores = get_emotion_scores(baseline_names)
SamLowe/roberta-base-go_emotions
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [00:43<00:00, 31.44it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [02:00<00:00, 11.29it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [01:21<00:00, 16.62it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [01:34<00:00, 14.43it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [00:47<00:00, 28.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [02:59<00:00,  7.57it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1357/1357 [18:36<00:00,  1.21it/s]
In [6]:
# Report the mean of each baseline's scores, one line per baseline.
for name in all_baseline_scores:
    score = all_baseline_scores[name]
    print(f'BASELINE {name} mean score: {score.mean()}')
BASELINE identity mean score: 0.010318498686651098
BASELINE random mean score: 0.030460640761845705
BASELINE word mean score: 0.11819195071168308
BASELINE phrase mean score: 0.019752760732233695
BASELINE sentence mean score: 0.0119969120149827
BASELINE clustering mean score: 0.08897856287357343
BASELINE archipelago mean score: 0.052713106135909224
In [ ]: