R"""Computes stuff for viewing localization information for components.

Right now, this is a bit temporary to generate JSON that can be viewed
by some one-off js code.





EXPS_DIR="${DATA_DIR}/pi1"
DATASETS_DIR="${EXPS_DIR}/datasets"
MODELS_DIR="${EXPS_DIR}/models"
FISHER_DIR="${EXPS_DIR}/fishers"
PER_EXAMPLE_FISHERS_DIR="${EXPS_DIR}/per_example_fishers"

og_pef_name="feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
og_h_name="spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.${og_pef_name}"
NMF_NAME="fit_w.skip50000.50000ex.65536vpe.${og_h_name}"

CUDA_VISIBLE_DEVICES=0 python ./scripts1/data_gen/make_H_localization_infos.py \
    --model="connectivity/feather_berts_0" \
    --nmf_path="${PER_EXAMPLE_FISHERS_DIR}/${NMF_NAME}"


"""
import json
import os


from absl import app
from absl import flags

from em.analysis.parameters import localization_infos
from em.analysis.parameters import bert_param_infos

from em.models import em_models
from em.tools.nmf import nmf_common
from em.util import sparse_util

from em.util.color_util import cu

FLAGS = flags.FLAGS

# Both path-like flags have no usable default: the script cannot run without
# them, so their help text says so explicitly.
flags.DEFINE_string(
    "model", None,
    "Model name or path; forwarded (after ~ expansion) to "
    "em_models.from_pretrained. Required.")
flags.DEFINE_bool(
    "from_pt", True,
    "Value forwarded as `from_pt` to em_models.from_pretrained.")

flags.DEFINE_string(
    "nmf_path", None,
    "Path (~ is expanded) to the saved SparseNmfDecomposition whose "
    "components are analyzed. Required.")


def _log(s):
    """Prints `s` to stdout after formatting it with color_util's `cu.hlw`."""
    highlighted = cu.hlw(s)
    print(highlighted)


def rename_infos(infos, rename_fn):
    """Renames every info in place by applying `rename_fn` to its name.

    Args:
        infos: Iterable of objects exposing a mutable `name` attribute.
        rename_fn: Callable mapping a current name to its replacement.

    Returns:
        None. The objects in `infos` are mutated in place.
    """
    for parameter_info in infos:
        old_name = parameter_info.name
        parameter_info.name = rename_fn(old_name)


def to_json_str(infos):
    """Serializes `infos` into a JSON array string.

    Each element of the array is whatever the corresponding info's
    `to_json_obj()` method returns, in iteration order.
    """
    json_objs = []
    for info in infos:
        json_objs.append(info.to_json_obj())
    return json.dumps(json_objs)


def main(argv):
    """Prints parameter localization infos for an NMF decomposition as JSON.

    Loads the model named by --model and the NMF decomposition at --nmf_path,
    normalizes the NMF components to unit norm, builds parameter infos from
    the decomposition's full sparse H, renames them with the model-type
    specific `to_nice_name`, and prints the resulting JSON.

    Args:
        argv: Command-line arguments left over after absl flag parsing; only
            the program name is expected.

    Raises:
        app.UsageError: If extra positional arguments are given, or if
            --model or --nmf_path is missing.
        ValueError: If the model type is not supported.
    """
    if len(argv) > 1:
        raise app.UsageError(f'Too many command-line arguments: {argv[1:]}')
    # Without these flags the calls below would fail with an opaque error
    # (e.g. os.path.expanduser(None)), so fail early with a usage message.
    for flag_name in ('model', 'nmf_path'):
        if not getattr(FLAGS, flag_name):
            raise app.UsageError(f'--{flag_name} is required.')

    # TODO: Make the variable name normalization less hacky.
    model_type, _ = em_models.split_model_ri(FLAGS.model)
    if model_type != 'transformer':
        raise ValueError(f'Invalid model type: {model_type}')
    to_nice_name = bert_param_infos.to_nice_name

    model = em_models.from_pretrained(os.path.expanduser(FLAGS.model), from_pt=FLAGS.from_pt)
    _log('Loaded model.')

    nmf = nmf_common.SparseNmfDecomposition.load(os.path.expanduser(FLAGS.nmf_path))
    nmf.normalize_components_to_unit_norm()
    _log('Loaded NMF.')

    infos = localization_infos.make_parameter_infos(model, nmf.get_full_sparse_H())
    rename_infos(infos, to_nice_name)

    json_str = to_json_str(infos)

    # Blank lines visually separate the JSON from the log lines printed above.
    print(3 * '\n')
    print(cu.hlc(json_str))


if __name__ == "__main__":
    app.run(main)
