"""The handler for training and evaluation."""

import os
# Ask the CUDA runtime to launch kernels synchronously. The variable is set
# here, before `import torch`, so it is already in the environment when CUDA
# is initialised.
# NOTE(review): this is normally a debugging aid (it makes CUDA errors surface
# at the offending call) and it slows down GPU execution -- confirm that it is
# intended for regular training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import time
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities import rank_zero_info

from arguments import arg_parser, save_args
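# Alternative TSPModel implementations. All of them are currently disabled, so
# main() cannot resolve TSPModel when args.task == 'tsp'; uncomment exactly one
# of these imports to enable the TSP task.
# from pl_tsp_model_rddm import TSPModel
# from pl_tsp_model_rddm_largescale import TSPModel
# from pl_tsp_model_rddm_tsplib import TSPModel
# from pl_tsp_puredecoding import TSPModel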
from pl_mis_model_rddm import MISModel
import multiprocessing
import warnings
 
# Suppress every Python warning issued from this point on; no category or
# module filter is given, so nothing is exempt.
warnings.filterwarnings("ignore")
def main(args):
  """Build the model and Lightning trainer from ``args`` and run the requested stages.

  Exactly one of three modes runs, checked in this order:

  * ``args.do_train``: fit the model (optionally followed by a test pass on
    the best checkpoint when ``args.do_test`` is also set).
  * ``args.do_test``: validate from ``args.ckpt_path`` and, unless
    ``args.do_valid_only`` is set, test as well.
  * ``args.do_predict``: run prediction from ``args.ckpt_path``.

  Args:
    args: Parsed argument namespace. Fields read here (depending on the
      branch taken): ``num_epochs``, ``task``, ``logger_name``,
      ``storage_path``, ``fp16``, ``ckpt_path``, ``do_train``,
      ``resume_weight_only``, ``do_test``, ``do_valid_only`` and
      ``do_predict``. The whole namespace is also passed to the model as
      ``param_args`` and to ``save_args``.

  Raises:
    NotImplementedError: If ``args.task`` is neither ``'tsp'`` nor ``'mis'``.
    NameError: If ``args.task == 'tsp'`` but no ``TSPModel`` import is active
      in this module.
    ValueError: If ``args.resume_weight_only`` is set while training but
      ``args.ckpt_path`` is empty.
  """
  epochs = args.num_epochs

  # Pick the model class and the direction in which 'val/solved_cost' is
  # considered better for checkpoint selection.
  if args.task == 'tsp':
    try:
      model_class = TSPModel
    except NameError as err:
      # Keep the original exception type but say what to do about it.
      raise NameError(
          "TSPModel is not imported: uncomment one of the 'pl_tsp_*' imports "
          "at the top of this file to run with args.task == 'tsp'."
      ) from err
    saving_mode = 'min'
  elif args.task == 'mis':
    model_class = MISModel
    saving_mode = 'max'
  else:
    raise NotImplementedError(
        f"Unsupported task {args.task!r}; expected 'tsp' or 'mis'.")

  model = model_class(param_args=args)

  # Run identifier: "<logger_name>-<local timestamp>". It names both the
  # TensorBoard run directory and the checkpoint directory.
  run_name = args.logger_name + '-' + time.strftime('%Y.%m.%d-%H-%M-%S')

  tb_logger = TensorBoardLogger(
      save_dir=os.path.join(args.storage_path, 'models'),
      name=run_name,
  )
  print(run_name)
  save_args(args, run_name)

  checkpoint_callback = ModelCheckpoint(
      monitor='val/solved_cost',
      mode=saving_mode,
      save_top_k=5,
      save_last=True,
      dirpath=os.path.join(tb_logger.save_dir, run_name, 'checkpoints'),
  )
  lr_callback = LearningRateMonitor(logging_interval='step')

  trainer = Trainer(
      accelerator="auto",
      # "auto" selects every visible GPU, or the single CPU device when CUDA
      # is unavailable -- the same choice the previous
      # `device_count() if is_available() else None` expression made, but
      # without passing None, which newer Lightning releases reject.
      devices="auto",
      max_epochs=epochs,
      callbacks=[
          TQDMProgressBar(refresh_rate=20),
          checkpoint_callback,
          lr_callback,
      ],
      logger=tb_logger,
      check_val_every_n_epoch=1,
      strategy=DDPStrategy(static_graph=True),
      precision=16 if args.fp16 else 32,
  )

  ckpt_path = args.ckpt_path

  if args.do_train:
    if args.resume_weight_only:
      # Load only the weights from the checkpoint and start a fresh fit
      # (trainer state is not restored because no ckpt_path is given to fit).
      if not ckpt_path:
        raise ValueError(
            "args.resume_weight_only requires args.ckpt_path to point to a "
            "checkpoint file.")
      model = model_class.load_from_checkpoint(ckpt_path, param_args=args)
      trainer.fit(model)
    else:
      # Hand the checkpoint path (possibly unset) to fit() itself.
      trainer.fit(model, ckpt_path=ckpt_path)

    if args.do_test:
      # Evaluate the best checkpoint recorded by the checkpoint callback.
      trainer.test(ckpt_path=checkpoint_callback.best_model_path)

  elif args.do_test:
    trainer.validate(model, ckpt_path=ckpt_path)
    if not args.do_valid_only:
      trainer.test(model, ckpt_path=ckpt_path)

  elif args.do_predict:
    trainer.predict(model, ckpt_path=ckpt_path)

  trainer.logger.finalize("success")


if __name__ == '__main__':
  # Script entry point: select the 'spawn' start method for processes created
  # through `multiprocessing`, then parse the command line and run.
  # NOTE(review): 'spawn' is presumably chosen because CUDA state does not
  # survive a fork -- confirm.
  multiprocessing.set_start_method('spawn')
  main(arg_parser())
