

""" Load a checkpoint file of pretrained transformer to a model in pytorch """

import numpy as np
import tensorflow as tf
import torch
#import ipdb
#from models import *

def load_param(checkpoint_file, conversion_table):
    """
    Load parameters in pytorch model from checkpoint file according to conversion_table
    checkpoint_file : pretrained checkpoint model file in tensorflow
    conversion_table : { pytorch tensor in a model : checkpoint variable name }
    """
    for pyt_param, tf_param_name in conversion_table.items():
        tf_param = tf.train.load_variable(checkpoint_file, tf_param_name)

        # for weight(kernel), we should do transpose
        if tf_param_name.endswith('kernel'):
            tf_param = np.transpose(tf_param)

        assert pyt_param.size() == tf_param.shape, \
            'Dim Mismatch: %s vs %s ; %s' % \
                (tuple(pyt_param.size()), tf_param.shape, tf_param_name)
        
        # assign pytorch tensor from tensorflow param
        pyt_param.data = torch.from_numpy(tf_param)


def load_model(model, checkpoint_file):
    """ Load the pytorch model from checkpoint file """

    # Embedding layer
    e, p = model.embed, 'bert/embeddings/'
    load_param(checkpoint_file, {
        e.tok_embed.weight: p+"word_embeddings",
        e.pos_embed.weight: p+"position_embeddings",
        e.seg_embed.weight: p+"token_type_embeddings",
        e.norm.gamma:       p+"LayerNorm/gamma",
        e.norm.beta:        p+"LayerNorm/beta"
    })

    # Transformer blocks
    for i in range(len(model.blocks)):
        b, p = model.blocks[i], "bert/encoder/layer_%d/"%i
        load_param(checkpoint_file, {
            b.attn.proj_q.weight:   p+"attention/self/query/kernel",
            b.attn.proj_q.bias:     p+"attention/self/query/bias",
            b.attn.proj_k.weight:   p+"attention/self/key/kernel",
            b.attn.proj_k.bias:     p+"attention/self/key/bias",
            b.attn.proj_v.weight:   p+"attention/self/value/kernel",
            b.attn.proj_v.bias:     p+"attention/self/value/bias",
            b.proj.weight:          p+"attention/output/dense/kernel",
            b.proj.bias:            p+"attention/output/dense/bias",
            b.pwff.fc1.weight:      p+"intermediate/dense/kernel",
            b.pwff.fc1.bias:        p+"intermediate/dense/bias",
            b.pwff.fc2.weight:      p+"output/dense/kernel",
            b.pwff.fc2.bias:        p+"output/dense/bias",
            b.norm1.gamma:          p+"attention/output/LayerNorm/gamma",
            b.norm1.beta:           p+"attention/output/LayerNorm/beta",
            b.norm2.gamma:          p+"output/LayerNorm/gamma",
            b.norm2.beta:           p+"output/LayerNorm/beta",
        })

