"""Computes and saves a set of fixed rank, fixed nnz LRM-pefs."""
import json
import pydoc
from typing import Optional

from absl import app
from absl import flags

import torch
from torch.utils.data import DataLoader
import transformers

from npeff_torch.models import parameter_filtering
from npeff_torch.models import parameter_infos
from npeff_torch.peis import position_selectors
from npeff_torch.peis import sparsifiers
from npeff_torch.peis.fishers import class_selectors
from npeff_torch.peis.fishers.computers import class_aligned_lrm_computers
from npeff_torch.peis.fishers.computers import streaming_svd_lrm_computers
from npeff_torch.peis.fishers.formats import frfn_lrm_pefs

from npeff_torch.util import tokenizer_utils


###############################################################################
_SUPPORTED_CLASS_SELECTORS = ('ExhaustiveClassSubsetSelector', 'LabelledClassSubsetSelector', 'TopClassesSubsetSelector')
_SUPPORTED_POSITION_SELECTORS = ('NoopPositionSelector', 'UniformRandomPositionSelector', 'SparseFeatureCircuitsPositionSelector')
_SUPPORTED_FISHER_COMPUTERS = ('ClassAlignedLrmComputer', 'StreamingSvdLrmComputer')
###############################################################################

FLAGS = flags.FLAGS


# Output.
flags.DEFINE_string('output_filepath', None, 'Path of the file the computed PEFs are saved to.')

# Model and tokenizer.
flags.DEFINE_string('model', None, 'Model name or path, passed to `from_pretrained` of --model_cls.')
flags.DEFINE_string('model_cls', None, 'Name of a class in the `transformers` package used to load --model.')

flags.DEFINE_string('tokenizer', None, 'If set to None, will be the same as --model.')


# Dataset. --task, --subtask, --split and --sequence_length are forwarded as keyword
# arguments to the function located at --load_dataset_fn_path.
flags.DEFINE_string('task', None, 'Passed as `task` to the dataset loading function.')
flags.DEFINE_string('subtask', None, 'Passed as `subtask` to the dataset loading function.')
flags.DEFINE_string('split', None, 'Passed as `split` to the dataset loading function.')

flags.DEFINE_integer('sequence_length', None, 'Passed as `sequence_length` to the dataset loading function.')
flags.DEFINE_integer('n_examples', None, 'Number of examples to compute PEFs for. Must be positive.')

flags.DEFINE_integer('dataset_offset', None, 'If set, this many examples at the start of the dataset are skipped.')


flags.DEFINE_string(
    'load_dataset_fn_path', None,
    'Dotted path (resolved with `pydoc.locate`) to a function called with task, subtask, split, tokenizer and '
    'sequence_length keyword arguments. Should return a HuggingFace dataset ready to be passed to a dataloader.')


# TODO: Maybe allow for user-defined selectors using a path.
flags.DEFINE_enum('class_selector', 'ExhaustiveClassSubsetSelector', list(_SUPPORTED_CLASS_SELECTORS),
                  'Name of the class selector (from `class_selectors`) to use.')
flags.DEFINE_enum('position_selector', 'NoopPositionSelector', list(_SUPPORTED_POSITION_SELECTORS),
                  'Name of the position selector (from `position_selectors`) to use.')
flags.DEFINE_enum('fisher_computer', 'ClassAlignedLrmComputer', list(_SUPPORTED_FISHER_COMPUTERS),
                  'Which Fisher computer to use.')

# If these are provided, they should be a JSON dict mapping parameter names to their values. Currently, only
# JSON-encodable values can be provided.
#
# Some selectors/computers require some of their arguments to be passed in this manner.
flags.DEFINE_string('class_selector_kwargs', None,
                    'Optional JSON dict of keyword arguments for the selected --class_selector.')
flags.DEFINE_string('position_selector_kwargs', None,
                    'Optional JSON dict of keyword arguments for the selected --position_selector.')
flags.DEFINE_string('fisher_computer_kwargs', None,
                    'Optional JSON dict of keyword arguments for the selected --fisher_computer. '
                    'StreamingSvdLrmComputer requires "output_rank".')

flags.DEFINE_bool('error_on_null_grads', True, 'Passed as `error_on_null_grads` to the Fisher computer.')


flags.DEFINE_integer('nnz_per_example', None, 'Number of nonzero entries kept for each example (passed to the PEF saver).')

flags.DEFINE_bool('include_embeddings', True, 'Whether to include the embeddings in the parameters we use when computing Fishers.')
flags.DEFINE_bool('include_layer_norms', True, 'Whether to include the layer norms in the parameters we use when computing Fishers.')


flags.DEFINE_string('label_key', 'labels', 'Passed as `label_key` to the PEF saver (the batch key holding the labels).')

flags.DEFINE_bool('save_examples', True, 'Whether to save the examples (passed as `save_examples` to the PEF saver).')
flags.DEFINE_bool('save_labels', True, 'Whether to save the labels (passed as `save_labels` to the PEF saver).')
flags.DEFINE_bool('save_logits', True, 'Whether to save the logits (passed as `save_logits` to the PEF saver).')


###############################################################################


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}


def _get_parameters_and_infos(model):
    """Selects the model parameters that Fishers are computed over.

    Parameters that `parameter_filtering` classifies as embeddings are dropped
    when --include_embeddings is False. Parameters it classifies as layer norms
    are dropped when --include_layer_norms is False.

    Args:
        model: the model whose `named_parameters()` are filtered.

    Returns:
        A `(parameters, param_infos)` tuple of parallel lists. `parameters`
        holds the kept parameters and `param_infos` holds a `ParameterInfo`
        for each. Both keep the order of `model.named_parameters()`.
    """
    drop_embeddings = not FLAGS.include_embeddings
    drop_layer_norms = not FLAGS.include_layer_norms

    def _is_kept(name):
        # The embedding check runs first; the layer-norm check only runs for names that pass it.
        if drop_embeddings and parameter_filtering.is_embedding(model, name):
            return False
        if drop_layer_norms and parameter_filtering.is_layer_norm(model, name):
            return False
        return True

    kept = [(name, param) for name, param in model.named_parameters() if _is_kept(name)]

    parameters = [param for _, param in kept]
    param_infos = []
    for name, param in kept:
        param_infos.append(parameter_infos.ParameterInfo.from_parameter(param, name=name))

    return parameters, param_infos


def _get_fisher_computer(model, parameters):
    """Builds the Fisher computer selected by --fisher_computer.

    Any extra keyword arguments come from --fisher_computer_kwargs and are
    forwarded to the computer class's `create`.

    Args:
        model: the model whose Fishers will be computed.
        parameters: the parameters Fishers are computed over.

    Returns:
        The object returned by the selected computer class's `create`.

    Raises:
        ValueError: if --fisher_computer is not supported, if
            --fisher_computer_kwargs sets an argument this function already
            supplies, or if a StreamingSvdLrmComputer is requested without a
            positive integer "output_rank".
    """
    flag_kwargs = _read_flag_kwargs(FLAGS.fisher_computer_kwargs)

    computer_classes = {
        'ClassAlignedLrmComputer': class_aligned_lrm_computers.ClassAlignedLrmComputer,
        'StreamingSvdLrmComputer': streaming_svd_lrm_computers.StreamingSvdLrmComputer,
    }
    computer_cls = computer_classes.get(FLAGS.fisher_computer)
    if computer_cls is None:
        raise ValueError(FLAGS.fisher_computer)

    # Without this check, these keys would fail with an opaque "got multiple values" TypeError.
    clashing = sorted(flag_kwargs.keys() & {'model', 'parameters', 'error_on_null_grads'})
    if clashing:
        raise ValueError(
            f'--fisher_computer_kwargs must not set {clashing}; these are supplied by the script.')

    if FLAGS.fisher_computer == 'StreamingSvdLrmComputer':
        output_rank = flag_kwargs.get('output_rank')
        # bool is a subclass of int, so exclude it explicitly (JSON `true` would otherwise pass).
        if not isinstance(output_rank, int) or isinstance(output_rank, bool) or output_rank <= 0:
            raise ValueError('When using a StreamingSvdLrmComputer, the "output_rank" kwarg must be provided as a positive integer.')

    return computer_cls.create(
        model=model,
        parameters=parameters,
        error_on_null_grads=FLAGS.error_on_null_grads,
        **flag_kwargs,
    )


###############################################################################


def _validate_flags():
    """Rejects flag values that would otherwise fail late or with an obscure error.

    Raises:
        ValueError: if --n_examples is unset or not positive, or if
            --dataset_offset is negative.
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if FLAGS.n_examples is None or FLAGS.n_examples <= 0:
        raise ValueError(f'--n_examples must be a positive integer, got {FLAGS.n_examples!r}.')
    if FLAGS.dataset_offset is not None and FLAGS.dataset_offset < 0:
        raise ValueError(f'--dataset_offset must be non-negative, got {FLAGS.dataset_offset!r}.')


def _resolve_load_dataset_fn():
    """Imports and returns the dataset-loading function named by --load_dataset_fn_path.

    Raises:
        ValueError: if the flag is unset or does not resolve to a callable.
    """
    path = FLAGS.load_dataset_fn_path
    if not path:
        raise ValueError('--load_dataset_fn_path must be set.')
    # pydoc.locate returns None (rather than raising) when the path cannot be resolved.
    load_dataset_fn = pydoc.locate(path)
    if not callable(load_dataset_fn):
        raise ValueError(f'--load_dataset_fn_path={path!r} does not resolve to a callable.')
    return load_dataset_fn


def main(_):
    """Loads the model, tokenizer and dataset named by the flags, then computes PEFs and saves them to --output_filepath."""
    # Check the cheap things before spending time loading the model.
    _validate_flags()
    load_dataset_fn = _resolve_load_dataset_fn()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ModelClass = getattr(transformers, FLAGS.model_cls)
    model = ModelClass.from_pretrained(FLAGS.model).to(device)

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)

    ds = load_dataset_fn(
        task=FLAGS.task,
        subtask=FLAGS.subtask,
        split=FLAGS.split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.sequence_length,
    )
    if FLAGS.dataset_offset is not None:
        ds = ds.skip(FLAGS.dataset_offset)
    # One example per batch; presumably required by the per-example saver -- confirm before changing.
    dataloader = DataLoader(ds, batch_size=1)

    class_selector = getattr(class_selectors, FLAGS.class_selector).create(**_read_flag_kwargs(FLAGS.class_selector_kwargs))
    position_selector = getattr(position_selectors, FLAGS.position_selector).create(**_read_flag_kwargs(FLAGS.position_selector_kwargs))

    parameters, param_infos = _get_parameters_and_infos(model)
    fisher_computer = _get_fisher_computer(model, parameters)

    saver = frfn_lrm_pefs.StreamingLrmPefSaver.create(
        model=model,
        fisher_computer=fisher_computer,
        nnz_per_example=FLAGS.nnz_per_example,
        position_selector=position_selector,
        class_subset_selector=class_selector,
        label_key=FLAGS.label_key,
        device=device,
        save_examples=FLAGS.save_examples,
        save_labels=FLAGS.save_labels,
        save_logits=FLAGS.save_logits,
        parameter_infos=param_infos,
    )

    saver.compute_and_save_pefs(
        filepath=FLAGS.output_filepath,
        dataloader=dataloader,
        n_examples=FLAGS.n_examples,
    )


if __name__ == "__main__":
    app.run(main)
