from __future__ import division, absolute_import, print_function, unicode_literals

import copy
import json
import sys
from io import open

from positional_embeddings import PositionalEmbeddingsTypes


class ModelConfig(object):
    """Configuration class to store the configuration of a `BertModel`.

    Every setting is kept as a plain instance attribute, so a configuration
    can be dumped with `to_dict` / `to_json_string` and rebuilt with
    `from_dict` / `from_json_file`.
    """

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 default_ffn_norm=None,
                 hidden_dropout_prob=0.1,
                 hidden_act="relu",
                 attention_complexity="auto",
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 token_type_embeddings=False,
                 embedding_ln_type="hardtanh",
                 type_vocab_size=2,
                 initializer_range=0.02,
                 pos_emb_type="learned",
                 embedding_dropout=0.25,
                 relpe_type=None,
                 final_ln_type=None,
                 pooler_function="mean",
                 pooler_no_dense=False,
                 pooler_act="gelu",
                 pooler_ln_type=None,
                 classifier_bias=False,
                 lm_head_act="gelu",
                 lm_head_ln_type="uncentered_ln",
                 local_attention=False,
                 window_size=1024
                 ):
        """Constructs ModelConfig.

        Args:
            vocab_size_or_config_json_file: Either the vocabulary size of
                `inputs_ids` in `BertModel` (int), or the path (str) of a JSON
                file holding a configuration. When a path is given, every
                key/value pair of the file becomes an attribute and ALL other
                arguments of this constructor are ignored.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically, set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            pos_emb_type: Name of a `PositionalEmbeddingsTypes` member; it is
                upper-cased and looked up by name, and the resulting member is
                what gets stored.
            default_ffn_norm, attention_complexity, token_type_embeddings,
            embedding_ln_type, embedding_dropout, relpe_type, final_ln_type,
            pooler_function, pooler_no_dense, pooler_act, pooler_ln_type,
            classifier_bias, lm_head_act, lm_head_ln_type, local_attention,
            window_size: Stored unchanged as attributes of the same name; their
                meaning is defined by the model code that consumes this
                configuration.

        Raises:
            ValueError: If the first argument is neither an int nor a string.
        """
        # On Python 2 a path may also arrive as `unicode`; the name is only
        # evaluated there because of the short-circuiting version check.
        is_path = isinstance(vocab_size_or_config_json_file, str) or (
            sys.version_info[0] == 2
            and isinstance(vocab_size_or_config_json_file, unicode))  # noqa: F821

        if is_path:
            with open(vocab_size_or_config_json_file, "r",
                      encoding='utf-8') as reader:
                json_config = json.load(reader)
            # Values are copied verbatim: no key is validated or converted.
            # NOTE(review): `pos_emb_type` therefore stays whatever the JSON
            # holds here, whereas the int branch stores an enum member --
            # confirm that consumers accept both forms.
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.default_ffn_norm = default_ffn_norm
            self.hidden_dropout_prob = hidden_dropout_prob
            # Previously accepted but never stored, which left configurations
            # built from keyword arguments without a `hidden_act` attribute.
            self.hidden_act = hidden_act
            self.attention_complexity = attention_complexity
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.token_type_embeddings = token_type_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            # Lookup by member name; an unknown name fails here rather than
            # later inside the model.
            self.pos_emb_type = PositionalEmbeddingsTypes[pos_emb_type.upper()]
            self.embedding_ln_type = embedding_ln_type
            self.embedding_dropout = embedding_dropout
            self.relpe_type = relpe_type
            self.final_ln_type = final_ln_type
            self.pooler_function = pooler_function
            self.pooler_no_dense = pooler_no_dense
            self.pooler_act = pooler_act
            self.pooler_ln_type = pooler_ln_type
            self.classifier_bias = classifier_bias
            self.lm_head_act = lm_head_act
            self.lm_head_ln_type = lm_head_ln_type
            self.local_attention = local_attention
            self.window_size = window_size
        else:
            raise ValueError(
                "First argument must be either a vocabulary size (int) "
                "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelConfig` from a Python dictionary of parameters.

        A configuration with the constructor defaults is created first, then
        every key of `json_object` overrides (or adds) the attribute of the
        same name; keys absent from `json_object` keep their default.
        """
        # Build through `cls` so that subclasses get an instance of their own
        # type. -1 is only a placeholder vocabulary size.
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            return cls.from_dict(json.load(reader))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary.

        The result is a deep copy, so mutating it does not affect this
        configuration.
        """
        return copy.deepcopy(self.__dict__)

    def to_json_string(self):
        """Serializes this instance to a JSON string.

        Keys are sorted, the indentation is two spaces and the string ends
        with a newline.
        """
        # NOTE(review): after construction from keyword arguments,
        # `pos_emb_type` is a `PositionalEmbeddingsTypes` member; this call
        # relies on that member being JSON-serializable -- confirm.
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"