import json
import math
import sys
import traceback
from pathlib import Path

import colorful
import horovod.torch as hvd
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

from pcn import (click, create_checkpoint_directory, load_trainer, save_model,
                 save_trainer)
from pcn.datasets.partial_sampling import (Dataset, MinibatchDescription,
                                           MinibatchGenerator,
                                           MinibatchIterator)
from pcn.experiment import (Events, LossFunction, Model, ModelHyperparameters,
                            Trainer, TrainerState, TrainingHyperparameters,
                            log_loss, log_message, setup_model)


def _split_list(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class LearningRateScheduler:
    """Step-decay multiplier intended as the ``lr_lambda`` of ``LambdaLR``.

    Calling the instance with a gradient-update count returns the factor by
    which ``initial_lr`` is scaled. The multiplier is multiplied by ``factor``
    at every nonzero multiple of ``decrease_every``. Decay stops once
    ``initial_lr * factor**num_decays`` is at or below ``min_lr``. Because that
    check happens before each decay, the final learning rate can end up one
    ``factor`` below ``min_lr``.

    The instance is stateful. It expects to be called with consecutive update
    counts, as ``LambdaLR.step()`` does. A repeated call with the same count
    returns the cached multiplier without decaying again. To resume, pass the
    number of updates already performed as ``last_steps``; the constructor
    replays the schedule up to that point.

    Raises:
        ValueError: If ``decrease_every < 1``.
    """

    def __init__(self, decrease_every: int, initial_lr: float, min_lr: float,
                 last_steps: int, factor: float):
        if decrease_every < 1:
            raise ValueError(
                f"decrease_every must be >= 1, got {decrease_every}")
        self.decrease_every = decrease_every
        self.initial_lr = initial_lr
        self.min_lr = min_lr
        self.factor = factor
        self.last_gamma = 1  # current multiplier
        self.current_num_decay = 0  # number of decays applied so far
        # Replay updates 1..last_steps from a fresh state. The "last seen"
        # count starts at 0 so the first replayed call (step 1) is never
        # mistaken for a repeated call and skipped.
        self.last_steps = 0
        for step in range(1, last_steps + 1):
            self(step)

    def __call__(self, num_grad_updates: int) -> float:
        """Return the learning-rate multiplier after ``num_grad_updates``."""
        # Same count as the previous call: reuse the cached value so that a
        # repeated evaluation does not decay twice.
        if num_grad_updates == self.last_steps:
            return self.last_gamma

        # Stop decaying once the current learning rate has reached min_lr.
        current_lr = self.initial_lr * self.factor**self.current_num_decay
        if current_lr <= self.min_lr:
            return self.last_gamma

        self.last_steps = num_grad_updates
        if num_grad_updates == 0:
            return self.last_gamma

        if num_grad_updates % self.decrease_every == 0:
            self.last_gamma = self.last_gamma * self.factor
            self.current_num_decay += 1
        return self.last_gamma


@click.command()
@click.argument("--dataset-directory", type=str, required=True)
@click.argument("--split-path", type=str, required=True)
@click.argument("--checkpoint-directory", type=str, default=None)
@click.argument("--checkpoint-directory-root", type=str, default="run")
@click.argument("--log-every", type=int, default=1)
@click.argument("--serialize-every", type=int, default=1)
@click.argument("--checkpoint-freq", type=int, default=1000)
@click.argument("--debug-mode", is_flag=True)
@click.argument("--max-epochs", type=int, default=2000)
@click.argument("--in-memory-dataset", is_flag=True)
@click.hyperparameter_class(ModelHyperparameters)
@click.hyperparameter_class(TrainingHyperparameters)
def main(args: click.Arguments, model_hyperparams: ModelHyperparameters,
         training_hyperparams: TrainingHyperparameters):
    """Train the model built by ``setup_model`` with Horovod data parallelism.

    Each Horovod process runs on GPU ``hvd.local_rank()``. It trains on its
    own contiguous shard of the samples listed in the ``--split-path`` JSON
    file. Only rank 0 logs epoch progress and writes checkpoints. Before
    training, ``load_trainer`` is given ``trainer.pt`` from the checkpoint
    directory (presumably restoring a previous run when it exists).

    Args:
        args: Command-line options; presumably built by pcn's ``click``
            wrapper from the ``@click.argument`` decorators above.
        model_hyperparams: Model hyperparameters (part of the run id).
        training_hyperparams: Training hyperparameters (part of the run id).

    Raises:
        FileNotFoundError: If ``--split-path`` is not a file.
        ValueError: If fewer samples are found than the global batch size.

    NOTE(review): ``--in-memory-dataset`` is declared but never read here.
    """
    hvd.init()
    rank = hvd.rank()
    local_rank = hvd.local_rank()
    num_processes = hvd.size()
    device = torch.device("cuda", local_rank)

    # ==============================================================================
    # Dataset
    # ==============================================================================
    dataset_directory = Path(args.dataset_directory)
    split_path = Path(args.split_path)
    if not split_path.is_file():
        raise FileNotFoundError(f"Split file not found: {split_path}")

    # The split maps each category name to a list of model ids.
    with open(split_path) as f:
        split = json.load(f)

    data_path_list = []
    for category, model_id_list in split.items():
        for model_id in model_id_list:
            surface_data = dataset_directory / category / model_id / "partial_point_cloud.npz"
            if not surface_data.exists():
                if rank == 0:
                    print(f"{surface_data} missing.")
                continue
            data_path_list.append(surface_data)

    dataset_size = len(data_path_list)
    if dataset_size < training_hyperparams.batchsize:
        raise ValueError(
            f"Found {dataset_size} samples, fewer than the batch size "
            f"{training_hyperparams.batchsize}.")

    # Assuming every rank sees the same files, each rank builds the same
    # ordered list and keeps only its own contiguous shard.
    data_path_list = _split_list(data_path_list, num_processes)[rank]

    dataset = Dataset(data_path_list,
                      num_coarse_points=model_hyperparams.num_coarse_gt_points,
                      num_dense_points=model_hyperparams.num_dense_gt_points)

    minibatch_generator = MinibatchGenerator(
        num_input_points=training_hyperparams.num_input_points, device=device)

    # The global batch size is divided across processes.
    minibatch_iterator_train = MinibatchIterator(
        dataset,
        batchsize=training_hyperparams.batchsize // num_processes,
        minibatch_generator=minibatch_generator,
        drop_last=True)
    # Shards can yield different minibatch counts. Every rank uses the
    # smallest count across all ranks (presumably so that all ranks run the
    # same number of allreduce-backed optimizer steps per epoch).
    minibatch_iterations = hvd.allgather(
        torch.full((1, ), len(minibatch_iterator_train), dtype=int))
    max_minibatch_iterations = torch.min(minibatch_iterations).item()
    print(
        f"# Minibatch iterations: {len(minibatch_iterator_train)} -> {max_minibatch_iterations}",
        flush=True)

    # ==============================================================================
    # Checkpoint directory
    # ==============================================================================
    run_id = (model_hyperparams + training_hyperparams).hash()[:12]
    checkpoint_directory = create_checkpoint_directory(args, run_id)
    if rank == 0:
        print(colorful.bold("Checkpoint directory: %s" % checkpoint_directory))

    args_path = checkpoint_directory / "args.json"
    model_path = checkpoint_directory / "model.pt"
    trainer_checkpoint_path = checkpoint_directory / "trainer.pt"

    if rank == 0:
        args.save(args_path)

    # ==============================================================================
    # Model
    # ==============================================================================
    model = setup_model(model_hyperparams)
    model.to(device)

    # ==============================================================================
    # Optimizer
    # ==============================================================================
    optimizer = optim.Adam(model.parameters(),
                           lr=training_hyperparams.learning_rate)

    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())

    # ==============================================================================
    # Training
    # ==============================================================================
    compute_loss = LossFunction()

    def loss_function(model: Model, data: MinibatchDescription):
        # ``trainer`` is bound later in main(); this closure only reads it
        # when called during training. The loss weight ``alpha`` increases in
        # steps with the gradient-update count and is recorded on the state.
        steps = trainer.state.num_gradient_updates
        if steps < 10000:
            alpha = 0.01
        elif steps < 20000:
            alpha = 0.1
        elif steps < 50000:
            alpha = 0.5
        else:
            alpha = 1.0
        trainer.state.alpha = alpha
        return compute_loss(model=model, data=data, alpha=alpha)

    initial_state = TrainerState()
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      loss_function=loss_function,
                      debug_mode=args.debug_mode,
                      initial_state=initial_state,
                      max_minibatch_iterations=max_minibatch_iterations)
    load_trainer(trainer_checkpoint_path, trainer)

    # Broadcast only after the checkpoint has been restored, so that every
    # rank starts from rank 0's (possibly resumed) weights and optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    lr_lambda = LearningRateScheduler(
        decrease_every=training_hyperparams.lr_decay_every,
        initial_lr=training_hyperparams.learning_rate,
        min_lr=training_hyperparams.lr_clip,
        last_steps=trainer.state.num_gradient_updates,
        factor=training_hyperparams.lr_decay_factor)

    # LambdaLR's constructor performs one step() itself, which advances
    # last_epoch by one and evaluates lr_lambda there. Passing
    # num_gradient_updates - 1 makes that initial step land exactly on
    # num_gradient_updates. For a fresh run this is the default -1, so the
    # schedule starts at 0. A resumed run continues where it stopped instead
    # of one step ahead. Assumes num_gradient_updates counts completed
    # iterations (one scheduler.step() each) — confirm against Trainer.
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lr_lambda,
                         last_epoch=trainer.state.num_gradient_updates - 1)

    class GpuStatus:
        # Prints a one-time confirmation, on every rank, after the first
        # completed iteration.
        checked = False

        def __call__(self):
            if self.checked is False:
                print("The device is working properly.")
                self.checked = True

    check_gpu_status = GpuStatus()

    def print_exception(e: Exception):
        # Format the full traceback carried by the exception object rather
        # than only its message.
        message = "".join(
            traceback.format_exception(type(e), e, e.__traceback__))
        print(colorful.bold_red(message), flush=True)

    @trainer.on(Events.ITERATION_COMPLETED)
    def on_iteration_completed(trainer: Trainer):
        check_gpu_status()
        scheduler.step()

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch_completed(trainer: Trainer):
        # Only rank 0 logs and writes files. An OSError while saving is
        # reported but does not stop training.
        epoch = trainer.state.epoch
        if rank == 0:
            if epoch == 1 or epoch % args.log_every == 0:
                print(log_message(run_id, model, trainer), flush=True)

            if epoch % args.serialize_every == 0:
                try:
                    save_model(model_path, model)
                    save_trainer(trainer_checkpoint_path, trainer)
                    log_loss(trainer, checkpoint_directory / "loss.csv")
                except OSError as e:
                    print_exception(e)

            if epoch % args.checkpoint_freq == 0:
                try:
                    path = checkpoint_directory / f"model.{epoch}.pt"
                    save_model(path, model)
                except OSError as e:
                    print_exception(e)

    if rank == 0:
        print(args)
        print(model)
        print("# Data:", dataset_size)
        print("# Params:", model.num_parameters(), flush=True)
        if trainer.state.epoch > 1:
            print("Resume training from epoch",
                  trainer.state.epoch,
                  flush=True)

    with torch.cuda.device(device):
        trainer.run(minibatch_iterator_train, max_epochs=args.max_epochs)


if __name__ == "__main__":
    # Script entry point. main() is called without arguments; its parameters
    # are presumably supplied by the decorators (pcn's click wrapper).
    main()  # pylint: disable=no-value-for-parameter
