# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Transformer XL checkpoint and datasets."""

from __future__ import absolute_import, division, print_function

import argparse
import os
import sys
from io import open

import torch

import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
                                                         WEIGHTS_NAME,
                                                         TransfoXLConfig,
                                                         TransfoXLLMHeadModel,
                                                         load_tf_weights_in_transfo_xl)
from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
                                                             VOCAB_NAME)

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

# Compatibility shim so that python 2 dataset pickles can be loaded:
# the tokenization module is exposed under the extra class names `Vocab` / `Corpus`
# and registered in `sys.modules` under the module names `data_utils` / `vocabulary`.
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
setattr(data_utils, 'Vocab', data_utils.TransfoXLTokenizer)
setattr(data_utils, 'Corpus', data_utils.TransfoXLCorpus)
sys.modules.update({'data_utils': data_utils, 'vocabulary': data_utils})

def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                             transfo_xl_config_file,
                                             pytorch_dump_folder_path,
                                             transfo_xl_dataset_file):
    """Convert a Transformer XL TensorFlow checkpoint and/or a pickled corpus to PyTorch files.

    The two conversions are independent: each one runs only when its input
    path is non-empty, and nothing is done when both are empty.

    Args:
        tf_checkpoint_path: Path of the TensorFlow checkpoint to convert, or an
            empty value to skip the model conversion. The converted weights are
            saved as ``WEIGHTS_NAME`` and the configuration as ``CONFIG_NAME``.
        transfo_xl_config_file: Configuration file handed to ``TransfoXLConfig``.
            When empty, ``TransfoXLConfig()`` is built without arguments.
        pytorch_dump_folder_path: Folder receiving the output files. It is
            created when it does not exist yet.
        transfo_xl_dataset_file: Path of a pickled corpus to convert, or an
            empty value to skip the dataset conversion. The vocabulary is saved
            as ``VOCAB_NAME`` and the remaining corpus attributes as
            ``CORPUS_NAME``.
    """
    if (transfo_xl_dataset_file or tf_checkpoint_path) and not os.path.isdir(pytorch_dump_folder_path):
        # Create the output folder up front so that a missing directory does not
        # abort the run only once the conversion work has already been done.
        os.makedirs(pytorch_dump_folder_path)

    if transfo_xl_dataset_file:
        # Convert a pre-processed corpus (see original TensorFlow repo)
        # NOTE(security): unpickling can execute arbitrary code; only convert
        # dataset files that come from a trusted source.
        with open(transfo_xl_dataset_file, "rb") as fp:
            if sys.version_info[0] == 2:
                # cPickle.load() does not accept the `encoding` keyword.
                corpus = pickle.load(fp)
            else:
                corpus = pickle.load(fp, encoding="latin1")
        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
        pytorch_vocab_dump_path = os.path.join(pytorch_dump_folder_path, VOCAB_NAME)
        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
        torch.save(corpus.vocab.__dict__, pytorch_vocab_dump_path)

        # Build a vocabulary-free copy instead of popping the entry from the
        # live object, so the loaded corpus is left untouched.
        corpus_dict_no_vocab = {key: value for key, value in corpus.__dict__.items() if key != 'vocab'}
        pytorch_dataset_dump_path = os.path.join(pytorch_dump_folder_path, CORPUS_NAME)
        print("Save dataset to {}".format(pytorch_dataset_dump_path))
        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

    if tf_checkpoint_path:
        # Convert a pre-trained TensorFlow model
        tf_path = os.path.abspath(tf_checkpoint_path)
        # abspath("") would resolve to the current directory and be reported as
        # the config location, so only resolve a path that was actually given.
        if transfo_xl_config_file:
            config_path = os.path.abspath(transfo_xl_config_file)
        else:
            config_path = "<default configuration>"

        print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
        # Initialise PyTorch model
        if transfo_xl_config_file:
            config = TransfoXLConfig(transfo_xl_config_file)
        else:
            config = TransfoXLConfig()
        print("Building PyTorch model from configuration: {}".format(str(config)))
        model = TransfoXLLMHeadModel(config)

        model = load_tf_weights_in_transfo_xl(model, config, tf_path)
        # Save pytorch-model
        pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
        pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
        print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())


if __name__ == "__main__":
    # Command-line entry point: parse the input/output paths and run the conversion.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the folder to store the PyTorch model or dataset/vocab.")
    parser.add_argument("--tf_checkpoint_path",
                        default="",
                        type=str,
                        help="An optional path to a TensorFlow checkpoint path to be converted.")
    # The help text used to mention BERT; this script converts Transformer XL models.
    parser.add_argument("--transfo_xl_config_file",
                        default="",
                        type=str,
                        help="An optional config json file corresponding to the pre-trained Transformer XL model. \n"
                             "This specifies the model architecture.")
    parser.add_argument("--transfo_xl_dataset_file",
                        default="",
                        type=str,
                        help="An optional dataset file to be converted in a vocabulary.")
    args = parser.parse_args()
    # Keyword arguments keep the call correct regardless of parameter order.
    convert_transfo_xl_checkpoint_to_pytorch(
        tf_checkpoint_path=args.tf_checkpoint_path,
        transfo_xl_config_file=args.transfo_xl_config_file,
        pytorch_dump_folder_path=args.pytorch_dump_folder_path,
        transfo_xl_dataset_file=args.transfo_xl_dataset_file)
