from torch import nn
from transformers import AutoModel


class ModelClassification(nn.Module):
    """Pretrained transformer encoder followed by a two-layer MLP classification head.

    The head is applied to the encoder output at sequence position 0.

    Args:
        n_classes: Number of output features of the head's final linear layer.
        tokenizer: Tokenizer whose ``pad_token_id`` is copied onto the encoder config.
        dropout_prob: Dropout probability used before each linear layer of the head.
            Any value ``<= 0`` disables dropout (an ``nn.Identity`` is used instead).
        model_name: Name or path passed to ``AutoModel.from_pretrained``.
        is_reduced: If true, the head's hidden layer has ``hidden_size // 2`` units
            instead of ``hidden_size``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        n_classes,
        tokenizer,
        dropout_prob,
        model_name="roberta-large",
        is_reduced=False,
        **kwargs
    ):
        super().__init__()

        encoder = AutoModel.from_pretrained(model_name)
        encoder.config.pad_token_id = tokenizer.pad_token_id
        self.transformer = encoder

        width = self.transformer.config.hidden_size
        head_width = width // 2 if is_reduced else width

        def make_dropout():
            # Using nn.Identity instead of omitting the layer keeps the Sequential's
            # indices (and therefore its state_dict keys) the same for any dropout_prob.
            if dropout_prob > 0:
                return nn.Dropout(dropout_prob)
            return nn.Identity()

        # Layer order is significant: state_dict keys are classifier.0 .. classifier.4.
        self.classifier = nn.Sequential(
            make_dropout(),
            nn.Linear(width, head_width),
            nn.ReLU(),
            make_dropout(),
            nn.Linear(head_width, n_classes),
        )

    def forward(self, inputs):
        """Return the head's output for a tokenized batch.

        Args:
            inputs: Mapping with ``"input_ids"`` and ``"attention_mask"`` entries.

        Returns:
            ``self.classifier`` applied to ``encoder_output[0][:, 0, :]``.
        """
        # presumably the last hidden state, shaped (batch, seq_len, hidden) — confirm
        hidden_states = self.transformer(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            return_dict=False,
        )[0]
        first_token = hidden_states[:, 0, :]
        return self.classifier(first_token)


class llm_base:
    """Callable wrapper that runs a transformer and returns its position-0 output.

    Args:
        transformer: Model called with ``input_ids``, ``attention_mask`` and
            ``return_dict=False``. Its first output is indexed as ``[:, 0, :]``
            (presumably shaped ``(batch, seq_len, hidden)`` — confirm).
        device: Target passed to ``.to()``. The model is moved there once at
            construction, and each call's input tensors are moved there too.
    """

    def __init__(self, transformer, device):
        self.device = device
        self.transformer = transformer.to(device)

    def __call__(self, tokenized_data):
        """Run the wrapped transformer and return ``outputs[0][:, 0, :]``.

        Args:
            tokenized_data: Mapping with ``"input_ids"`` and ``"attention_mask"``
                tensors. They may live on any device.

        Returns:
            The slice ``outputs[0][:, 0, :]`` of the transformer's output.
        """
        input_ids = tokenized_data["input_ids"]
        attention_mask = tokenized_data["attention_mask"]
        if self.device is not None:
            # The model was moved to self.device in __init__. Inputs on another
            # device would typically make the forward pass fail with a device
            # mismatch. ``.to`` is a no-op for tensors already on that device.
            input_ids = input_ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        return outputs[0][:, 0, :]
