import json
import math
import random
import traceback
from pathlib import Path

import colorful
import horovod.torch as hvd
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

from deep_sdf import (click, create_checkpoint_directory, load_model,
                      load_trainer, save_model, save_trainer)
from deep_sdf.datasets.shapenet import (MinibatchGenerator, MinibatchIterator,
                                        SdfDataset)
from deep_sdf.experiments.learning_shape_space import (
    Decoder, DecoderHyperparameters, Events, LossFunction,
    ModelHyperparameters, Trainer, TrainingHyperparameters, log_loss,
    log_message, setup_model)


def _split_list(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class LearningRateScheduler:
    """Step-decay learning-rate multiplier, used as ``LambdaLR``'s ``lr_lambda``.

    Calling the instance with an epoch number returns the factor applied to
    each parameter group's base learning rate. The factor starts at 1 and is
    halved at every multiple of ``decrease_every``, at most ``num_decay``
    times::

        factor(epoch) = 0.5 ** min(epoch // decrease_every, num_decay)

    for ``epoch > 0``. Epoch 0 (and anything below it) yields 1.

    The factor is computed in closed form, so it is correct even when epochs
    are queried out of order, repeated, or skipped (e.g. when resuming from
    a checkpoint). The attributes ``last_epoch``, ``last_gamma`` and
    ``current_num_decay`` are kept up to date after every call.
    ``LambdaLR.state_dict()`` saves the ``__dict__`` of callable-object
    lambdas, so these attribute names stay stable.

    Args:
        decrease_every: Number of epochs between halvings. Must be positive
            when ``num_decay > 0``.
        num_decay: Maximum number of halvings to apply.
        last_epoch: Epoch the schedule is being resumed from; initializes the
            state attributes.

    Raises:
        ValueError: If ``num_decay > 0`` and ``decrease_every <= 0``.
    """

    def __init__(self, decrease_every: int, num_decay: int, last_epoch: int):
        if num_decay > 0 and decrease_every <= 0:
            # Without this check the first decay-eligible epoch would hit a
            # ZeroDivisionError (or decay nonsensically for negative values).
            raise ValueError(
                f"decrease_every must be positive, got {decrease_every}")
        self.decrease_every = decrease_every
        self.num_decay = num_decay
        self.last_epoch = last_epoch
        self.current_num_decay = self._num_decays_at(last_epoch)
        self.last_gamma = 0.5 ** self.current_num_decay

    def _num_decays_at(self, epoch: int) -> int:
        """Return how many halvings are in effect at ``epoch``."""
        if epoch <= 0 or self.num_decay <= 0:
            return 0
        return min(epoch // self.decrease_every, self.num_decay)

    def __call__(self, epoch: int) -> float:
        """Return the learning-rate multiplier for ``epoch``."""
        self.last_epoch = epoch
        self.current_num_decay = self._num_decays_at(epoch)
        # Powers of 0.5 are exact in binary floating point, so this matches
        # repeated halving bit for bit.
        self.last_gamma = 0.5 ** self.current_num_decay
        return self.last_gamma


def _collect_sdf_paths(dataset_directory, split):
    """Return the ``sdf.npz`` paths named by ``split`` that exist on disk.

    ``split`` is the parsed train-split JSON. It is iterated as a mapping
    from category name to a list of model ids. The candidate path for each
    entry is ``dataset_directory / category / model_id / "sdf.npz"``; entries
    whose file is missing are skipped. The split's iteration order is
    preserved.
    """
    paths = []
    for category, model_ids in split.items():
        for model_id in model_ids:
            npz_path = dataset_directory / category / model_id / "sdf.npz"
            if npz_path.exists():
                paths.append(npz_path)
    return paths


@click.command()
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--train-split-path", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, default=None)
@click.argument("--checkpoint-directory-root", type=str, default="run")
@click.argument("--log-every", type=int, default=1)
@click.argument("--serialize-every", type=int, default=1)
@click.argument("--checkpoint-freq", type=int, default=100)
@click.argument("--debug-mode", is_flag=True)
@click.argument("--max-epochs", type=int, default=2000)
@click.argument("--in-memory-dataset", is_flag=True)
@click.hyperparameter_class(ModelHyperparameters)
@click.hyperparameter_class(DecoderHyperparameters)
@click.hyperparameter_class(TrainingHyperparameters)
def main(args: click.Arguments, model_hyperparams: ModelHyperparameters,
         decoder_hyperparams: DecoderHyperparameters,
         training_hyperparams: TrainingHyperparameters):
    """Train the shape-space model with Horovod data parallelism.

    Each Horovod process loads a disjoint shard of the training split and
    runs on the CUDA device given by its local rank. The optimizer is
    wrapped in ``hvd.DistributedOptimizer``, and the initial parameters and
    optimizer state are broadcast from rank 0.

    Rank 0 alone writes these files into the checkpoint directory:
    ``model.pt``, ``trainer.pt`` and ``loss.csv`` (every
    ``--serialize-every`` epochs) and ``model.<epoch>.pt`` (every
    ``--checkpoint-freq`` epochs). ``args.json`` is written by every rank
    before training starts.

    Args:
        args: Parsed command-line arguments (see the decorators above).
        model_hyperparams: Model hyperparameters; ``num_data`` is overwritten
            with the number of training shapes found on disk.
        decoder_hyperparams: Decoder hyperparameters.
        training_hyperparams: Optimization hyperparameters (batch size,
            learning rates, LR decay schedule, loss settings).

    Raises:
        FileNotFoundError: If ``--train-split-path`` is not a file.
    """
    hvd.init()
    rank = hvd.rank()
    local_rank = hvd.local_rank()
    num_processes = hvd.size()
    # One GPU per process, chosen by the process's node-local rank.
    device = torch.device("cuda", local_rank)

    # ==============================================================================
    # Dataset
    # ==============================================================================
    dataset_directory = Path(args.dataset_directory)
    train_split_path = Path(args.train_split_path)
    if not train_split_path.is_file():
        raise FileNotFoundError(
            f"Train split file not found: {train_split_path}")

    with open(train_split_path, encoding="utf-8") as f:
        split = json.load(f)

    data_path_list_train = _collect_sdf_paths(dataset_directory, split)
    dataset_size_train = len(data_path_list_train)

    # Assign each shape its index in the full list *before* sharding, so the
    # indices are unique across ranks. The index presumably selects the
    # shape's latent code in model.z_map (sized via num_data) -- confirm in
    # SdfDataset.
    path_index_pair_list_train = [
        (path, path_index)
        for path_index, path in enumerate(data_path_list_train)
    ]
    path_index_pair_list_train = _split_list(path_index_pair_list_train,
                                             num_processes)[rank]

    dataset_train = SdfDataset(path_index_pair_list_train,
                               memory_caching=args.in_memory_dataset)
    minibatch_generator = MinibatchGenerator(
        num_sdf_samples=training_hyperparams.num_sdf_samples, device=device)

    # The global batch size is divided evenly over the processes.
    minibatch_iterator_train = MinibatchIterator(
        dataset_train,
        batchsize=training_hyperparams.batchsize // num_processes,
        minibatch_generator=minibatch_generator,
        drop_last=True)
    # Shards can yield different numbers of minibatches. Every rank uses the
    # smallest count, presumably so all processes run the same number of
    # distributed steps per epoch (the Trainer receives it as
    # max_minibatch_iterations).
    minibatch_iterations = hvd.allgather(
        torch.full((1, ), len(minibatch_iterator_train), dtype=int))
    max_minibatch_iterations = torch.min(minibatch_iterations).item()
    print(
        f"# Minibatch iterations: {len(minibatch_iterator_train)} -> {max_minibatch_iterations}",
        flush=True)

    # ==============================================================================
    # Checkpoint directory
    # ==============================================================================
    model_hyperparams.num_data = dataset_size_train
    # The run id is derived from the hyperparameters, so identical
    # configurations map to the same run id.
    run_id = (model_hyperparams + decoder_hyperparams +
              training_hyperparams).hash()[:12]
    checkpoint_directory = create_checkpoint_directory(args, run_id)
    if rank == 0:
        print(colorful.bold("Checkpoint directory: %s" % checkpoint_directory))

    args_path = checkpoint_directory / "args.json"
    model_path = checkpoint_directory / "model.pt"
    trainer_checkpoint_path = checkpoint_directory / "trainer.pt"

    args.num_data = dataset_size_train
    args.save(args_path)

    # ==============================================================================
    # Model
    # ==============================================================================
    model = setup_model(model_hyperparams, decoder_hyperparams)
    model.to(device)
    loss_function = LossFunction(
        lam=training_hyperparams.loss_lam,
        clamping_distance=training_hyperparams.clamping_distance)

    # ==============================================================================
    # Optimizer
    # ==============================================================================
    # The decoder weights and the latent codes use separate learning rates.
    optimizer = optim.Adam([{
        "params": model.decoder.parameters()
    }, {
        "params": model.z_map,
        "lr": training_hyperparams.learning_rate_for_latent
    }],
                           lr=training_hyperparams.learning_rate)

    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # ==============================================================================
    # Training
    # ==============================================================================
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      debug_mode=args.debug_mode,
                      max_minibatch_iterations=max_minibatch_iterations)
    load_trainer(trainer_checkpoint_path, trainer)

    # NOTE(review): the multiplier is initialized from trainer.state.epoch,
    # while LambdaLR resumes from trainer.state.last_epoch; confirm that the
    # two counters agree after load_trainer.
    lr_lambda = LearningRateScheduler(training_hyperparams.decrease_lr_every,
                                      training_hyperparams.num_lr_decay,
                                      trainer.state.epoch)
    last_epoch = trainer.state.last_epoch if trainer.state.last_epoch > 0 else -1
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda, last_epoch=last_epoch)

    class GpuStatus:
        """Print a one-time confirmation after the first completed iteration."""
        checked = False

        def __call__(self):
            if not self.checked:
                print("The device is working properly.")
                self.checked = True

    check_gpu_status = GpuStatus()

    def print_exception(e: Exception):
        # Checkpoint I/O is best-effort: report the failure with its
        # traceback and keep training.
        message = "".join(
            traceback.format_exception(type(e), e, e.__traceback__))
        print(colorful.bold_red(message), flush=True)

    @trainer.on(Events.ITERATION_COMPLETED)
    def on_iteration_completed(trainer: Trainer):
        check_gpu_status()

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch_completed(trainer: Trainer):
        epoch = trainer.state.epoch
        scheduler.step(epoch)
        if rank != 0:
            return

        if epoch == 1 or epoch % args.log_every == 0:
            print(log_message(run_id, model, trainer), flush=True)

        if epoch % args.serialize_every == 0:
            try:
                save_model(model_path, model)
                save_trainer(trainer_checkpoint_path, trainer)
                log_loss(trainer, checkpoint_directory / "loss.csv")
            except OSError as e:
                print_exception(e)

        if epoch % args.checkpoint_freq == 0:
            path = checkpoint_directory / f"model.{epoch}.pt"
            try:
                save_model(path, model)
            except OSError as e:
                print_exception(e)

    if rank == 0:
        print(args)
        print(model)
        print("# GPUs:", num_processes)
        print("# Training Data:", dataset_size_train)
        print("# Params:", model.num_parameters(), flush=True)
        if trainer.state.epoch > 1:
            print("Resume training from epoch",
                  trainer.state.epoch,
                  flush=True)

    trainer.run(minibatch_iterator_train, max_epochs=args.max_epochs)


if __name__ == "__main__":
    # Entry point: main() is invoked with no arguments; the decorators
    # stacked on it supply its parameters.
    main()
