"""Saves logits and labels to disk."""
import json
import os
import pydoc
from typing import Optional

from absl import app
from absl import flags

import h5py
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import transformers

from npeff_torch.models import lm_mcqa
from npeff_torch.models import lm_suffix_mc
from npeff_torch.models import model_utils
from npeff_torch.util import hdf5_utils
from npeff_torch.util import tokenizer_utils

###############################################################################

FLAGS = flags.FLAGS


# Path of the HDF5 file to write. A leading `~` is expanded before the file is opened.
flags.DEFINE_string('output_filepath', None, '')

# Passed to `from_pretrained` of the model class to load the model.
flags.DEFINE_string('model', None, '')
# Name of the attribute of the `transformers` module to use as the model class.
flags.DEFINE_string('model_cls', None, '')

flags.DEFINE_string('tokenizer', None, 'If set to None, will be the same as --model.')


# These three are forwarded unchanged as the `task`, `subtask` and `split` keyword arguments of the
# function named by --load_dataset_fn_path.
flags.DEFINE_string('task', None, '')
flags.DEFINE_string('subtask', None, '')
flags.DEFINE_string('split', None, '')

# Forwarded as the `sequence_length` keyword argument of the dataset loading function.
flags.DEFINE_integer('sequence_length', None, '')
# Number of examples whose logits and labels are saved. Must be positive.
flags.DEFINE_integer('n_examples', None, '')
# Batch size of the DataLoader that feeds the model.
flags.DEFINE_integer('batch_size', 32, '')

# If set, this many leading examples are skipped via the dataset's `skip` method.
flags.DEFINE_integer('dataset_offset', None, '')

# Dotted path of the dataset loading function; it is resolved with `pydoc.locate`.
flags.DEFINE_string('load_dataset_fn_path', None, 'Should return a HuggingFace dataset ready to be passed to a dataloader.')
# If provided, this should be a JSON dict mapping keyword-argument names to their values. The entries are
# passed as extra keyword arguments to the dataset loading function. Currently, only JSON-encodable
# values can be provided.
flags.DEFINE_string('load_dataset_fn_kwargs', None, '')

# Key under which each batch stores its labels.
flags.DEFINE_string('label_key', 'labels', '')


# Multiple-choice question answering as language modeling stuff.
# When --lm_mcqa is set, the model is wrapped in `lm_mcqa.LmMcqaLogitsComputer`, which receives the
# two --lm_mcqa_answer_* flags below. Cannot be combined with --lm_suffix_mc.
flags.DEFINE_bool('lm_mcqa', False, '')
flags.DEFINE_list('lm_mcqa_answer_labels', None, '')
flags.DEFINE_string('lm_mcqa_answer_label_prefix', '', '')


# Multiple-choice as sentence completion as language modeling stuff.
# When set, the model is wrapped in `lm_suffix_mc.LmSuffixMcLogitsComputer`. Cannot be combined
# with --lm_mcqa.
flags.DEFINE_bool('lm_suffix_mc', False, '')


###############################################################################


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}


###############################################################################


@torch.no_grad()
def main(_):
    """Computes logits for the first --n_examples examples and saves them with their labels.

    Loads the model and tokenizer, optionally wraps the model in one of the language-modeling
    multiple-choice logits computers, builds the dataset with the function named by
    --load_dataset_fn_path, and writes the collected logits and labels to the HDF5 file at
    --output_filepath under `data/logits` and `data/labels`.

    Args:
        _: Positional command-line arguments passed by `app.run`; unused.

    Raises:
        ValueError: If a flag is missing or invalid, if the dataset loading function cannot be
            located, or if the dataset has fewer than --n_examples examples.
    """
    # Validate up front with real exceptions (asserts are stripped under `python -O`), so that a
    # bad invocation fails before the model is loaded and any inference is run.
    if FLAGS.n_examples is None or FLAGS.n_examples <= 0:
        raise ValueError(f'--n_examples must be a positive integer, got {FLAGS.n_examples!r}.')
    if not FLAGS.output_filepath:
        raise ValueError('--output_filepath must be set.')
    if FLAGS.lm_mcqa and FLAGS.lm_suffix_mc:
        raise ValueError('--lm_mcqa and --lm_suffix_mc are mutually exclusive.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ModelClass = getattr(transformers, FLAGS.model_cls)
    model = ModelClass.from_pretrained(FLAGS.model).to(device)

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)

    if FLAGS.lm_mcqa:
        model = lm_mcqa.LmMcqaLogitsComputer(
            model=model,
            tokenizer=tokenizer,
            answer_labels=FLAGS.lm_mcqa_answer_labels,
            answer_label_prefix=FLAGS.lm_mcqa_answer_label_prefix,
            device=device,
        )
    elif FLAGS.lm_suffix_mc:
        model = lm_suffix_mc.LmSuffixMcLogitsComputer(
            model=model,
            device=device,
        )

    # `pydoc.locate` returns None rather than raising when the dotted path cannot be resolved.
    load_dataset_fn = pydoc.locate(FLAGS.load_dataset_fn_path)
    if load_dataset_fn is None:
        raise ValueError(
            f'Could not locate the dataset loading function {FLAGS.load_dataset_fn_path!r} '
            '(--load_dataset_fn_path).'
        )
    ds = load_dataset_fn(
        task=FLAGS.task,
        subtask=FLAGS.subtask,
        split=FLAGS.split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.sequence_length,
        **_read_flag_kwargs(FLAGS.load_dataset_fn_kwargs)
    )
    if FLAGS.dataset_offset is not None:
        ds = ds.skip(FLAGS.dataset_offset)
    dataloader = DataLoader(ds, batch_size=FLAGS.batch_size)

    n_examples = 0
    labels = []
    logits = []
    # The context manager closes the progress bar even if a batch raises.
    with tqdm(total=FLAGS.n_examples) as progress_bar:
        for batch in dataloader:
            batch_labels = batch[FLAGS.label_key]
            batch_logits = model_utils.compute_logits(model, batch, device)

            # The last batch may overshoot the requested number of examples; keep only what is needed.
            n_to_keep = min(FLAGS.n_examples - n_examples, batch_logits.shape[0])
            labels.append(batch_labels[:n_to_keep])
            # Move each batch off the accelerator right away so the accumulated logits do not
            # occupy device memory for the whole run.
            logits.append(batch_logits[:n_to_keep].detach().cpu())

            progress_bar.update(n_to_keep)
            n_examples += n_to_keep
            if n_examples >= FLAGS.n_examples:
                break

    # The loop ends early only when the dataloader is exhausted; report that explicitly instead
    # of letting a bare StopIteration escape.
    if n_examples < FLAGS.n_examples:
        raise ValueError(
            f'The dataset was exhausted after {n_examples} examples, but --n_examples='
            f'{FLAGS.n_examples} were requested.'
        )

    logits = torch.cat(logits, dim=0)
    labels = torch.cat(labels, dim=0)

    with h5py.File(os.path.expanduser(FLAGS.output_filepath), "w") as f:
        hdf5_utils.save_h5_ds(f, 'data/logits', logits.detach().cpu().numpy())
        hdf5_utils.save_h5_ds(f, 'data/labels', labels.detach().cpu().numpy())


if __name__ == "__main__":
    # main() cannot run without these flags. Declaring them required makes absl report a missing
    # one as a flag-parsing error at startup, rather than an obscure failure later in the run.
    flags.mark_flags_as_required([
        'output_filepath',
        'model',
        'model_cls',
        'n_examples',
        'load_dataset_fn_path',
    ])
    app.run(main)
