from pathlib import Path

import colorful
import torch
import torch.optim as optim

from implicit_geometric_regularization import (click,
                                               create_checkpoint_directory,
                                               load_model, load_trainer,
                                               save_model, save_trainer)
from implicit_geometric_regularization.datasets.shapenet import (
    MinibatchGenerator, MinibatchIterator, UniformPointCloudDataset)
from implicit_geometric_regularization.experiments.surface_reconstruction import (
    DecoderHyperparameters, Events, LossFunction, Trainer,
    TrainingHyperparameters, setup_decoder, log_message)


def _is_due(epoch: int, every: int) -> bool:
    """Return True when ``epoch`` is a multiple of a non-zero ``every``.

    An interval of 0 means "disabled"; without this guard the modulo would
    raise ZeroDivisionError at the end of the first epoch.
    """
    return bool(every) and epoch % every == 0


@click.command()
@click.argument("--data-path", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, default=None)
@click.argument("--checkpoint-directory-root", type=str, default="run")
@click.argument("--log-every", type=int, default=100)
@click.argument("--serialize-every", type=int, default=1000)
@click.argument("--evaluate-every", type=int, default=1000)
@click.argument("--checkpoint-freq", type=int, default=10000)
@click.argument("--debug-mode", is_flag=True)
@click.argument("--max-epochs", type=int, default=1000)
@click.argument("--in-memory-dataset", is_flag=True)
@click.hyperparameter_class(DecoderHyperparameters)
@click.hyperparameter_class(TrainingHyperparameters)
def main(args: click.Arguments, decoder_hyperparams: DecoderHyperparameters,
         training_hyperparams: TrainingHyperparameters):
    """Train a decoder on the point cloud at ``--data-path``.

    The function builds the dataset, decoder, loss function and Adam
    optimizer, restores trainer state from ``trainer.pt`` in the checkpoint
    directory via ``load_trainer``, and runs the trainer for up to
    ``--max-epochs`` epochs. At the end of each epoch it optionally logs,
    serializes the latest model/trainer state, and writes a numbered model
    snapshot, each controlled by its own interval option (0 disables it).

    Args:
        args: Parsed command-line arguments.
        decoder_hyperparams: Hyperparameters passed to ``setup_decoder``.
        training_hyperparams: Sampling, loss and learning-rate settings.

    NOTE: ``--evaluate-every`` is parsed but not referenced in this function.
    """
    # Prefer the first GPU, but fall back to the CPU so the script still runs
    # on machines without CUDA instead of failing when tensors are moved.
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        print("CUDA is not available; falling back to CPU.", flush=True)
        device = torch.device("cpu")

    # ==============================================================================
    # Checkpoint directory
    # ==============================================================================
    # The run id is the first 12 characters of the hash of the combined
    # hyperparameters, so identical settings map to the same id.
    run_id = (decoder_hyperparams + training_hyperparams).hash()[:12]
    checkpoint_directory = create_checkpoint_directory(args, run_id)
    print(colorful.bold(f"Checkpoint directory: {checkpoint_directory}"))

    args_path = checkpoint_directory / "args.json"
    model_path = checkpoint_directory / "model.pt"
    trainer_path = checkpoint_directory / "trainer.pt"

    # Record the arguments next to the checkpoints.
    args.save(args_path)

    # ==============================================================================
    # Dataset
    # ==============================================================================
    dataset_train = UniformPointCloudDataset(
        [(args.data_path, 0)], memory_caching=args.in_memory_dataset)
    minibatch_generator = MinibatchGenerator(
        num_point_samples=training_hyperparams.num_surface_samples,
        with_normal=training_hyperparams.with_normal,
        device=device)
    minibatch_iterator = MinibatchIterator(
        dataset_train, batchsize=1, minibatch_generator=minibatch_generator)

    # ==============================================================================
    # Model
    # ==============================================================================
    decoder = setup_decoder(decoder_hyperparams)
    decoder.to(device)
    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples)

    # ==============================================================================
    # Optimizer
    # ==============================================================================
    optimizer = optim.Adam(decoder.parameters(),
                           lr=training_hyperparams.learning_rate)

    # ==============================================================================
    # Training
    # ==============================================================================
    trainer = Trainer(model=decoder,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      debug_mode=args.debug_mode)
    # NOTE(review): called unconditionally; presumably a no-op when
    # trainer.pt does not exist yet -- confirm in load_trainer.
    load_trainer(trainer_path, trainer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch_completed(trainer: Trainer):
        """Log, serialize and snapshot according to the configured intervals."""
        epoch = trainer.state.epoch
        if _is_due(epoch, args.log_every):
            print(log_message(run_id, decoder, trainer), flush=True)

        # Overwrite the "latest" model and trainer state.
        if _is_due(epoch, args.serialize_every):
            save_model(model_path, decoder)
            save_trainer(trainer_path, trainer)

        # Keep a separate, epoch-numbered snapshot of the model.
        if _is_due(epoch, args.checkpoint_freq):
            path = checkpoint_directory / f"model.{epoch}.pt"
            save_model(path, decoder)

    print(args)
    print(decoder)
    print("# Params:", decoder.num_parameters())
    trainer.run(minibatch_iterator, max_epochs=args.max_epochs)


# Run the command only when this file is executed as a script, not on import.
# main() is called without arguments here; the decorators applied to it are
# expected to supply them from the command line.
if __name__ == "__main__":
    main()
