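R"""Stores the relationship between component coefficients and KL divergences of perturbations.

For each selected NMF component, the script searches for perturbation
parameters (delta, lambda) that bring the KL divergence into a target range,
evaluates the perturbed model, and writes the results to one .h5 file per
component.

Written to work with ImageNet models as well.
"""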
import dataclasses
import os
from typing import List

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.models import em_models

from em.fishers import diagonal
from em.fishers import per_example
from em.perturbations import examples_context
from em.perturbations import h_to_fishers
from em.perturbations import kl_targeter
from em.perturbations import mm_perturbations
from em.perturbations import perturbation_exp_util as pe_util
from em.perturbations.scripts_util import coeff_kl_relationship_util
from em.tools.nmf import nmf_common

from em.util.color_util import cu

# Short local alias for the examples container used throughout this script.
ExamplesContext = examples_context.ExamplesContext

# Output record types: one `RunOutput` per evaluated perturbation, gathered into
# one `OutputForComponent` per component (see `run_for_component`).
RunOutput = coeff_kl_relationship_util.RunOutput
OutputForComponent = coeff_kl_relationship_util.OutputForComponent


# Global absl flag registry; values are only available after `app.run` has
# parsed the command line.
FLAGS = flags.FLAGS


# Flags describing what files to read.
# Passed to `per_example.PerExampleFlatFishers.load`.
flags.DEFINE_string("pef_path", None, "")
# Passed to `nmf_common.SparseNmfDecomposition.load`.
flags.DEFINE_string("nmf_path", None, "")
# Passed to `diagonal.DiagonalFisher.load`; its `.fishers` become the retaining Fisher.
flags.DEFINE_string("retaining_fisher_path", None, "")

# Other inputs.
# Model identifier handed to `em_models.from_pretrained` (loaded twice: once as
# the reference model, once as the model whose variables get perturbed).
flags.DEFINE_string("model", None, "")
# Forwarded as `from_pt=` to `em_models.from_pretrained`.
flags.DEFINE_bool("from_pt_model", True, "")
flags.DEFINE_string("tokenizer", None, "Assumed to be equal to model if left None. Not needed for image models.")

# Inputs required when the examples are not present in or do not come from the PEFs file.
# When `ds_task` is None the examples are read from the PEF file instead and the
# other `ds_*` flags are not used for loading (see `read_in_examples_context`).
flags.DEFINE_string("ds_task", None, "")
flags.DEFINE_string("ds_split", None, "")
# Number of dataset examples kept via `ds.take(...)`.
flags.DEFINE_integer("ds_n_examples", None, "")
flags.DEFINE_integer("ds_sequence_length", None, "Note that this is used to set image sizes as well.")


# Outputs
flags.DEFINE_string("output_dir", None, "Path to directory to write output to. Must already exist.")
flags.DEFINE_string("output_prefix", '', "Optional prefix to prepend to output file names.")


# Flags describing what components to run on.
# When True, components are ordered by descending column sum of the NMF `W`
# matrix before the [start, start + n_components) window is applied.
flags.DEFINE_bool('sort_by_coeff_mag', True, "")
# Offset of the first component to run (into the sorted order when sorting is on).
flags.DEFINE_integer('start_component_index', 0, "")
# Number of components to run, counted from `start_component_index`.
flags.DEFINE_integer('n_components', None, "")

# Flags describing the run parameters.
# Number of leading examples to evaluate on; None means every example in the
# examples context.
flags.DEFINE_integer('n_evaluation_examples', None, "")
# flags.DEFINE_enum('ablation_style', _ABLATION_STYLE[0], _ABLATION_STYLE, '')
# flags.DEFINE_integer('ablating_fisher_top_k', None,
#                      "Leave set to None or non-positive value to use all parameters present in the component's Fisher.")

# KL range targeter parameters.
# Lower and upper bound of the KL range handed to `kl_targeter.GenericKlTargeter`.
flags.DEFINE_float('min_target_kl', None, '')
flags.DEFINE_float('max_target_kl', None, '')
# flags.DEFINE_enum('kl_range_targeter_examples_type', _KL_TARGETER_EXAMPLES_TYPES[0], _KL_TARGETER_EXAMPLES_TYPES, '')
# Passed as `n_examples=` to `pe_util.get_top_example_indices` to choose the
# examples the targeter measures KL on.
flags.DEFINE_integer('n_kl_range_targeter_examples', None, "")
# Argument to `GenericKlTargeter.search`.
flags.DEFINE_integer('kl_range_targeter_max_iters', 100, "")

# Flags describing the number of runs to complete.
# Number of (delta, lmbda) searches performed per component.
flags.DEFINE_integer('n_kl_range_finds', None, "")
# Number of evaluations per successful search; every evaluation after the first
# re-draws the random ablating shift and reuses the found (delta, lmbda).
flags.DEFINE_integer('n_sign_guides_per_kl_range_find', 1, "")

# Misc flags.
# NOTE: every non-None value currently raises in `load_experiment`
# ('HF_MNLI' -> NotImplementedError, anything else -> ValueError).
flags.DEFINE_string("exp_special_processing", None, "")
# flags.DEFINE_string("exp_special_processing", 'HF_MNLI', "")

# Batch size used both for batching a loaded dataset and for model evaluation.
flags.DEFINE_integer('eval_batch_size', 32, "")


##########################################################################


"""
Misc notes:
- Text models can use the PEFs for the inputs, but image models will need a dataset.
"""


##########################################################################

def read_in_examples_context(model, tokenizer, pef):
    """Build the `ExamplesContext` the experiment evaluates on.

    The source of the examples is selected by the `ds_task` flag:

    * flag set: the task is loaded through `em_datasets.load`, truncated to
      `ds_n_examples` examples, batched with `eval_batch_size`, and handed to
      `ExamplesContext.from_dataset` together with `model`.
    * flag unset (None): the examples are taken from `pef` via
      `ExamplesContext.from_pef`.

    Args:
        model: Model forwarded to `ExamplesContext.from_dataset`; unused when
            the examples come from the PEF.
        tokenizer: Tokenizer forwarded to the dataset loader or to
            `ExamplesContext.from_pef`.
        pef: Loaded per-example Fishers object; unused when a dataset is read.

    Returns:
        The constructed `ExamplesContext`.
    """
    task = FLAGS.ds_task
    if task is not None:
        dataset = em_datasets.load(
            task,
            split=FLAGS.ds_split,
            tokenizer=tokenizer,
            sequence_length=FLAGS.ds_sequence_length,
        )
        truncated = dataset.take(FLAGS.ds_n_examples)
        batched = truncated.batch(FLAGS.eval_batch_size)
        return ExamplesContext.from_dataset(batched, model)

    # No dataset requested: fall back to the examples stored with the PEF.
    return ExamplesContext.from_pef(pef, tokenizer)


##########################################################################

@dataclasses.dataclass
class Experiment:
    """Everything that is loaded once and shared by all per-component runs.

    After construction, two convenience attributes are available:
    `variables` (trainable variables of `model`) and `output_variables`
    (trainable variables of `output_model`).
    """
    # Reference model; its trainable variables are the base the perturber
    # starts from.
    model: tf.keras.Model
    # Second copy of the model; its trainable variables are the ones the
    # perturber overwrites, and it is the model that gets evaluated.
    output_model: tf.keras.Model
    # Diagonal Fisher loaded from `retaining_fisher_path`.
    retaining_fisher: List[tf.Tensor]
    pef: per_example.PerExampleFlatFishers
    # NMF decomposition; `W` is indexed as [example, component].
    nmf: nmf_common.SparseNmfDecomposition
    examples_context: ExamplesContext

    def __post_init__(self):
        self.variables = self.model.trainable_variables
        self.output_variables = self.output_model.trainable_variables

    def get_component_indices(self):
        """Return the indices of the components selected by the flags.

        The window is `[start_component_index, start_component_index +
        n_components)`. When `n_components` is left at its default of None,
        the window extends to the last component, mirroring how
        `n_evaluation_examples=None` means "all examples".

        With `sort_by_coeff_mag` the window is applied to the components
        ordered by descending column sum of `W`; otherwise it is applied to
        the raw component indices.

        Returns:
            A `range` (unsorted mode) or a 1-D integer array (sorted mode) of
            component indices.
        """
        start = FLAGS.start_component_index
        if FLAGS.n_components is None:
            # Previously this path raised TypeError on `start + None`.
            end = self.nmf.W.shape[1]
        else:
            end = start + FLAGS.n_components

        if not FLAGS.sort_by_coeff_mag:
            return range(start, end)

        # Negating makes argsort (ascending) yield the largest column sum first.
        # NOTE(review): sorting along axis 0 before summing along the same axis
        # does not change the mathematical sum; it is kept so that float
        # rounding, and therefore the resulting order, is unchanged.
        ordered_comps = np.argsort(np.sort(-self.nmf.W, axis=0).sum(axis=0))
        return ordered_comps[start:end]


@dataclasses.dataclass
class ComponentExp:
    """State and operations for perturbing the model along one NMF component.

    Construction picks the examples used by the KL targeter and by the final
    evaluation, and builds the `MmPerturber` that writes perturbed values into
    the experiment's output variables.
    """
    exp: Experiment
    component_index: int

    def __post_init__(self):
        shared = self.exp

        self.batch_size = FLAGS.eval_batch_size

        # Filled in by a successful `find_delta_lmbda`.
        self.delta = None
        self.lmbda = None

        # Examples the targeter measures KL on while searching.
        self.kl_range_targeter_ex_indices = pe_util.get_top_example_indices(
            shared.nmf,
            self.component_index,
            n_examples=FLAGS.n_kl_range_targeter_examples,
        )

        # Evaluate on the first N examples, or on all of them when N is unset.
        n_eval = FLAGS.n_evaluation_examples
        if n_eval is None:
            n_eval = shared.examples_context.n_examples
        self.eval_ex_indices = np.arange(n_eval)

        # The shift is drawn before the Fisher is built, matching the order in
        # which the random generator is consumed.
        initial_shift = self._random_ablating_shift()
        component_fisher = self._make_ablating_fisher()
        self.perturber = mm_perturbations.MmPerturber(
            variables=shared.variables,
            retaining_fisher=shared.retaining_fisher,
            ablating_shift=initial_shift,
            ablating_fisher=component_fisher,
            output_variables=shared.output_variables,
        )

    #################################################################

    def _make_ablating_fisher(self):
        """Build the ablating Fisher for this component from the NMF."""
        shared = self.exp
        return h_to_fishers.single_component_to_fishers(
            shared.nmf, shared.variables, self.component_index)

    def _random_ablating_shift(self):
        """Draw one tensor of random signs per model variable.

        Each tensor has the shape of its variable and holds the sign of
        standard-normal samples.
        """
        signs = []
        for var in self.exp.variables:
            signs.append(tf.sign(tf.random.normal(var.shape)))
        return signs

    def _targeter_kl_fn(self, delta: float, lmbda: float) -> float:
        """Apply (delta, lmbda) to the output model and return the resulting KL.

        The KL is measured on the targeter examples only and is also printed.
        """
        self.perturber.update_output_variables(delta, lmbda)
        results = self.exp.examples_context.evaluate(
            self.exp.output_model,
            self.batch_size,
            self.kl_range_targeter_ex_indices,
        )
        kl = results.kl()
        print(cu.hlr(kl))
        return kl

    #################################################################

    def find_delta_lmbda(self) -> bool:
        """Search for a (delta, lmbda) pair whose KL lands in the target range.

        On success the pair is stored on `self.delta` / `self.lmbda`; on
        failure both are left untouched.

        Returns:
            Whatever `GenericKlTargeter.search` returned (truthy on success).
        """
        searcher = kl_targeter.GenericKlTargeter(
            kl_fn=self._targeter_kl_fn,
            kl_range=[FLAGS.min_target_kl, FLAGS.max_target_kl],
            delta_mag_range=[1e-5, 3],
        )

        succeeded = searcher.search(FLAGS.kl_range_targeter_max_iters)
        if not succeeded:
            return succeeded

        # NOTE(review): reads private attributes of the targeter.
        self.delta = searcher._last_delta
        self.lmbda = searcher._last_lmbda
        return succeeded

    def replace_ablating_shift(self):
        """Swap the perturber's ablating shift for a freshly drawn one."""
        fresh_shift = self._random_ablating_shift()
        self.perturber.ablating_shift = fresh_shift

    def evaluate_output_model(self):
        """Evaluate the output model perturbed with the stored (delta, lmbda).

        Requires a prior successful `find_delta_lmbda`.

        Returns:
            The result of `ExamplesContext.evaluate` on the evaluation examples.
        """
        assert not (self.delta is None or self.lmbda is None)
        self.perturber.update_output_variables(self.delta, self.lmbda)
        context = self.exp.examples_context
        return context.evaluate(
            self.exp.output_model, self.batch_size, self.eval_ex_indices
        )


##########################################################################

def load_experiment():
    """Load every input named by the flags and bundle it into an `Experiment`.

    Loads two copies of the model (reference and output), the retaining
    Fisher, the tokenizer, the PEF file (without its Fishers), and the NMF
    decomposition, then builds the examples context.

    Returns:
        A populated `Experiment`.

    Raises:
        NotImplementedError: If `exp_special_processing` is 'HF_MNLI'.
        ValueError: If `exp_special_processing` is any other non-None value.
    """
    # Validate first: every non-None value is currently rejected, so there is
    # no point loading two models and several large files before failing.
    special_processing = FLAGS.exp_special_processing
    if special_processing == 'HF_MNLI':
        # am.PefNmfAnalysisContainer
        # shift_labels=self.special_processing == 'HF_MNLI',
        raise NotImplementedError("TODO (probably)")
    elif special_processing is not None:
        raise ValueError(special_processing)

    model = em_models.from_pretrained(FLAGS.model, from_pt=FLAGS.from_pt_model)
    output_model = em_models.from_pretrained(FLAGS.model, from_pt=FLAGS.from_pt_model)
    retaining_fisher = diagonal.DiagonalFisher.load(FLAGS.retaining_fisher_path).fishers

    tokenizer = em_models.load_tokenizer(FLAGS.tokenizer or FLAGS.model)

    pef = per_example.PerExampleFlatFishers.load(
        FLAGS.pef_path,
        n_examples=None,
        # This leads to the Fishers not being loaded, which ends up being much faster.
        start_fisher_index=0,
        end_fisher_index=0,
    )

    nmf = nmf_common.SparseNmfDecomposition.load(FLAGS.nmf_path)
    nmf.normalize_components_to_unit_norm()

    # Rows of W are examples; drop PEF examples the NMF has no row for.
    n_nmf_examples = nmf.W.shape[0]
    if pef.input_ids.shape[0] > n_nmf_examples:
        pef = pef.create_for_subset(list(range(n_nmf_examples)))

    ex_context = read_in_examples_context(model, tokenizer, pef)

    return Experiment(
        model=model,
        output_model=output_model,
        retaining_fisher=retaining_fisher,
        pef=pef,
        nmf=nmf,
        examples_context=ex_context,
    )


def run_for_component(exp, component_index):
    """Run all searches and evaluations for one component and save the result.

    For each of `n_kl_range_finds` rounds a (delta, lmbda) search is run. If
    it succeeds, the perturbed model is evaluated
    `n_sign_guides_per_kl_range_find` times, re-drawing the random ablating
    shift before every evaluation after the first. A round whose search fails
    contributes no runs.

    The collected runs, together with the relevant slices of `W`, labels,
    original logits and the flag values describing the inputs, are written to
    `<output_dir>/<output_prefix>coeff_kl_relationship.comp<index>.h5`.

    Args:
        exp: The shared `Experiment`.
        component_index: Index of the NMF component to perturb.
    """
    comp_exp = ComponentExp(exp=exp, component_index=component_index)

    # NOTE(review): the ablating shift is only re-drawn for j > 0, so with the
    # default n_sign_guides_per_kl_range_find=1 every search uses the shift
    # drawn in ComponentExp.__post_init__ -- confirm this is intended.
    runs = []
    for _ in range(FLAGS.n_kl_range_finds):
        for j in range(FLAGS.n_sign_guides_per_kl_range_find):
            if j == 0:
                if not comp_exp.find_delta_lmbda():
                    # Previously the success message was printed before this
                    # check, so failed searches were logged as successes.
                    print(cu.hlr('Did not find delta lmbda'))
                    break
                print(cu.hlr('Found delta lmbda'))
            else:
                comp_exp.replace_ablating_shift()

            eval_results = comp_exp.evaluate_output_model()
            runs.append(
                RunOutput(
                    lmbda=comp_exp.lmbda,
                    delta=comp_exp.delta,
                    logits=eval_results.logits,
                )
            )

    eval_indices = comp_exp.eval_ex_indices
    output = OutputForComponent(
        component_index=component_index,
        runs=runs,
        evaluation_ex_indices=eval_indices,
        W=exp.nmf.W[eval_indices, :],
        labels=exp.examples_context.labels[eval_indices],
        og_logits=exp.examples_context.og_logits[eval_indices],
        pef_path=os.path.expanduser(FLAGS.pef_path),
        nmf_path=os.path.expanduser(FLAGS.nmf_path),
        retaining_fisher_path=os.path.expanduser(FLAGS.retaining_fisher_path),
        model=FLAGS.model,
        tokenizer=FLAGS.tokenizer or FLAGS.model,
        #
        ds_task=FLAGS.ds_task,
        ds_split=FLAGS.ds_split,
        ds_n_examples=FLAGS.ds_n_examples,
        ds_sequence_length=FLAGS.ds_sequence_length,
    )

    filename = f'{FLAGS.output_prefix}coeff_kl_relationship.comp{component_index}.h5'
    filepath = os.path.join(os.path.expanduser(FLAGS.output_dir), filename)
    output.save(filepath)


##########################################################################

def main(_):
    """Entry point: check the output directory, load inputs, run every component.

    Args:
        _: Positional command-line arguments left over after flag parsing
            (unused).

    Raises:
        FileNotFoundError: If `output_dir` is not an existing directory.
    """
    output_dir = os.path.expanduser(FLAGS.output_dir)
    # Explicit check rather than `assert`, which is stripped under `python -O`;
    # `isdir` also rejects a path that exists but is a regular file. Checking
    # before loading avoids discovering the problem only at save time.
    if not os.path.isdir(output_dir):
        raise FileNotFoundError(f'Please create the directory: {output_dir}')

    exp = load_experiment()

    for component_index in exp.get_component_indices():
        run_for_component(exp, component_index)


# Script entry point: absl parses the command-line flags and then calls `main`.
if __name__ == "__main__":
    app.run(main)
