import json
import math
import sys
import traceback
from pathlib import Path

import colorful
import horovod.torch as hvd
import torch
import torch.optim as optim

from occupancy_networks import (click, create_checkpoint_directory,
                                load_trainer, save_model, save_trainer)
from occupancy_networks.datasets.surface.uniform_sparse_sampling import (
    MinibatchGenerator, MinibatchIterator, SdfSurfacePairDataset)
from occupancy_networks.experiments.encoder import (
    Events, LossFunction, ModelHyperparameters, Trainer,
    TrainingHyperparameters, log_loss, log_message, setup_model)


def _split_list(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


def _collect_data_path_pairs(split, sdf_dataset_directory: Path,
                             surface_dataset_directory: Path, verbose: bool):
    """Collect (SDF file, surface point-cloud file) path pairs from a split.

    Args:
        split: Mapping (as loaded from the split JSON file) from a category
            name to an iterable of model ids in that category.
        sdf_dataset_directory: Root holding ``<category>/<model_id>/sdf.npz``.
        surface_dataset_directory: Root holding
            ``<category>/<model_id>/point_cloud.npz``.
        verbose: When true, print the path of the first missing file for
            every model that is skipped.

    Returns:
        A list of ``(sdf_data_path, surface_data_path)`` tuples, in split
        order, for every model whose two files both exist.
    """
    data_path_pair_list = []
    for category, model_id_list in split.items():
        for model_id in model_id_list:
            sdf_data_path = sdf_dataset_directory / category / model_id / "sdf.npz"
            if not sdf_data_path.exists():
                if verbose:
                    print(f"{sdf_data_path} missing.")
                continue
            surface_data_path = (surface_dataset_directory / category /
                                 model_id / "point_cloud.npz")
            if not surface_data_path.exists():
                if verbose:
                    print(f"{surface_data_path} missing.")
                continue
            data_path_pair_list.append((sdf_data_path, surface_data_path))
    return data_path_pair_list


@click.command()
@click.argument("--sdf-dataset-directory", type=str, required=True)
@click.argument("--surface-dataset-directory", type=str, required=True)
@click.argument("--split-path", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, default=None)
@click.argument("--checkpoint-directory-root", type=str, default="run")
@click.argument("--log-every", type=int, default=1)
@click.argument("--serialize-every", type=int, default=1)
@click.argument("--checkpoint-freq", type=int, default=100)
@click.argument("--debug-mode", is_flag=True)
@click.argument("--max-epochs", type=int, default=2000)
@click.argument("--in-memory-dataset", is_flag=True)
@click.hyperparameter_class(ModelHyperparameters)
@click.hyperparameter_class(TrainingHyperparameters)
def main(args: click.Arguments, model_hyperparams: ModelHyperparameters,
         training_hyperparams: TrainingHyperparameters):
    """Train the encoder with Horovod data parallelism, one GPU per process.

    Each process trains on a contiguous shard of the models listed in
    ``--split-path``. Only rank 0 prints progress logs and writes model /
    trainer checkpoints into the checkpoint directory.
    """
    hvd.init()
    rank = hvd.rank()
    local_rank = hvd.local_rank()
    num_processes = hvd.size()
    # Pin this process to its own GPU (Horovod's recommended setup) so that
    # CUDA work using the "current" device does not all land on cuda:0.
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # ==============================================================================
    # Dataset
    # ==============================================================================
    sdf_dataset_directory = Path(args.sdf_dataset_directory)
    surface_dataset_directory = Path(args.surface_dataset_directory)
    split_path = Path(args.split_path)
    assert split_path.is_file()

    with open(split_path, encoding="utf-8") as f:
        split = json.load(f)

    # Only rank 0 reports missing files to avoid duplicated output.
    data_path_pair_list = _collect_data_path_pairs(split,
                                                   sdf_dataset_directory,
                                                   surface_dataset_directory,
                                                   verbose=(rank == 0))

    dataset_size = len(data_path_pair_list)
    assert dataset_size >= training_hyperparams.batchsize

    # Each process keeps only its own contiguous shard of the data.
    data_path_pair_list = _split_list(data_path_pair_list, num_processes)[rank]

    dataset = SdfSurfacePairDataset(data_path_pair_list,
                                    memory_caching=args.in_memory_dataset)

    data_generator = MinibatchGenerator(
        num_input_points=training_hyperparams.num_input_points,
        num_gt_points=training_hyperparams.num_gt_points,
        noise_stddev=training_hyperparams.input_points_noise_stddev,
        device=device)

    # The global batch size is divided evenly across processes.
    minibatch_iterator_train = MinibatchIterator(
        dataset,
        batchsize=training_hyperparams.batchsize // num_processes,
        data_generator=data_generator,
        drop_last=True)
    # Shards may yield different iteration counts; take the smallest one so
    # every rank can perform the same number of gradient-allreduce steps.
    # It is handed to Trainer as `max_minibatch_iterations`.
    minibatch_iterations = hvd.allgather(
        torch.full((1, ), len(minibatch_iterator_train), dtype=int))
    max_minibatch_iterations = torch.min(minibatch_iterations).item()
    print(
        f"# Minibatch iterations: {len(minibatch_iterator_train)} -> {max_minibatch_iterations}",
        flush=True)

    # ==============================================================================
    # Checkpoint directory
    # ==============================================================================
    run_id = (model_hyperparams + training_hyperparams).hash()[:12]
    checkpoint_directory = create_checkpoint_directory(args, run_id)
    if rank == 0:
        print(colorful.bold(f"Checkpoint directory: {checkpoint_directory}"))

    args_path = checkpoint_directory / "args.json"
    model_path = checkpoint_directory / "model.pt"
    trainer_checkpoint_path = checkpoint_directory / "trainer.pt"

    if rank == 0:
        args.save(args_path)

    # ==============================================================================
    # Model
    # ==============================================================================
    model = setup_model(model_hyperparams)
    model.to(device)

    # ==============================================================================
    # Optimizer
    # ==============================================================================
    optimizer = optim.Adam(model.parameters(),
                           lr=training_hyperparams.learning_rate)

    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    # ==============================================================================
    # Training
    # ==============================================================================
    loss_function = LossFunction()
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      debug_mode=args.debug_mode,
                      max_minibatch_iterations=max_minibatch_iterations)
    load_trainer(trainer_checkpoint_path, trainer)

    # Broadcast only after the checkpoint (if any) has been restored, so all
    # ranks start from rank 0's weights and optimizer state instead of relying
    # on each rank having read an identical checkpoint on its own.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    class GpuStatus:
        """Print a one-time confirmation once the first iteration completes."""
        checked = False

        def __call__(self):
            if not self.checked:
                print("The device is working properly.")
                self.checked = True

    check_gpu_status = GpuStatus()

    def print_exception(e: Exception):
        # `e.with_traceback(tb)` merely re-attaches a traceback and returns
        # `e`, so printing it showed only str(e). Format the traceback
        # explicitly so the failure location is visible in the log.
        message = "".join(
            traceback.format_exception(type(e), e, e.__traceback__))
        print(colorful.bold_red(message))

    @trainer.on(Events.ITERATION_COMPLETED)
    def on_iteration_completed(trainer: Trainer):
        check_gpu_status()

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch_completed(trainer: Trainer):
        epoch = trainer.state.epoch
        if rank == 0:
            if epoch == 1 or epoch % args.log_every == 0:
                print(log_message(run_id, model, trainer), flush=True)

            # Best-effort serialization: I/O errors are reported, not fatal.
            if epoch % args.serialize_every == 0:
                try:
                    save_model(model_path, model)
                    save_trainer(trainer_checkpoint_path, trainer)
                    log_loss(trainer, checkpoint_directory / "loss.csv")
                except OSError as e:
                    print_exception(e)

            if epoch % args.checkpoint_freq == 0:
                try:
                    path = checkpoint_directory / f"model.{epoch}.pt"
                    save_model(path, model)
                except OSError as e:
                    print_exception(e)

    if rank == 0:
        print(args)
        print(model)
        print("# Data:", dataset_size)
        print("# Params:", model.num_parameters(), flush=True)
        if trainer.state.epoch > 1:
            print("Resume training from epoch",
                  trainer.state.epoch,
                  flush=True)

    trainer.run(minibatch_iterator_train, max_epochs=args.max_epochs)


if __name__ == "__main__":
    # The click decorators on `main` parse the command line and supply
    # its arguments, so it is invoked with none here.
    main()
