"""Model variable interaction classes for generic transformers.

For now, this focuses on BERT/RoBERTa.
"""
import abc
import collections
import dataclasses
import re
from typing import Any, List, Optional, Sequence, Union

from absl import flags
import numpy as np
import tensorflow as tf
import wrapt


# Global absl flag registry; the filter flags are defined by
# add_variable_filter_flags() and read back through _get_flag().
FLAGS = flags.FLAGS


# Some typedefs.
# A variable may be referred to either by the tf.Variable itself or by its
# name string (see var_name()).
NameOrVariable = Union[str, tf.Variable]


_ATTENTION_SUBLAYERS = frozenset({
    "attention/self/query",
    "attention/self/key",
    "attention/self/value",
    "attention/output/dense",
    "attention/output/LayerNorm",
})

_FFW_SUBLAYERS = frozenset({
    "intermediate/dense",
    "output/dense",
    "output/LayerNorm",
})


_SUBLAYER_ORDER = [
    [
        "attention/self/query",
        "attention/self/key",
        "attention/self/value",
    ],
    ["attention/output/dense"],
    ["attention/output/LayerNorm"],
    ["intermediate/dense"],
    ["output/dense"],
    ["output/LayerNorm"],
]


"""DELETE LATER


# Embeddings: (for both Bert and RoBERTa)
# tf_roberta_for_sequence_classification/roberta/embeddings/word_embeddings/weight:0
# tf_roberta_for_sequence_classification/roberta/embeddings/token_type_embeddings/embeddings:0
# tf_roberta_for_sequence_classification/roberta/embeddings/position_embeddings/embeddings:0
# tf_roberta_for_sequence_classification/roberta/embeddings/LayerNorm/gamma:0
# tf_roberta_for_sequence_classification/roberta/embeddings/LayerNorm/beta:0

# BERT-only
# tf_bert_for_sequence_classification/bert/pooler/dense/kernel:0
# tf_bert_for_sequence_classification/bert/pooler/dense/bias:0


Let's think of use-cases/interfaces/command-line interfaces.


--merge_layer_indices: List[int]
--merge_embeddings: bool
--merge_pooler: bool (BERT-only)
--merge_vars_subset: str (probably fairly direct specification of sublayer along with some predefined groups)

# Probably these (more like QoL common alternatives for --merge_layer_indices)
--merge_min_layer: int
--merge_max_layer: int

# Maybe these (more like QoL common alternatives for --merge_vars_subset)
--merge_layer_norms: bool
--merge_biases: bool
--merge_kernels: bool


"""


@dataclasses.dataclass
class VariableFilter:
    """Name-based predicate deciding which transformer variables to merge.

    Each variable is classified by its name as a pooler, classifier,
    embeddings, or encoder-layer variable, and the matching ``merge_*`` switch
    decides whether it is selected. Encoder-layer variables are additionally
    restricted to ``layer_indices`` when that is set.
    """

    # Sequence of 0-based indices of layers to include; None means all layers.
    layer_indices: Optional[Sequence[int]] = None

    # TODO: Maybe add option to configure what embeddings variables to merge.
    # Whether or not to merge the embeddings layer.
    merge_embeddings: bool = True

    # TODO: Maybe add option to configure what pooler variables to merge.
    # NOTE: It looks like RoBERTa doesn't have a pooler, so this
    # will have an effect on BERT models only.
    merge_pooler: bool = True

    # Whether to merge the self-attention sublayers of the selected layers.
    merge_attention: bool = True
    # Whether to merge the feed-forward sublayers of the selected layers.
    merge_ffw: bool = True

    # Whether to merge classifier-head variables (off by default).
    # NOTE: Something added later, hopefully this doesn't break stuff.
    merge_classifier: bool = False

    def __post_init__(self):
        # Validation hook; there is currently nothing to check.
        pass

    def does_variable_match(self, v: NameOrVariable) -> bool:
        """Return whether variable ``v`` is selected by this filter.

        Checks are applied in this order: pooler variables -> ``merge_pooler``;
        classifier variables -> ``merge_classifier``; embeddings variables ->
        ``merge_embeddings``. Anything else must be an encoder-layer variable:
        it is rejected if its layer is not in ``layer_indices`` (when set),
        otherwise ``merge_attention`` / ``merge_ffw`` decides based on its
        sublayer.

        Raises:
            ValueError: if ``v`` is none of the supported variable kinds.
        """
        v = var_name(v)

        if is_pooler_layer(v):
            return bool(self.merge_pooler)

        if is_classifier_layer(v):
            return bool(self.merge_classifier)

        embeddings_sublayer = extract_embeddings_variable_sublayer(v)
        if embeddings_sublayer is not None:
            return bool(self.merge_embeddings)

        layer_index = extract_layer_index(v)
        if layer_index is None:
            raise ValueError(f'Unsupported variable: {v}')

        if self.layer_indices is not None and layer_index not in self.layer_indices:
            return False

        sublayer = extract_encoder_variable_sublayer(v)

        if sublayer in _ATTENTION_SUBLAYERS:
            return bool(self.merge_attention)
        if sublayer in _FFW_SUBLAYERS:
            return bool(self.merge_ffw)
        raise ValueError(f'Unsupported variable: {v}')

    def filter_parallel_lists(
        self,
        named_variables: Sequence[NameOrVariable],
        *other_lists: Sequence[Any],
    ):
        """Filter parallel lists by the variables (names) in the first list.

        Position ``i`` is kept in every list iff
        ``self.does_variable_match(named_variables[i])``.

        Returns:
            The filtered ``named_variables`` as a list if no other lists were
            passed; otherwise a tuple of filtered lists, one per input list, in
            input order.

        Raises:
            ValueError: if the lists do not all have the same length.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        lengths = {len(named_variables)} | {len(lst) for lst in other_lists}
        if len(lengths) != 1:
            raise ValueError(
                f'All lists must have the same length, got lengths {sorted(lengths)}.')

        keep = [self.does_variable_match(v) for v in named_variables]
        ret = [
            [x for x, k in zip(lst, keep) if k]
            for lst in (named_variables, *other_lists)
        ]

        if len(ret) == 1:
            return ret[0]
        return tuple(ret)


def add_variable_filter_flags(prefix: str = 'merge'):
    """Define the command-line flags read by get_variable_filter_from_flags().

    Every flag name is prefixed with ``{prefix}_``, so several independent
    filters can coexist by using different prefixes.
    """
    flags.DEFINE_integer(f"{prefix}_min_layer", None, "Zero-based inclusive index of first layer to merge.")
    flags.DEFINE_integer(
        f"{prefix}_max_layer", None,
        "Zero-based exclusive upper bound on the indices of layers to merge "
        "(layers with index < this value are merged).")

    flags.DEFINE_bool(f"{prefix}_embeddings", True, "Whether to merge the embeddings variables.")
    flags.DEFINE_bool(
        f"{prefix}_pooler", True,
        "Whether to merge the pooler variables (no effect on models without a pooler).")

    flags.DEFINE_bool(
        f"{prefix}_attention", True,
        "Whether to merge the self-attention sublayers of the selected layers.")
    flags.DEFINE_bool(
        f"{prefix}_ffw", True,
        "Whether to merge the feed-forward sublayers of the selected layers.")

    flags.DEFINE_bool(f"{prefix}_classifier", False, "Whether to merge the classifier variables.")


def _get_flag(prefix: str, name: str):
    """Return the parsed value of the flag ``--{prefix}_{name}``."""
    flag_name = f'{prefix}_{name}'
    return getattr(FLAGS, flag_name)


def get_variable_filter_from_flags(prefix: str = 'merge') -> VariableFilter:
    """Build a VariableFilter from the flags defined by add_variable_filter_flags().

    Layer selection from ``--{prefix}_min_layer`` / ``--{prefix}_max_layer``:
      * neither given: all layers (``layer_indices=None``);
      * only max given: layers ``[0, max_layer)``;
      * both given: layers ``[min_layer, max_layer)``.

    Raises:
        ValueError: if only ``--{prefix}_min_layer`` is given, or if the
            resulting range start exceeds ``max_layer``.
    """
    min_layer = _get_flag(prefix, "min_layer")
    max_layer = _get_flag(prefix, "max_layer")

    if max_layer is None:
        if min_layer is not None:
            # TODO: Support this case.
            raise ValueError(
                f'If specifying --{prefix}_min_layer, must also specify --{prefix}_max_layer.')
        layer_indices = None
    else:
        start = 0 if min_layer is None else min_layer
        if start > max_layer:
            # Otherwise range() would be empty and silently select no layers.
            raise ValueError(
                f'Invalid layer range [{start}, {max_layer}) from '
                f'--{prefix}_min_layer / --{prefix}_max_layer.')
        layer_indices = list(range(start, max_layer))

    return VariableFilter(
        layer_indices=layer_indices,
        merge_embeddings=_get_flag(prefix, 'embeddings'),
        merge_pooler=_get_flag(prefix, 'pooler'),
        merge_attention=_get_flag(prefix, 'attention'),
        merge_ffw=_get_flag(prefix, 'ffw'),
        merge_classifier=_get_flag(prefix, 'classifier'),
    )


def var_name(v: Union[None, NameOrVariable]) -> Union[None, str]:
    """Return the name of ``v``.

    Strings are returned unchanged, None maps to None, and anything else is
    assumed to expose a ``.name`` attribute (e.g. a tf.Variable).
    """
    if v is None:
        return None
    if isinstance(v, str):
        return v
    return v.name


def extract_layer_index(v: NameOrVariable) -> Optional[int]:
    """Return the 0-based encoder layer index of ``v``, or None.

    The index is parsed from the ``/encoder/layer_._<i>/`` path component, e.g.
    ``.../encoder/layer_._3/output/dense/kernel:0`` -> ``3``. Returns None for
    variables outside the encoder layers (embeddings, pooler, classifier, ...).
    """
    name = var_name(v)
    m = re.search(r'/encoder/layer_\._(\d+)/', name)
    return int(m.group(1)) if m else None


def extract_encoder_variable_sublayer(v: NameOrVariable) -> Union[None, str]:
    """Return the sublayer path of an encoder-layer variable, or None.

    The sublayer is everything between ``/encoder/layer_._<i>/`` and the final
    ``/<tail>:`` component, e.g.
    ``.../encoder/layer_._0/attention/self/query/kernel:0`` ->
    ``attention/self/query``.
    """
    name = var_name(v)
    match = re.search(r'/encoder/layer_\._\d+/(.+)/\w+\:', name)
    if match is None:
        return None
    return match.group(1)


def extract_embeddings_variable_sublayer(v: NameOrVariable) -> Union[None, str]:
    """Return the embeddings sublayer path of ``v``, or None if it is not an embeddings variable.

    E.g. ``.../embeddings/LayerNorm/gamma:0`` -> ``LayerNorm``.
    """
    found = re.search(r'/embeddings/(.+)/\w+\:', var_name(v))
    return None if found is None else found.group(1)


def is_pooler_layer(v: NameOrVariable) -> bool:
    """Return whether ``v``'s name contains a ``/pooler/`` path component."""
    return '/pooler/' in var_name(v)


def is_classifier_layer(v: NameOrVariable) -> bool:
    """Return whether ``v``'s name contains a ``/classifier/`` path component."""
    return '/classifier/' in var_name(v)


def to_nice_name(v: NameOrVariable) -> str:
    """Return a compact, human-readable name for an encoder-layer variable.

    E.g. ``.../encoder/layer_._3/attention/self/query/kernel:0`` becomes
    ``layer3/attention/self/query/kernel``.

    Raises:
        ValueError: if ``v`` is not an encoder-layer variable whose layer index
            and sublayer can both be parsed.
    """
    v = var_name(v)
    layer_index = extract_layer_index(v)
    sublayer = extract_encoder_variable_sublayer(v)
    # Both parts are required: with sublayer=None the old `split(None)` would
    # split on whitespace and produce a bogus name like "layer0/None...".
    if layer_index is None or sublayer is None:
        raise ValueError(f'TODO: Support variable: {v}')
    name = v.split(':')[0]
    # Keep only what follows the (last occurrence of the) sublayer, e.g. "/kernel".
    tail = name.rpartition(sublayer)[2]
    return f'layer{layer_index}/{sublayer}{tail}'


###############################################################################

class NamedProxy(wrapt.ObjectProxy):
    """Transparent proxy that attaches a ``name`` to a wrapped np.ndarray.

    Attribute access other than ``name`` goes through ``wrapt.ObjectProxy``.
    Wrapping a tf.Tensor is rejected for now.
    """

    def __init__(self, wrappee: np.ndarray, name: str):
        # Reject unsupported inputs up front, before initializing the proxy.
        if isinstance(wrappee, tf.Tensor):
            raise TypeError('TODO: Support tensors, tf bugs out with the current version of this class.')
        super().__init__(wrappee)
        # wrapt keeps attributes prefixed with `_self_` on the proxy itself
        # rather than forwarding them to the wrapped object.
        self._self_name = name

    @property
    def name(self):
        """The name supplied at construction time."""
        return self._self_name


###############################################################################

def group_by_blocks(variables: Sequence[tf.Variable]) -> List[List[tf.Variable]]:
    """Group variables into blocks.

    A block means what is commonly referred to as a transformer layer.
    Embeddings, pooler, and classifier variables each form their own block.
    Blocks are returned in the order: embeddings, encoder layers (in order of
    first appearance), pooler, classifier. Empty blocks are omitted.
    Within a block, variables keep their input order. Grouping is by name, so a
    block's variables need not be contiguous in ``variables``.

    Raises:
        ValueError: if a variable is not an embeddings, pooler, classifier, or
            encoder-layer variable.
    """
    embedding_vars = []
    layer_index_to_vars = collections.OrderedDict()
    pooler_vars = []
    classifier_vars = []

    for v in variables:
        if extract_embeddings_variable_sublayer(v) is not None:
            embedding_vars.append(v)
        elif is_pooler_layer(v):
            pooler_vars.append(v)
        elif is_classifier_layer(v):
            # VariableFilter can select classifier variables (merge_classifier),
            # so they must be groupable instead of being rejected below.
            classifier_vars.append(v)
        else:
            layer_index = extract_layer_index(v)
            if layer_index is None:
                raise ValueError(f'Unsupported variable: {v}')
            layer_index_to_vars.setdefault(layer_index, []).append(v)

    groups = [embedding_vars, *layer_index_to_vars.values(), pooler_vars, classifier_vars]
    return [g for g in groups if g]


def group_by_sub_blocks(variables: Sequence[tf.Variable]) -> List[List[tf.Variable]]:
    """Split the blocks from group_by_blocks() into sub-blocks.

    Within an encoder-layer block, the self-attention variables and the FFW
    variables become separate sub-blocks, in the order each kind first appears
    in the block. A non-encoder block (e.g. embeddings, pooler) stays a single
    sub-block. Grouping is by name, so a sub-block's variables need not be
    contiguous in the input.

    Raises:
        ValueError: if an encoder-layer variable belongs to neither the
            attention nor the FFW sublayers.
    """

    def _sub_block_key(var):
        # Map a variable to the name of the sub-block it belongs to.
        if extract_layer_index(var) is None:
            return 'other'
        sublayer = extract_encoder_variable_sublayer(var)
        if sublayer in _ATTENTION_SUBLAYERS:
            return 'attn'
        if sublayer in _FFW_SUBLAYERS:
            return 'ffw'
        raise ValueError(f'Unsupported variable: {var}')

    result = []
    for block in group_by_blocks(variables):
        buckets = collections.OrderedDict()
        for var in block:
            buckets.setdefault(_sub_block_key(var), []).append(var)
        result.extend(buckets.values())
    return result


###############################################################################
# Helpers copied from a previous repo (some of them are kept commented out below).
###############################################################################

# def clean_name(v: Union[str, tf.Variable]):
#     if not isinstance(v, str):
#         v = v.name
#     return v.split(':')[0]

# def get_number_of_units_from_kernel(v: tf.Tensor):
#     return tf.shape(v)[-1]


def _get_layer_name(v: Union[str, tf.Variable]) -> str:
    """Return the layer path of ``v``: its name without the ``:N`` suffix and final component.

    E.g. ``a/b/kernel:0`` -> ``a/b``; a name without any ``/`` yields ``''``.
    """
    base = var_name(v).partition(':')[0]
    return base.rpartition('/')[0]


def group_by_kernel_bias(variables: Sequence[tf.Variable]):
    """Group variables by layer into ``(kernel, bias)`` tuples.

    Variables with the same layer name (see _get_layer_name) form one tuple,
    sorted by name in descending order, which puts ``.../kernel:N`` before
    ``.../bias:N``. A layer with a single variable yields a 1-tuple. Tuples are
    returned in order of each layer's first appearance.

    Assumes all variables are either a kernel or a bias.
    """
    layer_to_vars = collections.OrderedDict()
    for v in variables:
        layer_to_vars.setdefault(_get_layer_name(v), []).append(v)
    # Sort on var_name (like the other helpers here) so plain name strings are
    # accepted as well as tf.Variables.
    return [
        # (kernel, bias?)
        tuple(sorted(vs, key=var_name, reverse=True))
        for vs in layer_to_vars.values()
    ]


def _tail_of_variable_name(v: NameOrVariable) -> str:
    """Return the last path component of ``v``'s name, without the ``:N`` suffix.

    Typically used to see if it is a kernel/bias, e.g. ``a/b/kernel:0`` -> ``kernel``.
    """
    return var_name(v).partition(':')[0].rpartition('/')[2]


def homogenize_kernel_biases(variables: Sequence[tf.Variable]) -> List[tf.Tensor]:
    """Fold each layer's bias into its kernel as an extra row.

    A kernel and bias with the same layer name (see _get_layer_name) are
    replaced by ``tf.concat([kernel, bias[..., None, :]], axis=-2)``, i.e. the
    bias is appended as one more row along the second-to-last axis. Every other
    variable, and any kernel or bias without a partner, is passed through
    unchanged. The output follows the order in which each layer / variable
    first appears in ``variables``.

    Raises:
        ValueError: if a layer's kernel/bias group is not exactly one kernel
            and one bias.
    """
    layer_to_vars = collections.OrderedDict()
    for v in variables:
        if _tail_of_variable_name(v) in ('kernel', 'bias'):
            key = _get_layer_name(v)
        else:
            # Not a kernel/bias: key on the full name so it stays on its own.
            key = var_name(v)
        layer_to_vars.setdefault(key, []).append(v)

    ret = []
    for key, vs in layer_to_vars.items():
        if len(vs) == 1:
            ret.append(vs[0])
            continue

        # Identify kernel and bias explicitly instead of relying on name order.
        by_tail = {_tail_of_variable_name(x): x for x in vs}
        if len(vs) != 2 or set(by_tail) != {'kernel', 'bias'}:
            raise ValueError(
                f'Expected exactly one kernel and one bias for {key}, '
                f'got: {[var_name(x) for x in vs]}')
        kernel, bias = by_tail['kernel'], by_tail['bias']

        # kernel.shape = [d_input, d_output]
        # bias.shape = [d_output]
        h = tf.concat([kernel, bias[..., None, :]], axis=-2)
        ret.append(h)

    return ret
