"""Selects examples for ICL by taking them at random and evaluating.

Each randomly sampled context is scored by its accuracy on a held-out set of evaluation examples.
"""
import json
import os
import pydoc
import random
from typing import Optional, List, Tuple

from absl import app
from absl import flags

import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizer
from tqdm import tqdm

from npeff_torch.datasets import preprocessing_common
from npeff_torch.icl.example_selection import icl_evaluation_results
from npeff_torch.models import lm_mcqa
from npeff_torch.models import model_utils
from npeff_torch.peis.fishers.formats import pef_format_common
from npeff_torch.perturbations import evaluation_contexts

from npeff_torch.util import hdf5_utils
from npeff_torch.util import tokenizer_utils


R"""
- Randomly select context examples (either from the full set or from the NPEFF top examples) and evaluate them.
- For each evaluation, save the example indices, the labelled examples, and the score.
    - Which context to use can be chosen later.
    - The saved examples should be directly usable to construct the context.
"""

###############################################################################
FLAGS = flags.FLAGS

# Path of the JSON file that the per-trial results are written to. Required;
# its containing directory is checked for existence before any work is done.
flags.DEFINE_string('output_filepath', None, '')


# PEF files that provide the pool of examples (token ids, attention masks and
# labels) from which both the evaluation set and the ICL contexts are sampled.
flags.DEFINE_list('pef_filepaths', None,
                  'The PEF files used to compute the NPEFF coefficients, which MUST be in '
                  'the same order. These MUST have the examples saved.')

# Parsed by _read_n_examples_per_pef_flag: any negative entry becomes None.
flags.DEFINE_list('n_examples_per_pef', None,
                  "Comma-separated list of integers indicating the number of examples to use from each PEF file. "
                  "If provided, the list must be the same length as the --pef_filepaths list. "
                  "Leave empty to use all examples from all PEFs. "
                  "Use a value of -1 for a particular PEF to use all examples from that particular PEF.")


# --model is handed to `from_pretrained`; --model_cls is the name of the class
# looked up on the `transformers` module that performs the loading.
flags.DEFINE_string('model', None, '')
flags.DEFINE_string('model_cls', None, '')

flags.DEFINE_string("tokenizer", None, "If left None, assumed to be equal to --model.")


# Number of examples sampled (without replacement) into each ICL context, and
# the number of randomly sampled contexts that get evaluated.
flags.DEFINE_integer('n_context_examples', None, '')
flags.DEFINE_integer('n_trials', None, '')

# Number of examples sampled once as the fixed evaluation set, and the number
# of prompts tokenized and scored together per forward pass.
flags.DEFINE_integer('n_evaluation_examples', None, '')
flags.DEFINE_integer('batch_size', None, '')


# Seed for the `random.Random` instance that drives all sampling. Required.
flags.DEFINE_integer('rng_seed', None, '')


# Multiple-choice question answering as language modeling stuff.
# Both flags are forwarded unchanged to lm_mcqa.LmMcqaLogitsComputer.
flags.DEFINE_list('lm_mcqa_answer_labels', None, '')
flags.DEFINE_string('lm_mcqa_answer_label_prefix', '', '')

# When non-empty, the labels read from the PEF files are remapped through this
# list (as an int32 array) before any sampling happens.
flags.DEFINE_list('label_index_map', [],
                  'label_index_map[i] should be the index in the lm_mcqa_answer_labels corresponding to the label i '
                  'as represented in pef_filepaths.')


###############################################################################
# Separator placed between consecutive labelled examples inside an ICL context,
# and between the context and the example being evaluated.
_EXAMPLE_SEP_STR = '\n\n'
###############################################################################


def _check_valid_output_filepath():
    """Validates the --output_filepath flag and returns it with `~` expanded.

    Intended to be called before any expensive work so that a bad output path
    is reported immediately.

    Returns:
        The value of --output_filepath after `os.path.expanduser`.

    Raises:
        AssertionError: If the flag was not provided, or if the directory that
            would contain the output file does not exist.
    """
    output_filepath = FLAGS.output_filepath
    assert output_filepath is not None, 'The --output_filepath flag must be provided.'

    output_filepath = os.path.expanduser(output_filepath)

    # A bare filename such as "results.json" has an empty dirname, and
    # os.path.isdir('') is False. Such a path lives in the current working
    # directory, so validate that instead of rejecting a valid path.
    output_dir = os.path.dirname(output_filepath) or os.curdir
    assert os.path.isdir(output_dir), 'Invalid --output_filepath.'

    return output_filepath


def _read_n_examples_per_pef_flag(flag_value: Optional[List[str]]) -> Optional[List[Optional[int]]]:
    if not flag_value:
        return None
    ret = []
    for n_examples in flag_value:
        n_examples = int(n_examples)
        if n_examples < 0:
            ret.append(None)
        else:
            ret.append(n_examples)
    return ret


def _get_label_index_map() -> Optional[np.ndarray]:
    """Builds the label remapping array from the --label_index_map flag.

    Returns:
        None when the flag is empty; otherwise an int32 array whose i-th entry
        is the i-th flag element parsed as an integer.
    """
    if FLAGS.label_index_map:
        mapped_indices = list(map(int, FLAGS.label_index_map))
        return np.array(mapped_indices, dtype=np.int32)
    return None


def _read_in_model(tokenizer, device: torch.device) -> 'transformers.PreTrainedModel':
    """Loads the --model checkpoint and wraps it for multiple-choice scoring.

    Args:
        tokenizer: Tokenizer paired with the model. Its `pad_token_id` is
            copied onto the model config when the config has none.
        device: Device the loaded model is moved to.

    Returns:
        An `lm_mcqa.LmMcqaLogitsComputer` wrapping the loaded model (put in
        eval mode), configured from the --lm_mcqa_answer_labels and
        --lm_mcqa_answer_label_prefix flags.
    """
    # --model_cls names the transformers class used to load the checkpoint.
    model_cls = getattr(transformers, FLAGS.model_cls)
    base_model = model_cls.from_pretrained(FLAGS.model)
    base_model = base_model.to(device)
    # TODO: See if we want this for during unlearning?
    base_model.eval()

    # Without a pad token id on the config, batched inputs fail with:
    #   ValueError: Cannot handle batch sizes > 1 if no padding token is defined.
    if base_model.config.pad_token_id is None:
        base_model.config.pad_token_id = tokenizer.pad_token_id

    return lm_mcqa.LmMcqaLogitsComputer(
        model=base_model,
        tokenizer=tokenizer,
        answer_labels=FLAGS.lm_mcqa_answer_labels,
        answer_label_prefix=FLAGS.lm_mcqa_answer_label_prefix,
        device=device,
    )


###############################################################################


def _make_evaluation_examples(
    *,
    rng: random.Random,
    tokenizer: PreTrainedTokenizer,
    pef_extra_infos: 'pef_format_common.PefExtraInfos',
) -> Tuple[List[str], np.ndarray, List[int]]:
    """Samples the evaluation set and decodes it back to text.

    Args:
        rng: Random generator used to pick the example indices.
        tokenizer: Tokenizer used to decode the stored token ids.
        pef_extra_infos: Source of the stored examples and labels.

    Returns:
        A tuple `(examples, labels, indices)`: the decoded text of each sampled
        example, the stacked labels in the same order, and the sampled indices
        into `pef_extra_infos`. --n_evaluation_examples indices are drawn
        without replacement.
    """
    sampled_indices = rng.sample(range(pef_extra_infos.n_examples), k=FLAGS.n_evaluation_examples)

    def _decode_example(index):
        token_ids = pef_extra_infos.examples['input_ids'][index]
        mask = pef_extra_infos.examples['attention_mask'][index]
        # Only decode the positions whose attention mask entry is non-zero.
        return tokenizer.decode(token_ids[mask != 0])

    decoded_examples = [_decode_example(index) for index in sampled_indices]
    stacked_labels = np.stack(
        [pef_extra_infos.labels[index] for index in sampled_indices],
        axis=0,
    )

    return decoded_examples, stacked_labels, sampled_indices


def _make_icl_context(
    *,
    rng: random.Random,
    model: 'lm_mcqa.LmMcqaLogitsComputer',
    tokenizer: PreTrainedTokenizer,
    pef_extra_infos: 'pef_format_common.PefExtraInfos',
    context_example_index_options: List[int],
) -> Tuple[str, List[int]]:
    """Samples labelled examples and joins them into a single ICL context.

    Args:
        rng: Random generator used to pick the context examples.
        model: Provides `answer_labels` and `answer_label_prefix`, which are
            used to render each example's label as text.
        tokenizer: Tokenizer used to decode the stored token ids.
        pef_extra_infos: Source of the stored examples and labels.
        context_example_index_options: Indices that may be used as context.

    Returns:
        A tuple `(context, indices)`: the labelled examples joined by the
        example separator, and the sampled indices in the same order.
        --n_context_examples indices are drawn without replacement.
    """
    sampled_indices = rng.sample(context_example_index_options, k=FLAGS.n_context_examples)

    def _render_labelled_example(index):
        token_ids = pef_extra_infos.examples['input_ids'][index]
        mask = pef_extra_infos.examples['attention_mask'][index]
        # Only decode the positions whose attention mask entry is non-zero.
        prompt_text = tokenizer.decode(token_ids[mask != 0])
        answer_text = model.answer_labels[int(pef_extra_infos.labels[index])]
        return f'{prompt_text}{model.answer_label_prefix}{answer_text}'

    icl_context = _EXAMPLE_SEP_STR.join(
        _render_labelled_example(index) for index in sampled_indices
    )
    return icl_context, sampled_indices


@torch.no_grad()
def _evaluate_predictions(
    *,
    model: 'lm_mcqa.LmMcqaLogitsComputer',
    tokenizer: PreTrainedTokenizer,
    icl_context: str,
    evaluation_examples: List[str],
    device: torch.device,
) -> np.ndarray:
    """Predicts an answer for every evaluation example under one ICL context.

    Each evaluation example is appended to `icl_context` (joined by the example
    separator), and the resulting prompts are scored in batches of
    --batch_size.

    Args:
        model: Logits computer passed to `model_utils.compute_logits`.
        tokenizer: Tokenizer used to encode each batch of prompts.
        icl_context: Text of the labelled context examples.
        evaluation_examples: Text of the examples to predict.
        device: Device passed to `model_utils.compute_logits`.

    Returns:
        The per-batch argmax (over the last logits axis) results concatenated
        along axis 0, in the order of `evaluation_examples`.
    """
    prompts = []
    for evaluation_example in evaluation_examples:
        prompts.append(_EXAMPLE_SEP_STR.join((icl_context, evaluation_example)))

    step = FLAGS.batch_size
    batch_starts = range(0, len(prompts), step)

    per_batch_predictions = []
    for start in tqdm(batch_starts):
        # TODO: See if the varying sequence length will cause problems.
        encoded_batch = tokenizer(
            prompts[start : start + step],
            return_tensors="pt",
            add_special_tokens=False,
            padding=True,
        )

        batch_logits = model_utils.compute_logits(model, encoded_batch, device=device)
        predicted = torch.argmax(batch_logits, dim=-1)
        per_batch_predictions.append(predicted.detach().cpu().numpy())

    return np.concatenate(per_batch_predictions, axis=0)


###############################################################################


@torch.no_grad()
def main(_):
    """Evaluates randomly sampled ICL contexts and saves their accuracies.

    Loads the tokenizer, model and PEF examples, samples a fixed set of
    evaluation examples, then for each of --n_trials trials samples a context
    from the remaining examples, scores it by accuracy on the evaluation set,
    and writes all results collected so far to --output_filepath as JSON.

    Args:
        _: Unused positional argument supplied by `app.run`.

    Raises:
        AssertionError: If a required flag is missing, --output_filepath is
            invalid, or the PEF files lack examples, example counts or labels.
    """
    # Check these now so that we don't error out after doing all the work.
    # Each of these flags defaults to None and would otherwise only fail (with
    # a TypeError) after the model and PEF files have already been loaded.
    required_flags = (
        'rng_seed',
        'n_trials',
        'n_context_examples',
        'n_evaluation_examples',
        'batch_size',
    )
    for flag_name in required_flags:
        assert getattr(FLAGS, flag_name) is not None, f'The --{flag_name} flag must be provided.'
    output_filepath = _check_valid_output_filepath()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer or FLAGS.model)
    model = _read_in_model(tokenizer, device)

    n_examples_per_pef = _read_n_examples_per_pef_flag(FLAGS.n_examples_per_pef)
    pef_extra_infos = pef_format_common.PefExtraInfos.read_from_files(FLAGS.pef_filepaths, n_examples_per_pef)
    assert pef_extra_infos.examples is not None
    assert pef_extra_infos.n_examples is not None
    assert pef_extra_infos.labels is not None

    # Translate the labels stored in the PEF files into indices into the
    # answer labels when a mapping was provided.
    if (label_index_map := _get_label_index_map()) is not None:
        pef_extra_infos.labels = label_index_map[pef_extra_infos.labels]

    # A single generator drives all sampling, so the evaluation set and every
    # trial's context follow from --rng_seed.
    rng = random.Random(FLAGS.rng_seed)
    evaluation_examples, evaluation_labels, evaluation_example_indices = _make_evaluation_examples(
        rng=rng,
        tokenizer=tokenizer,
        pef_extra_infos=pef_extra_infos,
    )

    # Make sure we do not use evaluation examples as the context.
    # NOTE(review): the order of this list comes from set iteration order, and
    # rng.sample depends on that order; consider sorting if reproducibility
    # across interpreters matters.
    context_example_index_options = list(set(range(pef_extra_infos.n_examples)) - set(evaluation_example_indices))

    results_json = []
    for _ in tqdm(range(FLAGS.n_trials)):
        icl_context, context_example_indices = _make_icl_context(
            rng=rng,
            model=model,
            tokenizer=tokenizer,
            pef_extra_infos=pef_extra_infos,
            context_example_index_options=context_example_index_options,
        )

        predictions = _evaluate_predictions(
            model=model,
            tokenizer=tokenizer,
            icl_context=icl_context,
            evaluation_examples=evaluation_examples,
            device=device,
        )

        accuracy = float((evaluation_labels == predictions).mean())
        print(f'acc: {accuracy}')

        result = icl_evaluation_results.ContextEvaluationResults(
            icl_context=icl_context,
            score=accuracy,
            icl_context_example_indices=context_example_indices,
        )
        results_json.append(result.to_json())

        # The full results list is rewritten after every trial, presumably so
        # that finished trials are kept if the run is interrupted.
        # `output_filepath` already has `~` expanded.
        with open(output_filepath, 'wt') as f:
            json.dump(results_json, f)


# Run main() through absl's app runner when this file is executed as a script.
if __name__ == "__main__":
    app.run(main)
