from ml_collections import ConfigDict, config_dict
from jaxrl_m.agents.continuous.cql import (
    ContinuousCQLAgent,
    get_default_config as get_default_cql_config,
)
from jaxrl_m.agents.continuous.gc_bc import GCBCAgent
from jax import numpy as jnp


def _check_options(config_name: str, options: set, allowed: frozenset) -> None:
    """Raise ``ValueError`` if ``options`` contains anything not in ``allowed``.

    Without this check a typo such as ``"gc_cql,onlin"`` (or an option the
    selected config does not support, such as ``"gc_bc,online"``) would be
    silently dropped and the run would start with an unintended setting.
    """
    unsupported = sorted(options - allowed)
    if unsupported:
        raise ValueError(
            f"Unsupported option(s) {unsupported} for config {config_name}; "
            f"supported: {sorted(allowed)}"
        )


def get_config(config_str: str) -> ConfigDict:
    """Build the training config selected by ``config_str``.

    ``config_str`` has the form ``"<name>[,<option>,...]"``, e.g. ``"gc_cql"``
    or ``"gc_cql,online"``. Recognized names and their options:

    * ``gc_cql`` -- ``ContinuousCQLAgent`` built from the default CQL config
      with goal-conditioning overrides; accepts the ``online`` option, which
      sets the top-level ``online`` flag to True.
    * ``gc_bc`` -- ``GCBCAgent``; accepts no options and always sets
      ``online`` to False.

    Args:
        config_str: Config name, optionally followed by comma-separated
            options. Whitespace around the name/options and empty entries
            (e.g. a trailing comma) are ignored.

    Returns:
        A ``ConfigDict`` with keys ``agent_config``, ``agent_cls``,
        ``batch_size``, ``dataset_config`` (an optional ``str`` placeholder
        left unset), ``online`` and ``train_steps``.

    Raises:
        ValueError: If the config name is unknown, or an option is not
            supported by the selected config.
    """
    config_name, *raw_opts = config_str.split(",")
    config_name = config_name.strip()
    # Normalize so "gc_cql, online" and "gc_cql,online," behave like
    # "gc_cql,online" instead of the option being silently unmatched.
    config_opts = {opt.strip() for opt in raw_opts if opt.strip()}

    if config_name == "gc_cql":
        _check_options(config_name, config_opts, frozenset({"online"}))
        return ConfigDict(
            {
                # Overrides applied on top of the default CQL agent config.
                # NOTE(review): with autotuning off and cql_alpha=0 the CQL
                # penalty is presumably disabled -- confirm in ContinuousCQLAgent.
                "agent_config": get_default_cql_config(
                    dict(
                        goal_conditioned=True,
                        early_goal_concat=True,
                        cql_autotune_alpha=False,
                        cql_alpha=0,
                        cql_target_action_gap=1.5,
                        cql_temp=1.0,
                        discount=0.97,
                        cql_importance_sample=False,
                        use_calql=False,
                        gc_kwargs=ConfigDict(dict(negative_proportion=0.25)),
                    )
                ),
                "agent_cls": ContinuousCQLAgent,
                "batch_size": 256,
                "dataset_config": config_dict.placeholder(str, required=False),
                "online": "online" in config_opts,
                "train_steps": 1_000_000,
            }
        )
    elif config_name == "gc_bc":
        _check_options(config_name, config_opts, frozenset())
        return ConfigDict(
            {
                "agent_config": dict(
                    early_goal_concat=True,
                    shared_goal_encoder=True,
                    policy_kwargs=dict(
                        tanh_squash_distribution=False,
                        std_parameterization="fixed",
                        # NOTE(review): a JAX array inside the config may not
                        # serialize to JSON/YAML when the config is logged;
                        # length 2 presumably matches the action dim -- confirm.
                        fixed_std=jnp.array([0.2, 0.2]),
                    ),
                ),
                "agent_cls": GCBCAgent,
                "batch_size": 256,
                "dataset_config": config_dict.placeholder(str, required=False),
                "train_steps": 1_000_000,
                "online": False,
            }
        )
    else:
        raise ValueError(f"Unknown config {config_name}")