R"""Script for running guided ablations."""
import dataclasses
import os
from typing import List, Sequence

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.models import em_models

from em.fishers import diagonal
from em.fishers import per_example
from em.perturbations import examples_context
from em.perturbations import h_to_fishers
from em.perturbations import kl_targeter
from em.perturbations import mm_perturbations
from em.perturbations import perturbation_exp_util as pe_util
from em.perturbations import sign_patterns
from em.perturbations.scripts_util import guided_ablations_util
from em.tools.nmf import nmf_common

from em.util.color_util import cu


# Experiment-type names accepted by --ablation_exp_types. For image models,
# main() only accepts the SUPPORTED_IMAGE_EXP_TYPES subset.
ABLATION_EXP_TYPES = guided_ablations_util.ABLATION_EXP_TYPES
SUPPORTED_IMAGE_EXP_TYPES = guided_ablations_util.SUPPORTED_IMAGE_EXP_TYPES

# Result records written to disk: one per ablator run, one per component.
OutputForAblator = guided_ablations_util.OutputForAblator
OutputForComponent = guided_ablations_util.OutputForComponent

# Short local names for project types used throughout this script.
ExamplesContext = examples_context.ExamplesContext
InvalidKlFnOutputError = kl_targeter.InvalidKlFnOutputError


FLAGS = flags.FLAGS

# Flags describing what files to read.
flags.DEFINE_string("pef_path", None, "")
flags.DEFINE_string("nmf_path", None, "")
flags.DEFINE_string("retaining_fisher_path", None, "")

# Other inputs.
# `model` is used unconditionally (main() and load_experiment() pass it to
# em_models), so it is marked required below.
flags.DEFINE_string("model", None, "")
flags.DEFINE_bool("from_pt_model", True, "")
flags.DEFINE_string("tokenizer", None, "Assumed to be equal to model if left None. Not needed for image models.")

# Inputs required when the examples are not present in or do not come from the PEFs file.
# They are only read when --ds_task is set.
flags.DEFINE_string("ds_task", None, "")
flags.DEFINE_string("ds_split", None, "")
flags.DEFINE_integer("ds_n_examples", None, "")
flags.DEFINE_integer("ds_sequence_length", None, "Note that this is used to set image sizes as well.")

# Outputs
flags.DEFINE_string("output_dir", None, "Path to directory to write output to. Must already exist.")
flags.DEFINE_string("output_prefix", '', "Optional prefix to prepend to output file names.")

# Flags describing what components to run on.
flags.DEFINE_bool('sort_by_coeff_mag', True, "")
flags.DEFINE_integer('start_component_index', 0, "")
flags.DEFINE_integer('n_components', None, "")


# Flags for the guided ablations.

flags.DEFINE_list('ablation_exp_types', None, '')

flags.DEFINE_float('min_target_kl', None, '')
flags.DEFINE_float('max_target_kl', None, '')
flags.DEFINE_integer('kl_range_targeter_max_iters', 25, "")

flags.DEFINE_integer('n_selected_examples', None, "")
# NOTE: run_for_component currently records ablating_variable_style='fixed_offset';
# the flag below is disabled.
# flags.DEFINE_enum('ablating_variable_style',
#                   guided_ablations.ABLATING_VARIABLE_STYLES[0],
#                   guided_ablations.ABLATING_VARIABLE_STYLES,
#                   '')

flags.DEFINE_integer('n_kl_range_finds', None, "")

flags.DEFINE_float('max_delta_H', 5.0, 'Using the H as the Fisher can sometimes require larger deltas.')


# Misc flags.
flags.DEFINE_string("exp_special_processing", None, "")
# flags.DEFINE_string("exp_special_processing", 'HF_MNLI', "")

flags.DEFINE_integer('eval_batch_size', 32, "")
flags.DEFINE_integer('kl_gradient_batch_size', 32, "")

flags.DEFINE_bool('skip_existing_results', False, "")


flags.mark_flags_as_required([
    'model',
    'pef_path', 'nmf_path', 'retaining_fisher_path',
    'n_components', 'output_dir',
    'min_target_kl', 'max_target_kl', 'n_selected_examples',
    'n_kl_range_finds',
])

# The KL target interval [min_target_kl, max_target_kl] is meaningless when
# inverted. Missing values are left to the required-flag check above.
flags.register_multi_flags_validator(
    ['min_target_kl', 'max_target_kl'],
    lambda d: (d['min_target_kl'] is None
               or d['max_target_kl'] is None
               or d['min_target_kl'] <= d['max_target_kl']),
    message='--min_target_kl must not exceed --max_target_kl.')


##########################################################################

def read_in_examples_context(model, tokenizer, pef):
    """Build the ExamplesContext on which ablations are evaluated.

    When --ds_task is unset, the examples stored in the per-example Fishers
    file are used and `model` is ignored. Otherwise the examples are loaded
    with em_datasets using the --ds_* flags. They are batched by
    --eval_batch_size and passed with `model` to ExamplesContext.from_dataset.

    Args:
        model: Model forwarded to ExamplesContext.from_dataset (dataset branch only).
        tokenizer: Forwarded to ExamplesContext.from_pef or em_datasets.load.
        pef: Per-example Fishers; only consulted when --ds_task is unset.

    Returns:
        An ExamplesContext.
    """
    if FLAGS.ds_task is None:
        # Load from PEF.
        return ExamplesContext.from_pef(pef, tokenizer)

    ds = em_datasets.load(
        FLAGS.ds_task,
        split=FLAGS.ds_split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.ds_sequence_length)
    # `ds` is presumably a tf.data.Dataset (it is .take()n and .batch()ed).
    # tf.data rejects take(None), so an unset --ds_n_examples now means
    # "use the whole split" instead of crashing.
    if FLAGS.ds_n_examples is not None:
        ds = ds.take(FLAGS.ds_n_examples)
    ds = ds.batch(FLAGS.eval_batch_size)
    return ExamplesContext.from_dataset(ds, model)


##########################################################################

@dataclasses.dataclass
class Experiment:
    """Everything loaded once and shared by all components of a run.

    After construction, `variables` and `output_variables` hold the
    trainable variables of `model` and `output_model` respectively.
    """
    # Model whose trainable variables are used for KL gradients and H-derived Fishers.
    model: tf.keras.Model
    # Separate model instance whose trainable variables are handed to the
    # perturber as its output variables, and which is evaluated after perturbation.
    output_model: tf.keras.Model
    # Diagonal Fisher (one tensor per variable, presumably) of what should be retained.
    retaining_fisher: List[tf.Tensor]
    pef: per_example.PerExampleFlatFishers
    nmf: nmf_common.SparseNmfDecomposition
    examples_context: ExamplesContext

    def __post_init__(self):
        self.variables = self.model.trainable_variables
        self.output_variables = self.output_model.trainable_variables

    def get_component_indices(self):
        """Return the NMF component indices to run on.

        Selects up to --n_components components starting at position
        --start_component_index. With --sort_by_coeff_mag, positions refer to
        components ordered by descending column sum of `nmf.W`. Otherwise they
        are raw component indices. Both branches are clipped to the number of
        components (columns of `nmf.W`), so indices never point past it.
        """
        n_total = self.nmf.W.shape[1]
        start = FLAGS.start_component_index
        end = min(start + FLAGS.n_components, n_total)
        if not FLAGS.sort_by_coeff_mag:
            return range(start, end)
        # A column's sum does not depend on the order of its entries, so no
        # per-column sort is needed. Negating makes argsort descending.
        ordered_comps = np.argsort(-self.nmf.W.sum(axis=0))
        return ordered_comps[start:end]

    def get_helper(self, component_index: int) -> 'ComponentHelper':
        """Create the per-component helper for `component_index`."""
        return ComponentHelper(exp=self, component_index=component_index)


###########################################################

@dataclasses.dataclass
class ComponentHelper:
    """Per-component state used to build the Ablators for one NMF component.

    On construction, selects the component's top --n_selected_examples
    examples (via pe_util.get_top_example_indices) and computes their
    normalized KL gradient.
    """
    exp: Experiment
    component_index: int

    def __post_init__(self):
        self.batch_size = FLAGS.eval_batch_size

        self.top_example_inds = pe_util.get_top_example_indices(
            self.exp.nmf, self.component_index, n_examples=FLAGS.n_selected_examples)
        self.top_examples_gradient = self._compute_normalized_gradient(self.top_example_inds)

    #######################################################

    def _get_component_H_fisher(self):
        """Fisher for this component built from the NMF by h_to_fishers.

        Recomputed on every call.
        """
        return h_to_fishers.single_component_to_fishers(
            self.exp.nmf, self.exp.variables, self.component_index)

    def _get_random_example_indices(self) -> np.ndarray:
        """Uniformly random example indices (--n_selected_examples of them).

        Raises:
            NotImplementedError: for non-image models.
        """
        if not em_models.is_image_model(FLAGS.model):
            raise NotImplementedError('TODO: Sample according to same predictions as top component examples.')
        return pe_util.get_uniformly_random_example_indices(self.exp.nmf, n_examples=FLAGS.n_selected_examples)

    def _compute_normalized_gradient(self, example_inds: np.ndarray):
        """KL gradient over `example_inds`, scaled to unit global L2 norm.

        Returns:
            A list with one tensor per entry of `exp.variables`. All zeros if
            the raw gradient is identically zero.
        """
        grads = sign_patterns.compute_kl_gradient(
            model=self.exp.model,
            variables=self.exp.variables,
            examples_context=self.exp.examples_context,
            example_indices=example_inds,
            batch_size=FLAGS.kl_gradient_batch_size,
            allow_recompile=True
        )
        # Normalization.
        sq_norm = tf.reduce_sum([tf.reduce_sum(tf.square(g)) for g in grads])
        # rsqrt(0) is inf and inf * 0 is NaN. A NaN here would propagate into
        # the tf.sign() pattern the Ablator builds from this gradient, so a
        # zero gradient is returned unscaled (as zeros) instead.
        inv_norm = tf.where(sq_norm > 0, tf.math.rsqrt(sq_norm), tf.zeros_like(sq_norm))
        return [inv_norm * g for g in grads]

    #######################################################

    def get_component_examples_ablator(self) -> 'Ablator':
        """Not implemented yet."""
        raise NotImplementedError('TODO')

    def get_random_examples_ablator(self) -> 'Ablator':
        """Not implemented yet."""
        raise NotImplementedError('TODO')

    #######################################################

    def get_component_H_ablator(self) -> 'Ablator':
        """Ablator on the component's top examples, with the component-H Fisher."""
        return Ablator(
            helper=self,
            example_inds=self.top_example_inds,
            examples_fisher=self._get_component_H_fisher(),
            examples_gradient=self.top_examples_gradient,
        )

    def get_random_examples_H_ablator(self) -> 'Ablator':
        """Ablator on freshly sampled random examples, with the component-H Fisher.

        The KL target and the gradient both come from the random examples.
        Only image models are supported (see _get_random_example_indices).
        """
        random_example_inds = self._get_random_example_indices()
        random_examples_gradient = self._compute_normalized_gradient(random_example_inds)
        return Ablator(
            helper=self,
            example_inds=random_example_inds,
            examples_fisher=self._get_component_H_fisher(),
            examples_gradient=random_examples_gradient,
        )


###########################################################

@dataclasses.dataclass
class Ablator:
    """Searches for a perturbation of `exp.output_model` whose KL lands in range.

    An MmPerturber is built from the experiment's retaining Fisher plus this
    ablator's Fisher and gradient sign. A GenericKlTargeter then tunes the
    perturber's (delta, lmbda) until the KL on `example_inds` lies in
    [--min_target_kl, --max_target_kl].
    """
    helper: ComponentHelper

    # Indices into the ExamplesContext of the examples whose KL is targeted.
    example_inds: np.ndarray
    # Fisher passed to the perturber as its ablating Fisher.
    examples_fisher: List[tf.Tensor]
    # Gradient whose elementwise sign is used as the ablating shift.
    examples_gradient: List[tf.Tensor]

    def __post_init__(self):
        self.exp = self.helper.exp
        self.batch_size = FLAGS.eval_batch_size

        shift_signs = [tf.sign(grad) for grad in self.examples_gradient]
        self.perturber = mm_perturbations.MmPerturber(
            variables=self.exp.variables,
            retaining_fisher=self.exp.retaining_fisher,
            ablating_shift=shift_signs,
            ablating_fisher=self.examples_fisher,
            output_variables=self.exp.output_variables,
        )

    def _targeter_kl_fn(self, delta: float, lmbda: float) -> float:
        """Apply (delta, lmbda) to the output model and return its KL on `example_inds`."""
        self.perturber.update_output_variables(delta, lmbda)
        results = self.exp.examples_context.evaluate(
            self.exp.output_model, self.batch_size, self.example_inds)
        return results.kl()

    def find_model(self, max_iters: int, max_delta: float):
        """Run the KL-range search and leave `exp.output_model` at its final setting.

        The targeter's last (delta, lmbda) is re-applied whether or not the
        search reported success.

        Returns:
            `exp.output_model` (perturbed in place).
        """
        target_range = [FLAGS.min_target_kl, FLAGS.max_target_kl]
        delta_bounds = [1e-5, max_delta]
        targeter = kl_targeter.GenericKlTargeter(
            kl_fn=self._targeter_kl_fn,
            kl_range=target_range,
            delta_mag_range=delta_bounds,
        )
        # TODO: Do something if not found. The return value of search()
        # (presumably a success flag) is currently ignored.
        targeter.search(max_iters)
        self.perturber.update_output_variables(targeter._last_delta, targeter._last_lmbda)
        return self.exp.output_model

    def evaluate_output_model(self):
        """Evaluate `exp.output_model` on every example in the context (no index subset)."""
        context = self.exp.examples_context
        return context.evaluate(self.exp.output_model, self.batch_size)


##########################################################################


def ablator_run(ablator, n_runs: int, max_delta: float = 3.0) -> OutputForAblator:
    """Run `ablator`'s KL-range search `n_runs` times and collect the output logits.

    Args:
        ablator: An Ablator.
        n_runs: Number of independent find_model / evaluate cycles.
        max_delta: Upper bound of the delta magnitude range searched.

    Returns:
        OutputForAblator whose `output_logits` stacks each run's logits along
        a new leading axis.
    """
    def _one_run():
        # One search followed by a full evaluation of the perturbed model.
        print(cu.hlg('Starting ablator run.'))
        ablator.find_model(max_iters=FLAGS.kl_range_targeter_max_iters, max_delta=max_delta)
        run_logits = ablator.evaluate_output_model().logits
        print(cu.hlg('Finished ablator run.'))
        return run_logits

    per_run_logits = [_one_run() for _ in range(n_runs)]
    return OutputForAblator(
        selected_example_indices=ablator.example_inds,
        output_logits=np.stack(per_run_logits, axis=0),
    )

##########################################################################


def load_experiment():
    """Load models, Fishers, per-example Fishers, NMF and examples from the flags.

    Returns:
        An Experiment.

    Raises:
        NotImplementedError: if --exp_special_processing is 'HF_MNLI'.
        ValueError: for any other non-None --exp_special_processing value.
    """
    # Validate the flag-only option before the expensive loads below, so an
    # unsupported value fails immediately instead of after loading models.
    if FLAGS.exp_special_processing == 'HF_MNLI':
        # am.PefNmfAnalysisContainer
        # shift_labels=self.special_processing == 'HF_MNLI',
        raise NotImplementedError("TODO (probably)")
    elif FLAGS.exp_special_processing is not None:
        raise ValueError(FLAGS.exp_special_processing)

    # Two independent loads of the same checkpoint. Experiment uses `model` for
    # gradients/Fishers, while `output_model`'s variables are the ones perturbed.
    model = em_models.from_pretrained(FLAGS.model, from_pt=FLAGS.from_pt_model)
    output_model = em_models.from_pretrained(FLAGS.model, from_pt=FLAGS.from_pt_model)
    retaining_fisher = diagonal.DiagonalFisher.load(FLAGS.retaining_fisher_path).fishers

    tokenizer = em_models.load_tokenizer(FLAGS.tokenizer or FLAGS.model)

    pef = per_example.PerExampleFlatFishers.load(
        FLAGS.pef_path,
        n_examples=None,
        # This leads to the Fishers not being loaded, which ends up being much faster.
        start_fisher_index=0,
        end_fisher_index=0,
    )

    nmf = nmf_common.SparseNmfDecomposition.load(FLAGS.nmf_path)
    nmf.normalize_components_to_unit_norm()

    # Keep only the first nmf.W.shape[0] PEF examples when the PEF has more,
    # presumably because the NMF was fit on that prefix (TODO confirm). A PEF
    # with fewer examples than W has rows is left unchanged.
    if pef.input_ids.shape[0] > nmf.W.shape[0]:
        pef = pef.create_for_subset(list(range(nmf.W.shape[0])))

    examples_context = read_in_examples_context(model, tokenizer, pef)

    return Experiment(
        model=model,
        output_model=output_model,
        retaining_fisher=retaining_fisher,
        pef=pef,
        nmf=nmf,
        examples_context=examples_context,
    )


##########################################################################

def run_for_component(exp, component_index: int, ablation_exp_types: Sequence[str]):
    """Run the requested ablation experiments for one component and save the results.

    Results go to `<output_dir>/<output_prefix>guided_kl_ablation.comp<i>.h5`.
    The component is skipped when that file exists and
    --skip_existing_results is set.

    Args:
        exp: The loaded Experiment.
        component_index: NMF component to ablate.
        ablation_exp_types: Names of the experiments to run.
    """
    out_name = f'{FLAGS.output_prefix}guided_kl_ablation.comp{component_index}.h5'
    filepath = os.path.join(os.path.expanduser(FLAGS.output_dir), out_name)
    if FLAGS.skip_existing_results and os.path.exists(filepath):
        return

    helper = exp.get_helper(component_index)
    n_runs = FLAGS.n_kl_range_finds

    do_comp_ex = 'component_examples' in ablation_exp_types
    do_comp_H = 'component_examples_H' in ablation_exp_types
    do_rand_ex = 'random_examples' in ablation_exp_types
    do_rand_H = 'random_examples_H' in ablation_exp_types

    comp_ex_output = None
    if do_comp_ex:
        comp_ex_output = ablator_run(helper.get_component_examples_ablator(), n_runs)

    comp_H_output = None
    if do_comp_H:
        comp_H_output = ablator_run(helper.get_component_H_ablator(), n_runs, FLAGS.max_delta_H)

    # Each random repetition draws a fresh set of random examples, so build a
    # new single-run ablator per repetition (interleaving the two kinds).
    rand_ex_outputs = []
    rand_H_outputs = []
    if do_rand_ex or do_rand_H:
        for _ in range(n_runs):
            if do_rand_ex:
                rand_ex_outputs.append(
                    ablator_run(helper.get_random_examples_ablator(), 1))
            if do_rand_H:
                rand_H_outputs.append(
                    ablator_run(helper.get_random_examples_H_ablator(), 1, FLAGS.max_delta_H))

    output = OutputForComponent(
        component_index=component_index,
        #
        kl_target_range=[FLAGS.min_target_kl, FLAGS.max_target_kl],
        ablating_variable_style='fixed_offset',
        #
        component_top_fisher_ablation=comp_ex_output,
        component_H_ablation=comp_H_output,
        random_examples_ablations=rand_ex_outputs,
        random_examples_H_ablations=rand_H_outputs,
        #
        W=exp.nmf.W,
        labels=exp.examples_context.labels,
        og_logits=exp.examples_context.og_logits,
        #
        pef_path=os.path.expanduser(FLAGS.pef_path),
        nmf_path=os.path.expanduser(FLAGS.nmf_path),
        retaining_fisher_path=os.path.expanduser(FLAGS.retaining_fisher_path),
        model=FLAGS.model,
        tokenizer=FLAGS.tokenizer or FLAGS.model,
    )
    output.save(filepath)

##########################################################################


def main(_):
    """Validate flags, load the experiment, and run every selected component.

    Components whose KL search raises InvalidKlFnOutputError are reported and
    skipped rather than aborting the whole run.

    Raises:
        app.UsageError: if --output_dir is not an existing directory, if
            --ablation_exp_types names unknown types, or, for image models,
            if it names types outside SUPPORTED_IMAGE_EXP_TYPES.
    """
    is_image_model = em_models.is_image_model(FLAGS.model)

    # Explicit raises instead of `assert`, so the checks survive `python -O`
    # and the offending values are reported. absl's app.run prints usage
    # together with the UsageError message.
    output_dir = os.path.expanduser(FLAGS.output_dir)
    if not os.path.isdir(output_dir):
        raise app.UsageError(f'Please create the directory: {output_dir}')

    ablation_exp_types = FLAGS.ablation_exp_types
    if not ablation_exp_types:
        ablation_exp_types = SUPPORTED_IMAGE_EXP_TYPES if is_image_model else ABLATION_EXP_TYPES

    unknown_types = [t for t in ablation_exp_types if t not in ABLATION_EXP_TYPES]
    if unknown_types:
        raise app.UsageError(
            f'Unknown --ablation_exp_types {unknown_types}; '
            f'valid types are {list(ABLATION_EXP_TYPES)}.')
    if is_image_model:
        unsupported_types = [t for t in ablation_exp_types if t not in SUPPORTED_IMAGE_EXP_TYPES]
        if unsupported_types:
            raise app.UsageError(
                f'--ablation_exp_types {unsupported_types} are not supported for image '
                f'models; supported types are {list(SUPPORTED_IMAGE_EXP_TYPES)}.')

    exp = load_experiment()

    for component_index in exp.get_component_indices():
        try:
            run_for_component(exp, component_index, ablation_exp_types)
        except InvalidKlFnOutputError:
            print(cu.hlr(f'Invalid KL fn output, skipping component {component_index}.'))


if __name__ == "__main__":
    # absl parses the command-line flags defined above, then calls main().
    app.run(main)
