"""Prints information about the entropies of p(y|x)."""
import json
import pydoc
from typing import Optional

from absl import app
from absl import flags

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers

from npeff_torch.models import lm_mcqa
from npeff_torch.models import lm_suffix_mc
from npeff_torch.models import model_utils

from npeff_torch.util import tokenizer_utils

###############################################################################

FLAGS = flags.FLAGS

# Passed to `<model_cls>.from_pretrained(...)`; also used to load the tokenizer
# when --tokenizer is not set.
flags.DEFINE_string('model', None, '')
# Name of the model class; it is looked up as an attribute of the
# `transformers` module.
flags.DEFINE_string('model_cls', None, '')

# Passed to `tokenizer_utils.from_pretrained`; falls back to --model when unset.
flags.DEFINE_string('tokenizer', None, 'If set to None, will be the same as --model.')


# Forwarded unchanged to the dataset-loading function as the `task`, `subtask`,
# and `split` keyword arguments.
flags.DEFINE_string('task', None, '')
flags.DEFINE_string('subtask', None, '')
flags.DEFINE_string('split', None, '')

# Forwarded unchanged to the dataset-loading function as `sequence_length`.
flags.DEFINE_integer('sequence_length', None, '')


# Dotted path resolved with `pydoc.locate`. The resulting callable is invoked
# with `task`, `subtask`, `split`, `tokenizer`, and `sequence_length` keyword
# arguments plus any extra kwargs from --load_dataset_fn_kwargs.
flags.DEFINE_string('load_dataset_fn_path', None, 'Should return a HuggingFace dataset ready to be passed to a dataloader.')
# If provided, this should be a JSON object mapping keyword-argument names to
# their values. It is decoded with `json.loads` and splatted into the call to
# the dataset-loading function, so only JSON-encodable values can be provided.
flags.DEFINE_string('load_dataset_fn_kwargs', None, '')


# Multiple-choice question answering as language modeling stuff.
# When --lm_mcqa is set, the loaded model is wrapped in
# `lm_mcqa.LmMcqaLogitsComputer`, which receives the two flags below as
# `answer_labels` and `answer_label_prefix`. It cannot be combined with
# --lm_suffix_mc.
flags.DEFINE_bool('lm_mcqa', False, '')
flags.DEFINE_list('lm_mcqa_answer_labels', None, '')
flags.DEFINE_string('lm_mcqa_answer_label_prefix', '', '')


# Multiple-choice as sentence completion as language modeling stuff.
# When set, the loaded model is wrapped in
# `lm_suffix_mc.LmSuffixMcLogitsComputer`.
flags.DEFINE_bool('lm_suffix_mc', False, '')


# Passed to the torch `DataLoader` as `batch_size`.
flags.DEFINE_integer('batch_size', None, '')
# Per-example entropies are compared against this value with `>=` to report
# the fraction and count of examples at or above it.
flags.DEFINE_float('entropy_threshold', None, '')

# If set, iteration stops once this many batches have been processed.
flags.DEFINE_integer('n_batches', None, '')
# If set, the dataset is truncated with `ds.take(n_examples)` before the
# dataloader is built.
flags.DEFINE_integer('n_examples', None, '')


###############################################################################


def _read_flag_kwargs(flag_value: Optional[str]):
    if flag_value:
        return json.loads(flag_value)
    else:
        return {}

###############################################################################


def _select_last_token_logits(logits, attention_mask):
    """Picks, for every example, the logits at its last attended position.

    Args:
        logits: Rank-3 tensor whose first two axes are indexed as
            [example, position].
        attention_mask: Tensor whose nonzero entries mark attended tokens;
            assumed to be [example, position] -- TODO confirm.

    Returns:
        Tensor holding, for each example, `logits[example, n_attended - 1]`.
    """
    # NOTE: Specific to triviaqa. Counting the nonzero mask entries only lands
    # on the last real token if padding is on the right -- TODO confirm against
    # the dataset's tokenization.
    positions = (attention_mask != 0).to(torch.int64).sum(dim=-1) - 1
    positions = positions.to(logits.device)
    rows = torch.arange(logits.shape[0], device=logits.device)
    # A single vectorized gather instead of a per-row Python loop.
    return logits[rows, positions]


def _entropy_from_logits(logits):
    """Returns the entropy (in nats) of softmax(logits) for each row.

    Args:
        logits: Rank-2 tensor of unnormalized log-probabilities, one row per
            example.

    Returns:
        Rank-1 tensor with one entropy value per row.
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = torch.softmax(logits, dim=-1)
    # A -inf logit gives p = 0 and log p = -inf, and 0 * -inf is NaN. Use the
    # standard convention 0 * log 0 = 0 so one masked logit does not poison
    # the whole entropy (and the mean over the dataset).
    safe_log_probs = torch.where(probs > 0, log_probs, torch.zeros_like(log_probs))
    return -torch.einsum('bi,bi->b', probs, safe_log_probs)


@torch.no_grad()
def main(_):
    """Computes per-example predictive entropies and prints summary statistics.

    Loads the model and tokenizer named by the flags, optionally wraps the
    model in one of the multiple-choice logits computers, builds the dataset
    via the function at --load_dataset_fn_path, and iterates it in batches.
    Prints the mean entropy and, when --entropy_threshold is set, the fraction
    and count of examples whose entropy is >= the threshold.

    Args:
        _: Unused positional argument supplied by `app.run`.

    Raises:
        ValueError: If required flags are missing, --lm_mcqa and
            --lm_suffix_mc are both set, the dataset function cannot be
            located, the logits have an unexpected rank, or no batches were
            produced.
    """
    # Validate flags up front so misconfiguration fails before the (slow)
    # model load, and does not depend on `assert` surviving `python -O`.
    missing = [
        name for name in ('model', 'model_cls', 'load_dataset_fn_path')
        if getattr(FLAGS, name) is None
    ]
    if missing:
        raise ValueError(
            'Missing required flag(s): ' + ', '.join(f'--{name}' for name in missing)
        )
    if FLAGS.lm_mcqa and FLAGS.lm_suffix_mc:
        raise ValueError('--lm_mcqa and --lm_suffix_mc are mutually exclusive.')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ModelClass = getattr(transformers, FLAGS.model_cls)
    model = ModelClass.from_pretrained(FLAGS.model).to(device)

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)

    if FLAGS.lm_mcqa:
        model = lm_mcqa.LmMcqaLogitsComputer(
            model=model,
            tokenizer=tokenizer,
            answer_labels=FLAGS.lm_mcqa_answer_labels,
            answer_label_prefix=FLAGS.lm_mcqa_answer_label_prefix,
            device=device,
        )
    elif FLAGS.lm_suffix_mc:
        model = lm_suffix_mc.LmSuffixMcLogitsComputer(
            model=model,
            device=device,
        )

    load_dataset_fn = pydoc.locate(FLAGS.load_dataset_fn_path)
    # `pydoc.locate` returns None when the dotted path cannot be resolved.
    if not callable(load_dataset_fn):
        raise ValueError(
            f'Could not locate a callable at --load_dataset_fn_path={FLAGS.load_dataset_fn_path!r}.'
        )
    ds = load_dataset_fn(
        task=FLAGS.task,
        subtask=FLAGS.subtask,
        split=FLAGS.split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.sequence_length,
        **_read_flag_kwargs(FLAGS.load_dataset_fn_kwargs)
    )

    if FLAGS.n_examples is not None:
        ds = ds.take(FLAGS.n_examples)

    dataloader = DataLoader(ds, batch_size=FLAGS.batch_size)

    # Collect one entropy value per example, batch by batch.
    entropies = []

    for batch_index, batch in enumerate(tqdm(dataloader)):
        logits = model_utils.compute_logits(model, batch, device)

        if logits.dim() == 3:
            # Per-position logits: keep only the last attended position.
            logits = _select_last_token_logits(logits, batch['attention_mask'])
        elif logits.dim() != 2:
            raise ValueError(
                f'Expected logits of rank 2 or 3, got shape {tuple(logits.shape)}.'
            )

        entropies.append(_entropy_from_logits(logits))

        if FLAGS.n_batches is not None and batch_index + 1 >= FLAGS.n_batches:
            break

    # Summarize.
    if not entropies:
        # `torch.cat` rejects an empty list; report the real cause instead.
        raise ValueError('The dataloader produced no batches; there are no entropies to report.')

    entropies = torch.cat(entropies)
    print(f'mean_entropy: {entropies.mean().detach().cpu().numpy()}')

    if FLAGS.entropy_threshold is None:
        # Without a threshold there is nothing to compare against; comparing
        # a tensor with None would raise after all the work is already done.
        return

    at_or_above = entropies >= FLAGS.entropy_threshold

    frac_greater_entropy = at_or_above.to(torch.float32).mean()
    print(f'frac_greater_entropy: {frac_greater_entropy.detach().cpu().numpy()}')

    n_greater_entropy = at_or_above.to(torch.int64).sum().detach().cpu().numpy()
    print(f'{n_greater_entropy} / {entropies.numel()} had greater entropy')


if __name__ == "__main__":
    app.run(main)
