import json
from transformers import LlamaConfig


class LlamaConfigAttn(LlamaConfig):
    """``LlamaConfig`` subclass carrying extra ``amplify_*`` options and a layer range.

    Every extra option defaults to ``None`` and is stored verbatim as an
    instance attribute of the same name; nothing is validated or converted
    here. The single exception is ``end_layer``: when it is ``None`` it is
    replaced by ``self.num_hidden_layers`` after the base initializer has run.

    Any remaining positional and keyword arguments are forwarded unchanged to
    ``LlamaConfig.__init__``.

    NOTE(review): the extra options are declared before ``*args``, so
    positional arguments bind to them first (``amplify_topk``,
    ``amplify_topk_head``, ...). Pass base ``LlamaConfig`` settings by keyword.
    """

    def __init__(
        self,
        amplify_topk=None,
        amplify_topk_head=None,
        amplify_layer_head=None,
        amplify_head_id=None,
        amplify_k_scope=None,
        amplify_q_scope=None,
        amplify_reverse=None,
        amplify_total_topk=None,
        amplify_decay=None,
        amplify_total_threshold=None,
        amplify_total_threshold_output=None,
        amplify_total_threshold_upper=None,
        amplify_output_factor=None,
        amplify_exclude_self=None,
        amplify_exclude_cal=None,
        amplify_uncert_threshold=None,
        amplify_uncert_type=None,
        amplify_uncert_constant=None,
        amplify_uncert_upper=None,
        amplify_uncert_score_type=None,
        amplify_skip_stopwords=None,
        amplify_skip_penalty=None,
        amplify_use_sink=None,
        amplify_factor=None,
        amplify_smooth_window=None,
        # NOTE: an ``amplify_topk_grain`` option is currently disabled.
        start_layer=None,
        end_layer=None,
        *args,
        **kwargs,
    ):
        # Base configuration first, so inherited fields exist on ``self``.
        super().__init__(*args, **kwargs)

        # Options copied onto the instance as-is, in the order listed here.
        passthrough_options = {
            "amplify_layer_head": amplify_layer_head,
            "amplify_topk": amplify_topk,
            "amplify_topk_head": amplify_topk_head,
            "amplify_head_id": amplify_head_id,
            "amplify_k_scope": amplify_k_scope,
            "amplify_q_scope": amplify_q_scope,
            "amplify_reverse": amplify_reverse,
            "amplify_total_topk": amplify_total_topk,
            "amplify_decay": amplify_decay,
            "amplify_total_threshold": amplify_total_threshold,
            "amplify_total_threshold_output": amplify_total_threshold_output,
            "amplify_total_threshold_upper": amplify_total_threshold_upper,
            "amplify_output_factor": amplify_output_factor,
            "amplify_exclude_self": amplify_exclude_self,
            "amplify_exclude_cal": amplify_exclude_cal,
            "amplify_uncert_threshold": amplify_uncert_threshold,
            "amplify_uncert_type": amplify_uncert_type,
            "amplify_uncert_constant": amplify_uncert_constant,
            "amplify_uncert_upper": amplify_uncert_upper,
            "amplify_uncert_score_type": amplify_uncert_score_type,
            "amplify_skip_stopwords": amplify_skip_stopwords,
            "amplify_skip_penalty": amplify_skip_penalty,
            "amplify_use_sink": amplify_use_sink,
            "amplify_factor": amplify_factor,
            "amplify_smooth_window": amplify_smooth_window,
            "start_layer": start_layer,
        }
        for option_name, option_value in passthrough_options.items():
            setattr(self, option_name, option_value)

        # An unspecified end of the range falls back to the model's layer count.
        if end_layer is None:
            end_layer = self.num_hidden_layers
        self.end_layer = end_layer
