from transformers import PretrainedConfig, AutoTokenizer


class TSPConfig(PretrainedConfig):
    """Hugging Face configuration for TSP models (``model_type = "tsp"``).

    Stores the architecture hyper-parameters as plain attributes, can derive
    vocabulary-related settings from a pretrained tokenizer, and always sets an
    ``auto_map`` pointing the ``Auto*`` factory classes at the custom
    ``configuration_tsp`` / ``modeling_tsp`` implementations.
    """

    model_type = "tsp"

    def __init__(
        self,
        embedding_size=128,
        hidden_size=256,
        num_hidden_layers=12,
        num_attention_heads=4,
        intermediate_size=1024,
        dropout_prob=0.1,
        max_sequence_length=128,
        use_electra=False,
        electra_generator_size_divisor=4,
        auto_tokenizer_name: str = None,
        vocab_size: int = None,
        **kwargs
    ):
        """Create the configuration.

        Args:
            embedding_size, hidden_size, num_hidden_layers, num_attention_heads,
            intermediate_size, dropout_prob, max_sequence_length, use_electra,
            electra_generator_size_divisor: Architecture hyper-parameters.
                Each is stored unchanged as an attribute of the same name;
                their interpretation is left to the ``modeling_tsp`` classes.
            auto_tokenizer_name: Optional tokenizer name or path. When given,
                the tokenizer is loaded with ``AutoTokenizer.from_pretrained``
                and its ``bos_token_id``, ``sep_token_id``, ``pad_token_id``
                and ``vocab_size`` replace any values supplied through
                ``kwargs`` or ``vocab_size``.
            vocab_size: Vocabulary size. Only used when
                ``auto_tokenizer_name`` is ``None``.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``num_attention_heads`` is not positive or does not
                evenly divide ``hidden_size``.
        """
        # Explicit exception instead of ``assert`` so the check survives
        # ``python -O`` and a zero head count does not surface as
        # ZeroDivisionError.
        if num_attention_heads <= 0 or hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_attention_heads}) must be a positive "
                f"divisor of hidden_size ({hidden_size})"
            )
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout_prob = dropout_prob
        self.max_sequence_length = max_sequence_length
        self.use_electra = use_electra
        self.electra_generator_size_divisor = electra_generator_size_divisor
        if auto_tokenizer_name is not None:
            tokenizer = AutoTokenizer.from_pretrained(auto_tokenizer_name)
            # The tokenizer is the source of truth: its special-token ids
            # override any caller-supplied ones before reaching the base class.
            kwargs.update(
                {
                    "bos_token_id": tokenizer.bos_token_id,
                    "sep_token_id": tokenizer.sep_token_id,
                    "pad_token_id": tokenizer.pad_token_id,
                }
            )
            vocab_size = tokenizer.vocab_size
        self.vocab_size = vocab_size
        super().__init__(**kwargs)
        # Manually set mapping to auto models instead of using `register_for_auto_class`,
        # which can only register the model class that is used to execute `push_to_hub`.
        # Assigned after super().__init__ so it always replaces any `auto_map`
        # passed in through kwargs (e.g. when reloading a saved config).
        self.auto_map = {
            "AutoConfig": "configuration_tsp.TSPConfig",
            "AutoModel": "modeling_tsp.TSPModel",
            "AutoModelForPreTraining": "modeling_tsp.TSPModelForPreTraining",
            "AutoModelForTokenClassification": "modeling_tsp.TSPModelForTokenClassification",
            "AutoModelForSequenceClassification": "modeling_tsp.TSPModelForSequenceClassification",
            "AutoModelForQuestionAnswering": "modeling_tsp.TSPModelForQuestionAnswering",
        }

    @classmethod
    def from_config(cls, config):
        """Build a ``TSPConfig`` from an external (e.g. training) config object.

        ``config`` must expose the hyper-parameters read below as attributes
        and support ``config.get(key, default)`` for the optional ELECTRA
        settings. NOTE(review): this matches an OmegaConf ``DictConfig``;
        confirm against callers.

        Args:
            config: Source configuration object. ``config.tokenizer`` is
                passed on as ``auto_tokenizer_name``.

        Returns:
            A new ``TSPConfig`` instance.
        """
        # Forward the optional ELECTRA settings only when the source config
        # provides them, so missing keys fall back to ``__init__``'s defaults
        # instead of being stored (and serialized) as ``None``.
        optional_kwargs = {}
        for key in ("use_electra", "electra_generator_size_divisor"):
            value = config.get(key, None)
            if value is not None:
                optional_kwargs[key] = value
        return cls(
            embedding_size=config.embedding_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            dropout_prob=config.dropout_prob,
            max_sequence_length=config.max_sequence_length,
            position_embedding_type=config.position_embedding_type,
            auto_tokenizer_name=config.tokenizer,
            **optional_kwargs,
        )
