"""Prints the accuracies of tasks."""
import json
import pydoc
from typing import Optional

from absl import app
from absl import flags

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers

from npeff_torch.models import lm_mcqa
from npeff_torch.models import lm_suffix_mc
from npeff_torch.models import model_utils

from npeff_torch.util import tokenizer_utils

###############################################################################

# Global absl flag registry; every flag below is read through it by main() and its helpers.
FLAGS = flags.FLAGS

# Checkpoint name/path passed to `<model_cls>.from_pretrained`.
flags.DEFINE_string('model', None, '')
# Name of a class looked up on the `transformers` module with getattr.
flags.DEFINE_string('model_cls', None, '')

# Passed to tokenizer_utils.from_pretrained; main() falls back to --model when unset.
flags.DEFINE_string('tokenizer', None, 'If set to None, will be the same as --model.')


# The next four flags are forwarded as keyword arguments (task=, subtask=, split=,
# sequence_length=) to the function named by --load_dataset_fn_path.
flags.DEFINE_string('task', None, '')
flags.DEFINE_string('subtask', None, '')
flags.DEFINE_string('split', None, '')

flags.DEFINE_integer('sequence_length', None, '')


# Dotted path resolved with pydoc.locate; the located callable is invoked with the
# task flags above, the tokenizer, and any --load_dataset_fn_kwargs.
flags.DEFINE_string('load_dataset_fn_path', None, 'Should return a HuggingFace dataset ready to be passed to a dataloader.')
# If provided, this must be a JSON object mapping keyword-argument names to values.
# It is decoded with json.loads and splatted into the dataset-loading call, so only
# JSON-encodable values can be passed.
flags.DEFINE_string('load_dataset_fn_kwargs', None, '')


# Multiple-choice question answering framed as language modeling: when set, the model
# is wrapped in lm_mcqa.LmMcqaLogitsComputer using the answer labels and prefix below.
# Must not be combined with --lm_suffix_mc.
flags.DEFINE_bool('lm_mcqa', False, '')
flags.DEFINE_list('lm_mcqa_answer_labels', None, '')
flags.DEFINE_string('lm_mcqa_answer_label_prefix', '', '')


# Multiple choice framed as sentence completion with a language model: when set, the
# model is wrapped in lm_suffix_mc.LmSuffixMcLogitsComputer. Must not be combined with
# --lm_mcqa.
flags.DEFINE_bool('lm_suffix_mc', False, '')


# Passed to torch's DataLoader. NOTE(review): per the DataLoader docs, batch_size=None
# disables automatic batching; confirm invocations always set this flag.
flags.DEFINE_integer('batch_size', None, '')
# When set, the dataset is truncated with `ds.take(n_examples)` before evaluation.
flags.DEFINE_integer('n_examples', None, '')

# Key in each batch that holds the gold labels compared against argmax predictions.
flags.DEFINE_string('label_key', 'labels', '')


# Optional list of integers. When non-empty, each gold label is remapped as
# `permutation[label]` before comparison; empty means no remapping.
flags.DEFINE_list('label_permutation', [], '')


###############################################################################


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}


def _get_label_permutation(device):
    """Build the optional gold-label remapping tensor from --label_permutation.

    Args:
        device: Device on which the returned tensor is allocated.

    Returns:
        None when --label_permutation is empty; otherwise an int64 tensor whose
        i-th entry is the integer parsed from the i-th flag element.
    """
    raw_entries = FLAGS.label_permutation
    if not raw_entries:
        return None
    mapping = list(map(int, raw_entries))
    return torch.tensor(mapping, dtype=torch.int64, device=device)


###############################################################################


def _build_model(device, tokenizer):
    """Load the `transformers` model and optionally wrap it for LM-based multiple choice.

    Args:
        device: torch.device the model is moved to and that any wrapper is told to use.
        tokenizer: Tokenizer handed to the LM-MCQA wrapper (unused otherwise).

    Returns:
        The model loaded from --model via the `transformers` class named by
        --model_cls. It is wrapped in `lm_mcqa.LmMcqaLogitsComputer` when --lm_mcqa
        is set, or in `lm_suffix_mc.LmSuffixMcLogitsComputer` when --lm_suffix_mc is set.

    Raises:
        app.UsageError: If both --lm_mcqa and --lm_suffix_mc are set.
    """
    # Checked before loading weights so a bad invocation fails fast. A real exception
    # (not `assert`) keeps the check active under `python -O`.
    if FLAGS.lm_mcqa and FLAGS.lm_suffix_mc:
        raise app.UsageError('--lm_mcqa and --lm_suffix_mc are mutually exclusive.')

    model_class = getattr(transformers, FLAGS.model_cls)
    model = model_class.from_pretrained(FLAGS.model).to(device)
    # Inference only. Made explicit so the result does not depend on how the
    # checkpoint loader leaves the train/eval mode.
    model.eval()

    if FLAGS.lm_mcqa:
        return lm_mcqa.LmMcqaLogitsComputer(
            model=model,
            tokenizer=tokenizer,
            answer_labels=FLAGS.lm_mcqa_answer_labels,
            answer_label_prefix=FLAGS.lm_mcqa_answer_label_prefix,
            device=device,
        )
    if FLAGS.lm_suffix_mc:
        return lm_suffix_mc.LmSuffixMcLogitsComputer(
            model=model,
            device=device,
        )
    return model


def _load_dataset(tokenizer):
    """Resolve --load_dataset_fn_path, call it with the task flags, and apply --n_examples.

    Raises:
        app.UsageError: If --load_dataset_fn_path is unset or does not resolve to an object.
    """
    fn_path = FLAGS.load_dataset_fn_path
    # pydoc.locate returns None (rather than raising) for a path it cannot resolve,
    # which would otherwise surface later as an opaque "NoneType is not callable".
    load_dataset_fn = pydoc.locate(fn_path) if fn_path else None
    if load_dataset_fn is None:
        raise app.UsageError(f'Could not locate --load_dataset_fn_path={fn_path!r}.')

    ds = load_dataset_fn(
        task=FLAGS.task,
        subtask=FLAGS.subtask,
        split=FLAGS.split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.sequence_length,
        **_read_flag_kwargs(FLAGS.load_dataset_fn_kwargs)
    )

    if FLAGS.n_examples is not None:
        ds = ds.take(FLAGS.n_examples)
    return ds


def _collect_predictions(model, dataloader, device):
    """Run the model over every batch and gather argmax predictions and gold labels.

    Returns:
        A `(preds, labels)` pair of tensors, each concatenated over all batches.
        `labels` is moved to `preds`' device.

    Raises:
        ValueError: If a batch's logits are not 2-D, or the dataloader yields no batches.
    """
    preds = []
    labels = []
    for batch in tqdm(dataloader):
        logits = model_utils.compute_logits(model, batch, device)
        # The argmax over the last axis assumes one row of class scores per example.
        if len(logits.shape) != 2:
            raise ValueError(f'Expected 2-D logits, got shape {tuple(logits.shape)}.')
        preds.append(torch.argmax(logits, dim=-1))
        labels.append(batch[FLAGS.label_key])

    # torch.cat([]) fails with an unhelpful message; report the real problem instead.
    if not preds:
        raise ValueError('The dataset produced no batches; cannot compute accuracy.')

    all_preds = torch.cat(preds)
    all_labels = torch.cat(labels).to(all_preds.device)
    return all_preds, all_labels


@torch.no_grad()
def main(_):
    """Evaluate the configured model on the configured dataset and print its accuracy.

    Accuracy is the fraction of examples whose argmax logit equals the gold label
    read from `batch[--label_key]` (optionally remapped through --label_permutation).
    Everything runs under `torch.no_grad()`, including the helpers called here.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)
    model = _build_model(device, tokenizer)

    ds = _load_dataset(tokenizer)
    dataloader = DataLoader(ds, batch_size=FLAGS.batch_size)

    preds, labels = _collect_predictions(model, dataloader, device)

    # Build the permutation on the labels' own device, so the indexing below and
    # the comparison with `preds` never mix devices. This matters if compute_logits
    # returns tensors somewhere other than `device`.
    label_permutation = _get_label_permutation(labels.device)
    if label_permutation is not None:
        labels = label_permutation[labels]

    acc = (preds == labels).type(torch.float32).mean().detach().cpu().numpy()
    print(f'accuracy: {acc}')


if __name__ == "__main__":
    app.run(main)
