"""Runs perturbations focusing on KL divergences."""
import os

from absl import app
from absl import flags
import numpy as np
from transformers import AutoTokenizer
from tqdm import tqdm

from em.models import em_models
from em.projects.m_npeff import perturbations

from em.util.color_util import cu

# Global absl flag registry; values are only readable after app.run() has
# parsed the command line.
FLAGS = flags.FLAGS

# Flags describing what files to read.
flags.DEFINE_string("nmf_filepath", None, "")
flags.DEFINE_string("pef_filepath", None, "")

# Other inputs.
flags.DEFINE_string("model", None, "")
flags.DEFINE_boolean("from_pt", False, "")
flags.DEFINE_string("tokenizer", None, "If left None, assumed to be equal to model for text models and None for image models.")

flags.DEFINE_integer("sequence_length", None, 'Used for image size for image models, normal meaning for text models.')
flags.DEFINE_string("task", None, "")
flags.DEFINE_string("split", None, "")

# Outputs
flags.DEFINE_string("output_filepath", None, "Path to file to write json output to.")

# Flags describing what components to run on.
flags.DEFINE_list("component_indices", None, 'Leave set to None to run on all components.')

flags.DEFINE_integer("n_top_examples", None, '')
flags.DEFINE_integer("n_total_examples", None, '')

flags.DEFINE_float("perturbation_magnitude", None, '')
# NOTE: the code that reads this flag also maps any value that is not > 0 to
# None, so a non-positive value disables the step as well.
flags.DEFINE_float("max_abs_cos_sim", None, 'Leave None to not do the semi-orthogonalization.')

# When set, read_in_to_exp() replaces the second element of the experiment's
# all_examples pair with (value + 1) % 3.
flags.DEFINE_bool('nli_label_swapping', False, '')

##########################################################################


def read_in_to_exp():
    """Builds the perturbation experiment described by the command-line flags.

    The model is loaded first, then the tokenizer (named by --tokenizer, or by
    --model when --tokenizer is unset or empty). Both are passed to
    `PerturbationExperiment.from_filepaths` together with the NMF/PEF file
    paths and the task, split, example-count and sequence-length flags.

    Returns:
        The constructed experiment. When --nli_label_swapping is set, the
        second element of its `eval_ctx.all_examples` pair has been replaced
        by `(value + 1) % 3`.
    """
    loaded_model = em_models.from_pretrained(FLAGS.model, from_pt=FLAGS.from_pt)
    tokenizer_name = FLAGS.tokenizer or FLAGS.model
    loaded_tokenizer = em_models.load_tokenizer(tokenizer_name)

    experiment_kwargs = dict(
        nmf_filepath=FLAGS.nmf_filepath,
        pef_filepath=FLAGS.pef_filepath,
        model=loaded_model,
        variables=loaded_model.trainable_variables,
        tokenizer=loaded_tokenizer,
        task=FLAGS.task,
        split=FLAGS.split,
        n_top_examples=FLAGS.n_top_examples,
        n_total_examples=FLAGS.n_total_examples,
        sequence_length=FLAGS.sequence_length,
    )
    experiment = perturbations.PerturbationExperiment.from_filepaths(**experiment_kwargs)

    if not FLAGS.nli_label_swapping:
        return experiment

    # Presumably element 1 holds integer class labels of a 3-way NLI task and
    # this rotates each label to the next class -- TODO confirm.
    eval_ctx = experiment.eval_ctx
    examples = eval_ctx.all_examples
    eval_ctx.all_examples = (examples[0], (examples[1] + 1) % 3)
    return experiment


def _to_perturbation_stats(results) -> perturbations.PerturbationStats:
    """Converts one perturbation direction's results into plain-float stats.

    Args:
        results: Object exposing `top_results` and `total_results`, each with
            `kl()`, `loss()` and `acc()` methods (the `plus_results` or
            `minus_results` of what `ComponentPerturber.evaluate_pm` returns).

    Returns:
        A `PerturbationStats` whose six fields are Python floats.
    """
    top = results.top_results
    total = results.total_results
    return perturbations.PerturbationStats(
        top_kl=float(top.kl()),
        top_loss=float(top.loss()),
        top_acc=float(top.acc()),
        total_kl=float(total.kl()),
        total_loss=float(total.loss()),
        total_acc=float(total.acc()),
    )


def run_for_component(exp, component_index) -> perturbations.ComponentPmPerturbationOutput:
    """Runs the plus/minus perturbation for a single component.

    Args:
        exp: The experiment object handed to `ComponentPerturber`.
        component_index: Index of the component to perturb.

    Returns:
        A `ComponentPmPerturbationOutput` holding the KL, loss and accuracy
        statistics (top and total) for the plus and the minus perturbation.
    """
    # --max_abs_cos_sim defaults to None, and `None > 0` raises TypeError on
    # Python 3, so the None check has to come first. Any value that is not
    # strictly positive also disables the similarity cap.
    max_abs_cos_sim = FLAGS.max_abs_cos_sim
    max_sim = (
        max_abs_cos_sim
        if max_abs_cos_sim is not None and max_abs_cos_sim > 0
        else None
    )

    perturber = perturbations.ComponentPerturber(
        exp=exp,
        component_index=component_index,
        max_sim=max_sim,
    )

    pm_results = perturber.evaluate_pm(FLAGS.perturbation_magnitude)

    # Progress output: total accuracy for the minus, then the plus direction.
    print(pm_results.minus_results.total_results.acc())
    print(pm_results.plus_results.total_results.acc())

    return perturbations.ComponentPmPerturbationOutput(
        component_index=component_index,
        plus_results=_to_perturbation_stats(pm_results.plus_results),
        minus_results=_to_perturbation_stats(pm_results.minus_results),
    )


def main(_):
    """Runs the perturbation for every requested component and saves results.

    The component list comes from --component_indices; when that flag is None,
    every index along the last axis of `exp.nmf.W` is used. The accumulated
    output is written to --output_filepath after each component.

    Args:
        _: Unused positional arguments passed by `app.run`.

    Raises:
        ValueError: If there are no components to run on.
    """
    exp = read_in_to_exp()

    if FLAGS.component_indices is None:
        component_indices = list(range(exp.nmf.W.shape[-1]))
    else:
        component_indices = [int(c) for c in FLAGS.component_indices]

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would let an empty run finish silently without writing any output.
    if not component_indices:
        raise ValueError('No component indices to run on.')

    # --max_abs_cos_sim defaults to None, and `None > 0` raises TypeError on
    # Python 3, so the None check has to come first. Any value that is not
    # strictly positive is recorded as None.
    max_abs_cos_sim = FLAGS.max_abs_cos_sim
    max_sim = (
        max_abs_cos_sim
        if max_abs_cos_sim is not None and max_abs_cos_sim > 0
        else None
    )

    output = perturbations.PmPerturbationExperimentOutput(
        model_name=FLAGS.model,
        nmf_filepath=FLAGS.nmf_filepath,
        task=FLAGS.task,
        split=FLAGS.split,
        n_top_examples=FLAGS.n_top_examples,
        n_total_examples=FLAGS.n_total_examples,
        max_sim=max_sim,
        magnitude=FLAGS.perturbation_magnitude,
        component_outputs=[],
    )

    for component_index in tqdm(component_indices):
        print(cu.hlc(f'Starting ablator run for component {component_index}.'))
        component_output = run_for_component(exp, component_index)
        output.component_outputs.append(component_output)
        # Save after every component so that the run can be interrupted
        # (Ctrl-C) early and still leave partial results on disk.
        output.save(FLAGS.output_filepath)


# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls main().
if __name__ == "__main__":
    app.run(main)
