import functools
import logging
import os
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar

import lightning as L
import torch
from lightning.fabric.connector import _is_using_cli
from lightning.fabric.loggers import TensorBoardLogger
from lightning.fabric.utilities.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf

from fusion_bench.constants import RuntimeConstants
from fusion_bench.utils import import_object
from fusion_bench.utils.instantiate_utils import instantiate

if TYPE_CHECKING:
    import lightning.fabric.loggers.tensorboard
    from lightning.fabric.strategies import FSDPStrategy

log = logging.getLogger(__name__)

TensorOrModule = TypeVar("TensorOrModule", torch.Tensor, torch.nn.Module, Any)


def get_policy(*args: str) -> set:
    """
    Get the policy from the provided list of policy names.

    Args:
        *args (str): A list of policy names.

    Returns:
        set: A set of policy objects.
    """
    return {import_object(arg) for arg in args}


def get_size_based_auto_wrap_policy(*args, **kwargs):
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    policy = functools.partial(size_based_auto_wrap_policy, *args, **kwargs)
    return policy


class LightningFabricMixin:
    """
    A mixin class for integrating Lightning Fabric into a project.

    This class provides methods to initialize and manage a Lightning Fabric instance for distributed computing,
    including setup with optional logging, device management for tensors and modules, and hyperparameter logging.
    It leverages the Lightning framework to facilitate distributed training and inference across multiple devices
    and nodes, with support for custom logging via TensorBoard.

    Attributes:

    - _fabric (L.Fabric): The Lightning Fabric instance used for distributed computing.

    Note:

    This mixin is designed to be used with classes that require distributed computing capabilities and wish to
    leverage the Lightning Fabric for this purpose. It assumes the presence of a `config` attribute or parameter
    in the consuming class for configuration.
    """

    _fabric_instance: L.Fabric = None

    def setup_lightning_fabric(self, config: DictConfig):
        """
        Initializes and launches the Lightning Fabric with optional logging.

        This method sets up the Lightning Fabric for distributed computing based on the provided configuration. If a fabric
        configuration is not found, it logs a warning and exits. Optionally, if a fabric logger configuration is provided,
        it initializes a TensorBoardLogger with the specified settings.

        Expected configuration keys:
        - fabric: The configuration for the Lightning Fabric.
        - fabric.loggers: The configuration for the TensorBoardLogger.
        """
        if self._fabric_instance is None:
            if config.get("fabric", None) is None:
                log.warning("No fabric configuration found. use default settings.")
                self._fabric_instance = L.Fabric()
            else:
                self._fabric_instance = instantiate(config.fabric)
            if not _is_using_cli():  # if not using cli, launch the fabric
                self._fabric_instance.launch()
            # Set the log directory in config if it is not already set
            if (
                self.log_dir is not None
                and hasattr(config, "log_dir")
                and config.get("log_dir", None) is None
            ):
                if self._fabric_instance.is_global_zero:
                    log.info(f"Setting log_dir to {self.log_dir}")
                config.log_dir = self.log_dir

    @property
    def fabric(self):
        """
        Get the Lightning Fabric instance, initializing it if necessary.

        Returns:
            L.Fabric: The Lightning Fabric instance for distributed computing.
        """
        if self._fabric_instance is None:
            self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
        return self._fabric_instance

    @fabric.setter
    def fabric(self, instance: L.Fabric):
        """
        Set the Lightning Fabric instance.

        Args:
            instance: The Lightning Fabric instance to use.
        """
        self._fabric_instance = instance

    @property
    def log_dir(self):
        """
        Retrieves the log directory from the fabric's logger.
        """
        if self.fabric is not None and len(self.fabric._loggers) > 0:
            log_dir = self.fabric.logger.log_dir

            # Special handling for SwanLabLogger to get the correct log directory
            if (
                log_dir is None
                and self.fabric.logger.__class__.__name__ == "SwanLabLogger"
            ):
                log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir

            assert log_dir is not None, "log_dir should not be None"
            if self.fabric.is_global_zero and not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            return log_dir
        else:
            return None

    def to_device(self, obj: TensorOrModule) -> TensorOrModule:
        """
        Moves a tensor or module to the proper device.

        Args:
            obj (TensorOrModule): The tensor or module to move to the device.

        Returns:
            TensorOrModule: the same type of object as the input, moved to the device.
        """
        return self.fabric.to_device(obj)

    @rank_zero_only
    def log_hyperparams(
        self,
        config: Optional[DictConfig] = None,
        save_dir: Optional[str] = None,
        filename: str = "config.yaml",
    ):
        R"""
        Logs the hyperparameters and saves the configuration to a YAML file.
        The YAML file is saved in the log directory by default with the name `config.yaml`, or in the specified save directory `save_dir`.

        Args:
            config (Optional[DictConfig]): The configuration to log and save. If not provided, the class's `config` attribute is used.
            save_dir (Optional[str]): The directory in which to save the configuration file. If not provided, the log directory is used.
            filename (str): The name of the configuration file. Default is `config.yaml`.
        """
        if config is None:
            config = self.config
        if save_dir is None:
            save_dir = self.log_dir
        self.fabric.logger.log_hyperparams(
            OmegaConf.to_container(config, resolve=True, enum_to_str=True)
        )
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        OmegaConf.save(
            config,
            os.path.join(self.log_dir if save_dir is None else save_dir, filename),
        )

    @property
    def tensorboard_summarywriter(
        self,
    ) -> "lightning.fabric.loggers.tensorboard.SummaryWriter":
        """
        Get the TensorBoard SummaryWriter for detailed logging.

        Returns:
            SummaryWriter: The TensorBoard SummaryWriter instance.

        Raises:
            AttributeError: If the logger is not a TensorBoardLogger.
        """
        if isinstance(self.fabric.logger, TensorBoardLogger):
            return self.fabric.logger.experiment
        else:
            raise AttributeError("the logger is not a TensorBoardLogger.")

    @property
    def is_debug_mode(self):
        """
        Check if the program is running in debug mode (fast_dev_run).

        Returns:
            bool: True if fast_dev_run is enabled, False otherwise.
        """
        return RuntimeConstants().debug

    def log(self, name: str, value: Any, step: Optional[int] = None):
        """
        Logs a single metric to the fabric's logger.

        Args:
            name: The name of the metric to log.
            value: The value of the metric.
            step: Optional step number for the metric.
        """
        self.fabric.log(name, value, step=step)

    def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None):
        """
        Logs multiple metrics to the fabric's logger.

        Args:
            metrics: Dictionary of metric names and values.
            step: Optional step number for the metrics.
        """
        self.fabric.log_dict(metrics, step=step)

    def log_optimizer_lr(
        self,
        optimizer: torch.optim.Optimizer,
        step: Optional[int] = None,
        name_template: str = "train/lr_group_{0}",
    ):
        """
        Logs the learning rate of each parameter group in the optimizer.

        Args:
            optimizer: The optimizer whose learning rates should be logged.
            step: Optional step number for the log entry.
            name_template: Template string for the log name. Use {0} as placeholder for group index.
        """
        for i, param_group in enumerate(optimizer.param_groups):
            self.fabric.log(name_template.format(i), param_group["lr"], step=step)
