import json
import math
import sys
import traceback
from pathlib import Path

import colorful
import horovod.torch as hvd
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

from meta_learning_sdf import (click, create_checkpoint_directory,
                               load_trainer, save_model, save_trainer)
from meta_learning_sdf.datasets.shapenet.uniform_sparse_sampling import (
    UniformPointCloudDataset, MinibatchGenerator, MinibatchIterator)
from meta_learning_sdf.experiments.baseline import (
    Events, LossFunction, ModelHyperparameters, Trainer, TrainerState,
    TrainingHyperparameters, log_loss, log_message, setup_model)


def _split_list(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class LearningRateScheduler:
    """Step-decay multiplier intended as the ``lr_lambda`` of ``LambdaLR``.

    Calling the instance with an epoch index returns the factor applied to
    the base learning rate. The factor starts at 1 and is multiplied by
    ``decay_factor`` at every epoch divisible by ``decrease_every``, at most
    ``num_decay`` times. The object is stateful: it expects epochs to be
    queried in non-decreasing order, and repeating the most recent epoch
    returns the cached factor without decaying again.

    Args:
        decrease_every: Decay period in epochs; must be positive when a
            decay can still happen.
        num_decay: Maximum number of decays to apply.
        last_epoch: Epoch to resume from. Epochs ``1..last_epoch`` are
            replayed at construction so the factor matches an
            uninterrupted run.
        decay_factor: Multiplier applied at each decay (default 0.5).
    """

    def __init__(self, decrease_every: int, num_decay: int, last_epoch: int,
                 decay_factor: float = 0.5):
        self.decrease_every = decrease_every
        self.num_decay = num_decay
        self.decay_factor = decay_factor
        self.last_gamma = 1
        self.current_num_decay = 0
        # Replay from epoch 0. Seeding ``last_epoch`` with the resume epoch
        # would make the "same epoch" shortcut in ``__call__`` swallow the
        # replay of epoch 1 when resuming from epoch 1, losing its decay.
        self.last_epoch = 0
        for epoch in range(1, last_epoch + 1):
            self(epoch)

    def __call__(self, epoch):
        """Return the learning-rate multiplier for ``epoch``."""
        if epoch == self.last_epoch:
            # Same epoch queried again: never decay twice for one epoch.
            return self.last_gamma

        self.last_epoch = epoch
        if epoch == 0:
            return self.last_gamma

        # Check the decay budget first so the modulo is skipped once
        # num_decay decays have been applied.
        if (self.current_num_decay < self.num_decay
                and epoch % self.decrease_every == 0):
            self.last_gamma = self.last_gamma * self.decay_factor
            self.current_num_decay += 1
        return self.last_gamma


class KLDivergenceWeightScheduler:
    """Linearly anneal the KL-divergence loss weight over epochs.

    The weight is ``initial_value`` at epoch 0. It moves linearly towards
    ``final_value``, which it reaches at ``max_epochs``, and it stays at
    ``final_value`` for every later epoch.
    """

    def __init__(self, initial_value: float, final_value: float,
                 max_epochs: int):
        self.initial_value, self.final_value = initial_value, final_value
        self.max_epochs = max_epochs

    def __call__(self, epoch):
        """Return the KL weight to use at ``epoch``."""
        # Epoch 0 is answered without interpolating, which also avoids
        # computing 0 / 0 when max_epochs is 0.
        if epoch == 0:
            return self.initial_value
        if epoch > self.max_epochs:
            return self.final_value

        progress = epoch / self.max_epochs
        span = self.final_value - self.initial_value
        return self.initial_value + progress * span


@click.command()
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--split-path", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, default=None)
@click.argument("--checkpoint-directory-root", type=str, default="run")
@click.argument("--log-every", type=int, default=1)
@click.argument("--serialize-every", type=int, default=1)
@click.argument("--evaluate-every", type=int, default=100)
@click.argument("--checkpoint-freq", type=int, default=100)
@click.argument("--debug-mode", is_flag=True)
@click.argument("--max-epochs", type=int, default=10000)
@click.argument("--in-memory-dataset", is_flag=True)
@click.hyperparameter_class(ModelHyperparameters)
@click.hyperparameter_class(TrainingHyperparameters)
def main(args: click.Arguments, model_hyperparams: ModelHyperparameters,
         training_hyperparams: TrainingHyperparameters):
    """Train the baseline SDF model with Horovod data parallelism.

    Each process trains on its own shard of the files listed in the split.
    The smallest per-rank minibatch count is passed to the Trainer as
    ``max_minibatch_iterations``. Rank 0 handles logging and checkpointing.
    """
    hvd.init()
    rank = hvd.rank()
    local_rank = hvd.local_rank()
    num_processes = hvd.size()
    device = torch.device("cuda", local_rank)

    # ==============================================================================
    # Dataset
    # ==============================================================================
    dataset_directory = Path(args.dataset_directory)
    split_path = Path(args.split_path)
    assert split_path.is_file()

    # The split file maps each category to an iterable of model ids.
    with open(split_path) as f:
        split = json.load(f)

    # Collect the .npz file of every model in the split. Missing files are
    # skipped, and only rank 0 reports them.
    data_path_list_train = []
    for category in split:
        model_id_list = split[category]
        for model_id in model_id_list:
            npz_path = dataset_directory / category / "depth" / f"{model_id}.npz"
            if not npz_path.exists():
                if rank == 0:
                    print(f"{npz_path} missing.")
                continue
            data_path_list_train.append(npz_path)

    dataset_size_train = len(data_path_list_train)
    assert dataset_size_train >= training_hyperparams.batchsize

    # Each process keeps only its own contiguous shard of the file list.
    data_path_list_train = _split_list(data_path_list_train,
                                       num_processes)[rank]

    dataset_train = UniformPointCloudDataset(
        data_path_list_train, memory_caching=args.in_memory_dataset)

    minibatch_generator = MinibatchGenerator(
        min_num_context=training_hyperparams.min_num_context,
        max_num_context=training_hyperparams.max_num_context,
        min_num_target=training_hyperparams.min_num_target,
        max_num_target=training_hyperparams.max_num_target,
        device=device)

    # ``batchsize`` is the global batch size; each rank takes an equal share.
    minibatch_iterator_train = MinibatchIterator(
        dataset_train,
        batchsize=training_hyperparams.batchsize // num_processes,
        minibatch_generator=minibatch_generator,
        drop_last=True)
    # Shards may differ in size, so ranks can have different iteration
    # counts. Gather them all and keep the smallest, presumably so every rank
    # performs the same number of (allreduced) optimizer steps per epoch.
    minibatch_iterations = hvd.allgather(
        torch.full((1, ), len(minibatch_iterator_train), dtype=int))
    max_minibatch_iterations = torch.min(minibatch_iterations).item()
    print(
        f"# Minibatch iterations: {len(minibatch_iterator_train)} -> {max_minibatch_iterations}",
        flush=True)

    # ==============================================================================
    # Checkpoint directory
    # ==============================================================================
    run_id = (model_hyperparams + training_hyperparams).hash()[:12]
    checkpoint_directory = create_checkpoint_directory(args, run_id)
    if rank == 0:
        print(colorful.bold("Checkpoint directory: %s" % checkpoint_directory))

    args_path = checkpoint_directory / "args.json"
    model_path = checkpoint_directory / "model.pt"
    trainer_path = checkpoint_directory / "trainer.pt"

    if rank == 0:
        args.save(args_path)

    # ==============================================================================
    # Model
    # ==============================================================================
    model = setup_model(model_hyperparams)
    model.to(device)

    # ==============================================================================
    # Optimizer
    # ==============================================================================
    optimizer = optim.Adam(model.parameters(),
                           lr=training_hyperparams.learning_rate)

    # Wrap the optimizer for cross-process gradient reduction. Then start
    # every rank from rank 0's parameters and optimizer state.
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    kld_weight = training_hyperparams.loss_kld_initial_weight
    if training_hyperparams.anneal_kld_weight:
        kld_weight_scheduler = KLDivergenceWeightScheduler(
            initial_value=training_hyperparams.loss_kld_initial_weight,
            final_value=training_hyperparams.loss_kld_final_weight,
            max_epochs=training_hyperparams.kld_weight_annealing_epochs)
        # ``trainer`` is created further below. The lambda looks it up only
        # when called, after it exists.
        kld_weight_func = lambda: kld_weight_scheduler(trainer.state.epoch)
    else:
        kld_weight_func = lambda: kld_weight

    loss_function = LossFunction(
        tau=training_hyperparams.loss_tau,
        lam=training_hyperparams.loss_lambda,
        kld_weight_func=kld_weight_func,
        num_eikonal_samples=training_hyperparams.num_eikonal_samples,
        eikonal_term_stddev=training_hyperparams.eikonal_term_stddev)

    # ==============================================================================
    # Training
    # ==============================================================================
    initial_state = TrainerState(kld_weight=kld_weight)
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      debug_mode=args.debug_mode,
                      initial_state=initial_state,
                      max_minibatch_iterations=max_minibatch_iterations)
    load_trainer(trainer_path, trainer)

    # Build the LR schedule at the trainer's last epoch (presumably restored
    # by load_trainer when resuming). LambdaLR uses last_epoch=-1 for a
    # fresh run.
    lr_lambda = LearningRateScheduler(
        decrease_every=training_hyperparams.decrease_lr_every,
        num_decay=training_hyperparams.num_lr_decay,
        last_epoch=trainer.state.last_epoch)
    last_epoch = trainer.state.last_epoch if trainer.state.last_epoch > 0 else -1
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda, last_epoch=last_epoch)

    # Prints a one-time confirmation the first time it is called (registered
    # below to run after each training iteration).
    class GpuStatus:
        checked = False

        def __call__(self):
            if self.checked is False:
                print("The device is working properly.")
                self.checked = True

    check_gpu_status = GpuStatus()

    def print_exception(e: Exception):
        # Format the full traceback. Printing the exception object alone
        # (``e.with_traceback(...)`` returns ``e`` itself) shows only the
        # message.
        message = "".join(
            traceback.format_exception(type(e), e, e.__traceback__))
        print(colorful.bold_red(message), flush=True)

    @trainer.on(Events.ITERATION_COMPLETED)
    def on_iteration_completed(trainer: Trainer):
        check_gpu_status()

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch_completed(trainer: Trainer):
        epoch = trainer.state.epoch
        scheduler.step(epoch)
        if rank == 0:
            if epoch == 1 or epoch % args.log_every == 0:
                print(log_message(run_id, model, trainer,
                                  float(kld_weight_func())),
                      flush=True)

            # A failed checkpoint write is reported but does not stop
            # training.
            if epoch % args.serialize_every == 0:
                try:
                    save_model(model_path, model)
                    save_trainer(trainer_path, trainer)
                    log_loss(trainer, checkpoint_directory / "loss.csv")
                except OSError as e:
                    print_exception(e)

            if epoch % args.checkpoint_freq == 0:
                try:
                    path = checkpoint_directory / f"model.{epoch}.pt"
                    save_model(path, model)
                except OSError as e:
                    print_exception(e)

    if rank == 0:
        print(args)
        print(model)
        print("# Data:", dataset_size_train)
        print("# Params:", model.num_parameters(), flush=True)
        if trainer.state.epoch > 1:
            print("Resume training from epoch",
                  trainer.state.epoch,
                  flush=True)

    trainer.run(minibatch_iterator_train, max_epochs=args.max_epochs)


if __name__ == "__main__":
    # Script entry point. ``main`` is wrapped by the project's
    # ``click.command()`` decorator, which presumably supplies its arguments
    # from the command line.
    main()
