import re
from typing import List, Optional, Set

import torch.nn as nn
from transformers import Trainer
from transformers.utils import logging
from packaging import version

# accelerate's DummyOptim is optional here: if the import fails for any
# reason the name is bound to None, and create_optimizer() raises a
# RuntimeError only when it would actually need it.
try:
    from accelerate.utils import DummyOptim
except Exception:
    DummyOptim = None

# Resolve is_sagemaker_mp_enabled from whichever transformers module provides
# it; if neither import works, fall back to a stub that always reports False.
try:
    from transformers.utils import is_sagemaker_mp_enabled
except Exception:
    try:
        from transformers.trainer_utils import is_sagemaker_mp_enabled
    except Exception:
        def is_sagemaker_mp_enabled():
            """Fallback stub: report SageMaker model parallelism as off."""
            return False
from .skipping_trainer import TrainerDataSkipping

# Import the SageMaker model-parallel library only when it is reported as
# enabled. Otherwise `smp` is bound to None so later code can test for it.
if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    # True when the installed smdistributed version is at least 1.10.
    # NOTE(review): this flag is not referenced in the visible part of this
    # file -- confirm whether anything else imports it.
    IS_SAGEMAKER_MP_POST_1_10 = version.parse(
        SMP_VERSION) >= version.parse("1.10")

else:
    IS_SAGEMAKER_MP_POST_1_10 = False
    smp = None

# Module-level logger obtained through the transformers logging utilities.
log = logging.get_logger(__name__)


class LayerwiseLRTrainer(TrainerDataSkipping):
    """Trainer that assigns a separate learning rate to selected parameters.

    Parameters whose names match any of ``target_name_patterns`` (regular
    expressions, matched with ``re.search``) are placed in "boosted"
    optimizer parameter groups; every other trainable parameter stays in the
    "base" groups, which use ``args.learning_rate``. Inside both families,
    parameters are further split by whether weight decay applies to them.
    """

    def __init__(
        self,
        *args,
        target_name_patterns: Optional[List[str]] = None,
        boosted_lr: Optional[float] = None,
        boost_factor: float = 2.0,
        **kwargs,
    ):
        """Store the layer-wise learning-rate configuration.

        Args:
            *args: Forwarded unchanged to the parent trainer.
            target_name_patterns: Regex patterns selecting the parameter
                names that receive the boosted learning rate. ``None`` or an
                empty list means no parameter is boosted.
            boosted_lr: Absolute learning rate for matching parameters. When
                ``None``, ``args.learning_rate * boost_factor`` is used.
            boost_factor: Multiplier applied to the base learning rate when
                ``boosted_lr`` is not given.
            **kwargs: Forwarded unchanged to the parent trainer.

        Raises:
            re.error: If one of the patterns is not a valid regex.
        """
        super().__init__(*args, **kwargs)
        # Compile once here; the patterns are applied to every parameter name.
        self._compiled_pats = [
            re.compile(p) for p in (target_name_patterns or [])
        ]
        self._boosted_lr = boosted_lr
        self._boost_factor = boost_factor

    def _is_target(self, name: str) -> bool:
        """Return True if ``name`` matches at least one target pattern."""
        return any(pat.search(name) for pat in self._compiled_pats)

    def _base_and_boosted_lr(self):
        """Return ``(base_lr, boosted_lr)`` as floats.

        An explicit ``boosted_lr`` (including ``0.0``) takes precedence over
        the value derived from ``boost_factor``.
        """
        base_lr = float(self.args.learning_rate)
        if self._boosted_lr is not None:
            boosted_lr = float(self._boosted_lr)
        else:
            boosted_lr = base_lr * float(self._boost_factor)
        return base_lr, boosted_lr

    def _build_param_groups(self, opt_model):
        """Build the optimizer parameter groups for ``opt_model``.

        Up to four groups are produced, in this order: base/decay,
        base/no-decay, boosted/decay, boosted/no-decay. Groups that end up
        empty are omitted.

        Args:
            opt_model: Model whose ``named_parameters()`` are grouped.

        Returns:
            A list of dicts with the keys ``"params"``, ``"weight_decay"``
            and ``"lr"``.
        """
        # A set makes the per-parameter membership test O(1).
        decay_names = set(self.get_decay_parameter_names(opt_model))
        base_lr, boosted_lr = self._base_and_boosted_lr()

        # (is_target, in_decay) in the order the groups are emitted.
        layout = [
            (False, True),
            (False, False),
            (True, True),
            (True, False),
        ]
        buckets = {key: [] for key in layout}
        for name, param in opt_model.named_parameters():
            if not param.requires_grad:
                continue
            # Classify each parameter exactly once.
            key = (self._is_target(name), name in decay_names)
            buckets[key].append(param)

        seen: Set[int] = set()
        groups = []
        boosted_cnt = 0
        total_cnt = 0
        for is_target, in_decay in layout:
            # Keep only the first occurrence of a parameter object so that a
            # tensor reachable under several names is never handed to the
            # optimizer twice.
            uniq = []
            for param in buckets[(is_target, in_decay)]:
                pid = id(param)
                if pid in seen:
                    continue
                seen.add(pid)
                uniq.append(param)
            # Drop empty groups only AFTER de-duplication, so a group emptied
            # by it is not returned.
            if not uniq:
                continue
            groups.append(
                {
                    "params": uniq,
                    "weight_decay": (
                        self.args.weight_decay if in_decay else 0.0
                    ),
                    "lr": boosted_lr if is_target else base_lr,
                }
            )
            total_cnt += len(uniq)
            # Count boosted parameters from the group's own flag. Indexing the
            # filtered list by position is wrong once any group was dropped.
            if is_target:
                boosted_cnt += len(uniq)

        if boosted_cnt == 0 and self._compiled_pats:
            log.warning("no parameters matched target_name_patterns")
        log.info(
            "param groups ready (base=%g, boosted=%g): %d/%d boosted",
            base_lr,
            boosted_lr,
            boosted_cnt,
            total_cnt,
        )
        return groups

    def create_optimizer(self):
        """Create ``self.optimizer`` from the layer-wise groups and return it.

        An already existing optimizer is returned untouched. When the
        DeepSpeed plugin's config defines its own optimizer, accelerate's
        ``DummyOptim`` is instantiated with the parameter groups instead of a
        real optimizer class.

        Returns:
            The optimizer stored on ``self.optimizer``.

        Raises:
            RuntimeError: If the DeepSpeed config defines an optimizer but
                ``DummyOptim`` could not be imported.
        """
        if self.optimizer is not None:
            return self.optimizer

        opt_model = (
            self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        )

        if self.args.fsdp and not (
            getattr(self.args, "fsdp_config", None) or {}
        ).get("use_orig_params", False):
            log.warning(
                "fsdp without use_orig_params=True may break name-based "
                "grouping"
            )

        optimizer_grouped_parameters = self._build_param_groups(opt_model)

        ds_plugin = getattr(self.accelerator.state, "deepspeed_plugin", None)
        ds_cfg = getattr(ds_plugin, "deepspeed_config",
                         {}) if ds_plugin else {}
        ds_has_optimizer = (
            isinstance(ds_cfg, dict) and "optimizer" in ds_cfg
        )

        if ds_plugin and ds_has_optimizer:
            if DummyOptim is None:
                raise RuntimeError(
                    "deepspeed config defines an optimizer; "
                    "accelerate DummyOptim is required"
                )
            base_lr, _ = self._base_and_boosted_lr()
            self.optimizer = DummyOptim(
                optimizer_grouped_parameters,
                lr=base_lr,
                weight_decay=self.args.weight_decay,
            )
        else:
            # The attribute may be missing on older transformers releases,
            # so read it defensively.
            cls_and_kwargs = getattr(self, "optimizer_cls_and_kwargs", None)
            if cls_and_kwargs is not None:
                optimizer_cls, optimizer_kwargs = cls_and_kwargs
            else:
                optimizer_cls, optimizer_kwargs = (
                    self.get_optimizer_cls_and_kwargs(self.args, opt_model)
                )
            # Work on a copy: the pops below must not mutate a kwargs dict
            # that the caller supplied and still owns.
            optimizer_kwargs = dict(optimizer_kwargs)

            # These keys, when present, replace the positional argument of
            # the optimizer constructor. All of them are removed from the
            # kwargs; the last one present wins.
            for key in ("params", "model", "optimizer_dict"):
                if key not in optimizer_kwargs:
                    continue
                optimizer_grouped_parameters = optimizer_kwargs.pop(key)
                if self._compiled_pats:
                    # Make the loss of the layer-wise grouping visible.
                    log.warning(
                        "optimizer kwargs provide %r; the layer-wise "
                        "learning-rate parameter groups are not used",
                        key,
                    )

            self.optimizer = optimizer_cls(
                optimizer_grouped_parameters, **optimizer_kwargs
            )

            if (
                "bitsandbytes" in str(optimizer_cls)
                and optimizer_kwargs.get("optim_bits", None) == 8
            ):
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        # Key by data_ptr so parameters sharing storage are
                        # counted once per module.
                        skipped += sum(
                            {
                                p.data_ptr(): p.numel()
                                for p in module.parameters()
                            }.values()
                        )
                        log.info(
                            f"skipped {module}: {skipped / 2**20}M params"
                        )
                        # Embedding weights are registered with a 32-bit
                        # override rather than the 8-bit setting.
                        manager.register_module_override(
                            module, "weight", {"optim_bits": 32}
                        )
                        log.debug("bitsandbytes: will optimize "
                                  f"{module} in fp32")
                log.info(f"skipped: {skipped / 2**20}M params")

        if is_sagemaker_mp_enabled() and smp is not None:
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
