import sys
import warnings
from bisect import bisect_right

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

import threestudio
from torch.optim.lr_scheduler import _LRScheduler

class ParamGroupScheduler(_LRScheduler):
    """Restrict an existing LR scheduler to a single optimizer param group.

    Only the param group at ``param_group_index`` keeps the learning-rate
    changes made by ``base_scheduler``. After every step, all other param
    groups are reset to the lr they had before that step. Arguments passed to
    ``step`` are forwarded to the base scheduler, with a no-argument fallback
    for schedulers whose ``step`` rejects them (the signature differs across
    PyTorch versions).
    """

    def __init__(self, base_scheduler: _LRScheduler, param_group_index: int):
        self.base_scheduler = base_scheduler
        optimizer = base_scheduler.optimizer
        # range(n)[i] normalises a negative index and raises IndexError when it
        # is out of range. An invalid index would otherwise make every group
        # "non-target" and silently turn the scheduler into a no-op.
        self.param_group_index = range(len(optimizer.param_groups))[
            int(param_group_index)
        ]
        # Constructing base_scheduler already performed its initial step, which
        # may have rewritten the lr of *every* group. Put non-target groups back
        # to their configured lr ("initial_lr", recorded by torch's LRScheduler
        # constructor); fall back to the current lr when it is absent.
        for i, group in enumerate(optimizer.param_groups):
            if i != self.param_group_index:
                group["lr"] = group.get("initial_lr", group["lr"])
        # The parent constructor calls self.step() once. The base scheduler has
        # already taken its initial step, so that call must not advance it a
        # second time (which would skip the first value of the schedule).
        self._initializing = True
        try:
            # The optimizer must be passed to the parent so that Lightning can
            # find ``.optimizer``. last_epoch=-1 only fills in "initial_lr"
            # where missing instead of requiring it. The epoch counter is then
            # mirrored from the base scheduler in step(). `verbose` is not
            # forwarded: it is deprecated in recent PyTorch, and this wrapper
            # prints nothing.
            super().__init__(optimizer=optimizer, last_epoch=-1)
        finally:
            self._initializing = False

    def step(self, *args, **kwargs):
        """Step the base scheduler, keeping its lr change for the target group only."""
        opt = self.base_scheduler.optimizer
        if not getattr(self, "_initializing", False):
            # Snapshot every group's lr before stepping.
            old_lrs = [group["lr"] for group in opt.param_groups]
            self._step_base(*args, **kwargs)
            # Undo the base scheduler's change on every group except the target.
            for i, (group, old_lr) in enumerate(zip(opt.param_groups, old_lrs)):
                if i != self.param_group_index:
                    group["lr"] = old_lr
        # Mirror the base scheduler's epoch counter and cache the current lrs so
        # code reading _last_lr / get_last_lr sees consistent values.
        self.last_epoch = getattr(self.base_scheduler, "last_epoch", self.last_epoch)
        self._last_lr = [group["lr"] for group in opt.param_groups]

    def _step_base(self, *args, **kwargs):
        """Forward step arguments to the base scheduler, dropping them if rejected."""
        if not args and not kwargs:
            # Nothing to drop: a TypeError here is a genuine error, and retrying
            # could step the base scheduler twice.
            self.base_scheduler.step()
            return
        try:
            self.base_scheduler.step(*args, **kwargs)
        except TypeError:
            # e.g. SequentialLR does not accept an epoch argument on some versions.
            self.base_scheduler.step()

    def get_last_lr(self):
        """Return the current lr of every param group of the optimizer."""
        return [pg["lr"] for pg in self.optimizer.param_groups]

    # State save/restore is delegated unchanged to the base scheduler, so
    # checkpoints look the same as for the unwrapped scheduler.
    def state_dict(self):
        return self.base_scheduler.state_dict()

    def load_state_dict(self, state_dict):
        self.base_scheduler.load_state_dict(state_dict)
        self.last_epoch = getattr(self.base_scheduler, "last_epoch", self.last_epoch)



def get_scheduler(name):
    """Return the scheduler class named ``name`` from ``torch.optim.lr_scheduler``.

    Raises:
        NotImplementedError: if ``torch.optim.lr_scheduler`` has no attribute
            called ``name``. The message names the unknown scheduler.
    """
    try:
        return getattr(lr_scheduler, name)
    except AttributeError:
        raise NotImplementedError(
            f"Unknown learning-rate scheduler: {name!r}"
        ) from None


def getattr_recursive(m, attr):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting at ``m``.

    Raises AttributeError as soon as one segment of the path is missing.
    """
    head, sep, rest = attr.partition(".")
    child = getattr(m, head)
    if not sep:
        return child
    return getattr_recursive(child, rest)


def get_parameters(model, name):
    """Look up the dotted attribute ``name`` on ``model`` and return its parameters.

    Returns ``.parameters()`` for an ``nn.Module``, the object itself for a
    single ``nn.Parameter``, and an empty list for anything else.
    """
    target = getattr_recursive(model, name)
    if isinstance(target, nn.Parameter):
        return target
    if isinstance(target, nn.Module):
        return target.parameters()
    return []


def parse_optimizer(config, model):
    """Instantiate the optimizer described by ``config`` for ``model``.

    When ``config.params`` exists, each entry ``name: {...}`` becomes its own
    param group. The group holds the parameters from
    ``get_parameters(model, name)``, a ``"name"`` key, and the entry's
    options merged in. Otherwise all of ``model.parameters()`` is used.

    ``config.name`` selects the optimizer class:
    ``"FusedAdam"`` from ``apex.optimizers``, ``"Adan"`` from
    ``threestudio.systems.optimizers``, anything else from ``torch.optim``.
    The class is called with ``config.args`` as keyword arguments.
    """
    if not hasattr(config, "params"):
        params = model.parameters()
    else:
        params = [
            {"params": get_parameters(model, group_name), "name": group_name, **group_opts}
            for group_name, group_opts in config.params.items()
        ]
        threestudio.debug(f"Specify optimizer params: {config.params}")

    if config.name == "FusedAdam":
        import apex  # imported lazily: only needed for FusedAdam

        optimizer_module = apex.optimizers
    elif config.name == "Adan":
        from threestudio.systems import optimizers as optimizer_module
    else:
        optimizer_module = torch.optim
    optimizer_cls = getattr(optimizer_module, config.name)
    return optimizer_cls(params, **config.args)


def parse_scheduler_to_instance(config, optimizer):
    """Build a bare LR-scheduler instance for ``optimizer`` from ``config``.

    ``config.name`` selects what is built:

    * ``"ChainedScheduler"``: each entry of ``config.schedulers`` is parsed
      recursively, and the results are combined with ``ChainedScheduler``.
    * ``"Sequential"`` or ``"SequentialLR"``: each entry of
      ``config.schedulers`` is parsed recursively, and the results are
      combined with ``SequentialLR``, switching at ``config.milestones``.
      ``"SequentialLR"`` is the name ``parse_scheduler`` uses; accepting it
      here lets one config work with both functions.
    * anything else: the class of that name in ``torch.optim.lr_scheduler``,
      called as ``cls(optimizer, **config.args)``.
    """
    name = config.name
    if name == "ChainedScheduler":
        return lr_scheduler.ChainedScheduler(
            [parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers]
        )
    if name in ("Sequential", "SequentialLR"):
        return lr_scheduler.SequentialLR(
            optimizer,
            [parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers],
            milestones=config.milestones,
        )
    return getattr(lr_scheduler, name)(optimizer, **config.args)


def parse_scheduler(config, optimizer):
    """Build a Lightning-style LR scheduler config dict from ``config``.

    Returns ``{"scheduler": <scheduler>, "interval": <"epoch" | "step">}``.
    ``config.interval`` defaults to ``"epoch"``. ``config.name`` selects:

    * ``"SequentialLR"`` (or ``"Sequential"``): sub-configs in
      ``config.schedulers`` are parsed recursively and combined with
      ``SequentialLR`` at ``config.milestones``.
    * ``"ChainedScheduler"``: sub-configs are parsed recursively and chained.
    * anything else: ``get_scheduler(config.name)(optimizer, **config.args)``.

    If ``config.param_group_index`` is set (not missing or null), the scheduler
    is wrapped in ``ParamGroupScheduler`` so it only affects that param group.

    Raises:
        ValueError: if the interval is not ``"epoch"`` or ``"step"``.
    """
    interval = config.get("interval", "epoch")
    if interval not in ("epoch", "step"):
        # Explicit raise rather than assert: asserts are stripped under -O.
        raise ValueError(
            f"scheduler interval must be 'epoch' or 'step', got {interval!r}"
        )

    if config.name in ("SequentialLR", "Sequential"):
        scheduler = lr_scheduler.SequentialLR(
            optimizer,
            [
                parse_scheduler(conf, optimizer)["scheduler"]
                for conf in config.schedulers
            ],
            milestones=config.milestones,
        )
    elif config.name == "ChainedScheduler":
        scheduler = lr_scheduler.ChainedScheduler(
            [
                parse_scheduler(conf, optimizer)["scheduler"]
                for conf in config.schedulers
            ]
        )
    else:
        scheduler = get_scheduler(config.name)(optimizer, **config.args)

    # `.get` is already required of `config` above (for "interval"). Treating
    # an explicit null like a missing key avoids calling int(None) in the wrapper.
    param_group_index = config.get("param_group_index")
    if param_group_index is not None:
        scheduler = ParamGroupScheduler(scheduler, param_group_index)

    # Only "scheduler" and "interval" are returned. Custom keys such as
    # param_group_index are deliberately not exposed: Lightning does not
    # recognise them, and they can break its later parsing of this dict.
    return {"scheduler": scheduler, "interval": interval}
