from typing import Tuple

from dataclasses import fields as datafields
from functools import partialmethod
import sys
import logging
import timeit
import torch
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cli import LightningCLI


class Timer:
    """Context manager that measures the wall-clock duration of its body.

    Timestamps come from ``timeit.default_timer``. After the ``with`` block
    exits (normally or via an exception, which is not suppressed), the
    instance carries ``t_start``, ``t_end`` and ``dt = t_end - t_start``.
    """

    def __enter__(self):
        now = timeit.default_timer()
        self.t_start = now
        return self

    def __exit__(self, _1, _2, _3):
        # Exception info is ignored; returning None lets any exception propagate.
        finished = timeit.default_timer()
        self.t_end = finished
        self.dt = finished - self.t_start


# XXX: best-effort converter — on any conversion failure it returns the raw input unchanged.
def dataclass_from_dict(klass, dikt):
    """Recursively build a ``klass`` dataclass instance from a (nested) dict.

    Every key of ``dikt`` is converted using the type annotation of the
    dataclass field with the same name, so nested dataclass fields are rebuilt
    from nested dicts. The conversion is best effort. ``dikt`` is returned
    unchanged when:
    - ``klass`` is not a dataclass (e.g. a plain type such as ``int``, or a
      string annotation);
    - ``dikt`` has a key that is not a field;
    - ``dikt`` is not a mapping;
    - constructing ``klass`` fails.

    Args:
        klass: target dataclass type (or any annotation for leaf values).
        dikt: value to convert, normally a dict keyed by field name.

    Returns:
        A ``klass`` instance on success, otherwise ``dikt`` itself.
    """
    try:
        fieldtypes = {f.name: f.type for f in datafields(klass)}
        return klass(**{f: dataclass_from_dict(fieldtypes[f], dikt[f]) for f in dikt})
    except Exception:
        # Deliberate fallback for leaves and mismatches. Catching Exception
        # (not a bare ``except:``) lets KeyboardInterrupt / SystemExit propagate.
        return dikt


def get_logger(name=__name__) -> logging.Logger:
    """Return the named logger with its logging methods restricted to rank zero.

    Each level method of the logger is replaced by a ``rank_zero_only``-wrapped
    version. In a multi-GPU run, messages are then emitted once instead of once
    per process.

    Args:
        name: logger name passed to ``logging.getLogger``.

    Returns:
        The (patched) ``logging.Logger`` instance.
    """
    log = logging.getLogger(name)

    level_methods = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for method_name in level_methods:
        unwrapped = getattr(log, method_name)
        setattr(log, method_name, rank_zero_only(unwrapped))

    return log


def bootstrap(x: torch.Tensor, Nboot: int, binsize: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Bootstrap estimate of the mean of ``x`` using binned resampling.

    Consecutive samples along dim 0 are grouped into bins of ``binsize``. Each
    of the ``Nboot`` replicas draws ``len(x) // binsize`` bins with replacement.
    The replica value is the mean over all drawn bins and the samples in them.

    Args:
        x (torch.Tensor): samples along dim 0, with ``x.shape[0]`` divisible
            by ``binsize``. Any trailing dimensions are averaged independently.
            Must have a floating dtype, since ``mean`` is used.
        Nboot (int): number of bootstrap replicas (must be >= 1).
        binsize (int): number of consecutive samples per bin.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the mean and the (unbiased) standard
        deviation of the replica means. Each has shape ``x.shape[1:]``, which is
        0-dim for 1-D ``x``. With ``Nboot == 1`` the std is NaN.

    Raises:
        ValueError: if ``Nboot < 1``.
    """
    if Nboot < 1:
        raise ValueError(f"Nboot must be >= 1, got {Nboot}")

    # Resample whole bins rather than single samples so consecutive samples stay together.
    bins = x.reshape(-1, binsize, *x.shape[1:])
    nbins = len(bins)
    replicas = []
    for _ in range(Nboot):
        idx = torch.randint(nbins, (nbins,))
        replicas.append(torch.mean(bins[idx], dim=(0, 1)))
    # torch.stack keeps dtype/device and works for non-scalar replica means.
    # torch.tensor(list_of_tensors) fails for those.
    stacked = torch.stack(replicas)
    return stacked.mean(dim=0), stacked.std(dim=0)


# From https://stackoverflow.com/questions/38911146/python-equivalent-of-functools-partial-for-a-class-constructor
def partialclass(name, cls, *args, **kwds):
    """Create a subclass of ``cls`` whose ``__init__`` has arguments pre-bound.

    This is the class analogue of ``functools.partial``. The new class is a
    direct subclass of ``cls`` whose ``__init__`` is
    ``partialmethod(cls.__init__, *args, **kwds)``. Keyword arguments given at
    instantiation override the pre-bound ones.

    Args:
        name: ``__name__`` of the new class.
        cls: base class to derive from.
        *args: positional arguments bound to ``cls.__init__`` after ``self``.
        **kwds: keyword arguments bound to ``cls.__init__``.

    Returns:
        The newly created subclass.
    """
    new_cls = type(name, (cls,), {"__init__": partialmethod(cls.__init__, *args, **kwds)})

    # The following is copied nearly verbatim from `namedtuple`'s source.
    # For pickling to work, the __module__ variable needs to be set to the
    # module of the frame that called this function. The class must also be
    # bound to `name` at that module's top level for pickle to find it.
    # Skip this step in environments where sys._getframe is not defined
    # (Jython, for example) or does not accept arguments greater than 0
    # (IronPython).
    try:
        new_cls.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__")
    except (AttributeError, ValueError):
        pass

    return new_cls


# CPU device metrics: the keys of the dict returned by get_cpu_stats().
_CPU_VM_PERCENT = "cpu_vm_percent"
_CPU_PERCENT = "cpu_percent"
_CPU_SWAP_PERCENT = "cpu_swap_percent"


def get_cpu_stats():
    """Snapshot host memory/CPU utilisation via ``psutil`` (imported lazily).

    Returns:
        dict mapping the ``_CPU_*`` metric names to the percentages reported by
        ``psutil.virtual_memory()``, ``psutil.cpu_percent()`` and
        ``psutil.swap_memory()``.

    Note:
        ``psutil.cpu_percent()`` is called without an interval, so it compares
        against the previous call. Per the psutil docs, the first value
        (0.0) is not meaningful.
    """
    import psutil

    stats = {}
    stats[_CPU_VM_PERCENT] = psutil.virtual_memory().percent
    stats[_CPU_PERCENT] = psutil.cpu_percent()
    stats[_CPU_SWAP_PERCENT] = psutil.swap_memory().percent
    return stats


class FNOCLI(LightningCLI):
    """``LightningCLI`` that links data-module options to model options.

    Each ``data.*`` option listed below is linked via ``parser.link_arguments``
    to the corresponding ``model.*`` option, so the model value is derived
    from the data configuration.
    """

    def add_arguments_to_parser(self, parser) -> None:
        # (source, target) pairs, registered in this order.
        links = (
            ("data.time_gap", "model.time_gap"),
            ("data.time_history", "model.time_history"),
            ("data.time_future", "model.time_future"),
            ("data.pde", "model.pdeconfig"),
            ("data.usegrid", "model.usegrid"),
        )
        for source, target in links:
            parser.link_arguments(source, target)


# @rank_zero_only
# def log_hyperparameters(
#     config,
#     model: pl.LightningModule,
#     datamodule: pl.LightningDataModule,
#     trainer: pl.Trainer,
#     callbacks: List[pl.Callback],
#     logger: List[pl.loggers.LightningLoggerBase],
# ) -> None:
#     """Controls which config parts are saved by Lightning loggers.
#     Additionally saves:
#     - number of model parameters
#     """

#     if not trainer.logger:
#         return

#     hparams = {}

#     # choose which parts of hydra config will be saved to loggers
#     hparams["model"] = config["model"]

#     # save number of model parameters
#     hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
#     hparams["model/params/trainable"] = sum(
#         p.numel() for p in model.parameters() if p.requires_grad
#     )
#     hparams["model/params/non_trainable"] = sum(
#         p.numel() for p in model.parameters() if not p.requires_grad
#     )

#     hparams["datamodule"] = config["datamodule"]
#     hparams["trainer"] = config["trainer"]

#     if "seed" in config:
#         hparams["seed"] = config["seed"]
#     if "callbacks" in config:
#         hparams["callbacks"] = config["callbacks"]

#     # send hparams to all loggers
#     trainer.logger.log_hyperparams(hparams)
