"""Runs perturbations for LRM-NPEFF decompositions."""
import collections
import gc
import json
import os
import pydoc
from typing import List, Optional, Sequence, Tuple

from absl import app
from absl import flags

import torch
from tqdm import tqdm
import transformers

from npeff_torch.decomps.kmeans import kmeans
from npeff_torch.decomps.npeff import lrm_npeff_decomps
from npeff_torch.models import lm_mcqa
from npeff_torch.models import lm_suffix_mc
from npeff_torch.peis import random_projectors
from npeff_torch.peis.fishers.formats import frdn_lrm_pefs
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.perturbations import evaluation_contexts
from npeff_torch.perturbations import perturbation_contexts
from npeff_torch.perturbations import perturbation_results

from npeff_torch.util import hdf5_utils
from npeff_torch.util import tokenizer_utils


###############################################################################
# Allowed values for --decomposition_type; `main` dispatches on these to pick
# which decomposition class loads --npeff_filepath.
_DECOMPOSITION_TYPES = ['npeff', 'kmeans']
# Allowed values for --decomposition_wrapper_cls; `main` only knows how to build
# this one wrapper.
_DECOMPOSITION_WRAPPER_CLS_ENUMS = ['RandomlyProjectedDecompositionWrapper']
###############################################################################

# Shared absl flag registry; the flag values are read throughout this file.
FLAGS = flags.FLAGS


# Where the experiment results are saved. Only the parent directory is checked
# for existence up front (see `_check_valid_output_filepath`).
flags.DEFINE_string('output_filepath', None, '')


# File the decomposition is loaded from. Despite the name, it is also used for
# k-means clusterings when --decomposition_type=kmeans.
flags.DEFINE_string('npeff_filepath', None, 'The LRM-NPEFF decomposition.')
flags.DEFINE_enum('decomposition_type', 'npeff', _DECOMPOSITION_TYPES, 'Hack for allowing us to use this for gradient clusters.')

flags.DEFINE_list('pef_filepaths', None, 
                  'The PEF files used to compute the NPEFF coefficients, which MUST be in '
                  'the same order. These MUST have the examples saved.')

# Parsed by `_read_n_examples_per_pef_flag`: entries are converted to ints and
# negative entries become None.
flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


# --model is handed to `from_pretrained`; --model_cls is the name of the class
# looked up as an attribute of the `transformers` module.
flags.DEFINE_string('model', None, '')
flags.DEFINE_string('model_cls', None, '')

flags.DEFINE_string("tokenizer", None, "If left None, assumed to be equal to --model.")


flags.DEFINE_enum('decomposition_wrapper_cls', None, _DECOMPOSITION_WRAPPER_CLS_ENUMS, '')


# Entries are parsed as ints and must lie in [0, n_components).
flags.DEFINE_list('component_indices', None, 
                  'Leave set to None to run on all components. Does not affect semi-orthogonalization.')
flags.DEFINE_integer('max_non_empty_components', None,
                     'Leave None to run on all selected components. If provided, then runs perturbation experiments '
                     'for at most this many non-empty components.')


# Upper bound on the top examples requested per component.
flags.DEFINE_integer('n_top_examples', None,
                     'The number of top examples for each component to evaluate on.')
# `main` skips a component when it has fewer top examples than this value.
flags.DEFINE_integer('min_n_top_examples', 1,
                     'Examples with a coefficient of zero will be excluded from the top examples. Hence, do '
                     'not run evaluations for components with fewer than this number of examples with non-zero '
                     'coefficients.')
# The baseline set is the first --n_baseline_examples examples.
flags.DEFINE_integer('n_baseline_examples', None,
                     'The number of baseline examples to evaluate on.')

flags.DEFINE_integer('evaluation_batch_size', None,
                     'The batch size used when evaluating the models.')

# Checked to be strictly positive at the start of `main`.
flags.DEFINE_float("perturbation_magnitude", None, 'Must be a positive float.')
flags.DEFINE_float("rejection_max_abs_cos_similarity", None, 'Leave None to not do the semi-orthogonalization.')
flags.DEFINE_bool('cache_abs_cos_similarities', True, '')


# Flags for when --decomposition_wrapper_cls=RandomlyProjectedDecompositionWrapper:

# Dotted path resolved with `pydoc.locate`; the resulting class is constructed
# with `random_projector`, `d_original` and the --reconstructor_kwargs entries.
flags.DEFINE_string('reconstructor_cls_path', None, 'Must be a subclass of compressed_sensing_common.ReconstructorAbc.')
# If these are provided, they should be a JSON dict mapping parameter names to their values. Currently, only
# JSON-encodable values can be provided.
flags.DEFINE_string('reconstructor_kwargs', None, '')


# Misc stuff I added.
# Dotted path resolved with `pydoc.locate`; the function is called once per
# example as `fn(example, tokenizer=tokenizer)` and its returned values are
# stacked per key.
flags.DEFINE_string('reprocess_example_fn_path', None, '')
flags.DEFINE_string('random_projector_params', None,
                    'If provided, should be JSON of kwargs for random_projectors.RandomProjectionParams that overrides whatever '
                    'would be read from the file.')


# Multiple-choice question answering as language modeling stuff.
# When set, the loaded model is wrapped in `lm_mcqa.LmMcqaLogitsComputer`.
flags.DEFINE_bool('lm_mcqa', False, '')
flags.DEFINE_list('lm_mcqa_answer_labels', None, '')
flags.DEFINE_string('lm_mcqa_answer_label_prefix', '', '')

# Multiple-choice as sentence completion as language modeling stuff.
# When set, the loaded model is wrapped in `lm_suffix_mc.LmSuffixMcLogitsComputer`.
flags.DEFINE_bool('lm_suffix_mc', False, '')


# TODO: Probably add some more configuration flags.
#       - Maybe whether to use first --n_baseline_examples as the baseline or examples with lowest coefficients.


###############################################################################


def _check_valid_output_filepath():
    """Validates --output_filepath and returns it with a leading `~` expanded.

    Only the existence of the parent directory is checked; the output file
    itself is neither created nor opened here.

    Returns:
        The user-expanded output filepath.

    Raises:
        AssertionError: If the flag is unset or the parent directory of the
            path does not exist.
    """
    assert FLAGS.output_filepath is not None, 'The --output_filepath flag must be provided.'

    output_filepath = os.path.expanduser(FLAGS.output_filepath)

    # A bare filename such as 'results.out' has an empty dirname, and
    # `os.path.isdir('')` is False. Such a path is relative to the current
    # working directory, so check that directory instead of rejecting it.
    parent_dir = os.path.dirname(output_filepath) or os.curdir
    assert os.path.isdir(parent_dir), 'Invalid --output_filepath.'

    return output_filepath


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    """Parses the --n_examples_per_pef flag value.

    Args:
        flag_value: The raw list of strings from the flag, or None/empty when
            the flag was not provided.

    Returns:
        None when no value was provided. Otherwise a list with one entry per
        input string: the parsed integer, or None when that integer is
        negative (meaning "use all examples" for that PEF).

    Raises:
        ValueError: If an entry cannot be parsed as an integer.
    """
    if not flag_value:
        return None
    parsed_counts = (int(raw_count) for raw_count in flag_value)
    return [None if count < 0 else count for count in parsed_counts]


def _read_flag_kwargs(flag_value: Optional[str]):
    """Decodes a JSON-valued flag into keyword arguments.

    Args:
        flag_value: The JSON text from the flag, or None/empty if unset.

    Returns:
        The decoded JSON value, or an empty dict when the flag is unset.
    """
    return json.loads(flag_value) if flag_value else {}


def _get_component_indices(n_components: int) -> Sequence[int]:
    """Returns the component indices selected by --component_indices.

    Args:
        n_components: Total number of components in the decomposition.

    Returns:
        `range(n_components)` when the flag is unset or empty; otherwise the
        flag entries parsed as ints, in the order given.

    Raises:
        AssertionError: If a requested index lies outside [0, n_components).
    """
    requested = FLAGS.component_indices
    if not requested:
        return range(n_components)

    indices = [int(raw_index) for raw_index in requested]
    assert all(0 <= index < n_components for index in indices)
    return indices


def _read_in_model(tokenizer, device: torch.device) -> 'transformers.PreTrainedModel':
    """Loads --model as a --model_cls instance, in eval mode, on `device`.

    Args:
        tokenizer: Tokenizer whose `pad_token_id` is copied into the model
            config when the config has none. Also passed to the LM-MCQA
            wrapper when --lm_mcqa is set.
        device: Device the model is moved to.

    Returns:
        The loaded model. When --lm_mcqa or --lm_suffix_mc is set, the model is
        returned wrapped in the corresponding logits computer instead.

    Raises:
        AssertionError: If both --lm_mcqa and --lm_suffix_mc are set.
    """
    # Validate the wrapper flags before the (expensive) checkpoint load rather
    # than after it.
    assert not (FLAGS.lm_mcqa and FLAGS.lm_suffix_mc), (
        'At most one of --lm_mcqa and --lm_suffix_mc may be set.')

    ModelClass = getattr(transformers, FLAGS.model_cls)
    model = ModelClass.from_pretrained(FLAGS.model).to(device)
    model.eval()

    # Prevent the following error:
    #   ValueError: Cannot handle batch sizes > 1 if no padding token is defined.
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.pad_token_id

    if FLAGS.lm_mcqa:
        model = lm_mcqa.LmMcqaLogitsComputer(
            model=model,
            tokenizer=tokenizer,
            answer_labels=FLAGS.lm_mcqa_answer_labels,
            answer_label_prefix=FLAGS.lm_mcqa_answer_label_prefix,
            device=device,
        )
    elif FLAGS.lm_suffix_mc:
        model = lm_suffix_mc.LmSuffixMcLogitsComputer(
            model=model,
            device=device,
        )

    return model


###############################################################################


def _read_random_projector() -> 'random_projectors.RandomProjector':
    """Builds the random projector used by the PEFs.

    The projection parameters come from --random_projector_params (JSON kwargs)
    when that flag is set. Otherwise they are read from every file in
    --pef_filepaths, and all files must agree.

    Returns:
        A `RandomProjector` constructed from the resolved parameters.

    Raises:
        AssertionError: If the PEF files do not all yield the same parameters.
    """
    override_json = FLAGS.random_projector_params
    if override_json:
        params = random_projectors.RandomProjectionParams(**_read_flag_kwargs(override_json))
    else:
        # Collapsing into a set leaves exactly one element iff every file agrees.
        distinct_params = {
            frdn_lrm_pefs.read_random_projector_params_from_file(path)
            for path in FLAGS.pef_filepaths
        }
        assert len(distinct_params) == 1, 'Inconsistent random projector parameters from the different PEF files.'
        (params,) = distinct_params

    return random_projectors.RandomProjector(params=params)


###############################################################################


def _validate_main_flags() -> None:
    """Checks the flags `main` cannot run without, before any expensive work.

    Raises:
        AssertionError: If a required flag is missing or out of range.
        ValueError: If --decomposition_wrapper_cls is not a supported wrapper.
    """
    assert FLAGS.perturbation_magnitude is not None and FLAGS.perturbation_magnitude > 0.0, (
        'The --perturbation_magnitude must be a positive float.')
    assert FLAGS.n_baseline_examples is not None and FLAGS.n_baseline_examples >= 0, (
        'The --n_baseline_examples flag must be a non-negative integer.')
    assert FLAGS.model is not None and FLAGS.model_cls is not None, (
        'The --model and --model_cls flags must be provided.')
    if FLAGS.decomposition_wrapper_cls != 'RandomlyProjectedDecompositionWrapper':
        raise ValueError('Invalid --decomposition_wrapper_cls flag value.')
    assert FLAGS.reconstructor_cls_path, 'The --reconstructor_cls_path flag must be provided.'


def _load_raw_decomposition():
    """Loads the decomposition at --npeff_filepath according to --decomposition_type."""
    if FLAGS.decomposition_type == 'npeff':
        return lrm_npeff_decomps.LrmNpeffDecomposition.load(
            FLAGS.npeff_filepath, load_W=True, load_G=True)
    if FLAGS.decomposition_type == 'kmeans':
        return kmeans.KmeansClusteringTorch.load(FLAGS.npeff_filepath)
    raise ValueError(FLAGS.decomposition_type)


def _load_all_examples(pef_extra_infos, tokenizer, device: torch.device):
    """Converts the examples stored with the PEFs to tensors on `device`.

    If --reprocess_example_fn_path is set, every example is passed through that
    function first and the per-key results are stacked along dim 0.
    """
    examples = {k: torch.from_numpy(v) for k, v in pef_extra_infos.examples.items()}

    if FLAGS.reprocess_example_fn_path:
        reprocess_example_fn = pydoc.locate(FLAGS.reprocess_example_fn_path)
        # `pydoc.locate` returns None rather than raising when nothing is found.
        assert reprocess_example_fn is not None, (
            f'Could not locate --reprocess_example_fn_path={FLAGS.reprocess_example_fn_path!r}.')

        reprocessed = collections.defaultdict(list)
        for i in tqdm(range(pef_extra_infos.n_examples)):
            example = {k: v[i] for k, v in examples.items()}
            for k, v in reprocess_example_fn(example, tokenizer=tokenizer).items():
                reprocessed[k].append(v)

        examples = {k: torch.stack(v, dim=0) for k, v in reprocessed.items()}

    token_positions = pef_extra_infos.token_positions
    all_examples = evaluation_contexts.Examples(
        examples=examples,
        token_positions=torch.from_numpy(token_positions) if token_positions is not None else None,
    )
    return all_examples.to(device)


def _build_decomposition_wrapper(raw_decomposition, pef_extra_infos):
    """Wraps `raw_decomposition` as selected by --decomposition_wrapper_cls."""
    if FLAGS.decomposition_wrapper_cls != 'RandomlyProjectedDecompositionWrapper':
        raise ValueError('Invalid --decomposition_wrapper_cls flag value.')

    d_original = sum(pi.n_elements() for pi in pef_extra_infos.parameter_infos)
    print(f'd_original: {d_original}')

    random_projector = _read_random_projector()

    Reconstructor = pydoc.locate(FLAGS.reconstructor_cls_path)
    # `pydoc.locate` returns None rather than raising when nothing is found.
    assert Reconstructor is not None, (
        f'Could not locate --reconstructor_cls_path={FLAGS.reconstructor_cls_path!r}.')
    reconstructor = Reconstructor(
        random_projector=random_projector,
        d_original=d_original,
        **_read_flag_kwargs(FLAGS.reconstructor_kwargs),
    )

    return perturbation_contexts.RandomlyProjectedDecompositionWrapper(
        decomposition=raw_decomposition,
        rejection_max_abs_cos_similarity=FLAGS.rejection_max_abs_cos_similarity,
        cache_abs_cos_similarities=FLAGS.cache_abs_cos_similarities,
        parameter_infos=pef_extra_infos.parameter_infos,
        reconstructor=reconstructor,
    )


def main(_):
    """Runs the perturbation experiment for each selected component and saves the results.

    For every selected component with at least --min_n_top_examples top
    examples, the component is perturbed at +/- --perturbation_magnitude (via
    `ComponentPerturber.evaluate_perturbation_pm`) and evaluated on its top
    examples and on the shared baseline examples. All per-component results are
    written to --output_filepath at the end.
    """
    # Check these now so that we don't error out after doing all the work.
    output_filepath = _check_valid_output_filepath()
    _validate_main_flags()

    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)
    if n_examples_per_pef is not None and FLAGS.pef_filepaths is not None:
        assert len(n_examples_per_pef) == len(FLAGS.pef_filepaths), (
            'The --n_examples_per_pef list must be the same length as the --pef_filepaths list.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)

    # Two independent copies: one stays untouched as the reference, the other is
    # handed to the perturber as the model to perturb.
    original_model = _read_in_model(tokenizer, device)
    perturbed_model = _read_in_model(tokenizer, device)

    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(FLAGS.pef_filepaths, n_examples_per_pef)
    assert pef_extra_infos.parameter_infos is not None
    assert pef_extra_infos.examples is not None
    assert pef_extra_infos.n_examples is not None

    raw_decomposition = _load_raw_decomposition()

    assert raw_decomposition.n_examples == pef_extra_infos.n_examples, (
        'The decomposition and the PEF files disagree on the number of examples.')
    assert FLAGS.n_baseline_examples <= raw_decomposition.n_examples, (
        'The --n_baseline_examples flag exceeds the number of available examples.')

    raw_decomposition = raw_decomposition.to(device)
    raw_decomposition.normalize_reduced_components_to_unit_norm_()

    component_indices = _get_component_indices(raw_decomposition.n_components)

    all_examples = _load_all_examples(pef_extra_infos, tokenizer, device)
    decomposition_wrapper = _build_decomposition_wrapper(raw_decomposition, pef_extra_infos)

    # The baseline is simply the first --n_baseline_examples examples.
    baseline_examples = all_examples.gather_examples(torch.arange(FLAGS.n_baseline_examples, device=device))

    all_component_results = []

    for component_index in tqdm(component_indices):
        # Early exit if --max_non_empty_components is set.
        if FLAGS.max_non_empty_components is not None and len(all_component_results) >= FLAGS.max_non_empty_components:
            break

        top_example_indices = decomposition_wrapper.get_top_example_indices_for_component(
            component_index, FLAGS.n_top_examples)

        # Skip components without enough top examples; they do not count
        # towards --max_non_empty_components.
        n_actual_top_examples = top_example_indices.numel()
        if n_actual_top_examples < FLAGS.min_n_top_examples:
            continue

        top_examples = all_examples.gather_examples(top_example_indices)

        component_perturber = perturbation_contexts.ComponentPerturber(
            component_index=component_index,
            decomposition_wrapper=decomposition_wrapper,
            original_model=original_model,
            perturbed_model=perturbed_model,
            top_examples=top_examples,
            baseline_examples=baseline_examples,
            evaluation_batch_size=FLAGS.evaluation_batch_size,
        )
        all_component_results.append(
            component_perturber.evaluate_perturbation_pm(FLAGS.perturbation_magnitude))

        # Release per-component memory before moving on to the next component.
        gc.collect()
        torch.cuda.empty_cache()

    # TODO: Maybe save occasionally in the loop above.
    experiment_results = perturbation_results.ExperimentPerturbationResults(component_results=all_component_results)
    # TODO: Save, maybe put metadata in the ExperimentPerturbationResults
    experiment_results.save(output_filepath)


# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == "__main__":
    app.run(main)
