# -*- coding: UTF-8 -*-

import os
import time
from argparse import ArgumentParser, Namespace

import nni
import torch

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy

from config import proj_cfg
from utils.metric import BestAccuracy, NNIreport

from interface.model_imagenet224_interface import ImageNet224Interface

# Model interface class used throughout this script: main() builds it (or loads
# it from a checkpoint) and hands it to trainer.fit, and the CLI parser asks it
# for model-specific arguments. Rebind this alias to train a different interface.
ModelInterface = ImageNet224Interface


def main(args):
    """Train ``ModelInterface`` with PyTorch Lightning using the parsed CLI arguments.

    Steps: optionally overlay NNI trial parameters on ``args``, seed the RNGs,
    split batch size / worker count across devices for DDP, build a unique save
    directory, construct (or load) the model, set up callbacks and the Trainer,
    and run ``trainer.fit``.

    Args:
        args: ``argparse.Namespace`` from the parser built under ``__main__``
            (Lightning Trainer flags + script flags + model-specific flags).

    Raises:
        ValueError: if both ``model_load_path`` and ``continual_train_model_path``
            are set.
    """
    # If using NNI, overlay the tuner's parameters on the CLI args
    # (NNI values win on key collisions).
    if args.is_nni:
        nni_params = nni.get_next_parameter() or {}
        args = Namespace(**{**vars(args), **nni_params})
    print(args)

    # -------------- Set seed --------------
    if args.seed is not None:
        pl.seed_everything(args.seed)

    # -------------- Set DDP ---------------
    if args.strategy == 'ddp':
        # With one GPU per DDP process, every process builds its own DataLoader,
        # so divide the global batch size and worker count across devices here.
        devices = max(1, int(args.devices))
        # Clamp to 1: a per-device batch size of 0 would be invalid.
        args.batch_size = max(1, int(args.batch_size) // devices)
        args.workers = int(args.workers) // devices

    # -------------- Set Save Path ----------------------------
    # Layout: <save_root>/<data_name>/<arch>/<lr>/<model info>/<timestamp>_<seed>[_<nni trial id>]
    model_info_str = ModelInterface.get_model_info(args)
    # The seed may legitimately be None (seeding is skipped above); "%d" would raise on it.
    seed_tag = "%d" % args.seed if args.seed is not None else "noseed"
    save_path = os.path.join(proj_cfg.save_root, args.data_name, args.arch, "%.5f" % args.lr, model_info_str,
                             time.strftime('%Y%m%d_%H%M%S') + "_" + seed_tag)

    if args.is_nni:
        # Append the trial id so different NNI trials get distinct directories.
        save_path += '_' + nni.get_trial_id()

    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if args.model_load_path and args.continual_train_model_path:
        raise ValueError("Loading path cannot exist with continual path.")

    # -------------- Set Model with/without Load Path --------------
    if args.model_load_path:
        # Only the checkpoint path is passed; the CLI args are not forwarded here.
        model = ModelInterface.load_from_checkpoint(args.model_load_path)
    else:
        model = ModelInterface(**vars(args))

    # -------------- Set Callbacks --------------
    checkpoint_dir = os.path.join(save_path, "checkpoint")
    # NOTE(review): the "/" in "val/acc1" ends up in the formatted filename, which
    # Lightning may turn into a "val/" sub-directory; see ModelCheckpoint's
    # `auto_insert_metric_name` option before relying on the file layout.
    callbacks = [
        # Track the checkpoint with the highest val/acc1 and also save the latest one.
        ModelCheckpoint(checkpoint_dir, monitor="val/acc1", filename="{val/acc1:.2f}", save_last=True, mode='max'),
        # Project callback; the argument presumably selects top-1 accuracy — TODO confirm.
        BestAccuracy(1),
    ]

    if args.save_continual_checkpoint:
        # Extra monitor-less checkpoint in the same directory, for resuming training.
        callbacks.append(ModelCheckpoint(checkpoint_dir))

    if args.is_nni:
        callbacks.append(NNIreport())

    # -------------- Set Strategy --------------
    # The explicit `strategy=` kwarg below overrides args.strategy inside
    # from_argparse_args, so pass the user's own value through for non-DDP
    # strategies instead of None.
    if args.strategy == 'ddp':
        strategy = DDPStrategy(find_unused_parameters=False)
    else:
        strategy = args.strategy

    # -------------- Set Trainer --------------
    trainer = pl.Trainer.from_argparse_args(
        args,
        reload_dataloaders_every_n_epochs=1,
        # TensorBoard logs under save_path (name=None, version='log').
        logger=pl_loggers.TensorBoardLogger(save_path, None, 'log', default_hp_metric=False),
        callbacks=callbacks,
        strategy=strategy,
        # NOTE(review): newer Lightning releases deprecate this Trainer argument
        # in favour of `trainer.fit(model, ckpt_path=...)`.
        resume_from_checkpoint=args.continual_train_model_path)

    # -------------- Training --------------
    trainer.fit(model)


if __name__ == '__main__':
    # Build the CLI: Lightning's Trainer flags, then script-level flags, then the
    # model-specific flags contributed by ModelInterface.
    parent_parser = ArgumentParser(add_help=False)
    parent_parser = pl.Trainer.add_argparse_args(parent_parser)
    parent_parser.add_argument('--seed', type=int, default=233, help='normal random seed')
    parent_parser.add_argument("--is_nni", action="store_true", default=False)

    parent_parser.add_argument("--save_continual_checkpoint", action="store_true", default=False,
                               help='to save the last model for continual training')
    parent_parser.add_argument('--model_load_path', type=str, default=None,
                               help='a model load path, set to None to disable the load')
    # The value is forwarded to the Trainer as a checkpoint file to resume from,
    # so the help must not describe it as a folder.
    parent_parser.add_argument('--continual_train_model_path', type=str, default=None,
                               help='checkpoint file to resume training from (mutually exclusive with '
                                    '--model_load_path); omit to train from scratch')

    parser = ModelInterface.add_model_specific_args(parent_parser)
    # Script-level defaults that override Lightning's own Trainer defaults.
    parser.set_defaults(
        accelerator='auto',
        profiler="simple",
        max_epochs=90,
    )

    args = parser.parse_args()
    main(args)
