from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import torch
import transformers

from . import TextualModel


@TextualModel.register
class HuggingfaceModel(torch.nn.Module):
    def __init__(self, *, pretrained_model_name_or_path="bert-base-uncased", num_labels=2, freeze=False):
        super().__init__()
        # TODO: add tokenizer args
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        self.model = transformers.AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)

    def data_collator(self, 
        data: Dict[str, Any], 
        features: Optional[List[str]], 
        *, 
        padding: bool = True, 
        truncation: bool = True,
        max_length: Optional[int] = None, 
        return_tensors: str = 'pt'
    ):
        for feature in features:
            data[feature] = self.tokenizer(data[feature], padding=padding, max_length=max_length, return_tensors=return_tensors, truncation=truncation)
        
        return data
    
    def forward(self, **inputs):
        return self.model(**inputs)