"""Computes and saves a set of dense gradients."""
import json
import pydoc
from typing import Optional

from absl import app
from absl import flags

import torch
from torch.utils.data import DataLoader
import transformers

from npeff_torch.models import lm_mcqa
from npeff_torch.models import lm_suffix_mc
from npeff_torch.models import parameter_filtering
from npeff_torch.models import parameter_infos
from npeff_torch.peis import position_selectors
from npeff_torch.peis import random_projectors
from npeff_torch.peis.gradients import logit_functions
from npeff_torch.peis.gradients import gradient_computers
from npeff_torch.peis.gradients.formats import dn_gradients

from npeff_torch.util import tokenizer_utils


###############################################################################
_SUPPORTED_POSITION_SELECTORS = ('NoopPositionSelector', 'UniformRandomPositionSelector', 'LastPositionSelector', 'SparseFeatureCircuitsPositionSelector')
###############################################################################

# Global absl flag registry; the flag values below are read from it in main() and the helpers.
FLAGS = flags.FLAGS


# Passed as `filepath` to the gradient saver's compute_and_save_gradients().
flags.DEFINE_string('output_filepath', None, '')

# --model is handed to <model_cls>.from_pretrained(); --model_cls is the name of a class
# looked up as an attribute of the `transformers` module.
flags.DEFINE_string('model', None, '')
flags.DEFINE_string('model_cls', None, '')

# main() falls back to the value of --model when this is unset or empty.
flags.DEFINE_string('tokenizer', None, 'If set to None, will be the same as --model.')


# --task, --subtask, --split and --sequence_length are forwarded as keyword arguments to the
# function named by --load_dataset_fn_path.
flags.DEFINE_string('task', None, '')
flags.DEFINE_string('subtask', None, '')
flags.DEFINE_string('split', None, '')

flags.DEFINE_integer('sequence_length', None, '')
# main() requires this to be positive; it is passed as `n_examples` to the gradient saver.
flags.DEFINE_integer('n_examples', None, '')

# When set, main() calls ds.skip(dataset_offset) on the loaded dataset before building the dataloader.
flags.DEFINE_integer('dataset_offset', None, '')


# Dotted path resolved with pydoc.locate(). The located function is called with the keyword
# arguments task, subtask, split, tokenizer, and sequence_length.
flags.DEFINE_string('load_dataset_fn_path', None, 'Should return a HuggingFace dataset ready to be passed to a dataloader.')


# Resolved by _get_logit_fn(): a value containing a period is treated as a dotted path, anything
# else as the name of an attribute of the logit_functions module.
flags.DEFINE_string('logit_fn', 'cross_entropy_loss_logits_fn',
                    'Function of logits to compute gradients of. Either the name of a function from the logit_functions module '
                    'or a path to a function of type logit_functions.LogitFunctionType. These two cases will be distinguished '
                    'depending on whether the flag value contains a period.')
# TODO: Maybe allow for user-defined selectors using a path.
# Restricted to the class names listed in _SUPPORTED_POSITION_SELECTORS.
flags.DEFINE_enum('position_selector', 'NoopPositionSelector', list(_SUPPORTED_POSITION_SELECTORS), '')

# If provided, this should be a JSON dict mapping parameter names to their values. It is decoded
# with json.loads and passed as keyword arguments to the chosen position selector's create().
# Currently, only JSON-encodable values can be provided.
#
# Some types of selectors will require some arguments to be passed in this manner.
flags.DEFINE_string('position_selector_kwargs', None, '')

# Forwarded unchanged to gradient_computers.GradientComputer.
flags.DEFINE_bool('error_on_null_grads', True, '')


# Flags related to the random projection. They are only read when --d_projection is set, in which
# case they are packed into a random_projectors.RandomProjectionParams.
flags.DEFINE_integer('d_projection', None, 'Leave unset to not do a random projection.')
flags.DEFINE_string('projection_type', 'hypercubic_v2', '')
flags.DEFINE_string('projection_algorithm', 'alg3', '')
flags.DEFINE_integer('random_projection_seed', 2137, '')
flags.DEFINE_integer('random_projection_sparse_region_size', None, '')


# When False, parameters flagged by parameter_filtering.is_embedding / is_layer_norm are dropped
# from the list of parameters that gradients are computed for.
flags.DEFINE_bool('include_embeddings', True, 'Whether to include the embeddings in the parameters we use when computing Fishers.')
flags.DEFINE_bool('include_layer_norms', True, 'Whether to include the layer norms in the parameters we use when computing Fishers.')


# The flags below, up to and including --save_top_n_log_probs, are forwarded unchanged to
# dn_gradients.StreamingGradientSaver.create().
flags.DEFINE_string('label_key', 'labels', '')

flags.DEFINE_bool('save_examples', True, '')
flags.DEFINE_bool('save_labels', True, '')
flags.DEFINE_bool('save_logits', True, '')

flags.DEFINE_integer('save_top_n_log_probs', None, 'Leave unset or set to 0 to not save these.')


# Multiple-choice question answering as language modeling stuff.
# When --lm_mcqa is set, main() wraps the model in lm_mcqa.LmMcqaLogitsComputer, passing the answer
# labels and label prefix below. Must not be combined with --lm_suffix_mc.
flags.DEFINE_bool('lm_mcqa', False, '')
flags.DEFINE_list('lm_mcqa_answer_labels', None, '')
flags.DEFINE_string('lm_mcqa_answer_label_prefix', '', '')

# Multiple-choice as sentence completion as language modeling stuff.
# When --lm_suffix_mc is set, main() wraps the model in lm_suffix_mc.LmSuffixMcLogitsComputer.
# --lm_suffix_mc_save_only_first_example is always forwarded to the gradient saver.
flags.DEFINE_bool('lm_suffix_mc', False, '')
flags.DEFINE_bool('lm_suffix_mc_save_only_first_example', False, '')

###############################################################################


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}


def _get_parameters_and_infos(model):
    """Selects the model parameters to compute gradients for.

    Embedding and layer-norm parameters are dropped when --include_embeddings
    or --include_layer_norms, respectively, is False.

    Args:
        model: Object exposing named_parameters(); also passed to the
            parameter_filtering predicates.

    Returns:
        A (parameters, param_infos) tuple of equal-length lists: the kept
        parameters in named_parameters() order, and a
        parameter_infos.ParameterInfo built from each of them.
    """
    candidates = model.named_parameters()
    drop_embeddings = not FLAGS.include_embeddings
    drop_layer_norms = not FLAGS.include_layer_norms

    kept = []
    for name, param in candidates:
        if drop_embeddings and parameter_filtering.is_embedding(model, name):
            continue
        if drop_layer_norms and parameter_filtering.is_layer_norm(model, name):
            continue
        kept.append((name, param))

    parameters = [param for _, param in kept]
    param_infos = [
        parameter_infos.ParameterInfo.from_parameter(param, name=name)
        for name, param in kept
    ]
    return parameters, param_infos


def _get_logit_fn(logit_fn_flag: str):
    if '.' in logit_fn_flag:
        return pydoc.locate(logit_fn_flag)
    else:
        return getattr(logit_functions, logit_fn_flag)


def _get_gradient_computer(model, parameters):
    """Builds the gradient computer configured by the command-line flags.

    Args:
        model: Model passed through to the GradientComputer.
        parameters: Parameters passed through to the GradientComputer.

    Returns:
        A gradient_computers.GradientComputer using the function named by
        --logit_fn and the --error_on_null_grads setting.
    """
    logit_fn = _get_logit_fn(FLAGS.logit_fn)
    raise_on_null = FLAGS.error_on_null_grads
    return gradient_computers.GradientComputer(
        model=model, parameters=parameters, logit_fn=logit_fn, error_on_null_grads=raise_on_null)


def _maybe_read_random_projection_params():
    """Reads the random-projection settings from the flags.

    Returns:
        A random_projectors.RandomProjectionParams populated from the
        projection flags, or None when --d_projection is unset.
    """
    d_projection = FLAGS.d_projection
    if d_projection is not None:
        return random_projectors.RandomProjectionParams(
            d_projection=d_projection,
            projection_type=FLAGS.projection_type,
            algorithm=FLAGS.projection_algorithm,
            sparse_region_size=FLAGS.random_projection_sparse_region_size,
            seed=FLAGS.random_projection_seed,
        )
    return None


###############################################################################


def main(_):
    """Computes gradients for a model on a dataset and saves them to disk.

    Loads the model and tokenizer, optionally wraps the model in one of the
    multiple-choice language-modeling logit computers, loads the dataset with
    the user-supplied loader function, and hands everything to a
    dn_gradients.StreamingGradientSaver.

    Args:
        _: Positional command-line arguments supplied by absl; unused.

    Raises:
        app.UsageError: If --model, --model_cls or --load_dataset_fn_path is
            unset, --n_examples is unset or not positive, both --lm_mcqa and
            --lm_suffix_mc are set, or --load_dataset_fn_path cannot be
            resolved.
    """
    # Validate explicitly (asserts are stripped under -O) and before the
    # model is loaded, so misconfiguration is reported immediately.
    missing = [name for name in ('model', 'model_cls', 'load_dataset_fn_path') if getattr(FLAGS, name) is None]
    if missing:
        raise app.UsageError('Missing required flag(s): ' + ', '.join(f'--{name}' for name in missing))
    if FLAGS.n_examples is None or FLAGS.n_examples <= 0:
        raise app.UsageError('--n_examples must be set to a positive integer.')
    if FLAGS.lm_mcqa and FLAGS.lm_suffix_mc:
        raise app.UsageError('--lm_mcqa and --lm_suffix_mc are mutually exclusive.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ModelClass = getattr(transformers, FLAGS.model_cls)
    model = ModelClass.from_pretrained(FLAGS.model).to(device)

    # An unset --tokenizer means "use the same identifier as --model".
    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)

    # Optionally wrap the model in one of the multiple-choice logit computers.
    if FLAGS.lm_mcqa:
        model = lm_mcqa.LmMcqaLogitsComputer(
            model=model,
            tokenizer=tokenizer,
            answer_labels=FLAGS.lm_mcqa_answer_labels,
            answer_label_prefix=FLAGS.lm_mcqa_answer_label_prefix,
            device=device,
        )
    elif FLAGS.lm_suffix_mc:
        model = lm_suffix_mc.LmSuffixMcLogitsComputer(
            model=model,
            device=device,
        )

    # pydoc.locate returns None (rather than raising) for an unresolvable path.
    load_dataset_fn = pydoc.locate(FLAGS.load_dataset_fn_path)
    if load_dataset_fn is None:
        raise app.UsageError(
            f'Could not locate the function given by --load_dataset_fn_path={FLAGS.load_dataset_fn_path!r}.')
    ds = load_dataset_fn(
        task=FLAGS.task,
        subtask=FLAGS.subtask,
        split=FLAGS.split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.sequence_length,
    )
    if FLAGS.dataset_offset is not None:
        ds = ds.skip(FLAGS.dataset_offset)
    dataloader = DataLoader(ds, batch_size=1)

    position_selector_cls = getattr(position_selectors, FLAGS.position_selector)
    position_selector = position_selector_cls.create(**_read_flag_kwargs(FLAGS.position_selector_kwargs))

    parameters, param_infos = _get_parameters_and_infos(model)
    gradient_computer = _get_gradient_computer(model, parameters)

    saver = dn_gradients.StreamingGradientSaver.create(
        model=model,
        gradient_computer=gradient_computer,
        position_selector=position_selector,
        label_key=FLAGS.label_key,
        device=device,
        save_examples=FLAGS.save_examples,
        save_labels=FLAGS.save_labels,
        save_logits=FLAGS.save_logits,
        save_top_n_log_probs=FLAGS.save_top_n_log_probs,
        parameter_infos=param_infos,
        random_projection_params=_maybe_read_random_projection_params(),
        lm_suffix_mc_save_only_first_example=FLAGS.lm_suffix_mc_save_only_first_example,
    )

    saver.compute_and_save_gradients(
        filepath=FLAGS.output_filepath,
        dataloader=dataloader,
        n_examples=FLAGS.n_examples,
    )


# Script entry point: app.run parses the command-line flags and then calls main.
if __name__ == "__main__":
    app.run(main)
