import argparse
import json
import logging
import pathlib
import sys

import humanfriendly
import wandb

from rau.tools.torch.profile import get_current_memory

from recognizers.neural_networks.data import (
    add_data_arguments,
    load_prepared_data,
    load_vocabulary_data
)
from recognizers.neural_networks.model_interface import RecognitionModelInterface
from recognizers.neural_networks.training_loop import (
    RecognitionTrainingLoop,
    add_training_loop_arguments,
    get_training_loop_kwargs
)

def _write_wandb_run_file(model_dir, wandb_run, wandb_mode, console_logger):
    """Write the wandb run's entity, project, mode and id to ``wandb.json``.

    The file is created in ``model_dir``. Its path is logged to
    ``console_logger`` before it is opened.
    """
    run_file = model_dir / 'wandb.json'
    console_logger.info(f'writing {run_file}')
    with run_file.open('w') as fout:
        run_metadata = dict(
            entity=wandb_run.entity,
            project=wandb_run.project,
            mode=wandb_mode,
            id=wandb_run.id
        )
        json.dump(run_metadata, fout)


def _record_info(event_logger, wandb_run, key, info):
    """Log ``info`` as event ``key`` and store it in the wandb config under ``key``."""
    event_logger.log(key, info)
    wandb_run.config.update({key: info})


def main():
    """Command-line entry point: train a recognizer.

    In order, this:
    1. sets up a stdout logger;
    2. parses the arguments contributed by the data module, the model
       interface, the training loop and the ``--wandb-*`` flags;
    3. constructs the model;
    4. loads the prepared data;
    5. runs ``RecognitionTrainingLoop`` inside a ``wandb.init`` context.

    The wandb mode comes from ``--wandb-mode`` (default ``'disabled'``).
    """
    # Send the 'main' logger's output to stdout.
    log = logging.getLogger('main')
    log.addHandler(logging.StreamHandler(sys.stdout))
    log.setLevel(logging.INFO)
    log.info(f'arguments: {sys.argv}')

    interface = RecognitionModelInterface()

    # Each component contributes its own command-line arguments.
    parser = argparse.ArgumentParser(description='Train a recognizer.')
    add_data_arguments(parser)
    interface.add_arguments(parser)
    interface.add_forward_arguments(parser)
    add_training_loop_arguments(parser)
    parser.add_argument(
        '--wandb-mode',
        choices=['online', 'offline', 'disabled'],
        default='disabled'
    )
    for wandb_flag in ('--wandb-entity', '--wandb-project', '--wandb-name'):
        parser.add_argument(wandb_flag)
    args = parser.parse_args()
    log.info(f'parsed arguments: {args}')

    device = interface.get_device(args)
    log.info(f'device: {device}')
    # Device-memory profiling is only performed on CUDA devices.
    profile_memory = device.type == 'cuda'

    # The vocabulary determines the sizes of the model's embedding and
    # softmax layers.
    vocabulary_data = load_vocabulary_data(args, parser)

    # Build the model. When profiling, sample device memory before and after
    # construction and report the difference as the model's size.
    memory_before = get_current_memory(device) if profile_memory else None
    saver = interface.construct_saver(args, vocabulary_data)
    if interface.parameter_seed is not None:
        log.info(f'parameter random seed: {interface.parameter_seed}')
    num_parameters = sum(param.numel() for param in saver.model.parameters())
    log.info(f'number of parameters: {num_parameters}')
    model_size_in_bytes = None
    if profile_memory:
        model_size_in_bytes = get_current_memory(device) - memory_before
        log.info(f'model size: {humanfriendly.format_size(model_size_in_bytes)}')

    training_data, validation_data, vocabulary = load_prepared_data(
        args, parser, vocabulary_data, interface
    )

    wandb_enabled = args.wandb_mode != 'disabled'
    model_dir = pathlib.Path(saver.directory_name)
    with wandb.init(
        entity=args.wandb_entity,
        project=args.wandb_project,
        dir=model_dir.absolute(),
        name=args.wandb_name,
        config=dict(model_kwargs=saver.kwargs),
        mode=args.wandb_mode
    ) as run:
        # Only persist run metadata when wandb is actually logging.
        if wandb_enabled:
            _write_wandb_run_file(model_dir, run, args.wandb_mode, log)
        # Events are logged to disk for the duration of training.
        with saver.logger() as event_logger:
            _record_info(event_logger, run, 'model_info', dict(
                parameter_seed=interface.parameter_seed,
                size_in_bytes=model_size_in_bytes,
                num_parameters=num_parameters
            ))
            _record_info(event_logger, run, 'training_info', dict(
                max_tokens_per_batch=args.max_tokens_per_batch,
                language_modeling_loss_coefficient=args.language_modeling_loss_coefficient,
                next_symbols_loss_coefficient=args.next_symbols_loss_coefficient
            ))
            training_loop = RecognitionTrainingLoop(
                **get_training_loop_kwargs(parser, args),
                wandb_run=run
            )
            training_loop.run(
                saver, interface, training_data, validation_data,
                vocabulary, log, event_logger
            )
            run.log(dict(training_completed=True))

if __name__ == '__main__':
    # Start training only when this file is executed as a script, not when
    # it is imported as a module.
    main()
