import math

import numbers

import torch.nn as nn

from typing import List, Dict, Any, Mapping, Optional, Tuple



def get_stage_parameters(model: nn.Module, stage_name: str) -> List[nn.Parameter]:
    """
    Return the parameters that belong to one stage of the DNAChunker model.

    A parameter belongs to a stage when its qualified name (as yielded by
    ``model.named_parameters()``) starts with one of that stage's module-name
    prefixes. Matching is a plain string-prefix test. Unknown stage names
    yield an empty list.

    Args:
        model: Model whose parameters are searched.
        stage_name: One of "embeddings", "encoder1", "encoder2", "main",
            "decoder1", "decoder2", "norm" or "head".

    Returns:
        The matching parameters, in ``named_parameters()`` order.
    """
    prefixes_by_stage = {
        "embeddings": ("net.backbone.embeddings",),
        "encoder1": ("net.backbone.encoder1_layers", "net.backbone.routing_module_stage1"),
        "encoder2": ("net.backbone.encoder2_layers", "net.backbone.routing_module_stage2"),
        "main": ("net.backbone.main_model",),
        "decoder1": ("net.backbone.decoder1_layers", "net.backbone.dechunker2"),
        "decoder2": ("net.backbone.decoder2_layers", "net.backbone.dechunker1"),
        "norm": ("net.backbone.norm_f",),
        "head": ("lm_head",),
    }
    prefixes = prefixes_by_stage.get(stage_name)
    if prefixes is None:
        return []
    # str.startswith accepts a tuple and succeeds if any prefix matches.
    return [param for qual_name, param in model.named_parameters() if qual_name.startswith(prefixes)]



def get_stage_named_parameters(model: nn.Module, stage_name: str) -> List[Tuple[str, nn.Parameter]]:
    """
    Return ``(name, parameter)`` pairs for one stage of the DNAChunker model.

    Uses the same stage -> module-prefix table as ``get_stage_parameters``: a
    parameter is selected when its qualified name starts with any of the
    stage's prefixes (plain string-prefix test). Unknown stage names yield an
    empty list.

    Args:
        model: Model whose ``named_parameters()`` are searched.
        stage_name: One of "embeddings", "encoder1", "encoder2", "main",
            "decoder1", "decoder2", "norm" or "head".

    Returns:
        Matching ``(qualified_name, parameter)`` pairs in ``named_parameters()`` order.
    """
    prefix_table = {
        "embeddings": ("net.backbone.embeddings",),
        "encoder1": ("net.backbone.encoder1_layers", "net.backbone.routing_module_stage1"),
        "encoder2": ("net.backbone.encoder2_layers", "net.backbone.routing_module_stage2"),
        "main": ("net.backbone.main_model",),
        "decoder1": ("net.backbone.decoder1_layers", "net.backbone.dechunker2"),
        "decoder2": ("net.backbone.decoder2_layers", "net.backbone.dechunker1"),
        "norm": ("net.backbone.norm_f",),
        "head": ("lm_head",),
    }
    if stage_name not in prefix_table:
        return []
    prefixes = prefix_table[stage_name]
    return [
        (qual_name, param)
        for qual_name, param in model.named_parameters()
        if qual_name.startswith(prefixes)
    ]



def _is_no_decay_param(name: str, stage_name: Optional[str] = None) -> bool:
    """
    Decide whether a parameter should be excluded from weight decay.

    A parameter is exempt when it was assigned to the "embeddings" or "norm"
    stage, when it is a bias (``name`` is exactly "bias" or ends in ".bias"),
    or when its lower-cased name contains "norm" or "embedding".

    Args:
        name: Qualified parameter name, e.g. from ``named_parameters()``.
        stage_name: Stage the parameter was assigned to, if any.

    Returns:
        True if the parameter should get ``weight_decay=0``.
    """
    if stage_name in {"embeddings", "norm"}:
        return True
    is_bias = name == "bias" or name.endswith(".bias")
    lowered = name.lower()
    # A plain "norm" substring test already covers "layernorm" and "rmsnorm".
    return is_bias or "norm" in lowered or "embedding" in lowered



def _normalize_lambda_value(stage: str, value: Any) -> float:
    """
    Coerce one stage's LR multiplier to a plain ``float``.

    Accepts a real number, a numeric string, or a one-element list/tuple
    wrapping either of those. ``bool`` is rejected even though it is a
    ``numbers.Real`` subclass.

    Args:
        stage: Stage name, used only in error messages.
        value: Raw multiplier value.

    Returns:
        The multiplier as a ``float``.

    Raises:
        ValueError: for a list/tuple whose length is not 1, or a string that
            ``float()`` cannot parse.
        TypeError: for any other non-real value (including bool).
    """
    if isinstance(value, (list, tuple)):
        if len(value) == 1:
            (value,) = value
        else:
            raise ValueError(f"Stage LR lambda for {stage} must be a single float, got {value}")

    if isinstance(value, str):
        try:
            value = float(value)
        except ValueError as exc:
            raise ValueError(f"Stage LR lambda for {stage} must be numeric, got {value}") from exc

    is_plain_real = isinstance(value, numbers.Real) and not isinstance(value, bool)
    if not is_plain_real:
        raise TypeError(f"Stage LR lambda for {stage} must be numeric, got {type(value)}")
    return float(value)



def _get_model_d_model(model: nn.Module) -> int:
    """
    Infer the model hidden size ``d_model``.

    Looks, in order, at ``model.net.backbone.config.d_model`` and
    ``model.config.d_model``. A location is used only if it actually defines
    ``d_model``; otherwise the next one is tried.

    Args:
        model: Model to inspect.

    Returns:
        The first ``d_model`` value found.

    Raises:
        ValueError: if neither location provides ``d_model``.
    """
    backbone = getattr(getattr(model, "net", None), "backbone", None)
    candidate_configs = (getattr(backbone, "config", None), getattr(model, "config", None))
    for config in candidate_configs:
        # A backbone config without d_model must fall through rather than
        # raise AttributeError.
        if hasattr(config, "d_model"):
            return config.d_model
    raise ValueError("Unable to infer d_model from model; pass stage_wise_lr_dims overrides.")



def _validate_ratio(name: str, value: float) -> float:
    """
    Validate a sequence-compression ratio and return it unchanged.

    Args:
        name: Parameter name used in error messages.
        value: Ratio to check.

    Returns:
        ``value`` itself.

    Raises:
        ValueError: if ``value`` is None, not strictly positive, or not
            finite (NaN or infinity).
    """
    if value is None or value <= 0:
        raise ValueError(f"{name} must be > 0, got {value}")
    # NaN slips past "<= 0" and infinity counts as "> 0", but callers divide by
    # the ratio (1/inf == 0, 1/nan == nan), so both are rejected explicitly.
    # Integers are always finite; skipping them also avoids float overflow.
    if not isinstance(value, int) and not math.isfinite(value):
        raise ValueError(f"{name} must be finite, got {value}")
    return value



def _stage_batch_ratio(target_ratio_stage1: float, target_ratio_stage2: float) -> Dict[str, float]:
    """
    Map each stage to its sequence length relative to the "main" stage.

    Level 0 (embeddings, outer encoder/decoder, norm, head) gets
    ``1 / (target_ratio_stage1 * target_ratio_stage2)``. Level 1 (encoder2,
    decoder1) gets ``1 / target_ratio_stage2``. Level 2 ("main") is the
    reference, 1.0.

    Raises:
        ValueError: if either ratio fails ``_validate_ratio`` or their product
            is not > 0.
    """
    r1 = _validate_ratio("target_ratio_stage1", target_ratio_stage1)
    r2 = _validate_ratio("target_ratio_stage2", target_ratio_stage2)
    combined = r1 * r2
    if combined <= 0:
        raise ValueError("target_ratio_stage1 * target_ratio_stage2 must be > 0")

    relative_length = {0: 1.0 / combined, 1: 1.0 / r2, 2: 1.0}
    stage_levels = (
        ("embeddings", 0),
        ("encoder1", 0),
        ("encoder2", 1),
        ("main", 2),
        ("decoder1", 1),
        ("decoder2", 0),
        ("norm", 0),
        ("head", 0),
    )
    return {stage: relative_length[level] for stage, level in stage_levels}



def compute_stage_lr_lambdas(
    model: nn.Module,
    target_ratio_stage1: float,
    target_ratio_stage2: float,
    stage_dim_overrides: Optional[Mapping[str, int]] = None,
) -> Dict[str, float]:
    """
    Compute stage-wise LR scaling factors using:
        lambda = sqrt(B_s / B_ref) * sqrt(D_ref / D_s)
    where B_s is the effective batch size for stage s (sequence length relative to main),
    and D_s is the hidden dimension for that stage.

    ``B_s / B_ref`` comes from ``_stage_batch_ratio``. ``D_ref`` is the "main"
    stage dimension. Every dimension defaults to the model's ``d_model``
    unless given in ``stage_dim_overrides``. The model is queried for
    ``d_model`` only when at least one needed dimension is not overridden.

    Args:
        model: Model used to infer ``d_model`` when required.
        target_ratio_stage1: Compression ratio of the first stage (> 0).
        target_ratio_stage2: Compression ratio of the second stage (> 0).
        stage_dim_overrides: Optional stage name -> hidden dimension mapping.

    Returns:
        Stage name -> LR multiplier, in ``_stage_batch_ratio`` key order.

    Raises:
        ValueError: for invalid ratios, when ``d_model`` is needed but cannot
            be inferred, or when any dimension is <= 0.
    """
    stage_dim_overrides = stage_dim_overrides or {}
    stage_batch_ratio = _stage_batch_ratio(target_ratio_stage1, target_ratio_stage2)

    # Infer d_model lazily. A fully specified override mapping must work even
    # for models that do not expose d_model; _get_model_d_model's own error
    # message tells users to pass overrides.
    required_stages = set(stage_batch_ratio) | {"main"}
    d_model: Optional[int] = None
    if not required_stages.issubset(stage_dim_overrides):
        d_model = _get_model_d_model(model)

    d_ref = stage_dim_overrides.get("main", d_model)
    if d_ref <= 0:
        raise ValueError(f"Reference dimension must be > 0, got {d_ref}")

    stage_lr_lambdas = {}
    for stage, b_ratio in stage_batch_ratio.items():
        d_s = stage_dim_overrides.get(stage, d_model)
        if d_s <= 0:
            raise ValueError(f"Stage dimension for {stage} must be > 0, got {d_s}")
        stage_lr_lambdas[stage] = math.sqrt(b_ratio) * math.sqrt(d_ref / d_s)

    return stage_lr_lambdas



def resolve_stage_lr_lambdas(
    model: nn.Module,
    stage_lr_config: Any,
    target_ratio_stage1: float,
    target_ratio_stage2: float,
    stage_dim_overrides: Optional[Mapping[str, int]] = None,
) -> Dict[str, float]:
    """
    Resolve per-stage LR multipliers from a user configuration.

    Accepted forms of ``stage_lr_config``:
      * ``None``, the string "auto" (case-insensitive) or an empty mapping:
        use ``compute_stage_lr_lambdas``.
      * A mapping with a truthy "auto" entry: the automatic values, with the
        mapping's other entries overriding them.
      * Any other mapping: exactly its entries. An "auto" key with a falsy
        value is ignored rather than treated as a stage.

    Automatic values are computed only when actually used, so a fully
    explicit mapping does not require an inferable ``d_model`` or valid
    target ratios. All returned values go through ``_normalize_lambda_value``.

    Raises:
        ValueError: for an unsupported string, or from the computation /
            normalization helpers.
        TypeError: for an unsupported config type, or a non-numeric multiplier.
    """

    def _auto() -> Dict[str, float]:
        return compute_stage_lr_lambdas(
            model,
            target_ratio_stage1=target_ratio_stage1,
            target_ratio_stage2=target_ratio_stage2,
            stage_dim_overrides=stage_dim_overrides,
        )

    def _normalized(values: Mapping[str, Any]) -> Dict[str, float]:
        return {stage: _normalize_lambda_value(stage, val) for stage, val in values.items()}

    if stage_lr_config is None:
        return _normalized(_auto())

    if isinstance(stage_lr_config, str):
        if stage_lr_config.lower() == "auto":
            return _normalized(_auto())
        raise ValueError(f"Unsupported stage_wise_lr value: {stage_lr_config}")

    if isinstance(stage_lr_config, Mapping):
        if not stage_lr_config:
            return _normalized(_auto())
        # "auto" is a control flag, never a stage name.
        explicit = {k: v for k, v in stage_lr_config.items() if k != "auto"}
        if stage_lr_config.get("auto"):
            return _normalized({**_auto(), **explicit})
        return _normalized(explicit)

    raise TypeError(f"Unsupported stage_wise_lr type: {type(stage_lr_config)}")







def get_stage_parameters_1stage(model: nn.Module, stage_name: str) -> List[nn.Parameter]:
    """
    Return the parameters that belong to one stage of the 1-stage DNAChunker model.

    A parameter belongs to a stage when its qualified name starts with one of
    the stage's module-name prefixes (plain string-prefix test). Unknown stage
    names yield an empty list.

    Args:
        model: Model whose parameters are searched.
        stage_name: One of "embeddings", "encoder", "main", "decoder", "norm"
            or "head".

    Returns:
        The matching parameters, in ``named_parameters()`` order.
    """
    prefixes_by_stage = {
        "embeddings": ("net.backbone.embeddings",),
        "encoder": ("net.backbone.encoder_layers", "net.backbone.routing_module"),
        "main": ("net.backbone.main_model",),
        "decoder": ("net.backbone.decoder_layers", "net.backbone.dechunker"),
        "norm": ("net.backbone.norm_f",),
        "head": ("lm_head",),
    }
    prefixes = prefixes_by_stage.get(stage_name)
    if prefixes is None:
        return []
    return [param for qual_name, param in model.named_parameters() if qual_name.startswith(prefixes)]





def get_stage_named_parameters_1stage(model: nn.Module, stage_name: str) -> List[Tuple[str, nn.Parameter]]:
    """
    Return ``(name, parameter)`` pairs for one stage of the 1-stage DNAChunker model.

    A parameter is selected when its qualified name starts with any of the
    stage's module-name prefixes (plain string-prefix test). Unknown stage
    names yield an empty list.

    Args:
        model: Model whose ``named_parameters()`` are searched.
        stage_name: One of "embeddings", "encoder", "main", "decoder", "norm"
            or "head".

    Returns:
        Matching ``(qualified_name, parameter)`` pairs in ``named_parameters()`` order.
    """
    prefix_table = {
        "embeddings": ("net.backbone.embeddings",),
        "encoder": ("net.backbone.encoder_layers", "net.backbone.routing_module"),
        "main": ("net.backbone.main_model",),
        "decoder": ("net.backbone.decoder_layers", "net.backbone.dechunker"),
        "norm": ("net.backbone.norm_f",),
        "head": ("lm_head",),
    }
    if stage_name not in prefix_table:
        return []
    prefixes = prefix_table[stage_name]
    return [
        (qual_name, param)
        for qual_name, param in model.named_parameters()
        if qual_name.startswith(prefixes)
    ]





def _stage_batch_ratio_1stage(target_ratio: float) -> Dict[str, float]:
    """
    Compute batch ratio for 1-stage model.
    L0 operates at full length, L1 operates at compressed length.

    Every stage except "main" gets ``1 / target_ratio`` (its length relative
    to "main"); "main" is the reference with 1.0.

    Raises:
        ValueError: if ``target_ratio`` fails ``_validate_ratio``.
    """
    ratio = _validate_ratio("target_ratio", target_ratio)
    outer_relative_length = 1.0 / ratio
    stage_names = ("embeddings", "encoder", "main", "decoder", "norm", "head")
    return {
        stage: (1.0 if stage == "main" else outer_relative_length)
        for stage in stage_names
    }





def compute_stage_lr_lambdas_1stage(
    model: nn.Module,
    target_ratio: float,
    stage_dim_overrides: Optional[Mapping[str, int]] = None,
) -> Dict[str, float]:
    """
    Compute stage-wise LR scaling factors for 1-stage model.

    Uses ``lambda = sqrt(B_s / B_ref) * sqrt(D_ref / D_s)``. ``B_s / B_ref``
    comes from ``_stage_batch_ratio_1stage(target_ratio)`` and ``D_ref`` is
    the "main" stage dimension. Every dimension defaults to the model's
    ``d_model`` unless given in ``stage_dim_overrides``. The model is queried
    for ``d_model`` only when at least one needed dimension is not overridden.

    Args:
        model: Model used to infer ``d_model`` when required.
        target_ratio: Compression ratio of the single stage (> 0).
        stage_dim_overrides: Optional stage name -> hidden dimension mapping.

    Returns:
        Stage name -> LR multiplier, in ``_stage_batch_ratio_1stage`` key order.

    Raises:
        ValueError: for an invalid ratio, when ``d_model`` is needed but cannot
            be inferred, or when any dimension is <= 0.
    """
    stage_dim_overrides = stage_dim_overrides or {}
    stage_batch_ratio = _stage_batch_ratio_1stage(target_ratio)

    # Infer d_model lazily so a fully specified override mapping works even
    # for models that do not expose d_model.
    required_stages = set(stage_batch_ratio) | {"main"}
    d_model: Optional[int] = None
    if not required_stages.issubset(stage_dim_overrides):
        d_model = _get_model_d_model(model)

    d_ref = stage_dim_overrides.get("main", d_model)
    if d_ref <= 0:
        raise ValueError(f"Reference dimension must be > 0, got {d_ref}")

    stage_lr_lambdas = {}
    for stage, b_ratio in stage_batch_ratio.items():
        d_s = stage_dim_overrides.get(stage, d_model)
        if d_s <= 0:
            raise ValueError(f"Stage dimension for {stage} must be > 0, got {d_s}")
        stage_lr_lambdas[stage] = math.sqrt(b_ratio) * math.sqrt(d_ref / d_s)

    return stage_lr_lambdas





def resolve_stage_lr_lambdas_1stage(
    model: nn.Module,
    stage_lr_config: Any,
    target_ratio: float,
    stage_dim_overrides: Optional[Mapping[str, int]] = None,
) -> Dict[str, float]:
    """
    Resolve stage-wise LR lambdas for 1-stage model.

    Accepted forms of ``stage_lr_config``:
      * ``None``, the string "auto" (case-insensitive) or an empty mapping:
        use ``compute_stage_lr_lambdas_1stage``.
      * A mapping with a truthy "auto" entry: the automatic values, with the
        mapping's other entries overriding them.
      * Any other mapping: exactly its entries. An "auto" key with a falsy
        value is ignored rather than treated as a stage.

    Automatic values are computed only when actually used, so a fully
    explicit mapping does not require an inferable ``d_model`` or a valid
    target ratio. All returned values go through ``_normalize_lambda_value``.

    Raises:
        ValueError: for an unsupported string, or from the computation /
            normalization helpers.
        TypeError: for an unsupported config type, or a non-numeric multiplier.
    """

    def _auto() -> Dict[str, float]:
        return compute_stage_lr_lambdas_1stage(
            model,
            target_ratio=target_ratio,
            stage_dim_overrides=stage_dim_overrides,
        )

    def _normalized(values: Mapping[str, Any]) -> Dict[str, float]:
        return {stage: _normalize_lambda_value(stage, val) for stage, val in values.items()}

    if stage_lr_config is None:
        return _normalized(_auto())

    if isinstance(stage_lr_config, str):
        if stage_lr_config.lower() == "auto":
            return _normalized(_auto())
        raise ValueError(f"Unsupported stage_wise_lr value: {stage_lr_config}")

    if isinstance(stage_lr_config, Mapping):
        if not stage_lr_config:
            return _normalized(_auto())
        # "auto" is a control flag, never a stage name.
        explicit = {k: v for k, v in stage_lr_config.items() if k != "auto"}
        if stage_lr_config.get("auto"):
            return _normalized({**_auto(), **explicit})
        return _normalized(explicit)

    raise TypeError(f"Unsupported stage_wise_lr type: {type(stage_lr_config)}")





def build_stage_wise_param_groups(model: nn.Module, lr: float, weight_decay: float, stage_lr_lambdas: Dict[str, float]) -> List[Dict[str, Any]]:
    """
    Build parameter groups for the optimizer with stage-wise learning rates.

    Stages are processed in this order:
      1. Keys of ``stage_lr_lambdas`` that are known stages, in the fixed
         order below.
      2. Any other keys, in mapping order. These are still normalized and
         looked up via ``get_stage_named_parameters``.

    Each stage's parameters are split into a weight-decayed group and a
    no-decay group (decided by ``_is_no_decay_param``), both with
    ``lr * lambda``. Parameters not claimed by any stage end up in a final
    decay/no-decay pair at the base ``lr``. Empty groups are never emitted,
    and each parameter tensor is placed in at most one group.

    Args:
        model: Model whose ``named_parameters()`` are grouped.
        lr: Base learning rate (real number, bool rejected).
        weight_decay: Weight decay of the decayed groups (real number, bool rejected).
        stage_lr_lambdas: Stage name -> LR multiplier; each value is passed
            through ``_normalize_lambda_value``.

    Returns:
        List of dicts with "params", "lr" and "weight_decay" keys.

    Raises:
        TypeError: if ``lr`` / ``weight_decay`` is not a real number, or a
            multiplier is not numeric.
        ValueError: if a multiplier is an unparseable string or a
            multi-element list.
    """
    if isinstance(lr, bool) or not isinstance(lr, numbers.Real):
        raise TypeError(f"learning_rate must be numeric, got {type(lr)}")
    if isinstance(weight_decay, bool) or not isinstance(weight_decay, numbers.Real):
        raise TypeError(f"weight_decay must be numeric, got {type(weight_decay)}")

    param_groups: List[Dict[str, Any]] = []
    assigned_param_ids = set()

    def _emit_groups(named_params, group_lr, stage_name=None):
        """Split named_params into decay / no-decay groups at group_lr and append the non-empty ones."""
        decay_params = []
        no_decay_params = []
        for name, p in named_params:
            # Torch optimizers reject a tensor that appears in two groups.
            if id(p) in assigned_param_ids:
                continue
            assigned_param_ids.add(id(p))
            if _is_no_decay_param(name, stage_name):
                no_decay_params.append(p)
            else:
                decay_params.append(p)
        if decay_params:
            param_groups.append({
                "params": decay_params,
                "lr": group_lr,
                "weight_decay": weight_decay,
            })
        if no_decay_params:
            param_groups.append({
                "params": no_decay_params,
                "lr": group_lr,
                "weight_decay": 0.0,
            })

    stage_order = [
        "embeddings",
        "encoder1",
        "encoder2",
        "main",
        "decoder1",
        "decoder2",
        "norm",
        "head",
    ]
    stages_to_build = [stage for stage in stage_order if stage in stage_lr_lambdas]
    stages_to_build += [stage for stage in stage_lr_lambdas if stage not in stage_order]

    for stage in stages_to_build:
        lambda_val = _normalize_lambda_value(stage, stage_lr_lambdas[stage])
        _emit_groups(get_stage_named_parameters(model, stage), lr * lambda_val, stage)

    # Everything not claimed above trains at the base lr (already-assigned
    # tensors are skipped inside _emit_groups).
    _emit_groups(model.named_parameters(), lr)

    return param_groups





def build_stage_wise_param_groups_1stage(
    model: nn.Module,
    lr: float,
    weight_decay: float,
    stage_lr_lambdas: Dict[str, float]
) -> List[Dict[str, Any]]:
    """
    Build parameter groups for the optimizer with stage-wise learning rates (1-stage model).

    Stages are processed in this order:
      1. Keys of ``stage_lr_lambdas`` that are known stages, in the fixed
         order below.
      2. Any other keys, in mapping order. These are still normalized and
         looked up via ``get_stage_named_parameters_1stage``.

    Each stage's parameters are split into a weight-decayed group and a
    no-decay group (decided by ``_is_no_decay_param``), both with
    ``lr * lambda``. Parameters not claimed by any stage end up in a final
    decay/no-decay pair at the base ``lr``. Empty groups are never emitted,
    and each parameter tensor is placed in at most one group.

    Args:
        model: Model whose ``named_parameters()`` are grouped.
        lr: Base learning rate (real number, bool rejected).
        weight_decay: Weight decay of the decayed groups (real number, bool rejected).
        stage_lr_lambdas: Stage name -> LR multiplier; each value is passed
            through ``_normalize_lambda_value``.

    Returns:
        List of dicts with "params", "lr" and "weight_decay" keys.

    Raises:
        TypeError: if ``lr`` / ``weight_decay`` is not a real number, or a
            multiplier is not numeric.
        ValueError: if a multiplier is an unparseable string or a
            multi-element list.
    """
    if isinstance(lr, bool) or not isinstance(lr, numbers.Real):
        raise TypeError(f"learning_rate must be numeric, got {type(lr)}")
    if isinstance(weight_decay, bool) or not isinstance(weight_decay, numbers.Real):
        raise TypeError(f"weight_decay must be numeric, got {type(weight_decay)}")

    param_groups: List[Dict[str, Any]] = []
    assigned_param_ids = set()

    def _emit_groups(named_params, group_lr, stage_name=None):
        """Split named_params into decay / no-decay groups at group_lr and append the non-empty ones."""
        decay_params = []
        no_decay_params = []
        for name, p in named_params:
            # Torch optimizers reject a tensor that appears in two groups.
            if id(p) in assigned_param_ids:
                continue
            assigned_param_ids.add(id(p))
            if _is_no_decay_param(name, stage_name):
                no_decay_params.append(p)
            else:
                decay_params.append(p)
        if decay_params:
            param_groups.append({
                "params": decay_params,
                "lr": group_lr,
                "weight_decay": weight_decay,
            })
        if no_decay_params:
            param_groups.append({
                "params": no_decay_params,
                "lr": group_lr,
                "weight_decay": 0.0,
            })

    stage_order = [
        "embeddings",
        "encoder",
        "main",
        "decoder",
        "norm",
        "head",
    ]
    stages_to_build = [stage for stage in stage_order if stage in stage_lr_lambdas]
    stages_to_build += [stage for stage in stage_lr_lambdas if stage not in stage_order]

    for stage in stages_to_build:
        lambda_val = _normalize_lambda_value(stage, stage_lr_lambdas[stage])
        _emit_groups(get_stage_named_parameters_1stage(model, stage), lr * lambda_val, stage)

    # Everything not claimed above trains at the base lr (already-assigned
    # tensors are skipped inside _emit_groups).
    _emit_groups(model.named_parameters(), lr)

    return param_groups

