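R"""Records how NMF component coefficients relate to KL divergences of perturbations.

For each selected component, a KL-range-targeted perturbation is searched for
`--n_kl_range_finds` times and evaluated with one or more sign guides. The
resulting logits and run metadata are written to one `.h5` file per component
in `--output_dir`.
"""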

import os

from absl import app
from absl import flags
import numpy as np
from transformers import AutoTokenizer

from em.fishers import diagonal

from em.projects.pi import qqp_components_context as QCC
from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util

# Short local aliases for the output record types from coeff_kl_relationship_util.
RunOutput, OutputForComponent = (
    coeff_kl_relationship_util.RunOutput,
    coeff_kl_relationship_util.OutputForComponent,
)


# Allowed values for enum-like flags.
# _KL_TARGETER_EXAMPLES_TYPES[0] is the default of --kl_range_targeter_examples_type.
_KL_TARGETER_EXAMPLES_TYPES = ['random', 'highest_coeffs']
# NOTE(review): not referenced by any active flag in this file; only the
# commented-out ablation_style flag mentions it.
_ABLATION_STYLE = ['sign_guide', 'gaussian']


FLAGS = flags.FLAGS

# Flags describing what files to read. Each path has `~` expanded before use.
flags.DEFINE_string("pef_path", None, "")
flags.DEFINE_string("nmf_path", None, "")
flags.DEFINE_string("retaining_fisher_path", None, "")

# Other inputs.
flags.DEFINE_string("model", None, "")
flags.DEFINE_string("tokenizer", None, "Assumed to be equal to model if left None.")

# Outputs
flags.DEFINE_string("output_dir", None, "Path to directory to write output to. Must already exist.")
flags.DEFINE_string("output_prefix", '', "Optional prefix to prepend to output file names.")

# Flags describing what components to run on.
# When sort_by_coeff_mag is True, components are visited in descending order of
# their column sum in the NMF W matrix; otherwise in raw index order.
# start_component_index is the position (in that order) of the first component
# to run, and n_components is how many consecutive components to run.
flags.DEFINE_bool('sort_by_coeff_mag', True, "")
flags.DEFINE_integer('start_component_index', 0, "")
flags.DEFINE_integer('n_components', None, "")

# Flags describing the run parameters.
flags.DEFINE_integer('n_evaluation_examples', None, "")
# NOTE: ablation style is not configurable at present; the flag stays disabled.
# flags.DEFINE_enum('ablation_style', _ABLATION_STYLE[0], _ABLATION_STYLE, '')
flags.DEFINE_integer('ablating_fisher_top_k', None,
                     "Leave set to None or non-positive value to use all parameters present in the component's Fisher.")

# KL range targeter parameters. min_target_kl must be strictly below max_target_kl.
# 'highest_coeffs' targets the examples with the largest coefficients for the
# component; 'random' passes no explicit example indices to the helper.
flags.DEFINE_float('min_target_kl', None, '')
flags.DEFINE_float('max_target_kl', None, '')
flags.DEFINE_enum('kl_range_targeter_examples_type', _KL_TARGETER_EXAMPLES_TYPES[0], _KL_TARGETER_EXAMPLES_TYPES, '')
flags.DEFINE_integer('n_kl_range_targeter_examples', None, "")
flags.DEFINE_integer('kl_range_targeter_max_iters', 100, "")

# Flags describing the number of runs to complete.
# Per component, helper.do_run is called n_kl_range_finds times. Each call is
# followed by (n_sign_guides_per_kl_range_find - 1) re-evaluations with a
# different sign guide.
flags.DEFINE_integer('n_kl_range_finds', None, "")
flags.DEFINE_integer('n_sign_guides_per_kl_range_find', 1, "")

# Misc flags.
flags.DEFINE_string("exp_special_processing", 'HF_MNLI', "")


flags.mark_flags_as_required([
    'pef_path', 'nmf_path', 'retaining_fisher_path',
    # main() expands output_dir unconditionally, so it cannot be left unset.
    'output_dir',
    'n_components', 'n_evaluation_examples',
    'min_target_kl', 'max_target_kl', 'n_kl_range_targeter_examples',
    'n_kl_range_finds'
])

##########################################################################


def read_in_to_exp() -> ablation_exp_util.Experiment1:
    """Loads the tokenizer, component context and retaining Fisher from flags.

    The tokenizer is taken from `--tokenizer`, falling back to `--model`.
    Every input path flag has `~` expanded before use.

    Returns:
        An `ablation_exp_util.Experiment1` holding the model context built by
        the `QqpComponentContext` and the loaded diagonal retaining Fisher.
    """
    tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer or FLAGS.model)
    component_context = QCC.QqpComponentContext(
        model_name_pattern=FLAGS.model,
        pef_filepath_pattern=os.path.expanduser(FLAGS.pef_path),
        nmf_filepath_pattern=os.path.expanduser(FLAGS.nmf_path),
        tokenizer=tokenizer,
        special_processing=FLAGS.exp_special_processing,
    )
    # Expand `~` like the other input paths and like the path recorded in each
    # output file, so `~/...` paths load correctly.
    retaining_fisher_path = os.path.expanduser(FLAGS.retaining_fisher_path)
    retaining_fisher = diagonal.DiagonalFisher.load(retaining_fisher_path).fishers
    return ablation_exp_util.Experiment1(
        # NOTE(review): the meaning of the 'a' argument is defined by
        # QqpComponentContext.make_model_context; confirm there.
        mc=component_context.make_model_context('a'),
        retaining_fisher=retaining_fisher,
    )


def get_component_indices(exp):
    """Returns the component indices to run on, as selected by flags.

    Takes `--n_components` entries starting at position
    `--start_component_index`. With `--sort_by_coeff_mag` the positions refer
    to components ordered by descending column sum of `exp.nmf.W`, which is the
    total coefficient magnitude if W is non-negative (assumed for NMF factors,
    not checked here). Otherwise the positions are raw component indices.

    Returns:
        A `range` when unsorted, or a NumPy integer array when sorted. In the
        sorted case, slicing past the last component yields fewer indices.
        In the unsorted case the range is not clipped.
    """
    start = FLAGS.start_component_index
    end = start + FLAGS.n_components
    if not FLAGS.sort_by_coeff_mag:
        return range(start, end)
    # A column sum does not depend on row order, so no per-column sort is
    # needed. Negate so that argsort yields descending order.
    column_totals = exp.nmf.W.sum(axis=0)
    ordered_comps = np.argsort(-column_totals)
    return ordered_comps[start:end]


##########################################################################


def make_helper(exp, component_index: int):
    """Builds the `ExperimentHelper1` that drives perturbation runs for one component.

    Args:
        exp: The experiment returned by `read_in_to_exp`. It is read directly
            only for `exp.nmf.W`, and only when
            `--kl_range_targeter_examples_type` is 'highest_coeffs'. Otherwise
            it is just forwarded to the helper.
        component_index: Column of `exp.nmf.W` whose component is targeted.

    Returns:
        An `ablation_exp_util.ExperimentHelper1` configured from flags.

    Raises:
        ValueError: If `--min_target_kl` is not strictly below `--max_target_kl`.
    """
    # Raise rather than assert so the check is not stripped under `python -O`.
    if not FLAGS.min_target_kl < FLAGS.max_target_kl:
        raise ValueError(
            f'--min_target_kl ({FLAGS.min_target_kl}) must be less than '
            f'--max_target_kl ({FLAGS.max_target_kl}).')

    if FLAGS.kl_range_targeter_examples_type == 'highest_coeffs':
        # Examples with the largest coefficients for this component, descending.
        kl_range_targeter_ex_indices = np.argsort(
            -exp.nmf.W[:, component_index])[:FLAGS.n_kl_range_targeter_examples]
    else:
        # 'random': no explicit indices; selection is left to the helper.
        kl_range_targeter_ex_indices = None

    # The flag defaults to None. None or a non-positive value means "use every
    # parameter in the component's Fisher". None must be tested explicitly
    # because `None > 0` raises TypeError in Python 3.
    top_k = FLAGS.ablating_fisher_top_k
    ablate_top_k_params = top_k if top_k is not None and top_k > 0 else None

    return ablation_exp_util.ExperimentHelper1(
        exp=exp,
        component_index=component_index,
        n_evaluation_examples=FLAGS.n_evaluation_examples,
        kl_target_range=[FLAGS.min_target_kl, FLAGS.max_target_kl],
        n_kl_range_targeter_examples=FLAGS.n_kl_range_targeter_examples,
        kl_range_targeter_ex_indices=kl_range_targeter_ex_indices,
        ablate_top_k_params=ablate_top_k_params,
        fixed_sign_guide=False,
    )


def run_for_component(exp, component_index: int):
    """Runs every KL-range find for one component and saves the results.

    Each of the `--n_kl_range_finds` finds works as follows:
      * its first evaluation comes from `helper.do_run(...)`;
      * each further evaluation comes from
        `helper.evaluate_with_different_sign_guide()`;
      * there are `--n_sign_guides_per_kl_range_find` evaluations in total.
    Every evaluation is stored as a `RunOutput` holding the helper's last
    lambda, its last delta, and the evaluated logits.

    The collected `OutputForComponent` is saved to
    `<output_dir>/<output_prefix>coeff_kl_relationship.comp<index>.h5`.
    """
    helper = make_helper(exp, component_index)

    output = OutputForComponent(
        component_index=component_index,
        runs=[],
        # Slices of the experiment restricted to the helper's evaluation examples.
        evaluation_ex_indices=helper.evaluation_ex_indices,
        W=exp.W[helper.evaluation_ex_indices, :],
        labels=exp.labels[helper.evaluation_ex_indices],
        og_logits=exp.mc.container.predicted_logits[helper.evaluation_ex_indices, :],
        # Input locations and model/tokenizer names, recorded alongside results.
        pef_path=os.path.expanduser(FLAGS.pef_path),
        nmf_path=os.path.expanduser(FLAGS.nmf_path),
        retaining_fisher_path=os.path.expanduser(FLAGS.retaining_fisher_path),
        model=FLAGS.model,
        tokenizer=FLAGS.tokenizer or FLAGS.model,
    )

    max_iters = FLAGS.kl_range_targeter_max_iters
    guides_per_find = FLAGS.n_sign_guides_per_kl_range_find
    for _ in range(FLAGS.n_kl_range_finds):
        for guide_number in range(guides_per_find):
            is_fresh_find = guide_number == 0
            eval_results = (
                helper.do_run(max_iters)
                if is_fresh_find
                else helper.evaluate_with_different_sign_guide()
            )
            run_logits = eval_results.logits
            output.runs.append(RunOutput(
                lmbda=helper._last_lmbda,
                delta=helper._last_delta,
                logits=run_logits,
            ))

    out_name = f'{FLAGS.output_prefix}coeff_kl_relationship.comp{component_index}.h5'
    output.save(os.path.join(os.path.expanduser(FLAGS.output_dir), out_name))

##########################################################################


def main(_):
    """Entry point: runs the experiment for each selected component.

    Raises:
        app.UsageError: If `--output_dir` is unset or is not an existing
            directory.
    """
    if not FLAGS.output_dir:
        raise app.UsageError('--output_dir must be set.')
    output_dir = os.path.expanduser(FLAGS.output_dir)
    # Validate before read_in_to_exp loads the model and Fisher. Use isdir
    # because an existing regular file at this path would only fail later,
    # when the output is written.
    if not os.path.isdir(output_dir):
        raise app.UsageError(f'Please create the directory: {output_dir}')

    exp = read_in_to_exp()

    for component_index in get_component_indices(exp):
        run_for_component(exp, component_index)


if __name__ == '__main__':
    # app.run parses the command-line flags (enforcing required ones), then calls main.
    app.run(main)
