from ml_collections import ConfigDict
from typing import Any, Mapping


def get_config(updates: Mapping[str, Any] | None = None) -> ConfigDict:
    """Assemble the default TD3 hyperparameter config, with optional overrides.

    Args:
        updates: Optional mapping of override values. When given, it is
            wrapped in a ``ConfigDict``, its field references are resolved,
            and the result is merged over the defaults via
            ``ConfigDict.update``.

    Returns:
        A ``ConfigDict`` holding the defaults listed below, with ``updates``
        applied on top when provided.
    """

    def _mlp_kwargs() -> ConfigDict:
        # Build a fresh ConfigDict (and a fresh hidden_dims list) on every
        # call so the critic and policy settings never alias each other.
        return ConfigDict(
            {
                "hidden_dims": [256, 256],
                "activate_final": True,
                "use_layer_norm": True,
            }
        )

    # Insertion order of this table is the order fields are added to the config.
    defaults = {
        # Standard TD3 algorithm parameters.
        "discount": 0.99,
        "soft_target_update_rate": 5e-3,
        # TD3-specific parameters (names match what TD3Agent expects).
        "target_policy_noise": 0.2,  # target smoothing noise
        "target_noise_clip": 0.5,  # clip applied to the smoothing noise
        "exploration_noise": 0.1,  # exploration noise
        "policy_delay": 2,  # delayed policy updates
        # Twin critics (TD3 uses exactly 2); no subsampling of the ensemble.
        "critic_ensemble_size": 2,
        "critic_subsample_size": None,
        # Network architectures: critic and policy use the same MLP settings.
        "critic_network_kwargs": _mlp_kwargs(),
        "policy_network_kwargs": _mlp_kwargs(),
        # Deterministic policy: tanh-squashed output with a fixed std of zero.
        "policy_kwargs": ConfigDict(
            {
                "tanh_squash_distribution": True,
                "std_parameterization": "fixed",
                "fixed_std": 0.0,
            }
        ),
        # Optimizers: actor (pi_lr) and critic (q_lr) learning rates.
        "actor_optimizer_kwargs": ConfigDict({"learning_rate": 1e-3}),
        "critic_optimizer_kwargs": ConfigDict({"learning_rate": 1e-3}),
        # Behavioral cloning regularization (optional, off by default).
        "use_bc": False,
        "normalize_q": False,
        "q_lambda_alpha": 2.5,
    }

    config = ConfigDict()
    for field_name, default_value in defaults.items():
        setattr(config, field_name, default_value)

    if updates is None:
        return config

    # Resolve any field references in the overrides before merging them in.
    overrides = ConfigDict(updates).copy_and_resolve_references()
    config.update(overrides)
    return config
