from pathlib import Path

import colorful
import torch
import torch.optim as optim

from deep_sdf import (click, create_checkpoint_directory, load_model,
                      load_trainer, save_model, save_trainer)
from deep_sdf.datasets.shapenet import (MinibatchGenerator, MinibatchIterator,
                                        SdfDataset)
from deep_sdf.experiments.surface_reconstruction import (
    Decoder, DecoderHyperparameters, Events, LossFunction, Trainer,
    TrainingHyperparameters, define_decoder, log_message)


@click.command()
@click.argument("--data-path", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, default=None)
@click.argument("--checkpoint-directory-root", type=str, default="run")
@click.argument("--log-every", type=int, default=1000)
@click.argument("--serialize-every", type=int, default=1000)
@click.argument("--checkpoint-freq", type=int, default=2000000)
@click.argument("--debug-mode", is_flag=True)
@click.argument("--max-epochs", type=int, default=2000000)
@click.argument("--in-memory-dataset", is_flag=True)
@click.hyperparameter_class(DecoderHyperparameters)
@click.hyperparameter_class(TrainingHyperparameters)
def main(args: click.Arguments, decoder_hyperparams: DecoderHyperparameters,
         training_hyperparams: TrainingHyperparameters):
    """Train a surface-reconstruction decoder on a single SDF dataset.

    The run id is the first 12 characters of the hash of the combined decoder
    and training hyperparameters. The checkpoint directory receives
    ``args.json`` (the parsed arguments), ``model.pt`` and ``trainer.pt``
    (rolling checkpoints written every ``--serialize-every`` epochs) and
    ``model.<epoch>.pt`` snapshots every ``--checkpoint-freq`` epochs.

    Args:
        args: Parsed command-line arguments (declared by the decorators).
        decoder_hyperparams: Hyperparameters passed to ``define_decoder``.
        training_hyperparams: Sampling, loss-clamping and optimizer settings.
    """
    # Prefer the first GPU, but fall back to the CPU instead of crashing on
    # machines where CUDA is unavailable.
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    # ==============================================================================
    # Checkpoint directory
    # ==============================================================================
    run_id = (decoder_hyperparams + training_hyperparams).hash()[:12]
    checkpoint_directory = create_checkpoint_directory(args, run_id)
    print(colorful.bold(f"Checkpoint directory: {checkpoint_directory}"))
    print(f"Device: {device}")

    args_path = checkpoint_directory / "args.json"
    model_path = checkpoint_directory / "model.pt"
    trainer_checkpoint_path = checkpoint_directory / "trainer.pt"

    args.save(args_path)

    # ==============================================================================
    # Dataset
    # ==============================================================================
    # Honor --in-memory-dataset; previously the flag was declared but ignored
    # and caching was unconditionally enabled.
    dataset_train = SdfDataset([(args.data_path, 0)],
                               memory_caching=args.in_memory_dataset)
    minibatch_generator = MinibatchGenerator(
        num_sdf_samples=training_hyperparams.num_sdf_samples, device=device)
    minibatch_iterator = MinibatchIterator(
        dataset_train, batchsize=1, minibatch_generator=minibatch_generator)

    # ==============================================================================
    # Model
    # ==============================================================================
    decoder = define_decoder(decoder_hyperparams)
    decoder.to(device)
    loss_function = LossFunction(
        clamping_distance=training_hyperparams.clamping_distance)

    # ==============================================================================
    # Optimizer
    # ==============================================================================
    optimizer = optim.Adam(decoder.parameters(),
                           lr=training_hyperparams.learning_rate)

    # ==============================================================================
    # Training
    # ==============================================================================
    trainer = Trainer(model=decoder,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      debug_mode=args.debug_mode)
    # NOTE(review): presumably restores trainer state from an existing
    # trainer.pt so an interrupted run can resume — confirm in load_trainer.
    load_trainer(trainer_checkpoint_path, trainer)

    def _is_due(epoch: int, interval: int) -> bool:
        """Return True when ``epoch`` hits ``interval``; non-positive disables."""
        # Guard against ZeroDivisionError and let 0 mean "never".
        return interval > 0 and epoch % interval == 0

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch_completed(trainer: Trainer):
        epoch = trainer.state.epoch
        if _is_due(epoch, args.log_every):
            print(log_message(run_id, decoder, trainer), flush=True)

        # Rolling checkpoint: overwritten each time so the latest state is
        # always available under a fixed name.
        if _is_due(epoch, args.serialize_every):
            save_model(model_path, decoder)
            save_trainer(trainer_checkpoint_path, trainer)

        # Permanent per-epoch snapshot of the model weights.
        if _is_due(epoch, args.checkpoint_freq):
            path = checkpoint_directory / f"model.{epoch}.pt"
            save_model(path, decoder)

    print(args)
    print(decoder)
    print("# Params:", decoder.num_parameters())
    trainer.run(minibatch_iterator, max_epochs=args.max_epochs)


if __name__ == "__main__":
    # Script entry point. main() is called without arguments; its parameters
    # are presumably filled in by the @click.command / @click.argument /
    # @click.hyperparameter_class decorators from the command line — confirm
    # against deep_sdf.click.
    main()
