import json
import os
from os import path

"""
Script to generate a JSON configuration file of per-weight prune ratios for a BERT network.
The sparsity set here is applied uniformly to every listed weight; the resulting file is
meant for use with the more flexible unstructured ADMM implementation.
"""

SPARSITY = 0.5  # prune ratio written for every listed weight
NUM_LAYERS = 12  # number of encoder layers to emit entries for

# Per-layer weight names; "{}" is replaced by the encoder layer index.
LAYER_TYPES = ["bert.encoder.layer.{}.attention.self.query.weight",
               "bert.encoder.layer.{}.attention.self.key.weight",
               "bert.encoder.layer.{}.attention.self.value.weight",
               "bert.encoder.layer.{}.attention.output.dense.weight",
               "bert.encoder.layer.{}.intermediate.dense.weight",
               "bert.encoder.layer.{}.output.dense.weight"]

# Weight outside the per-layer templates; emitted once, after all encoder entries.
POOLER_WEIGHT = "bert.pooler.dense.weight"

OUT_DIR = path.join(".", "configs")  # resolved relative to the current working directory
OUT_FILE_NAME = "bert_{}_config.json".format(str(SPARSITY).replace(".", "_"))
OUT_LOC = path.join(OUT_DIR, OUT_FILE_NAME)


def build_prune_ratios(num_layers=NUM_LAYERS, sparsity=SPARSITY, layer_types=LAYER_TYPES):
    """Return an insertion-ordered dict mapping weight name to prune ratio.

    Each template in ``layer_types`` is formatted with every layer index
    ``0 .. num_layers - 1`` (layer-major order), then ``POOLER_WEIGHT`` is
    appended last. Every entry gets the same ``sparsity`` value.
    ``layer_types`` is only read, never mutated.
    """
    ratios = {}
    for i in range(num_layers):
        for layer_type in layer_types:
            ratios[layer_type.format(i)] = sparsity
    ratios[POOLER_WEIGHT] = sparsity
    return ratios


print("Generating {}".format(OUT_FILE_NAME))

# Create the output directory if needed; open(..., "w") fails when it is missing.
os.makedirs(OUT_DIR, exist_ok=True)

with open(OUT_LOC, "w") as f:
    # Let the json module handle quoting and comma placement. The tab indent and
    # " : " key separator give the tab-indented `"key" : value` layout with no
    # trailing newline. allow_nan=False refuses to emit NaN/Infinity, which are
    # not valid JSON.
    json.dump({"prune_ratios": build_prune_ratios()}, f,
              indent="\t", separators=(",", " : "), allow_nan=False)