"""Computes and saves a set of fixed rank dense LRM-pefs."""
import json
import pydoc
from typing import Optional

from absl import app
from absl import flags

import torch
from torch.utils.data import DataLoader
import transformers

from npeff_torch.models import lm_mcqa
from npeff_torch.models import lm_suffix_mc
from npeff_torch.models import parameter_filtering
from npeff_torch.models import parameter_infos
from npeff_torch.peis import position_selectors
from npeff_torch.peis import random_projectors
from npeff_torch.peis.fishers import class_selectors
from npeff_torch.peis.fishers.computers import class_aligned_lrm_computers
from npeff_torch.peis.fishers.computers import evrp_rp_svd_lrm_computers
from npeff_torch.peis.fishers.computers import rp_svd_lrm_computers
from npeff_torch.peis.fishers.computers import sampling_rp_svd_lrm_computers
from npeff_torch.peis.fishers.computers import streaming_svd_lrm_computers
from npeff_torch.peis.fishers.formats import frdn_lrm_pefs

from npeff_torch.util import tokenizer_utils


###############################################################################
# Class names accepted by --class_selector, --position_selector and --fisher_computer (used as the
# allowed values of those enum flags below). Selector names are resolved in main() with getattr() on
# the class_selectors / position_selectors modules; Fisher computer names are dispatched on in
# _get_fisher_computer().
_SUPPORTED_CLASS_SELECTORS = ('ExhaustiveClassSubsetSelector', 'LabelledClassSubsetSelector', 'TopClassesSubsetSelector')
_SUPPORTED_POSITION_SELECTORS = ('NoopPositionSelector', 'UniformRandomPositionSelector', 'LastPositionSelector', 'SparseFeatureCircuitsPositionSelector')
_SUPPORTED_FISHER_COMPUTERS = ('ClassAlignedLrmComputer', 'StreamingSvdLrmComputer', 'RpSvdLrmComputer', 'BatchedRpSvdLrmComputer', 'BatchedEvrpRpSvdLrmComputer', 'BatchedSamplingRpSvdLrmComputer')
###############################################################################

FLAGS = flags.FLAGS


# Path the computed PEFs are written to (passed as `filepath` to saver.compute_and_save_pefs in main()).
flags.DEFINE_string('output_filepath', None, '')

# --model is passed to `<model_cls>.from_pretrained()`; --model_cls is the name of a class in the
# `transformers` package, looked up with getattr().
flags.DEFINE_string('model', None, '')
flags.DEFINE_string('model_cls', None, '')

flags.DEFINE_string('tokenizer', None, 'If set to None, will be the same as --model.')


# Forwarded as keyword arguments (together with the tokenizer and --sequence_length) to the dataset
# loading function named by --load_dataset_fn_path.
flags.DEFINE_string('task', None, '')
flags.DEFINE_string('subtask', None, '')
flags.DEFINE_string('split', None, '')

flags.DEFINE_integer('sequence_length', None, '')
# Passed to saver.compute_and_save_pefs as `n_examples`; main() requires a positive value.
flags.DEFINE_integer('n_examples', None, '')

# If set, this many examples are dropped from the start of the dataset via `ds.skip()`.
flags.DEFINE_integer('dataset_offset', None, '')


# Dotted path to the dataset loading function, resolved with pydoc.locate().
flags.DEFINE_string('load_dataset_fn_path', None, 'Should return a HuggingFace dataset ready to be passed to a dataloader.')
# If provided, this should be a JSON object mapping parameter names to their values. Only
# JSON-encodable values are supported. These are passed as extra keyword arguments to the dataset loading function.
flags.DEFINE_string('load_dataset_fn_kwargs', None, '')


# TODO: Maybe allow for user-defined selectors using a path.
flags.DEFINE_enum('class_selector', 'ExhaustiveClassSubsetSelector', list(_SUPPORTED_CLASS_SELECTORS), '')
flags.DEFINE_enum('position_selector', 'NoopPositionSelector', list(_SUPPORTED_POSITION_SELECTORS), '')
flags.DEFINE_enum('fisher_computer', 'ClassAlignedLrmComputer', list(_SUPPORTED_FISHER_COMPUTERS), '')

# If provided, each of these should be a JSON object mapping parameter names to their values. Only
# JSON-encodable values are supported. They are passed as extra keyword arguments to the `create()`
# factory of the selected selector/computer.
#
# Some selectors/computers require certain arguments to be passed this way (e.g. the SVD-based Fisher
# computers require "output_rank").
flags.DEFINE_string('class_selector_kwargs', None, '')
flags.DEFINE_string('position_selector_kwargs', None, '')
flags.DEFINE_string('fisher_computer_kwargs', None, '')

# Forwarded as `error_on_null_grads` to the selected Fisher computer's create().
flags.DEFINE_bool('error_on_null_grads', True, '')


# Random projection settings. They are combined into a RandomProjectionParams only when
# --d_projection is set; otherwise no random projection is used.
flags.DEFINE_integer('d_projection', None, 'Leave unset to not do a random projection.')
flags.DEFINE_string('projection_type', 'hypercubic_v2', '')
flags.DEFINE_string('projection_algorithm', 'alg3', '')
flags.DEFINE_integer('random_projection_seed', 2137, '')
flags.DEFINE_integer('random_projection_sparse_region_size', None, '')


flags.DEFINE_bool('include_embeddings', True, 'Whether to include the embeddings in the parameters we use when computing Fishers.')
flags.DEFINE_bool('include_layer_norms', True, 'Whether to include the layer norms in the parameters we use when computing Fishers.')


# Passed to the saver as `label_key`.
flags.DEFINE_string('label_key', 'labels', '')

# Passed through to StreamingLrmPefSaver.create() as the arguments of the same name.
flags.DEFINE_bool('save_examples', True, '')
flags.DEFINE_bool('save_labels', True, '')
flags.DEFINE_bool('save_logits', True, '')

flags.DEFINE_integer('save_top_n_log_probs', None, 'Leave unset or set to 0 to not save these.')


# Multiple-choice question answering posed as language modeling. When --lm_mcqa is set, main() wraps
# the model in lm_mcqa.LmMcqaLogitsComputer using the answer labels and label prefix below.
# Mutually exclusive with --lm_suffix_mc.
flags.DEFINE_bool('lm_mcqa', False, '')
flags.DEFINE_list('lm_mcqa_answer_labels', None, '')
flags.DEFINE_string('lm_mcqa_answer_label_prefix', '', '')


# Multiple choice posed as sentence completion with a language model. When --lm_suffix_mc is set,
# main() wraps the model in lm_suffix_mc.LmSuffixMcLogitsComputer. The second flag is passed through
# to the saver as `lm_suffix_mc_save_only_first_example`.
flags.DEFINE_bool('lm_suffix_mc', False, '')
flags.DEFINE_bool('lm_suffix_mc_save_only_first_example', False, '')


###############################################################################


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}


def _get_parameters_and_infos(model):
    """Selects the model parameters used for the Fisher computation, with metadata.

    Parameters are taken in ``model.named_parameters()`` order. Parameters that
    ``parameter_filtering.is_embedding`` flags are dropped when
    --include_embeddings is false. Parameters that
    ``parameter_filtering.is_layer_norm`` flags are dropped when
    --include_layer_norms is false.

    Returns:
        A ``(parameters, param_infos)`` pair of equal-length lists, where
        ``param_infos[i]`` is the ``ParameterInfo`` built from ``parameters[i]``
        and its name.
    """
    candidates = model.named_parameters()
    drop_embeddings = not FLAGS.include_embeddings
    drop_layer_norms = not FLAGS.include_layer_norms

    # The embedding check runs first. The layer-norm check only runs for
    # parameters that survived it.
    selected = [
        (param_name, tensor)
        for param_name, tensor in candidates
        if not (drop_embeddings and parameter_filtering.is_embedding(model, param_name))
        and not (drop_layer_norms and parameter_filtering.is_layer_norm(model, param_name))
    ]

    parameters = [tensor for _, tensor in selected]
    infos = [
        parameter_infos.ParameterInfo.from_parameter(tensor, name=param_name)
        for param_name, tensor in selected
    ]
    return parameters, infos


def _maybe_read_random_projection_params():
    """Reads the random-projection flags.

    Returns:
        None when --d_projection is unset, meaning no random projection is
        requested. Otherwise, a ``random_projectors.RandomProjectionParams``
        assembled from --d_projection, --projection_type,
        --projection_algorithm, --random_projection_sparse_region_size and
        --random_projection_seed.
    """
    projection_dim = FLAGS.d_projection
    if projection_dim is not None:
        return random_projectors.RandomProjectionParams(
            d_projection=projection_dim,
            projection_type=FLAGS.projection_type,
            algorithm=FLAGS.projection_algorithm,
            sparse_region_size=FLAGS.random_projection_sparse_region_size,
            seed=FLAGS.random_projection_seed,
        )
    return None


def _require_int_kwarg(flag_kwargs, key, computer_name, positive=True):
    """Raises ValueError unless ``flag_kwargs[key]`` is an int (and > 0 when ``positive``).

    Checking here makes a missing or malformed --fisher_computer_kwargs entry
    fail immediately with a clear message, instead of failing deep inside the
    computer.
    """
    value = flag_kwargs.get(key)
    # JSON `true`/`false` decode to bool, which is a subclass of int; reject them explicitly.
    is_int = isinstance(value, int) and not isinstance(value, bool)
    if not is_int or (positive and value <= 0):
        expected = 'a positive integer' if positive else 'an integer'
        raise ValueError(
            f'When using {computer_name}, the "{key}" kwarg must be provided as {expected} '
            f'(got {value!r}).')


def _require_random_projection_params(random_projection_params, computer_name):
    """Raises ValueError if no random projection params were configured."""
    if random_projection_params is None:
        raise ValueError(
            f'When using {computer_name}, the random_projection_params must be provided '
            '(set --d_projection).')


def _get_fisher_computer(model, parameters, random_projection_params):
    """Creates the Fisher computer selected by --fisher_computer.

    Every computer receives ``model``, ``parameters`` and --error_on_null_grads,
    plus the JSON keyword arguments from --fisher_computer_kwargs.

    Requirements by computer:
    - StreamingSvdLrmComputer and all RpSvd-style computers require an
      ``output_rank`` kwarg that is a positive integer.
    - BatchedSamplingRpSvdLrmComputer additionally requires ``n_samples`` (a
      positive integer) and ``sampler_seed`` (an integer).
    - RpSvdLrmComputer, BatchedRpSvdLrmComputer, BatchedEvrpRpSvdLrmComputer and
      BatchedSamplingRpSvdLrmComputer also receive ``random_projection_params``,
      which must not be None.

    Args:
        model: The (possibly wrapped) model whose Fishers are computed.
        parameters: The parameters to compute Fishers over.
        random_projection_params: Output of _maybe_read_random_projection_params().

    Returns:
        The result of the selected computer class's ``create()``.

    Raises:
        ValueError: If a required kwarg or the random projection params are
            missing or invalid, or if --fisher_computer is not recognized.
    """
    name = FLAGS.fisher_computer
    flag_kwargs = _read_flag_kwargs(FLAGS.fisher_computer_kwargs)
    common_kwargs = dict(
        model=model,
        parameters=parameters,
        error_on_null_grads=FLAGS.error_on_null_grads,
    )

    if name == 'ClassAlignedLrmComputer':
        return class_aligned_lrm_computers.ClassAlignedLrmComputer.create(**common_kwargs, **flag_kwargs)

    if name == 'StreamingSvdLrmComputer':
        _require_int_kwarg(flag_kwargs, 'output_rank', name)
        return streaming_svd_lrm_computers.StreamingSvdLrmComputer.create(**common_kwargs, **flag_kwargs)

    # The remaining computers all take random_projection_params. Each maps to the integer kwargs it
    # requires (value: whether the integer must be positive), checked in insertion order.
    if name in ('RpSvdLrmComputer', 'BatchedRpSvdLrmComputer'):
        computer_cls = getattr(rp_svd_lrm_computers, name)
        required_ints = {'output_rank': True}
    elif name == 'BatchedEvrpRpSvdLrmComputer':
        computer_cls = evrp_rp_svd_lrm_computers.BatchedEvrpRpSvdLrmComputer
        required_ints = {'output_rank': True}
    elif name == 'BatchedSamplingRpSvdLrmComputer':
        computer_cls = sampling_rp_svd_lrm_computers.BatchedSamplingRpSvdLrmComputer
        # A seed of 0 is valid, so only integrality is checked for sampler_seed.
        required_ints = {'output_rank': True, 'n_samples': True, 'sampler_seed': False}
    else:
        raise ValueError(name)

    for key, positive in required_ints.items():
        _require_int_kwarg(flag_kwargs, key, name, positive=positive)
    _require_random_projection_params(random_projection_params, name)

    return computer_cls.create(
        random_projection_params=random_projection_params,
        **common_kwargs,
        **flag_kwargs,
    )

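# Historical note: this helper (presumably used to decide whether the saver should also receive random_projection_params)
# was not kept up to date as new computers that project their own output were added. As a result, some outputs were
# projected twice. It is kept here, commented out, as a record of which configurations might be affected by that bug.
# main() now makes this decision with fisher_computer.is_output_projected().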
# def _has_random_projection_params_on_saver(fisher_computer):
#     return not isinstance(fisher_computer, (rp_svd_lrm_computers.RpSvdLrmComputer,))


###############################################################################


def _validate_main_flags():
    """Raises ValueError for flag values that main() cannot work with.

    The checks run before the model is loaded and before any Fisher computation
    starts, so misconfiguration is reported immediately. Explicit exceptions are
    used instead of ``assert`` so the checks survive ``python -O``.
    """
    for flag_name in ('output_filepath', 'model', 'model_cls', 'load_dataset_fn_path'):
        if not getattr(FLAGS, flag_name):
            raise ValueError(f'--{flag_name} must be set.')
    if FLAGS.n_examples is None or FLAGS.n_examples <= 0:
        raise ValueError(f'--n_examples must be a positive integer, got {FLAGS.n_examples}.')
    if FLAGS.lm_mcqa and FLAGS.lm_suffix_mc:
        raise ValueError('--lm_mcqa and --lm_suffix_mc cannot both be set.')


def _maybe_wrap_lm_mc_model(model, tokenizer, device):
    """Wraps ``model`` for --lm_mcqa or --lm_suffix_mc; otherwise returns it unchanged."""
    if FLAGS.lm_mcqa:
        return lm_mcqa.LmMcqaLogitsComputer(
            model=model,
            tokenizer=tokenizer,
            answer_labels=FLAGS.lm_mcqa_answer_labels,
            answer_label_prefix=FLAGS.lm_mcqa_answer_label_prefix,
            device=device,
        )
    if FLAGS.lm_suffix_mc:
        return lm_suffix_mc.LmSuffixMcLogitsComputer(
            model=model,
            device=device,
        )
    return model


def _make_dataloader(tokenizer):
    """Loads the dataset via --load_dataset_fn_path and wraps it in a DataLoader.

    Raises:
        ValueError: If --load_dataset_fn_path cannot be resolved.
    """
    load_dataset_fn = pydoc.locate(FLAGS.load_dataset_fn_path)
    # pydoc.locate returns None (rather than raising) when the path does not resolve.
    if load_dataset_fn is None:
        raise ValueError(f'Could not locate --load_dataset_fn_path={FLAGS.load_dataset_fn_path!r}.')
    ds = load_dataset_fn(
        task=FLAGS.task,
        subtask=FLAGS.subtask,
        split=FLAGS.split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.sequence_length,
        **_read_flag_kwargs(FLAGS.load_dataset_fn_kwargs)
    )
    if FLAGS.dataset_offset is not None:
        ds = ds.skip(FLAGS.dataset_offset)
    # One example per batch; presumably the saver computes one PEF per example (confirm before changing).
    return DataLoader(ds, batch_size=1)


def main(_):
    """Computes fixed-rank dense LRM-PEFs and saves them to --output_filepath.

    Loads the model, tokenizer and dataset described by the flags. It then
    builds the class selector, the position selector and the Fisher computer,
    and runs ``StreamingLrmPefSaver.compute_and_save_pefs`` with --n_examples.

    Raises:
        ValueError: On missing or invalid flags, or an unresolvable
            --load_dataset_fn_path.
    """
    _validate_main_flags()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ModelClass = getattr(transformers, FLAGS.model_cls)
    model = ModelClass.from_pretrained(FLAGS.model).to(device)

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)
    model = _maybe_wrap_lm_mc_model(model, tokenizer, device)

    dataloader = _make_dataloader(tokenizer)

    class_selector = getattr(class_selectors, FLAGS.class_selector).create(**_read_flag_kwargs(FLAGS.class_selector_kwargs))
    position_selector = getattr(position_selectors, FLAGS.position_selector).create(**_read_flag_kwargs(FLAGS.position_selector_kwargs))

    random_projection_params = _maybe_read_random_projection_params()
    parameters, param_infos = _get_parameters_and_infos(model)
    fisher_computer = _get_fisher_computer(model, parameters, random_projection_params)

    saver = frdn_lrm_pefs.StreamingLrmPefSaver.create(
        model=model,
        fisher_computer=fisher_computer,
        position_selector=position_selector,
        class_subset_selector=class_selector,
        label_key=FLAGS.label_key,
        device=device,
        save_examples=FLAGS.save_examples,
        save_labels=FLAGS.save_labels,
        save_logits=FLAGS.save_logits,
        save_top_n_log_probs=FLAGS.save_top_n_log_probs,
        parameter_infos=param_infos,
        # Only hand the projection params to the saver when the computer has not already projected
        # its output; otherwise the saved PEFs would be projected twice.
        random_projection_params=None if fisher_computer.is_output_projected() else random_projection_params,
        lm_suffix_mc_save_only_first_example=FLAGS.lm_suffix_mc_save_only_first_example,
    )

    saver.compute_and_save_pefs(
        filepath=FLAGS.output_filepath,
        dataloader=dataloader,
        n_examples=FLAGS.n_examples,
    )


if __name__ == "__main__":
    app.run(main)
