from transformers import PretrainedConfig
import json


class HyenaConfig(PretrainedConfig):
    """Hugging Face configuration class for HyenaDNA (``model_type = "hyenadna"``).

    Every constructor argument is stored verbatim as an attribute of the same
    name, with one exception: ``d_inner`` falls back to ``4 * d_model`` when
    it is ``None``. Any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Use ``from_original_config`` to build an instance from an original
    (non-Hugging-Face) HyenaDNA JSON config file. The precise meaning of each
    hyperparameter is defined by the model code that consumes this config,
    which is not part of this class.
    """

    model_type = "hyenadna"

    def __init__(
        self,
        vocab_size=12,
        d_model=256,
        d_inner=None,
        use_bias=True,
        train_freq=True,
        max_seq_len=1024,
        emb_dim=3,
        n_layer=12,
        num_inner_mlps=2,
        hyena_order=2,
        short_filter_order=3,
        filter_order=64,
        activation_freq=1,
        embed_dropout=0.1,
        hyena_dropout=0.0,
        hyena_filter_dropout=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        pad_vocab_size_multiple=8,
        num_prompts=64,
        prompts_size=64,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        # ``None`` means "derive from d_model" using a 4x expansion.
        self.d_inner = 4 * d_model if d_inner is None else d_inner
        self.use_bias = use_bias
        self.train_freq = train_freq
        self.max_seq_len = max_seq_len
        self.emb_dim = emb_dim
        self.n_layer = n_layer
        self.hyena_order = hyena_order
        self.filter_order = filter_order
        self.short_filter_order = short_filter_order
        self.activation_freq = activation_freq
        self.num_inner_mlps = num_inner_mlps
        self.embed_dropout = embed_dropout
        self.hyena_dropout = hyena_dropout
        self.hyena_filter_dropout = hyena_filter_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.num_prompts = num_prompts
        self.prompts_size = prompts_size
        # Set the custom attributes first, then let PretrainedConfig handle
        # the generic ones (e.g. tie_word_embeddings) passed through kwargs.
        super().__init__(**kwargs)

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        """Create a config from an original HyenaDNA JSON config file.

        Field mapping (original JSON key -> constructor argument):

        * ``vocab_size``, ``d_model``, ``n_layer``, ``embed_dropout``,
          ``pad_vocab_size_multiple`` -> argument of the same name
        * ``d_inner`` -> ``d_inner``; a missing key or ``null`` falls back
          to ``4 * d_model``
        * ``layer.l_max`` -> ``max_seq_len``
        * ``layer.emb_dim`` -> ``emb_dim``
        * ``layer.filter_order`` -> ``filter_order``
        * ``layer.local_order`` (preferred) or ``layer.short_filter_order``
          -> ``short_filter_order``; 3 if neither key is present
        * ``layer.w`` -> ``activation_freq``

        ``tie_word_embeddings`` is set to ``False``. Arguments not listed
        above keep the ``__init__`` defaults.

        Args:
            config_path: Path to the JSON config file.
            **kwargs: Additional constructor arguments. They take precedence
                over values read from the file, including
                ``tie_word_embeddings``.

        Returns:
            A new instance of ``cls``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a required key (any mapped key above other than
                ``d_inner`` and the short-filter keys) is missing.
        """
        # JSON is UTF-8 by specification; don't depend on the locale default.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        layer = config["layer"]
        # Older configs call this "local_order"; newer ones use
        # "short_filter_order". Prefer the former when both are present.
        if "local_order" in layer:
            short_filter_order = layer["local_order"]
        else:
            short_filter_order = layer.get("short_filter_order", 3)

        params = {
            "vocab_size": config["vocab_size"],
            "d_model": config["d_model"],
            "d_inner": config.get("d_inner"),
            "max_seq_len": layer["l_max"],
            "emb_dim": layer["emb_dim"],
            "filter_order": layer["filter_order"],
            "short_filter_order": short_filter_order,
            "n_layer": config["n_layer"],
            "activation_freq": layer["w"],
            "embed_dropout": config["embed_dropout"],
            "pad_vocab_size_multiple": config["pad_vocab_size_multiple"],
            "tie_word_embeddings": False,
        }
        # Merge rather than pass both explicitly: passing a key twice to
        # cls(...) would raise "got multiple values for keyword argument".
        params.update(kwargs)
        return cls(**params)