R"""Script for running guided ablations."""
import os
from typing import Sequence

from absl import app
from absl import flags
import numpy as np
from transformers import AutoTokenizer

from em.fishers import diagonal

from em.projects.pi import qqp_components_context as QCC
from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import guided_ablations

from em.util.color_util import cu

# Valid ablation experiment type names, re-exported from guided_ablations.
# main() validates --ablation_exp_types against this collection and falls back
# to it when the flag is left unset.
ABLATION_EXP_TYPES = guided_ablations.ABLATION_EXP_TYPES

# Local aliases for the output container types defined in guided_ablations.
OutputForAblator = guided_ablations.OutputForAblator
OutputForComponent = guided_ablations.OutputForComponent

FLAGS = flags.FLAGS

# Flags describing what files to read.
# pef_path and nmf_path are user-expanded and handed to QqpComponentContext as
# file path patterns; retaining_fisher_path is loaded via DiagonalFisher.load.
flags.DEFINE_string("pef_path", None, "")
flags.DEFINE_string("nmf_path", None, "")
flags.DEFINE_string("retaining_fisher_path", None, "")

# Other inputs.
# `model` is passed to QqpComponentContext as the model name pattern and is the
# fallback tokenizer name when --tokenizer is not given.
# NOTE(review): `model` is read by read_in_to_exp but is not in the required
# flag list below -- confirm whether it should be.
flags.DEFINE_string("model", None, "")
flags.DEFINE_string("tokenizer", None, "Assumed to be equal to model if left None.")

# Outputs
# One file per component is written to
# <output_dir>/<output_prefix>guided_kl_ablation.comp<index>.h5.
flags.DEFINE_string("output_dir", None, "Path to directory to write output to. Must already exist.")
flags.DEFINE_string("output_prefix", '', "Optional prefix to prepend to output file names.")

# Flags describing what components to run on.
# Components [start_component_index, start_component_index + n_components) are
# processed. With sort_by_coeff_mag the positions refer to an ordering derived
# from the column sums of the NMF W matrix (see get_component_indices);
# otherwise they are raw component indices.
flags.DEFINE_bool('sort_by_coeff_mag', True, "")
flags.DEFINE_integer('start_component_index', 0, "")
flags.DEFINE_integer('n_components', None, "")


# Flags for the guided ablations.

# Target KL range handed to the experiment helper as
# [min_target_kl, max_target_kl]; make_helper requires min < max.
flags.DEFINE_float('min_target_kl', None, '')
flags.DEFINE_float('max_target_kl', None, '')
# Passed as `max_iters` to each ablator's find_model call.
flags.DEFINE_integer('kl_range_targeter_max_iters', 25, "")

# Both are forwarded unchanged to guided_ablations.ExperimentHelper1.
flags.DEFINE_integer('n_selected_examples', None, "")
flags.DEFINE_enum('ablating_variable_style',
                  guided_ablations.ABLATING_VARIABLE_STYLES[0],
                  guided_ablations.ABLATING_VARIABLE_STYLES,
                  '')

# Number of find_model searches per component ablator, and also the number of
# random-example rounds (each round uses a freshly resampled example set).
flags.DEFINE_integer('n_kl_range_finds', None, "")

# `max_delta` used for the H ablators; the other ablators use ablator_run's
# default of 3.0.
flags.DEFINE_float('max_delta_H', 5.0, 'Using the H as the Fisher can sometimes require larger deltas.')

# Forwarded to QqpComponentContext as `special_processing`.
flags.DEFINE_string("exp_special_processing", 'HF_MNLI', "")

# Subset of ABLATION_EXP_TYPES to run; all of them when unset or empty.
flags.DEFINE_list('ablation_exp_types', None, '')


# These flags have no usable default, so absl rejects the command line when
# any of them is missing.
flags.mark_flags_as_required([
    'pef_path', 'nmf_path', 'retaining_fisher_path',
    'n_components', 'output_dir',
    'min_target_kl', 'max_target_kl', 'n_selected_examples',
    'n_kl_range_finds',
])


##########################################################################


def read_in_to_exp() -> ablation_exp_util.Experiment1:
    """Build the experiment object from the inputs named by the flags.

    Loads the tokenizer (``--tokenizer``, falling back to ``--model``), builds a
    ``QqpComponentContext`` from the PEF/NMF file patterns, and loads the
    retaining Fisher from ``--retaining_fisher_path``.

    All three file paths are passed through ``os.path.expanduser`` so that a
    leading ``~`` works uniformly; this also matches the expanded paths that
    ``run_for_component`` records in its output.

    Returns:
        An ``ablation_exp_util.Experiment1`` wrapping the model context and the
        loaded retaining Fisher.
    """
    tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer or FLAGS.model)
    hacc = QCC.QqpComponentContext(
        model_name_pattern=FLAGS.model,
        pef_filepath_pattern=os.path.expanduser(FLAGS.pef_path),
        nmf_filepath_pattern=os.path.expanduser(FLAGS.nmf_path),
        tokenizer=tokenizer,
        special_processing=FLAGS.exp_special_processing,
    )
    # Expand `~` here as well; previously only the PEF/NMF paths were expanded,
    # so a home-relative Fisher path was handed to the loader verbatim.
    retaining_fisher_path = os.path.expanduser(FLAGS.retaining_fisher_path)
    retaining_fisher = diagonal.DiagonalFisher.load(retaining_fisher_path).fishers
    return ablation_exp_util.Experiment1(
        # NOTE(review): the meaning of the 'a' argument is defined by
        # QqpComponentContext.make_model_context -- confirm there.
        mc=hacc.make_model_context('a'),
        retaining_fisher=retaining_fisher,
    )


def get_component_indices(exp):
    """Return the component indices this run should process.

    The window is ``[start_component_index, start_component_index +
    n_components)``.

    * With ``--nosort_by_coeff_mag`` the window is applied to the raw component
      indices and a ``range`` is returned (no bounds check against the number
      of components is made).
    * Otherwise the components are first ordered by the column sums of
      ``exp.nmf.W``, largest sum first, and the window is applied to that
      ordering. A NumPy integer array is returned; NumPy slicing silently
      truncates a window that runs past the last component.

    Args:
        exp: Experiment object exposing the NMF factor matrix as ``exp.nmf.W``
            (components along axis 1).

    Returns:
        An iterable of component indices.
    """
    start = FLAGS.start_component_index
    end = start + FLAGS.n_components
    if not FLAGS.sort_by_coeff_mag:
        return range(start, end)
    # A column total does not depend on the order of the rows, so the columns
    # are summed directly instead of being sorted first. Summing the negated
    # matrix makes argsort's ascending order correspond to descending totals.
    negated_totals = np.asanyarray(-exp.nmf.W).sum(axis=0)
    ordered_comps = np.argsort(negated_totals)
    return ordered_comps[start:end]


##########################################################################

def make_helper(exp, component_index: int):
    """Create the guided-ablation helper for one component.

    Args:
        exp: Experiment object, forwarded unchanged to the helper.
        component_index: Index of the component to ablate.

    Returns:
        A ``guided_ablations.ExperimentHelper1`` configured from the
        ``n_selected_examples``, ``min_target_kl``/``max_target_kl`` and
        ``ablating_variable_style`` flags.

    Raises:
        ValueError: If ``min_target_kl`` is not strictly below
            ``max_target_kl``.
    """
    kl_low, kl_high = FLAGS.min_target_kl, FLAGS.max_target_kl
    # Raise explicitly rather than `assert`: asserts are stripped under
    # `python -O`, which would let an empty or inverted range through.
    if not kl_low < kl_high:
        raise ValueError(
            'min_target_kl must be strictly less than max_target_kl; '
            f'got min_target_kl={kl_low}, max_target_kl={kl_high}.')
    return guided_ablations.ExperimentHelper1(
        exp=exp,
        component_index=component_index,
        n_selected_examples=FLAGS.n_selected_examples,
        kl_target_range=[kl_low, kl_high],
        ablating_variable_style=FLAGS.ablating_variable_style,
    )


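# NOTE: Superseded version of ablator_run, kept for reference only. It called
# find_model once and then evaluated that same model n_runs times, whereas the
# live version below repeats the find_model search on every run.
# def ablator_run(ablator, n_runs: int, max_delta: float = 3.0) -> OutputForAblator:
#     # TODO: Is it wrong to reuse a single model for all n_runs evaluations?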
#     model = ablator.find_model(
#         max_iters=FLAGS.kl_range_targeter_max_iters,
#         max_delta=max_delta,
#     )
#     output_logits = [
#         ablator.helper.eval_ctx.evaluate(model).logits
#         for _ in range(n_runs)
#     ]
#     return OutputForAblator(
#         selected_example_indices=ablator.example_inds,
#         output_logits=np.stack(output_logits, axis=0),
#     )


def ablator_run(ablator, n_runs: int, max_delta: float = 3.0) -> OutputForAblator:
    """Repeat the ablator's model search and collect the resulting logits.

    Each of the ``n_runs`` repetitions calls ``ablator.find_model`` afresh
    (bounded by the ``kl_range_targeter_max_iters`` flag and ``max_delta``) and
    evaluates the model it returns with the ablator helper's evaluation
    context.

    Args:
        ablator: Ablator exposing ``find_model``, ``helper.eval_ctx`` and
            ``example_inds``.
        n_runs: Number of independent find-and-evaluate repetitions.
        max_delta: Forwarded to ``find_model`` as ``max_delta``.

    Returns:
        An ``OutputForAblator`` holding the ablator's selected example indices
        and the per-run logits stacked along a new leading axis.
    """

    def _find_and_evaluate():
        # One full repetition: search for a model, then score it.
        print(cu.hlc('Starting ablator run.'))
        ablated_model = ablator.find_model(
            max_iters=FLAGS.kl_range_targeter_max_iters,
            max_delta=max_delta,
        )
        run_logits = ablator.helper.eval_ctx.evaluate(ablated_model).logits
        print(cu.hlc('Finished ablator run.'))
        return run_logits

    logits_per_run = [_find_and_evaluate() for _ in range(n_runs)]

    return OutputForAblator(
        selected_example_indices=ablator.example_inds,
        output_logits=np.stack(logits_per_run, axis=0),
    )


def run_for_component(exp, component_index: int, ablation_exp_types: Sequence[str]):
    """Run the requested guided ablations for one component and save them.

    Only the experiment types named in ``ablation_exp_types`` are executed:

    * ``'component_examples'`` / ``'component_examples_H'``: one ablator each,
      searched ``n_kl_range_finds`` times.
    * ``'random_examples'`` / ``'random_examples_H'``: ``n_kl_range_finds``
      rounds of a single search each, with the helper's random examples
      resampled between rounds.

    The results, together with run metadata taken from the flags and from
    ``exp``, are written to
    ``<output_dir>/<output_prefix>guided_kl_ablation.comp<component_index>.h5``.

    Args:
        exp: Experiment object providing ``W``, ``labels`` and
            ``mc.container.predicted_logits``.
        component_index: Index of the component to ablate.
        ablation_exp_types: Names of the experiment types to run.
    """
    helper = make_helper(exp, component_index)
    n_finds = FLAGS.n_kl_range_finds

    def wants(exp_type: str) -> bool:
        return exp_type in ablation_exp_types

    # Component-driven ablations stay None when they were not requested.
    comp_ex_output = (
        ablator_run(helper.get_component_examples_ablator(), n_finds)
        if wants('component_examples') else None
    )
    comp_H_output = (
        ablator_run(helper.get_component_H_ablator(), n_finds, FLAGS.max_delta_H)
        if wants('component_examples_H') else None
    )

    # Random-example baselines: one single-search run per round.
    rand_ex_outputs = []
    rand_H_outputs = []
    run_rand_ex = wants('random_examples')
    run_rand_H = wants('random_examples_H')
    if run_rand_ex or run_rand_H:
        final_round = n_finds - 1
        for round_idx in range(n_finds):
            if run_rand_ex:
                rand_ex_outputs.append(
                    ablator_run(helper.get_random_examples_ablator(), 1))

            if run_rand_H:
                rand_H_outputs.append(
                    ablator_run(helper.get_random_examples_H_ablator(), 1, FLAGS.max_delta_H))

            # Draw a fresh random sample for the next round; nothing follows
            # the final round, so no resample is needed after it.
            if round_idx != final_round:
                helper.resample_random_examples()

    output_fields = {
        'component_index': component_index,
        # Ablation settings.
        'kl_target_range': [FLAGS.min_target_kl, FLAGS.max_target_kl],
        'ablating_variable_style': FLAGS.ablating_variable_style,
        # Ablation results.
        'component_top_fisher_ablation': comp_ex_output,
        'component_H_ablation': comp_H_output,
        'random_examples_ablations': rand_ex_outputs,
        'random_examples_H_ablations': rand_H_outputs,
        # Reference data from the experiment.
        'W': exp.W,
        'labels': exp.labels,
        'og_logits': exp.mc.container.predicted_logits,
        # Provenance of the inputs.
        'pef_path': os.path.expanduser(FLAGS.pef_path),
        'nmf_path': os.path.expanduser(FLAGS.nmf_path),
        'retaining_fisher_path': os.path.expanduser(FLAGS.retaining_fisher_path),
        'model': FLAGS.model,
        'tokenizer': FLAGS.tokenizer or FLAGS.model,
    }
    output = OutputForComponent(**output_fields)

    filename = f'{FLAGS.output_prefix}guided_kl_ablation.comp{component_index}.h5'
    output.save(os.path.join(os.path.expanduser(FLAGS.output_dir), filename))


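# NOTE: Earlier version of run_for_component, kept for reference only. It did
# not take the ablation_exp_types selection into account: it always ran both
# component ablations plus the random-example ablations, and it never ran the
# random-example H ablations.
# def run_for_component(exp, component_index: int, ablation_exp_types: Sequence[str]):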
#     helper = make_helper(exp, component_index)

#     n_runs = FLAGS.n_kl_range_finds

#     comp_ex_output = ablator_run(helper.get_component_examples_ablator(), n_runs)
#     comp_H_output = ablator_run(helper.get_component_H_ablator(), n_runs, FLAGS.max_delta_H)

#     rand_ex_outputs = []
#     for i in range(n_runs):
#         rand_ex_outputs.append(
#             ablator_run(helper.get_random_examples_ablator(), 1))
#         if i < n_runs - 1:
#             helper.resample_random_examples()

#     output = OutputForComponent(
#         component_index=component_index,
#         #
#         kl_target_range=[FLAGS.min_target_kl, FLAGS.max_target_kl],
#         ablating_variable_style=FLAGS.ablating_variable_style,
#         #
#         component_top_fisher_ablation=comp_ex_output,
#         component_H_ablation=comp_H_output,
#         random_examples_ablations=rand_ex_outputs,
#         #
#         W=exp.W,
#         labels=exp.labels,
#         og_logits=exp.mc.container.predicted_logits,
#         #
#         pef_path=os.path.expanduser(FLAGS.pef_path),
#         nmf_path=os.path.expanduser(FLAGS.nmf_path),
#         retaining_fisher_path=os.path.expanduser(FLAGS.retaining_fisher_path),
#         model=FLAGS.model,
#         tokenizer=FLAGS.tokenizer or FLAGS.model,
#     )

#     filename = f'{FLAGS.output_prefix}guided_kl_ablation.comp{component_index}.h5'
#     filepath = os.path.join(os.path.expanduser(FLAGS.output_dir), filename)
#     output.save(filepath)

##########################################################################


def main(_):
    """Validate the flags, load the experiment, and run every selected component.

    All validation happens before the experiment inputs are loaded, so a bad
    command line fails immediately instead of after the model and Fisher have
    been read.

    Args:
        _: Positional command-line arguments left over by absl; unused.

    Raises:
        FileNotFoundError: If ``--output_dir`` is not an existing directory.
        ValueError: If ``--ablation_exp_types`` names an unknown experiment
            type, or if ``min_target_kl`` is not strictly below
            ``max_target_kl``.
    """
    output_dir = os.path.expanduser(FLAGS.output_dir)
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    # `isdir` also rejects a regular file that happens to have this name.
    if not os.path.isdir(output_dir):
        raise FileNotFoundError(f'Please create the directory: {output_dir}')

    # An unset or empty flag means "run every known experiment type".
    ablation_exp_types = FLAGS.ablation_exp_types or ABLATION_EXP_TYPES
    unknown_types = [t for t in ablation_exp_types if t not in ABLATION_EXP_TYPES]
    if unknown_types:
        raise ValueError(
            f'Unknown ablation_exp_types: {unknown_types}. '
            f'Valid types are: {list(ABLATION_EXP_TYPES)}.')

    # Same condition that make_helper enforces per component, checked up front.
    if not FLAGS.min_target_kl < FLAGS.max_target_kl:
        raise ValueError(
            'min_target_kl must be strictly less than max_target_kl; '
            f'got min_target_kl={FLAGS.min_target_kl}, '
            f'max_target_kl={FLAGS.max_target_kl}.')

    exp = read_in_to_exp()

    for component_index in get_component_indices(exp):
        print(cu.hlg(f'Starting run for component {component_index}.'))
        run_for_component(exp, component_index, ablation_exp_types)


# Script entry point: absl parses the command-line flags (enforcing the
# required ones) and then invokes main.
if __name__ == "__main__":
    app.run(main)
