"""Utilities for HuggingFace."""
import os
from typing import Tuple, Union

import tensorflow as tf
from transformers import PreTrainedTokenizer
from transformers import TFBertPreTrainedModel, BertConfig
from transformers import TFRobertaPreTrainedModel, RobertaConfig


def get_body_and_head(
    model: Union[TFBertPreTrainedModel, TFRobertaPreTrainedModel]
) -> Tuple[tf.keras.layers.Layer, Union[tf.keras.layers.Layer, None]]:
    """Split a HuggingFace TF model into its "body" layer and its "head" layer.

    Dropout layers in ``model.layers`` are skipped. The first remaining layer
    is the body. At most one further layer is allowed, and it is the head.

    Args:
        model: A HuggingFace TF BERT/RoBERTa model exposing ``.layers``.

    Returns:
        ``(body, head)`` where ``head`` is ``None`` if the model has no
        non-dropout layer after the body.

    Raises:
        ValueError: if the model has no non-dropout layers at all, or has more
            than one non-dropout layer after the body.
    """
    # BERT appears to have a dropout layer between the body and head, so this
    # takes care of that.
    layers = [el for el in model.layers if not isinstance(el, tf.keras.layers.Dropout)]
    if not layers:
        # Without this check the star-unpack below fails with an opaque
        # "not enough values to unpack" message.
        raise ValueError("Expected model to have at least one non-dropout layer. Found none.")

    body, *rest = layers
    if len(rest) > 1:
        raise ValueError(
            f"Expected model to have a single 'head' layer. Instead found {len(rest)}. TODO: Support this."
        )
    head = rest[0] if rest else None
    return body, head


def get_body(model):
    """Return only the body layer of ``model``, as split by ``get_body_and_head``."""
    body, _head = get_body_and_head(model)
    return body


def get_mergeable_variables(model):
    """Return the trainable variables of the model's body layer only.

    Variables belonging to the head (if any) are not included.
    """
    body_layer, _ = get_body_and_head(model)
    return body_layer.trainable_variables


def get_all_variables(model):
    """Return every trainable variable of ``model`` (body and head together)."""
    all_trainable = model.trainable_variables
    return all_trainable


def clone_model(model):
    """Return a new instance of ``model``'s class carrying a copy of its weights.

    The clone is constructed from a deep copy of ``model.config``. Sharing the
    same config object would mean that mutating one model's config silently
    changes the other's.

    Args:
        model: A model exposing ``config``, ``dummy_inputs``, ``get_weights``
            and ``set_weights`` (e.g. a HuggingFace TF model).

    Returns:
        The cloned model, with weights copied from ``model``.
    """
    import copy

    cloned = type(model)(copy.deepcopy(model.config))
    # Run one forward pass so the clone's variables are created before
    # set_weights tries to assign to them.
    cloned(model.dummy_inputs)
    cloned.set_weights(model.get_weights())
    return cloned

###############################################################################


class HfSaverCallback(tf.keras.callbacks.Callback):
    """Keras callback that saves a HuggingFace model at the end of every epoch.

    Each epoch the model is written with ``save_pretrained`` to:

    * ``filepath``, which is overwritten every epoch and so always holds the
      most recent model; and
    * ``f"{filepath}_epoch{epoch}"``, a per-epoch snapshot, where ``epoch`` is
      the value Keras passes to ``on_epoch_end``. This snapshot is only written
      if ``save_epoch_snapshots`` is True (the default).

    Args:
        filepath: Directory to save into. ``~`` is expanded, and the path is
            normalised so that a trailing separator does not put the per-epoch
            snapshots *inside* the rolling checkpoint directory.
        save_epoch_snapshots: Whether to also keep a per-epoch snapshot.
    """

    def __init__(self, filepath: str, *, save_epoch_snapshots: bool = True):
        # Let the Keras base class set up its own attributes before ours.
        super().__init__()
        self.filepath = os.path.normpath(os.path.expanduser(filepath))
        self.save_epoch_snapshots = save_epoch_snapshots

    def on_epoch_end(self, epoch, logs=None):
        self.model.save_pretrained(self.filepath)
        if self.save_epoch_snapshots:
            self.model.save_pretrained(f"{self.filepath}_epoch{epoch}")


###############################################################################

def make_roberta_config(
    tokenizer: PreTrainedTokenizer,
    hidden_size: int,
    num_hidden_layers: int,
    max_position_embeddings: int,
) -> RobertaConfig:
    """Build a ``RobertaConfig`` for a fresh model sized to ``tokenizer``.

    The feed-forward width is ``4 * hidden_size``. One attention head is used
    per 64 hidden units (``hidden_size // 64``). Dropout (0.1), initializer
    range (0.02), activation (gelu), layer-norm epsilon (1e-05) and
    ``type_vocab_size`` (1) are fixed.

    Args:
        tokenizer: Supplies the special-token ids and the vocabulary size.
        hidden_size: Model width. Must be at least 64.
        num_hidden_layers: Number of transformer layers.
        max_position_embeddings: Size of the position-embedding table.

    Returns:
        The constructed ``RobertaConfig``.

    Raises:
        ValueError: if ``hidden_size < 64``, which would give zero attention
            heads and only fail later, inside model construction.
    """
    num_attention_heads = hidden_size // 64
    if num_attention_heads < 1:
        raise ValueError(
            f"hidden_size must be at least 64 to get one attention head; got {hidden_size}."
        )
    intermediate_size = 4 * hidden_size

    return RobertaConfig.from_dict({
        "model_type": "roberta",
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "initializer_range": 0.02,
        "hidden_size": hidden_size,
        "intermediate_size": intermediate_size,
        "hidden_act": "gelu",
        "layer_norm_eps": 1e-05,
        "max_position_embeddings": max_position_embeddings,
        "num_attention_heads": num_attention_heads,
        "num_hidden_layers": num_hidden_layers,
        "type_vocab_size": 1,
        # len(tokenizer) includes tokens added on top of the base vocabulary;
        # tokenizer.vocab_size does not, and would leave added-token ids
        # outside the embedding table.
        "vocab_size": len(tokenizer),
    })


def make_bert_config(
    tokenizer: PreTrainedTokenizer,
    hidden_size: int,
    num_hidden_layers: int,
    max_position_embeddings: int,
) -> BertConfig:
    """Build a ``BertConfig`` for a fresh model sized to ``tokenizer``.

    The feed-forward width is ``4 * hidden_size``. One attention head is used
    per 64 hidden units (``hidden_size // 64``). Dropout (0.1), initializer
    range (0.02), activation (gelu), layer-norm epsilon (1e-12) and
    ``type_vocab_size`` (2) are fixed.

    Args:
        tokenizer: Supplies the special-token ids and the vocabulary size.
        hidden_size: Model width. Must be at least 64.
        num_hidden_layers: Number of transformer layers.
        max_position_embeddings: Size of the position-embedding table.

    Returns:
        The constructed ``BertConfig``.

    Raises:
        ValueError: if ``hidden_size < 64``, which would give zero attention
            heads and only fail later, inside model construction.
    """
    num_attention_heads = hidden_size // 64
    if num_attention_heads < 1:
        raise ValueError(
            f"hidden_size must be at least 64 to get one attention head; got {hidden_size}."
        )
    intermediate_size = 4 * hidden_size

    return BertConfig.from_dict({
        "model_type": "bert",
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "initializer_range": 0.02,
        "hidden_size": hidden_size,
        "intermediate_size": intermediate_size,
        "hidden_act": "gelu",
        "layer_norm_eps": 1e-12,
        "max_position_embeddings": max_position_embeddings,
        "num_attention_heads": num_attention_heads,
        "num_hidden_layers": num_hidden_layers,
        "type_vocab_size": 2,
        # len(tokenizer) includes tokens added on top of the base vocabulary;
        # tokenizer.vocab_size does not, and would leave added-token ids
        # outside the embedding table.
        "vocab_size": len(tokenizer),
    })
