import copy
import logging
import os
from typing import Optional, Union, List, Dict, Any, Iterable

import torch
import torch.cuda
from accelerate.optimizer import AcceleratedOptimizer
from accelerate.utils import gather
from omegaconf import DictConfig
from pado.core import PadoModule
from pado.core.accelerator import PadoAccelerator
from pado.core.base import PadoOptimizer, PadoOptimizerList, PadoScheduler, PadoSchedulerList
from pado.core.tracker import *
from pado.data.dataloader import PadoDataLoader
from pado.utils.param_utils import *
from pado.utils.swa_utils import SWAModel

# Public API of this module: only the trainer class is exported by `from ... import *`.
__all__ = ["PadoTrainer"]


class PadoTrainer(object):

    def __init__(self,
                 model: PadoModule,
                 accelerator: PadoAccelerator,
                 save_dir: Optional[str] = None,
                 logger: Optional[logging.Logger] = None,
                 train_dataloader: Optional[PadoDataLoader] = None,
                 valid_dataloader: Optional[PadoDataLoader] = None,
                 test_dataloader: Optional[PadoDataLoader] = None,
                 optimizer: Optional[Union[PadoOptimizer, PadoOptimizerList]] = None,
                 scheduler: Optional[Union[PadoScheduler, PadoSchedulerList]] = None,
                 max_epochs: int = 1000,
                 max_iters: int = 10000000,
                 print_interval_iters: int = 50,
                 log_interval_iters: Optional[int] = None,
                 valid_interval_epochs: Union[int, float] = 1,
                 clip_grad: float = 0.0,
                 clip_grad_method: str = "norm",
                 acc_num_batches: int = 1,
                 start_epoch: int = 0,
                 start_iter: int = 0,
                 *,
                 run_test_with_valid: bool = False,
                 final_test_mode: str = "best",
                 ckpt_save_interval_iters: int = -1,
                 ckpt_save_latest: bool = True,
                 ckpt_save_best: bool = True,
                 ckpt_swa_start_epoch: int = -1,
                 ckpt_swa_start_iter: int = -1,
                 grad_debugging: bool = False,
                 verbose: bool = True) -> None:
        """Set up the trainer: accelerator wrapping, checkpoint paths, SWA and trackers.

        Args:
            model: network to train; wrapped with ``accelerator.prepare_model``.
            accelerator: device / distributed / fp16 helper.
            save_dir: checkpoint directory (created on the master process). ``None`` disables the
                best / latest / swa checkpoint paths.
            logger: logger to use; defaults to ``logging.getLogger("pado")``.
            train_dataloader, valid_dataloader, test_dataloader: default loaders. A missing valid
                loader falls back to the test loader and vice versa.
            optimizer: single optimizer or ``PadoOptimizerList``; each one is wrapped by the accelerator.
            scheduler: learning-rate scheduler (optional).
            max_epochs, max_iters: training stops when either limit is reached.
            print_interval_iters: iteration interval used by ``is_print_iter``.
            log_interval_iters: iteration interval used by ``is_log_iter`` (defaults to print interval).
            valid_interval_epochs: int -> validate every n epochs; float in (0, 1) -> additionally
                validate every that fraction of an epoch.
            clip_grad: threshold handed to the clip helper (negative values become 0).
            clip_grad_method: ``"norm"`` or ``"value"`` (case-insensitive).
            acc_num_batches: gradient-accumulation window (at least 1).
            start_epoch, start_iter: counters to resume from.
            run_test_with_valid: also run the test loop after every validation.
            final_test_mode: ``"latest"``, ``"best"`` or ``"swa"``; weights used by the final test.
            ckpt_save_interval_iters: save ``interval-<iter>.ckpt`` every n iterations (<= 0 disables).
            ckpt_save_latest, ckpt_save_best: toggle the latest / best checkpoints.
            ckpt_swa_start_epoch, ckpt_swa_start_iter: enable SWA from this epoch / iteration
                (negative disables that condition; both negative disables SWA).
            grad_debugging: scan gradients for None/NaN/Inf before each optimizer step (slow).
            verbose: include ``param_log`` output in the train-start summary.

        Raises:
            ValueError: unknown ``clip_grad_method`` / ``final_test_mode``, or a final test mode
                whose checkpoint path or SWA model is not configured.
        """
        # -------------------------------------------------------------------------------------- #
        self.current_epoch = start_epoch
        self.current_iter = start_iter  # global iteration

        self.save_dir = save_dir
        self.verbose = verbose

        if logger is None:
            logger = logging.getLogger("pado")
        self.logger = logger

        self.max_epochs = max_epochs
        self.max_iters = max_iters
        self.print_interval_iters = print_interval_iters
        self.log_interval_iters = log_interval_iters if (log_interval_iters is not None) else print_interval_iters
        self.valid_interval_epochs = valid_interval_epochs  # if int, epoch / float: ratio
        self.clip_grad = max(clip_grad, 0.0)
        if clip_grad_method.lower() not in ("norm", "value"):
            raise ValueError(f"Gradient clipping method should be either `norm` or `value`, got {clip_grad_method}.")
        self.clip_grad_method = clip_grad_method.lower()

        # -------------------------------------------------------------------------------------- #
        # accumulate setting
        self.acc_num_batches = max(acc_num_batches, 1)
        self._current_acc_batches = 0

        # -------------------------------------------------------------------------------------- #
        # accelerate setup
        self.accelerator = accelerator
        self.fp16 = self.accelerator.use_fp16
        self.logger.info(f"Accelerator setup:\n"
                         f"... device: {accelerator.device} (fp16: {accelerator.use_fp16})\n"
                         f"... world_size: {accelerator.world_size}, local_rank: {accelerator.local_rank}\n")

        self.model = self.accelerator.prepare_model(model)

        if optimizer is None:
            self.optimizer = optimizer
        elif not isinstance(optimizer, PadoOptimizerList):
            self.optimizer = self.accelerator.prepare_optimizer(optimizer)
        else:  # is list, re-wrap the optimizer.
            self.optimizer = PadoOptimizerList([self.accelerator.prepare_optimizer(opt)
                                                for opt in optimizer.optimizers])
        self.scheduler = scheduler
        if (scheduler is not None) and hasattr(self.scheduler, "_num_iters") and (start_iter > 0):
            self.scheduler.set_num_iters(start_iter)

        self.device = self.accelerator.device
        self.is_master = self.accelerator.is_local_main_process
        self.world_size = self.accelerator.num_processes
        self.local_rank = self.accelerator.process_index

        # -------------------------------------------------------------------------------------- #
        # only the master process creates the checkpoint directory
        if (save_dir is not None) and self.is_master:
            os.makedirs(save_dir, exist_ok=True)

        # -------------------------------------------------------------------------------------- #
        # dataloader, optimizer, scheduler
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader if (valid_dataloader is not None) else test_dataloader
        self.test_dataloader = test_dataloader if (test_dataloader is not None) else valid_dataloader

        self.grad_debugging = grad_debugging  # check which parameter makes NaN or Inf

        # -------------------------------------------------------------------------------------- #
        # ckpt setup
        self.ckpt_save_interval_iters = ckpt_save_interval_iters
        self.ckpt_save_latest = ckpt_save_latest
        self.ckpt_save_best = ckpt_save_best
        if save_dir is not None:
            # currently we track three main checkpoints
            self.ckpt_best_path = os.path.join(save_dir, "best.ckpt")
            self.ckpt_latest_path = os.path.join(save_dir, "latest.ckpt")
            self.ckpt_swa_path = os.path.join(save_dir, "swa.ckpt")
        else:
            self.ckpt_best_path = self.ckpt_latest_path = self.ckpt_swa_path = None

        # -------------------------------------------------------------------------------------- #
        # swa setup
        self.ckpt_swa_start_epoch = ckpt_swa_start_epoch
        self.ckpt_swa_start_iters = ckpt_swa_start_iter
        if (ckpt_swa_start_epoch >= 0) or (ckpt_swa_start_iter >= 0):
            self.swa_model = SWAModel(self.accelerator.unwrap_model(self.model), device="cpu")
        else:
            self.swa_model = None

        # -------------------------------------------------------------------------------------- #
        # test and swa setup
        self.run_test_with_valid = run_test_with_valid
        self.final_test_mode = final_test_mode.lower()
        if self.final_test_mode not in ("latest", "best", "swa"):
            raise ValueError(f"Final test mode {self.final_test_mode} is unsupported.")
        elif (self.final_test_mode == "best") and (self.ckpt_best_path is None):
            raise ValueError("Final test mode BEST set but ckpt_best_path is None.")
        elif (self.final_test_mode == "swa") and ((self.swa_model is None) or (self.ckpt_swa_path is None)):
            raise ValueError("Final test mode SWA set but swa_model or ckpt_swa_path is None.")

        # -------------------------------------------------------------------------------------- #
        # tracker setup
        self.time_tracker = TimeTrackerDict("epoch", "iter", "valid", "test",
                                            "data", "forward", "backward", "step")
        self.static_tracker = MetricTrackerDict("num_samples")  # tracker that is not reset through epoch
        self.epoch_tracker = MetricTrackerDict("loss")  # tracker that is reset occasionally.

        # -------------------------------------------------------------------------------------- #
        # extra
        # Initialise the flag BEFORE the subclass hook: the hook can then read it, and anything it
        # sets (e.g. a state loaded via load_state_dict, if that sets the flag) is not overwritten.
        self._state_loaded = False
        self.init_additional()

    def init_additional(self):
        # can be override by child class
        pass

    def reset_save_dir(self, save_dir: str) -> None:
        self.save_dir = save_dir
        if save_dir is not None:
            self.ckpt_best_path = os.path.join(save_dir, "best.ckpt")
            self.ckpt_latest_path = os.path.join(save_dir, "latest.ckpt")
            self.ckpt_swa_path = os.path.join(save_dir, "swa.ckpt")
        else:
            self.ckpt_best_path = self.ckpt_latest_path = self.ckpt_swa_path = None

    @property
    def current_acc_batches(self) -> int:
        return self._current_acc_batches

    def is_zero_grad_iter(self, num_iter: int) -> bool:
        # helper to use within train_epoch_body
        return num_iter % self.acc_num_batches == 0

    def is_step_iter(self, num_iter: int) -> bool:
        # helper to use within train_epoch_body
        return (num_iter % self.acc_num_batches) == (self.acc_num_batches - 1)

    def is_print_iter(self, num_iter: int) -> bool:
        # helper to use within train_epoch_body, valid_body
        return (num_iter % self.print_interval_iters == 0) and (num_iter > 0)

    def is_log_iter(self, num_iter: int) -> bool:
        # helper to use within train_epoch_body
        return (num_iter % self.log_interval_iters == 0) and (num_iter > 0)

    def is_valid_iter(self, num_iter: int, dataloader_length: int) -> bool:
        # helper to use within train_epoch_body
        c = int(self.valid_interval_epochs * dataloader_length)
        return (0 < self.valid_interval_epochs < 1) and (num_iter > 0) and (num_iter % c == 0)

    def is_swa_update(self) -> bool:
        if self.swa_model is None:
            return False

        # True if either condition is satisfied.
        is_epoch = (self.current_epoch >= self.ckpt_swa_start_epoch) if (self.ckpt_swa_start_epoch >= 0) else False
        is_iter = (self.current_iter >= self.ckpt_swa_start_iters) if (self.ckpt_swa_start_iters >= 0) else False
        return is_epoch or is_iter

    def compute_grad_norm(self) -> torch.Tensor:
        """Unscale fp16 gradients if needed, clip them, and return the clip helper's result.

        ``clip_grad_method == "norm"`` uses ``clip_grad_norm``; otherwise ``clip_grad_value``. Both use
        threshold ``self.clip_grad``. ``self.accelerator.clip_grad_norm/value`` are not used because
        they do not return a value.
        """
        if self.accelerator.state.use_fp16 and self.accelerator.native_amp:
            for opt in self.accelerator._optimizers:  # noqa
                # GradScaler tracks "already unscaled" per optimizer object, and AcceleratedOptimizer.step()
                # passes the *inner* optimizer to scaler.step(). Unscaling the wrapper would leave the inner
                # optimizer marked as not unscaled, so step() would unscale the gradients a second time.
                # Unwrap first, as Accelerator.unscale_gradients does.
                while isinstance(opt, AcceleratedOptimizer):
                    opt = opt.optimizer
                self.accelerator.scaler.unscale_(opt)

        if self.clip_grad_method == "norm":
            return clip_grad_norm(self.model.parameters(), self.clip_grad)
        return clip_grad_value(self.model.parameters(), self.clip_grad)  # "value"

    def compute_param_norm(self) -> torch.Tensor:
        """Return the parameter norm of ``self.model``.

        The computation is delegated to the module-level (star-imported) ``compute_param_norm`` helper.
        """
        model_params = self.model.parameters()
        # inside the method body this name resolves to the module-level helper, not to this method
        return compute_param_norm(model_params)

    def check_grad(self) -> None:
        """Check Gradient status. WARN: THIS MAKES NETWORK VERY SLOW."""
        if self.grad_debugging:  # double safety
            for p_name, p in self.model.named_parameters():
                if p.grad is None:
                    self.logger.warning(f"Parameter {p_name} grad is None.")
                elif torch.any(torch.isnan(p.grad)):
                    self.logger.warning(f"Parameter {p_name} grad is NaN.")
                elif torch.any(torch.isinf(p.grad)):
                    self.logger.warning(f"[WARN:TRAINER] Parameter {p_name} grad is Inf.")

    def current_lrs(self):
        """Return the current learning rate(s) reported by the wrapped optimizer(s).

        Returns:
            For a ``PadoOptimizerList``: a list with one ``current_lrs()`` result per sub-optimizer.
            Otherwise: the single optimizer's ``current_lrs()`` result.

        Raises:
            TypeError: if the optimizer is neither a ``PadoOptimizerList`` nor an ``AcceleratedOptimizer``
                (e.g. the trainer was built without an optimizer).
        """
        if isinstance(self.optimizer, PadoOptimizerList):
            # each entry was wrapped by accelerator.prepare_optimizer in __init__
            return [opt.optimizer.current_lrs() for opt in self.optimizer.optimizers]
        if not isinstance(self.optimizer, AcceleratedOptimizer):
            raise TypeError(f"Cannot read learning rates from optimizer of type {type(self.optimizer).__name__}.")
        return self.optimizer.optimizer.current_lrs()

    def gather(self, output):
        """Collect ``output`` from every process and reduce each entry to a detached scalar.

        Keys containing ``"loss"`` (case-insensitive) are averaged; all other entries are summed.
        Subclasses may override this; the default uses ``accelerate.utils.gather``.
        """
        gathered = gather(output)  # module-level accelerate helper, not this method
        for key in list(gathered.keys()):
            reduce_fn = torch.mean if "loss" in key.lower() else torch.sum  # heuristic
            gathered[key] = reduce_fn(gathered[key]).detach()
        return gathered

    def train(self,
              train_dataloader: Optional[Iterable] = None,
              valid_dataloader: Optional[Iterable] = None,
              test_dataloader: Optional[Iterable] = None):
        # ------------------------------------------------------------ #
        if train_dataloader is None:
            train_dataloader = self.train_dataloader
        if valid_dataloader is None:
            valid_dataloader = self.valid_dataloader
        if test_dataloader is None:
            test_dataloader = self.test_dataloader

        if train_dataloader is None:
            self.logger.warning("SKIP train run because train_dataloader is None.")
            return
        # ------------------------------------------------------------ #
        self.on_train_start()

        if self._state_loaded:  # loaded from somewhere, should run valid first.
            self.valid(valid_dataloader, test_dataloader)
        else:
            self.scheduler.step()  # initial step
        # ------------------------------------------------------------ #
        while (self.current_epoch < self.max_epochs) and (self.current_iter < self.max_iters):
            # ---------------------------------------------- #
            try:
                train_dataloader.set_epoch(self.current_epoch)
            except AttributeError:
                self.logger.warning("Train dataloader does not support set_epoch. "
                                    "Maybe using inappropriate loader: consider using PadoDataLoader.")
            # ---------------------------------------------- #
            train_dataloader = self.on_train_epoch_start(train_dataloader)
            # ---------------------------------------------- #
            # don't forget to increase self.current_iter inside!
            self.train_epoch_body(train_dataloader)
            # ---------------------------------------------- #
            if isinstance(self.valid_interval_epochs, int) and (
                    self.current_epoch % self.valid_interval_epochs == 0):
                # valid per n epoch(s)
                self.valid(valid_dataloader, test_dataloader, track=True)
            elif 0 < self.valid_interval_epochs < 1:
                # valid inside train_body & at the end of every epoch
                self.valid(valid_dataloader, test_dataloader, track=True)
            # ---------------------------------------------- #
            self.on_train_epoch_end()
            # ---------------------------------------------- #
            self.current_epoch += 1
        # ------------------------------------------------------------ #
        self.on_train_end()
        # ------------------------------------------------------------ #
        self.test(test_dataloader, mode=self.final_test_mode)
        # ------------------------------------------------------------ #

    def valid(self,
              valid_dataloader: Optional[Iterable] = None,
              test_dataloader: Optional[Iterable] = None,
              *, track: bool = True):
        # ------------------------------------------------------------ #
        if valid_dataloader is None:
            valid_dataloader = self.valid_dataloader
        if test_dataloader is None:
            test_dataloader = self.test_dataloader

        if valid_dataloader is None:
            self.logger.warning("SKIP valid run because valid_dataloader is None.")
            return
        # ------------------------------------------------------------ #
        valid_dataloader = self.on_valid_start(valid_dataloader)
        # ------------------------------------------------------------ #
        self.valid_body(valid_dataloader, track=track)
        # ------------------------------------------------------------ #
        self.on_valid_end()
        # ------------------------------------------------------------ #
        if self.run_test_with_valid:
            if test_dataloader is None:
                self.logger.warning("Set run_test_with_valid ON, but test_dataloader is None.")
            self.test(test_dataloader)
        # ------------------------------------------------------------ #

    def test(self,
             test_dataloader: Optional[Iterable] = None,
             *, mode: str = "latest"):
        # ------------------------------------------------------------ #
        if test_dataloader is None:
            test_dataloader = self.test_dataloader

        if test_dataloader is None:
            self.logger.warning("SKIP test run because test_dataloader is None.")
            return
        # ------------------------------------------------------------ #
        mode = mode.lower()
        # WARNING: test with (mode != latest) loads parameter from past state, and do not recover current state.
        if not self._state_loaded:  # ignore mode when state is loaded
            if mode == "latest":
                self.logger.info("Test with 'latest' state_dict.")
            elif mode == "best":
                if self.ckpt_best_path is not None:
                    try:
                        self.load_state_dict(torch.load(self.ckpt_best_path, map_location=self.device),
                                             strict=True, model_only=False)
                        self.logger.info("Test with 'best' state_dict.")
                    except FileNotFoundError:  # best not yet created
                        self.logger.info("Test with 'best' state_dict requested but failed (no file), "
                                         "running with 'latest'.")
                else:
                    self.logger.info("Test with 'best' state_dict requested but failed (best not set), "
                                     "running with 'latest'.")
            elif mode == "swa":
                if (self.ckpt_swa_path is not None) and (self.swa_model is not None):
                    try:
                        self.load_state_dict(torch.load(self.ckpt_swa_path, map_location=self.device),
                                             strict=True, model_only=True)
                        self.logger.info("Test with 'swa' state_dict.")
                    except FileNotFoundError:  # swa not yet created
                        self.logger.info("Test with 'swa' state_dict requested but failed (no file), "
                                         "running with 'latest'.")
                else:
                    self.logger.info("Test with 'swa' state_dict requested but failed (swa not set), "
                                     "running with 'latest'.")
            else:
                self.logger.info(f"Test mode {mode} state_dict is invalid, running with 'latest'.")
        else:  # _state_loaded (from outside)
            self.logger.info(f"Test with loaded state_dict.")
        # ------------------------------------------------------------ #
        test_dataloader = self.on_test_start(test_dataloader)
        # ------------------------------------------------------------ #
        self.test_body(test_dataloader)
        # ------------------------------------------------------------ #
        self.on_test_end()
        # ------------------------------------------------------------ #

    def train_epoch_body(self, dataloader):
        """Run one training epoch over ``dataloader`` with gradient accumulation.

        Per batch: ``train_iter_body`` forward, backward on ``loss / acc_num_batches``, gather and track.
        Every ``acc_num_batches`` batches, the gradients are clipped and the optimizer (and scheduler, if
        any) step, and ``current_iter`` advances. At the end of the epoch the master process saves the
        LATEST checkpoint and updates/saves SWA. Each save happens only if its path is configured.

        Raises:
            ValueError: if ``train_iter_body`` does not return a dict containing ``"loss"``.
        """
        self.model.train()
        self.epoch_tracker.reset()
        torch.set_grad_enabled(True)
        grad_norm = torch.tensor(0, dtype=torch.float32, device=self.device)  # placeholder until first step

        # we decided to wrap dataloader every epoch
        dataloader = self.accelerator.prepare_data_loader(dataloader)

        # ------------------------------------------------------------ #
        for num_iter, batch in enumerate(dataloader):
            if self.current_iter >= self.max_iters:
                break
            # ------------------------------------------------------------ #
            self.time_tracker.update_and_reset("data", "forward")
            self.on_train_iter_start()
            # ------------------------------------------------------------ #
            # forward-backward
            if self.is_zero_grad_iter(num_iter):
                self.optimizer.zero_grad(set_to_none=True)

            output = self.train_iter_body(batch, num_iter, len(dataloader))
            self._current_acc_batches += 1
            self.time_tracker.update_and_reset("forward", "backward")

            # `output is None` catches an un-overridden / non-returning train_iter_body.
            if (output is None) or ("loss" not in output.keys()):
                raise ValueError("Wrapper does not contain 'loss' as key.")
            loss = output["loss"]
            if self.acc_num_batches > 1:
                loss = loss / self.acc_num_batches

            self.accelerator.backward(loss)
            self.time_tracker.update_and_reset("backward", "step")
            # ------------------------------------------------------------ #
            # gather and track
            output = self.gather(output)  # output will be stacked
            self.train_iter_track(output)
            # ------------------------------------------------------------ #
            # step
            is_step = self.is_step_iter(num_iter)
            if is_step:
                if self.grad_debugging:
                    self.check_grad()

                grad_norm = self.compute_grad_norm()
                self.optimizer.step()
                if self.scheduler is not None:  # the scheduler is optional
                    self.scheduler.step()
            # ------------------------------------------------------------ #
            # logging
            if self.is_master:
                # both logging and printing, don't forget to check by is_print_iter, is_log_iter.
                param_norm = self.compute_param_norm()
                self.train_iter_log_and_print(output, num_iter, len(dataloader), param_norm, grad_norm)
            # ------------------------------------------------------------ #
            # valid if needed
            # NOTE(review): this validates on self.valid_dataloader / self.test_dataloader, not on loaders
            # passed explicitly to train().
            if self.is_valid_iter(num_iter, len(dataloader)):
                self.valid()

                # restart
                self.model.train()
                self.epoch_tracker.reset()
                torch.set_grad_enabled(True)
            # ------------------------------------------------------------ #
            # iteration done
            self.on_train_iter_end()
            if is_step:
                self.current_iter += 1
                # Only check right after an optimizer step. Otherwise, with acc_num_batches > 1, the same
                # current_iter would trigger repeated saves (including an "interval-0" checkpoint).
                if self.is_master and (self.save_dir is not None) and (self.ckpt_save_interval_iters > 0) \
                        and (self.current_iter % self.ckpt_save_interval_iters == 0):
                    ckpt_interval_path = os.path.join(self.save_dir, f"interval-{self.current_iter}.ckpt")
                    self.logger.info(f"Save INTERVAL checkpoint to {ckpt_interval_path}.")
                    self.accelerator.save(self.state_dict(), ckpt_interval_path)
            self.time_tracker.reset("data")

        # ------------------------------------------------------------ #
        if self.is_master:
            # checkpoint paths are None when no save_dir was given; saving to None would fail.
            if self.ckpt_save_latest and (self.ckpt_latest_path is not None):
                self.logger.info(f"Save LATEST checkpoint to {self.ckpt_latest_path}.")
                self.accelerator.save(self.state_dict(), self.ckpt_latest_path)
            if self.is_swa_update():
                self.swa_model.update_state(self.accelerator.unwrap_model(self.model))
                if self.ckpt_swa_path is not None:
                    self.logger.info(
                        f"Save SWA checkpoint to {self.ckpt_swa_path}. (averaged count: {self.swa_model.count})")
                    self.accelerator.save(self.swa_model.state_dict(), self.ckpt_swa_path)

    def valid_body(self, dataloader, *, track: bool = True):
        """Run one validation pass over ``dataloader`` and write the latest / best checkpoints.

        Args:
            dataloader: validation loader; re-wrapped by the accelerator on every call.
            track: when True, ``valid_update_best()`` is consulted and a BEST checkpoint is written if it
                reports an improvement.

        Checkpoints are written only on the master process, and only when the corresponding path is
        configured (i.e. a ``save_dir`` was given).
        """
        self.model.eval()
        self.epoch_tracker.reset()
        torch.set_grad_enabled(False)

        # we decided to wrap dataloader every epoch
        dataloader = self.accelerator.prepare_data_loader(dataloader)

        # ------------------------------------------------------------ #
        for num_iter, batch in enumerate(dataloader):
            # ------------------------------------------------------------ #
            self.time_tracker.update_and_reset("data", "forward")
            self.on_valid_iter_start()
            # ------------------------------------------------------------ #
            # forward
            output = self.valid_iter_body(batch, num_iter, len(dataloader))
            self.time_tracker.update("forward")
            # ------------------------------------------------------------ #
            # gather and track
            output = self.gather(output)  # output will be stacked
            # The default gather() already reduces each value to a scalar, so for it this pass is a no-op.
            # It is kept so that overridden gather() implementations are still reduced here.
            for output_key, output_value in output.items():
                if "loss" in output_key:  # heuristic
                    output[output_key] = torch.mean(output_value).detach()
                else:
                    output[output_key] = torch.sum(output_value).detach()
            self.valid_iter_track(output)
            # ------------------------------------------------------------ #
            # logging
            if self.is_master and self.is_print_iter(num_iter):
                # only printing
                self.valid_iter_print(output, num_iter, len(dataloader))
            # ------------------------------------------------------------ #
            # iteration done
            self.on_valid_iter_end()
            self.time_tracker.reset("data")

        # ------------------------------------------------------------ #
        if self.is_master:
            self.valid_epoch_log_and_print()

            # ckpt paths are None without a save_dir; saving to None would fail.
            if self.ckpt_save_latest and (self.ckpt_latest_path is not None):
                self.logger.info(f"Save LATEST checkpoint to {self.ckpt_latest_path}.")
                self.accelerator.save(self.state_dict(), self.ckpt_latest_path)

        is_updated = track and self.valid_update_best()  # scheduler.update_best, and add wandb summary

        if self.is_master:
            if is_updated and self.ckpt_save_best and (self.ckpt_best_path is not None):
                self.logger.info(f"Save BEST checkpoint to {self.ckpt_best_path}.")
                self.accelerator.save(self.state_dict(), self.ckpt_best_path)

    def test_body(self, dataloader):
        """Run one evaluation pass over the test ``dataloader`` (no gradients, no checkpoints).

        Each batch goes through ``test_iter_body``, is gathered, and every entry is reduced to a detached
        scalar (mean for keys containing ``"loss"``, sum otherwise) before ``test_iter_track``.
        """
        self.model.eval()
        self.epoch_tracker.reset()
        torch.set_grad_enabled(False)

        # the loader is re-wrapped by the accelerator on every call
        loader = self.accelerator.prepare_data_loader(dataloader)

        for step_idx, batch in enumerate(loader):
            self.time_tracker.update_and_reset("data", "forward")
            self.on_test_iter_start()

            # forward
            result = self.test_iter_body(batch, step_idx, len(loader))
            self.time_tracker.update("forward")

            # gather, reduce, track
            result = self.gather(result)  # output will be stacked
            for key in list(result.keys()):
                reduce_fn = torch.mean if "loss" in key else torch.sum  # heuristic
                result[key] = reduce_fn(result[key]).detach()
            self.test_iter_track(result)

            # printing only (master process)
            if self.is_master and self.is_print_iter(step_idx):
                self.test_iter_print(result, step_idx, len(loader))

            self.on_test_iter_end()
            self.time_tracker.reset("data")

        if self.is_master:
            self.test_epoch_log_and_print()

    # -------------------------------------------------------------------------------------------------------- #
    # -------------------------------------------------------------------------------------------------------- #
    # Implement from here ...

    def train_iter_body(self, batch: Any,
                        num_iter: int, dataloader_length: int) -> Dict[str, Any]:
        pass

    def train_iter_track(self, output: Dict[str, Any]) -> None:
        pass

    def train_iter_log_and_print(self, output: Dict[str, Any],
                                 num_iter: int, dataloader_length: int,
                                 param_norm: torch.Tensor, grad_norm: torch.Tensor) -> None:
        pass

    def valid_iter_body(self, batch: Any,
                        num_iter: int, dataloader_length: int) -> Dict[str, Any]:
        pass

    def valid_iter_track(self, output: Dict[str, Any]) -> None:
        pass

    def valid_iter_print(self, output: Dict[str, Any],
                         num_iter: int, dataloader_length: int) -> None:
        pass

    def valid_epoch_log_and_print(self) -> None:
        pass

    def valid_update_best(self) -> bool:
        pass

    def test_iter_body(self, batch: Any,
                       num_iter: int, dataloader_length: int) -> Dict[str, Any]:
        pass

    def test_iter_track(self, output: Dict[str, Any]) -> None:
        pass

    def test_iter_print(self, output: Dict[str, Any],
                        num_iter: int, dataloader_length: int) -> None:
        pass

    def test_epoch_log_and_print(self) -> None:
        pass

    # ... to here.
    # -------------------------------------------------------------------------------------------------------- #
    # -------------------------------------------------------------------------------------------------------- #

    def on_train_start(self) -> None:
        """Log a one-time summary of the (unwrapped) network before training begins.

        The summary contains the parameter and element counts and the initial parameter norm. With
        ``verbose`` enabled it also contains the ``param_log`` listing.
        """
        net = self.accelerator.unwrap_model(self.model)
        parts = ["Train start!\n"]
        if self.verbose:
            parts.append(param_log(net))
        num_params, num_elements = count_params(net.parameters())
        init_norm = compute_param_norm(net.parameters())
        parts.append(f"... Number of parameters: {num_params}, elements: {num_elements}\n")
        parts.append(f"... Initial norm of parameters: {init_norm.item():.4f}\n")
        self.logger.info("".join(parts))

    def on_train_end(self) -> None:
        self.logger.info(f"Train done! "
                         f"Final epoch {self.current_epoch} / {self.max_epochs}, "
                         f"total iters {self.current_iter} / {self.max_iters}")

    def on_valid_start(self, dataloader):
        """Hook run before validation: log progress, start the "valid" timer, return the loader.

        Subclasses may return a different loader; the base version returns ``dataloader`` unchanged.
        """
        progress = (f"epoch {self.current_epoch} / {self.max_epochs}, "
                    f"iter {self.current_iter} / {self.max_iters}")
        self.logger.info(f"Valid at {progress} start!")
        self.time_tracker.reset("valid")
        return dataloader

    def on_valid_end(self) -> None:
        """Hook run after validation: log the time measured by the "valid" timer."""
        elapsed = self.time_tracker.update("valid")
        message = f"Valid done! (Time: {elapsed:.4f} s)"
        self.logger.info(message)

    def on_test_start(self, dataloader):
        """Hook run before testing: announce the start, start the "test" timer, return the loader.

        Subclasses may return a different loader; the base version returns ``dataloader`` unchanged.
        """
        self.logger.info("Test start!")
        self.time_tracker.reset("test")  # timing starts after the announcement
        loader = dataloader
        return loader

    def on_test_end(self) -> None:
        """Hook run after testing: log the time measured by the "test" timer."""
        elapsed = self.time_tracker.update("test")
        message = f"Test done! (Time: {elapsed:.4f} s)"
        self.logger.info(message)

    def on_train_epoch_start(self, dataloader):
        """Hook run before each training epoch; returns the loader to iterate.

        Logs progress, resets every timer, and clears the accumulated-batch counter. The base version
        returns ``dataloader`` unchanged.
        """
        header = (f"Train epoch {self.current_epoch} / {self.max_epochs}, "
                  f"iter {self.current_iter} / {self.max_iters} start!")
        self.logger.info(header)
        self.time_tracker.reset()  # no argument: reset all timers
        self._current_acc_batches = 0
        return dataloader

    def on_train_epoch_end(self) -> None:
        """Hook run after each training epoch (and its validation): log the "epoch" timer."""
        elapsed = self.time_tracker.update("epoch")
        message = f"Train epoch done! (Time: {elapsed:.4f} s)"
        self.logger.info(message)

    def on_train_iter_start(self) -> None:
        pass

    def on_train_iter_end(self) -> None:
        pass

    def on_valid_iter_start(self) -> None:
        pass

    def on_valid_iter_end(self) -> None:
        """Per-iteration hook after a validation step; a no-op in this class."""

    def on_test_iter_start(self) -> None:
        """Per-iteration hook before a test step; a no-op in this class."""

    def on_test_iter_end(self) -> None:
        """Per-iteration hook after a test step; a no-op in this class."""

    def state_dict(self) -> dict:
        state = dict()
        state["epoch"] = self.current_epoch
        state["iteration"] = self.current_iter

        state["network"] = self.accelerator.unwrap_model(self.model).state_dict()
        state["optimizer"] = self.optimizer.state_dict()
        state["scheduler"] = self.scheduler.state_dict()

        state["tracker"] = {
            "static": self.static_tracker.state_dict(),
            "epoch": self.epoch_tracker.state_dict(),
        }
        if self.fp16:
            state["grad_scaler"] = self.accelerator.scaler.state_dict()
        return state

    def load_state_dict(self,
                        state_dict: dict,
                        strict: bool = True,
                        model_only: bool = False,
                        load_keys: Optional[List] = None,
                        ignore_keys: Optional[List] = None) -> None:
        """Load state dict.
        Possible keys:
            [network, optimizer, scheduler, epoch, iteration, tracker, grad_scaler]

        model_only: 1st priority, force keys to [network]
        load_keys or ignore_keys: 2nd priority, exclusive. both non-None is prohibited.
        """
        if (load_keys is not None) and (ignore_keys is not None):
            raise ValueError(f"Load state dict load_keys and ignore_keys cannot be both activated.")

        default_keys = ["network", "optimizer", "scheduler", "epoch", "iteration", "tracker", "grad_scaler"]

        if model_only:
            load_keys = ["network"]
        elif load_keys is not None:
            load_keys = [k.lower() for k in load_keys]
        elif ignore_keys is not None:
            ignore_keys = [k.lower() for k in ignore_keys]
            load_keys = [k for k in default_keys if k not in ignore_keys]
        else:  # both are None
            load_keys = default_keys

        if "network" in load_keys:
            unwrapped_model = self.accelerator.unwrap_model(self.model)
            unwrapped_model.load_state_dict(state_dict["network"], strict=strict)
            self.logger.info(f"Load state dict 'network' (strict={strict}).")
        else:
            self.logger.warning(f"Load state dict does not contain 'network' as key.")

        if "epoch" in load_keys:
            self.current_epoch = state_dict.get("epoch", 0)
            self.logger.info(f"Load state dict 'epoch' (current_epoch: {self.current_epoch}).")
        if "iteration" in load_keys:
            self.current_iter = state_dict.get("iteration", 0)
            self.logger.info(f"Load state dict 'iteration' (current_iteration: {self.current_iter}).")

        if ("optimizer" in load_keys) and (self.optimizer is not None):
            if "optimizer" in state_dict:
                self.optimizer.load_state_dict(state_dict["optimizer"])
                self.logger.info(f"Load state dict 'optimizer'.")
            else:
                self.logger.info(f"Tried to load state dict 'optimizer', but does not exist.")

        if ("scheduler" in load_keys) and (self.scheduler is not None):
            if "scheduler" in state_dict:
                self.scheduler.load_state_dict(state_dict["scheduler"])
                self.logger.info(f"Load state dict 'scheduler'.")
            else:
                self.logger.info(f"Tried to load state dict 'scheduler', but does not exist.")

        # always sync with trainer steps
        if (self.scheduler is not None) and hasattr(self.scheduler, "_num_iters"):
            self.logger.info(f"Scheduler iteration synced to {self.current_iter}.")
            self.scheduler.set_num_iters(self.current_iter)

        if "grad_scaler" in load_keys:
            if self.fp16 and ("grad_scaler" in state_dict):
                self.accelerator.scaler.load_state_dict(state_dict["grad_scaler"])
                self.logger.info(f"Load state dict 'grad_scaler'.")

        if "tracker" in load_keys:
            if "tracker" in state_dict:
                self.static_tracker.load_state_dict(state_dict["tracker"]["static"])
                self.epoch_tracker.load_state_dict(state_dict["tracker"]["epoch"])

        self._state_loaded = True

    def resume(self, cfg: DictConfig) -> bool:
        # return true if resume success.

        if "resume" in cfg:  # check if config is the outer-most.
            cfg = copy.deepcopy(cfg["resume"])

        if cfg["from_scratch"] and (cfg["checkpoint"] is None):
            return False

        # start from checkpoint
        if cfg["checkpoint"] is None:
            raise ValueError(f"Resume ON, but checkpoint is None.")
        if not os.path.isfile(cfg["checkpoint"]):
            raise ValueError(f"Resume checkpoint {cfg['checkpoint']} does not exist.")

        self.load_state_dict(
            state_dict=torch.load(cfg["checkpoint"], map_location=self.device),
            strict=cfg.get("strict", True),
            model_only=cfg.get("model_only", False),
            load_keys=cfg.get("load_keys", None),
            ignore_keys=cfg.get("ignore_keys", None),
        )
        return True
