from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import torch
import transformers

from . import TextualModel


@TextualModel.register
class BertModel(torch.nn.Module):
    """Sequence classifier built on a pretrained BERT checkpoint.

    Bundles a Hugging Face tokenizer with a
    ``BertForSequenceClassification`` model loaded from the same checkpoint.
    """

    def __init__(
        self,
        num_classes: int = 2,
        *,
        freeze: bool = False,
        model_name: str = "bert-base-uncased",
    ) -> None:
        """Load the tokenizer and the classification model.

        Args:
            num_classes: Number of labels, forwarded as ``num_labels`` to the
                classification model.
            freeze: When true, gradients are disabled for the base encoder so
                that only the classification head remains trainable.
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        super().__init__()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.model = transformers.BertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_classes,
        )

        if freeze:
            # Only the encoder is frozen; the classifier head sits outside
            # ``base_model`` and keeps ``requires_grad=True``.
            self.model.base_model.requires_grad_(False)

    def data_collator(
        self,
        data: Dict[str, Any],
        features: Optional[List[str]],
        *,
        padding: bool = True,
        truncation: bool = True,
        max_length: Optional[int] = None,
        return_tensors: str = 'pt',
    ) -> Dict[str, Any]:
        """Tokenize the selected entries of ``data`` in place.

        Args:
            data: Mapping whose values under ``features`` are passed to the
                tokenizer. The mapping is modified in place.
            features: Keys of ``data`` to tokenize. ``None`` or an empty list
                leaves ``data`` untouched.
            padding: Forwarded to the tokenizer.
            truncation: Forwarded to the tokenizer.
            max_length: Forwarded to the tokenizer.
            return_tensors: Forwarded to the tokenizer.

        Returns:
            The same ``data`` object, with each selected entry replaced by
            the tokenizer output for that entry.

        Raises:
            KeyError: If a name in ``features`` is not a key of ``data``.
        """
        # ``features`` is declared Optional; treat None as "nothing to do"
        # instead of failing on iteration.
        if not features:
            return data

        for feature in features:
            data[feature] = self.tokenizer(
                data[feature],
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
            )

        return data

    def forward(self, **inputs: Any) -> Any:
        """Pass ``inputs`` to the wrapped model and return its output."""
        return self.model(**inputs)