"""Parameter information for BERT: helpers that parse and shorten BERT variable names."""
import dataclasses
import re
from typing import Union

###############################################################################


def is_embeddings_layer(v: str) -> bool:
    """Return True when the variable name contains an ``/embeddings/`` path segment."""
    return v.find('/embeddings/') != -1


def is_encoder_layer(v: str) -> bool:
    """Return True when the variable name contains an ``/encoder/`` path segment."""
    return v.find('/encoder/') != -1


def is_pooler_layer(v: str) -> bool:
    """Return True when the variable name contains a ``/pooler/`` path segment."""
    return v.find('/pooler/') != -1


def is_classifier_layer(v: str) -> bool:
    """Return True when the variable name contains a ``/classifier/`` path segment."""
    return v.find('/classifier/') != -1


###############################################################################

def extract_embeddings_sublayer(v: str) -> Union[None, str]:
    """Return the path between ``/embeddings/`` and the trailing ``/<type>:``.

    For example ``.../embeddings/LayerNorm/gamma:0`` yields ``'LayerNorm'``.
    Returns None when the name does not have that shape.
    """
    found = re.search(r'/embeddings/(.+)/\w+\:', v)
    if found is None:
        return None
    return found.group(1)


def extract_encoder_layer_index(v: str) -> Union[None, int]:
    """Return the numeric encoder layer index from a variable name.

    Matches the ``/encoder/layer_._<N>/`` segment; for example
    ``.../encoder/layer_._8/output/dense/bias:0`` yields ``8``.

    Returns:
        The layer index as an ``int``, or None if the segment is absent.
    """
    # The layer scope is literally "layer_._<N>", so the dot is escaped.
    m = re.search(r'/encoder/layer_\._(\d+)/', v)
    return int(m.group(1)) if m else None


def extract_encoder_sublayer(v: str) -> Union[None, str]:
    """Return the path inside an encoder layer, excluding the variable type.

    For example ``.../encoder/layer_._8/attention/self/query/kernel:0``
    yields ``'attention/self/query'``. Returns None when the name does not
    have that shape.
    """
    found = re.search(r'/encoder/layer_\._\d+/(.+)/\w+\:', v)
    if found is None:
        return None
    return found.group(1)


def extract_variable_type(v: str) -> Union[None, str]:
    """Return the first ``/<word>:`` component, i.e. the variable type.

    For example ``.../pooler/dense/kernel:0`` yields ``'kernel'``.
    Returns None when no such component exists.
    """
    found = re.search(r'/(\w+)\:', v)
    return None if found is None else found.group(1)


###############################################################################


def to_nice_name(v: str) -> str:
    """Shorten a full BERT variable name into a compact, readable name.

    The section is detected from the path, and each section is rendered as
    follows (examples use the ``tf_bert_for_sequence_classification`` prefix):

    * embeddings: ``.../bert/embeddings/word_embeddings/weight:0`` ->
      ``embeddings/word_embeddings``. LayerNorm additionally keeps the
      variable type, e.g. ``embeddings/LayerNorm/gamma``.
    * encoder: ``.../bert/encoder/layer_._8/attention/self/query/kernel:0``
      -> ``encoder/layer8/attention/self/query/kernel``.
    * pooler: ``.../bert/pooler/dense/bias:0`` -> ``pooler/bias``.
    * classifier: ``.../classifier/kernel:0`` -> ``classifier/kernel``.

    Args:
        v: Full variable name.

    Returns:
        The shortened name.

    Raises:
        ValueError: If ``v`` matches none of the sections above, or if a
            component that the matched section needs (sublayer, layer index,
            variable type) cannot be parsed from it.
    """

    def _require(value, what):
        # Explicit check instead of ``assert``, so that malformed names are
        # still rejected when Python runs with -O (which strips asserts).
        if value is None:
            raise ValueError(f'Could not extract {what} from variable: {v}')
        return value

    var_type = extract_variable_type(v)

    if is_embeddings_layer(v):
        sublayer = _require(extract_embeddings_sublayer(v),
                            'embeddings sublayer')
        if sublayer == 'LayerNorm':
            # LayerNorm owns both gamma and beta; append the variable type so
            # the two get distinct names.
            sublayer = f'{sublayer}/{_require(var_type, "variable type")}'
        return f'embeddings/{sublayer}'

    if is_encoder_layer(v):
        layer_index = _require(extract_encoder_layer_index(v),
                               'encoder layer index')
        sublayer = _require(extract_encoder_sublayer(v), 'encoder sublayer')
        var_type = _require(var_type, 'variable type')
        return f'encoder/layer{layer_index}/{sublayer}/{var_type}'

    if is_pooler_layer(v):
        return f'pooler/{_require(var_type, "variable type")}'

    if is_classifier_layer(v):
        return f'classifier/{_require(var_type, "variable type")}'

    raise ValueError(f'Unrecognized variable: {v}')


###############################################################################
###############################################################################
# Reference sample of raw variable names (as they appear for
# tf_bert_for_sequence_classification) of the shapes the functions in this
# module parse. This bare string is not assigned to anything; it serves only
# as documentation.
R"""Example variable names:

# Embeddings:
tf_bert_for_sequence_classification/bert/embeddings/word_embeddings/weight:0
tf_bert_for_sequence_classification/bert/embeddings/token_type_embeddings/embeddings:0
tf_bert_for_sequence_classification/bert/embeddings/position_embeddings/embeddings:0
tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/gamma:0
tf_bert_for_sequence_classification/bert/embeddings/LayerNorm/beta:0

# Encoder layer:
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/self/query/kernel:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/self/query/bias:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/self/key/kernel:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/self/key/bias:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/self/value/kernel:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/self/value/bias:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/output/dense/kernel:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/output/dense/bias:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/output/LayerNorm/gamma:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/attention/output/LayerNorm/beta:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/intermediate/dense/kernel:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/intermediate/dense/bias:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/output/dense/kernel:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/output/dense/bias:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/output/LayerNorm/gamma:0
tf_bert_for_sequence_classification/bert/encoder/layer_._8/output/LayerNorm/beta:0

# Pooler:
tf_bert_for_sequence_classification/bert/pooler/dense/kernel:0
tf_bert_for_sequence_classification/bert/pooler/dense/bias:0

# Classifier:
tf_bert_for_sequence_classification/classifier/kernel:0
tf_bert_for_sequence_classification/classifier/bias:0

"""
