import argparse
import json
import math
import pathlib
import sys

import torch

from rau.tasks.common.data import load_prepared_data_file
from rau.tasks.common.training_loop import MicroAveragedScoreAccumulator
from rau.tasks.language_modeling.model import LanguageModelingModelInterface

from intervention_sampling.neural_networks.batching import group_into_batches_with_extra

def generate_batches(examples, max_tokens):
    """Group examples into batches that stay within a token budget.

    The actual grouping is delegated to ``group_into_batches_with_extra``;
    this function only supplies the size predicate, which accepts a batch
    while the product of its two arguments does not exceed ``max_tokens``.

    NOTE(review): the predicate arguments are presumably the batch size and
    the (padded) sequence length -- confirm against the helper.
    """
    def within_budget(b, n):
        return b * n <= max_tokens

    return group_into_batches_with_extra(examples, within_budget)

def evaluate_batch(model, model_interface, prepared_batch):
    """Score a single prepared batch.

    ``prepared_batch`` is unpacked into the model input and the target
    tensor. The logits are obtained from ``model_interface.get_logits`` and
    are treated as a rank-3 tensor whose last dimension holds the classes
    (presumably batch x sequence x vocabulary -- confirm with the interface).

    Returns a pair:
      * the cross-entropy of each sequence, i.e. the per-token cross-entropy
        summed over dimension 1, with positions whose target equals the
        output padding index excluded;
      * the negated log-softmax of the logits over dimension 2, i.e. the
        negative log-probability of every class at every position.
    """
    model_input, correct_target = prepared_batch
    padding_index = model_interface.output_padding_index
    logits = model_interface.get_logits(model, model_input)
    functional = torch.nn.functional
    # cross_entropy() wants the class dimension in position 1, so swap the
    # last two dimensions of the logits before calling it.
    per_token_loss = functional.cross_entropy(
        logits.permute(0, 2, 1),
        correct_target,
        reduction='none',
        ignore_index=padding_index
    )
    per_sequence_loss = per_token_loss.sum(dim=1)
    negative_log_probabilities = functional.log_softmax(logits, dim=2).neg()
    return per_sequence_loss, negative_log_probabilities

def evaluate(model, model_interface, batches, num_examples):
    """Run the model over all batches and collect per-example scores.

    Each batch is a collection of ``(example, index)`` pairs. The result is
    a list of length ``num_examples`` in which slot ``index`` holds

        ((cross_entropy, length), token_negative_log_probabilities)

    where ``length`` is ``len(example) + 1`` and the second item is that
    example's rows of the batch array, truncated to ``length`` positions.
    Slots whose index never appears in ``batches`` remain ``None``.

    NOTE(review): the ``+ 1`` presumably accounts for an end-of-sequence
    position -- confirm against the data preparation code.
    """
    device = model_interface.get_device(None)
    scores_by_index = [None for _ in range(num_examples)]
    model.eval()
    with torch.inference_mode():
        for indexed_batch in batches:
            prepared = model_interface.prepare_batch(
                [example for example, _ in indexed_batch],
                device
            )
            sequence_losses, negative_log_probs = evaluate_batch(
                model,
                model_interface,
                prepared
            )
            # Move the results off the device once per batch, then walk the
            # batch rows in lockstep with the (example, index) pairs.
            rows = zip(
                indexed_batch,
                sequence_losses.tolist(),
                negative_log_probs.cpu().numpy()
            )
            for (example, index), loss, example_negative_log_probs in rows:
                num_positions = len(example) + 1
                scores_by_index[index] = (
                    (loss, num_positions),
                    example_negative_log_probs[:num_positions]
                )
    return scores_by_index

def main():
    """Evaluate a language model on a prepared dataset and write the results.

    The input is either the file given by --input-file or, failing that,
    <training-data>/datasets/<input>/main.prepared. Three files are written
    to the --output directory (which is created if it does not exist):

      * token-negative-log-probabilities.pt: a list with one array of
        negative log-probabilities per example, saved with torch.save;
      * scores.jsonl: one JSON object per example holding the
        (cross-entropy, length) pair under 'cross_entropy_per_token';
      * scores.json: the aggregate value reported by
        MicroAveragedScoreAccumulator over all of those pairs.

    Exits through parser.error() if no input file can be determined.
    """

    model_interface = LanguageModelingModelInterface(
        use_load=True,
        use_init=False,
        use_output=False,
        require_output=False
    )

    parser = argparse.ArgumentParser(
        description=
        'Evaluate a language model on a dataset. Output the results as JSON.'
    )
    parser.add_argument('--training-data', type=pathlib.Path,
        help='A directory containing training data. The file '
             '<training-data>/datasets/<input>/main.prepared will be used as '
             'input.')
    parser.add_argument('--input',
        help='Name of a dataset in the training data directory that will be '
             'used as input. The file '
             '<training-data>/datasets/<input>/main.prepared will be used as '
             'input.')
    parser.add_argument('--input-file', type=pathlib.Path,
        help='A .prepared file to be used as input. This overrides '
             '--training-data and --input.')
    parser.add_argument('--output', type=pathlib.Path, required=True,
        help='A directory where output files will be written.')
    parser.add_argument('--batching-max-tokens', type=int, required=True,
        help='The maximum number of tokens allowed per batch.')
    model_interface.add_arguments(parser)
    model_interface.add_forward_arguments(parser)
    args = parser.parse_args()

    # --input-file takes precedence over --training-data/--input.
    if args.input_file is not None:
        input_file = args.input_file
    elif args.training_data is not None and args.input is not None:
        input_file = args.training_data / 'datasets' / args.input / 'main.prepared'
    else:
        parser.error('either --training-data and --input or --input-file is required')

    examples = load_prepared_data_file(input_file)
    # Pair every example with its original position so that results can be
    # put back in input order regardless of how the batches are formed.
    examples = [(x, i) for i, x in enumerate(examples)]
    saver = model_interface.construct_saver(args)
    batches = generate_batches(examples, args.batching_max_tokens)

    # Make sure the output directory exists *before* running the evaluation,
    # so that a missing directory cannot cause the results to be lost after
    # all of the work has already been done.
    args.output.mkdir(parents=True, exist_ok=True)

    results = evaluate(saver.model, model_interface, batches, len(examples))

    fout_name = args.output / 'token-negative-log-probabilities.pt'
    print(f'writing {fout_name}')
    torch.save([x for _, x in results], fout_name)

    # Only the (numerator, denominator) pairs are needed from here on.
    score_pairs = [x for x, _ in results]
    accumulator = MicroAveragedScoreAccumulator()
    fout_name = args.output / 'scores.jsonl'
    print(f'writing {fout_name}')
    with fout_name.open('w') as fout:
        for numerator, denominator in score_pairs:
            accumulator.update(numerator, denominator)
            json.dump({
                'cross_entropy_per_token' : (numerator, denominator)
            }, fout)
            # Terminate each JSON object with a newline (JSON Lines format).
            print(file=fout)

    cross_entropy_per_token = accumulator.get_value()
    fout_name = args.output / 'scores.json'
    print(f'writing {fout_name}')
    with fout_name.open('w') as fout:
        json.dump({ 'cross_entropy_per_token' : cross_entropy_per_token }, fout, indent=2)
        print(file=fout)

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
