import re


class ParamNameMgr:
    """Translate model parameter names into short category / metric names.

    A parameter name such as
    ``language_model.encoder.layers.0.self_attention.dense.weight`` is
    classified by searching it against a fixed table of regular
    expressions; the first pattern that matches determines the category.
    """

    def __init__(self):
        # Captures the index following ".layers." in a parameter name,
        # e.g. language_model.encoder.layers.3.input_norm.weight -> "3".
        self.pattern_layer = re.compile(r"\.layers\.(\d+?)\.")

        # (category, regex) pairs. The order matters: lookups return the
        # first category whose regex is found in the parameter name.
        pattern_table = (
            # language_model.embedding.word_embeddings.weight
            ('wte', r"\.word_embeddings\."),
            # language_model.embedding.position_embeddings.weight
            ('wpe', r"\.position_embeddings\."),
            # language_model.encoder.layers.#.input_norm.weight/bias
            ('ln1', r"\.input_norm\."),
            # language_model.encoder.layers.#.self_attention.query_key_value.weight/bias
            ('qkv', r"\.query_key_value\."),
            # language_model.encoder.layers.#.self_attention.dense.weight/bias
            ('dense', r"\.self_attention\.dense\."),
            # language_model.encoder.layers.#.post_attention_norm.weight/bias
            ('ln2', r"\.post_attention_norm\."),
            # language_model.encoder.layers.#.mlp.dense_h_to_4h.weight/bias
            ('mlp1', r"\.mlp\.dense_h_to_4h\."),
            # language_model.encoder.layers.#.mlp.dense_4h_to_h.weight/bias
            ('mlp2', r"\.mlp\.dense_4h_to_h\."),
            # language_model.encoder.final_norm.weight/bias
            ('final_norm', r"\.final_norm\."),
        )
        self.var_name_pattern_dict = {
            cate: re.compile(regex) for cate, regex in pattern_table
        }

        # Category reported for names that match none of the patterns.
        self.default_cate = 'others'

    def _first_matching_cate(self, name):
        """Return the first category whose pattern occurs in *name*, else None."""
        candidates = (
            cate
            for cate, regex in self.var_name_pattern_dict.items()
            if regex.search(name)
        )
        return next(candidates, None)

    def get_all_cate_names(self):
        """Return every known category name followed by the default category."""
        return [*self.var_name_pattern_dict, self.default_cate]

    def get_metric_name(self, name, layer):
        """Build the metric name for parameter *name* located in *layer*.

        Args:
            name: Full parameter name.
            layer: Layer number of the parameter; ``0`` means the parameter
                is not reported under a per-layer prefix.

        Returns:
            ``"<cate>"`` when *layer* is 0, otherwise ``"h_<layer>_<cate>"``;
            ``"_bias"`` is appended when *name* ends with ``".bias"``.
            An empty string is returned when no category matches.
        """
        cate = self._first_matching_cate(name)
        if cate is None:
            return ""

        prefix = f'h_{layer}_' if layer != 0 else ''
        suffix = '_bias' if name.endswith(".bias") else ''
        return f'{prefix}{cate}{suffix}'

    def get_var_layer(self, name, model):
        """Return the layer number that parameter *name* belongs to.

        Args:
            name: Full parameter name.
            model: Wrapped model; only accessed when *name* contains a
                ``.layers.<idx>.`` component, in which case
                ``model.module.module.language_model.encoder.layers[idx]``
                is looked up.

        Returns:
            The ``layer_number`` attribute of the indexed encoder layer, or
            ``0`` when *name* has no ``.layers.<idx>.`` component.
        """
        found = self.pattern_layer.search(name)
        if found is None:
            return 0  # parameter does not live inside an encoder layer

        local_idx = int(found.group(1))
        # NOTE(review): layer_number presumably differs from the local index
        # (e.g. a global, 1-based number) - confirm against the model code.
        encoder = model.module.module.language_model.encoder
        return encoder.layers[local_idx].layer_number

    def get_var_category(self, name):
        """Return the category of parameter *name*.

        When no pattern matches, a notice is printed to stdout and the
        default category is returned.
        """
        cate = self._first_matching_cate(name)
        if cate is not None:
            return cate

        print(f'name={name} not categorized...')
        return self.default_cate


# Reference dump of parameter names and shapes for GPT-2 Small (12 layers, hidden size 768):
'''
language_model.embedding.word_embeddings.weight: torch.Size([50048, 768])
language_model.embedding.position_embeddings.weight: torch.Size([1024, 768])
language_model.encoder.layers.0.input_norm.weight: torch.Size([768])
language_model.encoder.layers.0.input_norm.bias: torch.Size([768])
language_model.encoder.layers.0.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.0.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.0.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.0.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.0.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.0.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.0.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.0.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.0.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.0.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.1.input_norm.weight: torch.Size([768])
language_model.encoder.layers.1.input_norm.bias: torch.Size([768])
language_model.encoder.layers.1.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.1.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.1.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.1.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.1.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.1.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.1.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.1.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.1.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.1.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.2.input_norm.weight: torch.Size([768])
language_model.encoder.layers.2.input_norm.bias: torch.Size([768])
language_model.encoder.layers.2.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.2.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.2.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.2.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.2.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.2.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.2.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.2.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.2.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.2.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.3.input_norm.weight: torch.Size([768])
language_model.encoder.layers.3.input_norm.bias: torch.Size([768])
language_model.encoder.layers.3.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.3.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.3.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.3.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.3.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.3.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.3.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.3.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.3.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.3.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.4.input_norm.weight: torch.Size([768])
language_model.encoder.layers.4.input_norm.bias: torch.Size([768])
language_model.encoder.layers.4.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.4.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.4.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.4.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.4.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.4.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.4.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.4.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.4.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.4.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.5.input_norm.weight: torch.Size([768])
language_model.encoder.layers.5.input_norm.bias: torch.Size([768])
language_model.encoder.layers.5.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.5.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.5.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.5.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.5.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.5.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.5.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.5.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.5.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.5.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.6.input_norm.weight: torch.Size([768])
language_model.encoder.layers.6.input_norm.bias: torch.Size([768])
language_model.encoder.layers.6.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.6.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.6.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.6.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.6.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.6.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.6.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.6.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.6.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.6.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.7.input_norm.weight: torch.Size([768])
language_model.encoder.layers.7.input_norm.bias: torch.Size([768])
language_model.encoder.layers.7.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.7.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.7.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.7.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.7.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.7.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.7.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.7.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.7.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.7.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.8.input_norm.weight: torch.Size([768])
language_model.encoder.layers.8.input_norm.bias: torch.Size([768])
language_model.encoder.layers.8.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.8.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.8.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.8.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.8.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.8.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.8.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.8.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.8.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.8.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.9.input_norm.weight: torch.Size([768])
language_model.encoder.layers.9.input_norm.bias: torch.Size([768])
language_model.encoder.layers.9.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.9.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.9.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.9.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.9.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.9.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.9.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.9.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.9.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.9.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.10.input_norm.weight: torch.Size([768])
language_model.encoder.layers.10.input_norm.bias: torch.Size([768])
language_model.encoder.layers.10.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.10.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.10.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.10.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.10.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.10.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.10.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.10.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.10.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.10.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.layers.11.input_norm.weight: torch.Size([768])
language_model.encoder.layers.11.input_norm.bias: torch.Size([768])
language_model.encoder.layers.11.self_attention.query_key_value.weight: torch.Size([2304, 768])
language_model.encoder.layers.11.self_attention.query_key_value.bias: torch.Size([2304])
language_model.encoder.layers.11.self_attention.dense.weight: torch.Size([768, 768])
language_model.encoder.layers.11.self_attention.dense.bias: torch.Size([768])
language_model.encoder.layers.11.post_attention_norm.weight: torch.Size([768])
language_model.encoder.layers.11.post_attention_norm.bias: torch.Size([768])
language_model.encoder.layers.11.mlp.dense_h_to_4h.weight: torch.Size([3072, 768])
language_model.encoder.layers.11.mlp.dense_h_to_4h.bias: torch.Size([3072])
language_model.encoder.layers.11.mlp.dense_4h_to_h.weight: torch.Size([768, 3072])
language_model.encoder.layers.11.mlp.dense_4h_to_h.bias: torch.Size([768])
language_model.encoder.final_norm.weight: torch.Size([768])
language_model.encoder.final_norm.bias: torch.Size([768])
'''
