from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import torch
import transformers
from transformers import GPT2Config,GPT2Tokenizer,GPT2ForSequenceClassification

from . import TextualModel


@TextualModel.register
class GPT2Model(torch.nn.Module):
    """GPT-2 sequence classifier registered as a ``TextualModel``.

    Bundles a pretrained ``GPT2ForSequenceClassification`` with its
    ``GPT2Tokenizer``. Callers tokenize raw text with :meth:`data_collator`
    and run the classifier with :meth:`forward`.
    """

    def __init__(self, num_classes=2, *, freeze=False, pretrained_name: str = 'gpt2'):
        """Load the pretrained config, tokenizer and classification model.

        Args:
            num_classes: Number of output labels for the classification head.
            freeze: If True, disable gradients for the pretrained GPT-2
                backbone so that only the classification head is trained.
            pretrained_name: Hugging Face model id or local path from which the
                config, tokenizer and weights are loaded (defaults to ``'gpt2'``).
        """
        super().__init__()
        self.model_config = GPT2Config.from_pretrained(
            pretrained_model_name_or_path=pretrained_name, num_labels=num_classes
        )
        self.tokenizer = GPT2Tokenizer.from_pretrained(
            pretrained_model_name_or_path=pretrained_name
        )

        # Reuse the EOS token as the padding token so batches can be padded.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # NOTE(review): '<|SEP|>' is assigned as an attribute only; it is not
        # registered via ``add_special_tokens`` and the model embeddings are not
        # resized, so it presumably does not get its own token id. Confirm this
        # is intended wherever the separator is used.
        self.tokenizer.sep_token = '<|SEP|>'

        self.model = GPT2ForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=pretrained_name, config=self.model_config
        )
        # Keep the model's padding id consistent with the tokenizer's pad token
        # (EOS). The classifier relies on pad_token_id to find the last real token.
        self.model.config.pad_token_id = self.model.config.eos_token_id

        if freeze:
            # Freeze only the pretrained backbone; the classification head stays
            # trainable so the model can still be fine-tuned for the new labels.
            for param in self.model.base_model.parameters():
                param.requires_grad = False

    def data_collator(
        self,
        data: Dict[str, Any],
        features: Optional[Sequence[str]],
        *,
        padding: bool = True,
        truncation: bool = True,
        max_length: Optional[int] = None,
        return_tensors: str = 'pt',
    ) -> Dict[str, Any]:
        """Tokenize the selected fields of ``data`` in place.

        Args:
            data: Mapping from field name to value. Each selected value is
                passed to the tokenizer, so it presumably holds raw text (a
                string or a list of strings). The mapping is modified in place.
            features: Names of the fields in ``data`` to tokenize. ``None`` is
                treated as "no fields", which leaves ``data`` unchanged.
            padding: Forwarded to the tokenizer call.
            truncation: Forwarded to the tokenizer call.
            max_length: Forwarded to the tokenizer call.
            return_tensors: Forwarded to the tokenizer call.

        Returns:
            The same ``data`` object, with each selected field replaced by the
            tokenizer's output.

        Raises:
            KeyError: If a name in ``features`` is not present in ``data``.
        """
        for feature in features or ():
            data[feature] = self.tokenizer(
                data[feature],
                padding=padding,
                max_length=max_length,
                return_tensors=return_tensors,
                truncation=truncation,
            )

        return data

    def forward(self, **inputs):
        """Run the wrapped ``GPT2ForSequenceClassification`` on ``inputs``.

        All keyword arguments are passed through unchanged, and the model's
        output is returned as-is.
        """
        return self.model(**inputs)