# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MultiSeqCalibrator model configuration"""

from transformers.utils import logging
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from typing import Literal, Optional


# Module-level logger obtained from the transformers logging utilities,
# keyed by this module's name.
logger = logging.get_logger(__name__)



class MultiSeqCalibratorConfig(Qwen2Config):
    """Configuration for the MultiSeqCalibrator model.

    Extends ``Qwen2Config`` with calibrator-specific options. Every option
    accepted by ``__init__`` is stored as an instance attribute of the same
    name; anything not listed in the signature is forwarded unchanged to
    ``Qwen2Config.__init__`` through ``**kwargs``.

    The meaning of the individual options is defined by the modeling code
    that consumes this config (not part of this module); this class only
    records them and performs one sanity check on the attention layout.
    """

    model_type = "multi_seq_calibrator"

    def __init__(
        self,
        group_size: int = 4,
        architecture: str = "probe",  # probe, summary, direct
        num_summarization_vectors: int = 1,
        mlp_hidden_size: Optional[int] = None,
        max_context_len: int = 8192,
        append_bin_idx: Optional[int] = None,
        input_embeds_size: Optional[int] = None,
        increment_position_ids: bool = False,
        attn_types: str = "",
        agent_emb: bool = False,
        node_features: str = "",
        bin_aggregate: bool = False,
        no_early_node_features_projection: bool = False,
        late_node_features_projection: bool = False,
        group_softmax: bool = False,
        sum_group_softmax: bool = False,
        attend_all_group_softmax: bool = False,
        late_group_softmax: bool = False,
        late_node_features_projection_norm: bool = False,
        sum_bin_aggregate: bool = False,
        causal_bin_aggregate: bool = False,
        **kwargs
    ):
        """Build the config and store all calibrator-specific options.

        Args:
            group_size: Stored as ``self.group_size``.
            architecture: Calibrator variant name. The signature's inline
                note lists ``"probe"``, ``"summary"`` and ``"direct"``; the
                value is not validated here.
            num_summarization_vectors: Stored as
                ``self.num_summarization_vectors``.
            mlp_hidden_size: Optional size, stored as-is (``None`` allowed).
            max_context_len: Stored as ``self.max_context_len``.
            append_bin_idx: Optional index, stored as-is (``None`` allowed).
            input_embeds_size: Optional size, stored as-is (``None`` allowed).
            increment_position_ids: Boolean switch, stored as-is.
            attn_types: String option, stored as-is (empty by default).
            agent_emb: Boolean switch, stored as-is.
            node_features: String option, stored as-is (empty by default).
            bin_aggregate, sum_bin_aggregate, causal_bin_aggregate:
                Boolean switches, stored as-is.
            no_early_node_features_projection, late_node_features_projection,
            late_node_features_projection_norm:
                Boolean switches, stored as-is.
            group_softmax, sum_group_softmax, attend_all_group_softmax,
            late_group_softmax:
                Boolean switches, stored as-is.
            **kwargs: Forwarded to ``Qwen2Config.__init__``.

        Raises:
            ValueError: If the resulting ``layer_types`` contains
                ``"sliding_attention"``.
        """
        # Let the parent populate every argument Qwen2Config knows about.
        super().__init__(**kwargs)

        # Store each option on the instance so it is part of the config's
        # state (and therefore round-trips through save_pretrained).
        self.group_size = group_size
        self.architecture = architecture
        # Previously accepted but silently dropped; now persisted like the
        # rest so a value passed by the caller is not lost.
        self.num_summarization_vectors = num_summarization_vectors
        self.mlp_hidden_size = mlp_hidden_size
        self.max_context_len = max_context_len
        # Same as num_summarization_vectors: was accepted but never stored.
        self.append_bin_idx = append_bin_idx
        self.input_embeds_size = input_embeds_size
        self.attn_types = attn_types
        self.increment_position_ids = increment_position_ids
        self.agent_emb = agent_emb
        self.node_features = node_features
        self.bin_aggregate = bin_aggregate
        self.no_early_node_features_projection = no_early_node_features_projection
        self.late_node_features_projection = late_node_features_projection
        self.group_softmax = group_softmax
        self.sum_group_softmax = sum_group_softmax
        self.attend_all_group_softmax = attend_all_group_softmax
        self.late_group_softmax = late_group_softmax
        self.late_node_features_projection_norm = late_node_features_projection_norm
        self.sum_bin_aggregate = sum_bin_aggregate
        self.causal_bin_aggregate = causal_bin_aggregate

        # Guard the lookup: the parent may not define ``layer_types`` (or may
        # leave it as None), in which case there is nothing to reject.
        layer_types = getattr(self, "layer_types", None) or ()
        if "sliding_attention" in layer_types:
            raise ValueError("Sliding attention is not supported for hivemind")