from typing import Optional, List, Dict, Any, Iterable
import os
import copy
import logging

from omegaconf import DictConfig
import torch
import torch.cuda
from accelerate.utils import gather

from pado.core import PadoModule
from pado.core.accelerator import PadoAccelerator
from pado.data.dataloader import PadoDataLoader

from pado.core.tracker import *
from pado.utils.param_utils import *

__all__ = ["PadoEvaluator"]


class PadoEvaluator(object):
    """Drive an inference/evaluation loop through a ``PadoAccelerator``.

    The evaluator prepares the model with the accelerator, iterates over a
    dataloader with gradients disabled, gathers each iteration's output
    across processes and forwards it to overridable hooks.

    Subclasses are expected to implement ``infer_iter_body`` (and usually
    ``infer_iter_track``, ``infer_iter_print`` and
    ``infer_epoch_log_and_print``). The ``on_*`` hooks and ``gather`` may be
    overridden as well.
    """

    def __init__(self,
                 model: "PadoModule",
                 accelerator: "PadoAccelerator",
                 save_dir: Optional[str] = None,
                 logger: Optional[logging.Logger] = None,
                 dataloader: Optional["PadoDataLoader"] = None,
                 print_interval_iters: int = 50,
                 *,
                 verbose: bool = True) -> None:
        """Set up accelerator, model, output directory and trackers.

        Args:
            model: network to evaluate; it is passed through
                ``accelerator.prepare_model``.
            accelerator: accelerator wrapper providing device / process info.
            save_dir: optional output directory; created on the master
                process when given.
            logger: logger to use; defaults to ``logging.getLogger("pado")``.
            dataloader: default dataloader used by ``infer`` when none is
                passed explicitly.
            print_interval_iters: iteration interval used by
                ``is_print_iter``.
            verbose: when true, ``on_infer_start`` also logs the output of
                ``param_log``.
        """
        # -------------------------------------------------------------------------------------- #
        self.save_dir = save_dir
        self.verbose = verbose
        self.logger = logging.getLogger("pado") if logger is None else logger
        self.print_interval_iters = print_interval_iters

        # -------------------------------------------------------------------------------------- #
        # accelerate setup
        self.accelerator = accelerator
        self.fp16 = self.accelerator.use_fp16
        self.logger.info(f"Accelerator setup:\n"
                         f"... device: {accelerator.device} (fp16: {accelerator.use_fp16})\n"
                         f"... world_size: {accelerator.world_size}, local_rank: {accelerator.local_rank}\n")

        self.model = self.accelerator.prepare_model(model)

        self.device = self.accelerator.device
        self.is_master = self.accelerator.is_local_main_process
        self.world_size = self.accelerator.num_processes
        # NOTE(review): ``process_index`` is stored under the name ``local_rank`` while the
        # log line above reads ``accelerator.local_rank`` -- confirm both mean the same thing.
        self.local_rank = self.accelerator.process_index

        # -------------------------------------------------------------------------------------- #
        # only the master process creates the output directory
        if (save_dir is not None) and self.is_master:
            os.makedirs(save_dir, exist_ok=True)

        # -------------------------------------------------------------------------------------- #
        # dataloader
        self.dataloader = dataloader

        # -------------------------------------------------------------------------------------- #
        # tracker setup
        self.time_tracker = TimeTrackerDict("infer", "iter", "data", "forward", "backward", "step")
        self.static_tracker = MetricTrackerDict("num_samples")  # tracker that is not reset through epoch
        self.epoch_tracker = MetricTrackerDict("loss")  # tracker that is reset occasionally.

        # -------------------------------------------------------------------------------------- #
        # extra
        # The flag is initialised BEFORE the subclass hook, so an ``init_additional`` that
        # already loads weights (via ``load_state_dict`` / ``resume``) is not reset afterwards.
        self._state_loaded = False
        self.init_additional()

    def init_additional(self):
        """Hook called at the end of ``__init__``; can be overridden by a child class."""
        pass

    def is_print_iter(self, num_iter: int) -> bool:
        """Return True on every ``print_interval_iters``-th iteration, except iteration 0."""
        return (num_iter > 0) and (num_iter % self.print_interval_iters == 0)

    @staticmethod
    def _reduce_output(output: Dict[str, Any]) -> Dict[str, Any]:
        """Reduce every value of ``output`` in place to a detached scalar tensor.

        Heuristic: keys containing "loss" (case-insensitive) are averaged,
        everything else is summed.
        """
        for key, value in output.items():
            reduce_fn = torch.mean if "loss" in key.lower() else torch.sum
            output[key] = reduce_fn(value).detach()
        return output

    def gather(self, output):
        """Gather ``output`` across processes and reduce each entry.

        Can be overridden; the default uses ``accelerate.utils.gather``
        followed by the mean/sum heuristic of ``_reduce_output``.
        """
        return self._reduce_output(gather(output))

    def infer(self,
              dataloader: Optional[Iterable] = None):
        """Run one full inference pass.

        Args:
            dataloader: data to iterate over; falls back to the dataloader
                given at construction time. When both are ``None`` the run is
                skipped with a warning.
        """
        # ------------------------------------------------------------ #
        if dataloader is None:
            dataloader = self.dataloader

        if dataloader is None:
            self.logger.warning("SKIP infer run because dataloader is None,")
            return

        if not self._state_loaded:
            self.logger.warning("Parameters are not loaded. Is this intended?")
        # ------------------------------------------------------------ #
        dataloader = self.on_infer_start(dataloader)
        # ------------------------------------------------------------ #
        self.infer_body(dataloader)
        # ------------------------------------------------------------ #
        self.on_infer_end()
        # ------------------------------------------------------------ #

    def infer_body(self, dataloader):
        """Iterate over ``dataloader`` in eval mode with gradients disabled.

        For each batch: run ``infer_iter_body``, gather/reduce its output,
        track it, and print on the master process at print iterations. After
        the loop the master process calls ``infer_epoch_log_and_print``.
        """
        self.model.eval()
        self.epoch_tracker.reset()

        # ``no_grad`` restores the previous grad mode on exit (even on error), so running
        # inference no longer leaves autograd disabled for whatever code runs afterwards.
        with torch.no_grad():
            dataloader = self.accelerator.prepare_data_loader(dataloader)

            # ------------------------------------------------------------ #
            for num_iter, batch in enumerate(dataloader):
                # ------------------------------------------------------------ #
                self.time_tracker.update_and_reset("data", "forward")
                self.on_infer_iter_start()
                # ------------------------------------------------------------ #
                # forward
                output = self.infer_iter_body(batch, num_iter, len(dataloader))
                self.time_tracker.update("forward")
                # ------------------------------------------------------------ #
                # gather and track
                output = self.gather(output)
                # Reduce again with the same heuristic: a no-op on the scalars produced by
                # the default ``gather``, but needed when an overridden ``gather`` returns
                # stacked (un-reduced) tensors.
                output = self._reduce_output(output)
                self.infer_iter_track(output)
                # ------------------------------------------------------------ #
                # logging
                if self.is_master and self.is_print_iter(num_iter):
                    # only printing
                    self.infer_iter_print(output, num_iter, len(dataloader))
                # ------------------------------------------------------------ #
                # iteration done
                self.on_infer_iter_end()
                self.time_tracker.reset("data")

            # ------------------------------------------------------------ #
            if self.is_master:
                self.infer_epoch_log_and_print()

    # -------------------------------------------------------------------------------------------------------- #
    # -------------------------------------------------------------------------------------------------------- #
    # Implement from here ...

    def infer_iter_body(self, batch: Any,
                        num_iter: int, dataloader_length: int) -> Dict[str, Any]:
        """Process one batch and return a dict of outputs to be gathered (to implement)."""
        pass

    def infer_iter_track(self, output: Dict[str, Any]) -> None:
        """Record the gathered, reduced output of one iteration (to implement)."""
        pass

    def infer_epoch_log_and_print(self) -> None:
        """Log/print the summary after the loop; called on the master process only (to implement)."""
        pass

    def infer_iter_print(self, output: Dict[str, Any],
                         num_iter: int, dataloader_length: int) -> None:
        """Print per-iteration progress; called on the master process at print iterations (to implement)."""
        pass

    # ... to here.
    # -------------------------------------------------------------------------------------------------------- #
    # -------------------------------------------------------------------------------------------------------- #

    def on_infer_start(self, dataloader):
        """Log parameter statistics, reset the "infer" timer and return the dataloader to use."""
        s = "Inference start!\n"
        # print simple statistic of network
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        if self.verbose:
            s += param_log(unwrapped_model)
        param_num, param_count = count_params(unwrapped_model.parameters())
        param_norm = compute_param_norm(unwrapped_model.parameters())
        s += f"... Number of parameters: {param_num}, elements: {param_count}\n" \
             f"... Norm of parameters: {param_norm.item():.4f}\n"
        self.logger.info(s)
        self.time_tracker.reset("infer")
        return dataloader

    def on_infer_end(self) -> None:
        """Log the elapsed time reported by the "infer" timer."""
        infer_time = self.time_tracker.update("infer")
        self.logger.info(f"Inference done! (Time: {infer_time:.4f} s)")

    def on_infer_iter_start(self) -> None:
        """Hook called before each iteration's forward pass; no-op by default."""
        pass

    def on_infer_iter_end(self) -> None:
        """Hook called after each iteration's tracking/printing; no-op by default."""
        pass

    def load_state_dict(self,
                        state_dict: dict,
                        strict: bool = True):
        """
        Load state dict for inference. Only load network.

        Args:
            state_dict: checkpoint dict; only the ``"network"`` entry is used.
            strict: forwarded to the model's ``load_state_dict``.

        The evaluator is marked as loaded only when the ``"network"`` entry
        was actually applied. A checkpoint without it logs a warning and
        leaves the flag untouched, so ``infer`` still warns about unloaded
        parameters.
        """
        if "network" not in state_dict:
            self.logger.warning("Load state dict does not contain 'network' as key.")
            return

        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.load_state_dict(state_dict["network"], strict=strict)
        self.logger.info(f"Load state dict 'network' (strict={strict}).")
        self._state_loaded = True

    def resume(self, cfg: "DictConfig") -> bool:
        """Load a checkpoint according to ``cfg``.

        Args:
            cfg: either the resume section itself or an outer config that
                contains it under ``"resume"``. Keys read: ``from_scratch``,
                ``checkpoint`` and the optional ``strict`` (default True).

        Returns:
            False when starting from scratch without a checkpoint, True after
            the checkpoint file has been read and passed to
            ``load_state_dict``.

        Raises:
            ValueError: if a checkpoint is required but is ``None`` or does
                not point to an existing file.
        """
        if "resume" in cfg:  # check if config is the outer-most.
            cfg = copy.deepcopy(cfg["resume"])

        checkpoint = cfg["checkpoint"]
        if cfg["from_scratch"] and (checkpoint is None):
            return False

        # start from checkpoint
        if checkpoint is None:
            raise ValueError("Resume ON, but checkpoint is None.")
        if not os.path.isfile(checkpoint):
            raise ValueError(f"Resume checkpoint {checkpoint} does not exist.")

        # NOTE(security): ``torch.load`` unpickles the file; only resume from trusted checkpoints.
        self.load_state_dict(
            state_dict=torch.load(checkpoint, map_location=self.device),
            strict=cfg.get("strict", True),
        )
        return True