

import torch

from transformers import AutoModel, AutoTokenizer


class DetModel(torch.nn.Module):
    """Pretrained encoder with a two-way linear head on the first token.

    The constructor loads both the encoder and its tokenizer from
    ``model_name``. The head reads the encoder hidden state at sequence
    position 0 (presumably the ``[CLS]`` token for BERT-style checkpoints).
    """

    def __init__(self, model_name="bert-base-uncased", num_labels=2):
        """Load the encoder and tokenizer and build the head.

        Args:
            model_name: Identifier handed to ``AutoModel.from_pretrained``
                and ``AutoTokenizer.from_pretrained``.
            num_labels: Accepted but not used by this class; the head
                always has 2 outputs.
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        encoder_config = self.model.config

        # Submodules are registered in a fixed order (model, detector, act,
        # dropout, softmax) so parameter/state_dict ordering stays stable.
        self.detector = torch.nn.Linear(encoder_config.hidden_size, 2)

        self.act = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(encoder_config.hidden_dropout_prob)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask, **kwargs):
        """Encode the batch and score the first-token representation.

        Args:
            input_ids: Token ids, passed to the encoder unchanged.
            attention_mask: Attention mask, passed to the encoder unchanged.
            **kwargs: Extra keyword arguments forwarded to the encoder.

        Returns:
            A dict with three entries:
              * ``"det_logits"``: head output after softmax over dim 1, so
                these are probabilities rather than raw logits.
              * ``"hidden_states"`` and ``"last_hidden_state"``: both hold
                the encoder's ``last_hidden_state``.
        """
        encoder_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        token_states = encoder_out.last_hidden_state

        # Position 0 of every sequence; dropout is applied before the ReLU.
        first_token = token_states[:, 0]
        features = self.act(self.dropout(first_token))

        det_probs = self.softmax(self.detector(features))

        return dict(
            det_logits=det_probs,
            hidden_states=token_states,
            last_hidden_state=encoder_out.last_hidden_state,
        )


class CLSModel(torch.nn.Module):
    """Pretrained encoder with a ``num_labels``-way linear head.

    The constructor loads both the encoder and its tokenizer from
    ``model_name``. The head reads the encoder hidden state at sequence
    position 0 (presumably the ``[CLS]`` token for BERT-style checkpoints).
    """

    def __init__(self, model_name="bert-base-uncased", num_labels=2):
        """Load the encoder and tokenizer and build the classification head.

        Args:
            model_name: Identifier handed to ``AutoModel.from_pretrained``
                and ``AutoTokenizer.from_pretrained``.
            num_labels: Output size of the linear classification head.
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.classifier = torch.nn.Linear(self.model.config.hidden_size, num_labels)

        self.act = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(self.model.config.hidden_dropout_prob)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask, **kwargs):
        """Encode the batch and classify the first-token representation.

        Args:
            input_ids: Token ids, passed to the encoder unchanged.
            attention_mask: Attention mask; moved to the encoder's device
                (``self.model.device``) before the encoder is called.
            **kwargs: Extra keyword arguments forwarded to the encoder.

        Returns:
            A dict with three entries:
              * ``"cls_logits"``: head output after softmax over dim 1, so
                these are probabilities rather than raw logits.
              * ``"hidden_states"`` and ``"last_hidden_state"``: both hold
                the encoder's ``last_hidden_state``.
        """
        # Tensor.to() is not in-place: it returns the converted tensor, so
        # the result must be rebound or the device move is silently lost.
        attention_mask = attention_mask.to(self.model.device)
        outputs = self.model(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        hidden_states = outputs.last_hidden_state

        # Hidden state at sequence position 0 of every example.
        pooled_output = hidden_states[:, 0]

        # Dropout is applied before the ReLU (order kept from the original).
        pooled_output = self.dropout(pooled_output)
        pooled_output = self.act(pooled_output)

        cls_logits = self.classifier(pooled_output)
        # Normalised over the label dimension; the key name is kept for
        # backward compatibility even though the values are probabilities.
        cls_logits = self.softmax(cls_logits)

        return {
            "cls_logits": cls_logits,
            "hidden_states": hidden_states,
            "last_hidden_state": outputs.last_hidden_state,
        }
