from ml_collections import ConfigDict

from experiments.configs.cql_config import get_config as get_cql_config
from experiments.configs.iql_config import get_config as get_iql_config
from experiments.configs.sac_config import get_config as get_sac_config
from experiments.configs.wsrl_config import get_config as get_wsrl_config
from experiments.configs.td3_config import get_config as get_td3_config
from experiments.configs.rebrac_config import get_config as get_rebrac_config
from experiments.configs.bc_config import get_config as get_bc_config

def _agent_config(get_agent_config, updates):
    """Wrap a base agent config in the experiment-level ConfigDict layout.

    Calls ``get_agent_config(updates=updates)``, converts the result to a
    plain dict via ``.to_dict()``, and stores it under ``agent_kwargs``.
    """
    return ConfigDict(
        dict(agent_kwargs=get_agent_config(updates=updates).to_dict())
    )


def get_config(config_string):
    """Return the experiment config registered under ``config_string``.

    Names follow the pattern ``"<domain>_<agent>"`` (e.g. ``"antmaze_cql"``,
    ``"kitchen_rebrac"``). Each config is a ConfigDict with a single
    ``agent_kwargs`` entry: the agent's base config with the per-domain
    ``updates`` below passed to its ``get_*_config`` builder, as a plain dict.

    Args:
        config_string: name of the config to build.

    Returns:
        ConfigDict with an ``agent_kwargs`` entry.

    Raises:
        KeyError: if ``config_string`` is not a registered name; the message
            lists the available names.
    """
    # Each entry is a zero-argument builder so that only the requested config
    # is constructed; the base ``get_*_config`` calls for every other entry
    # are not executed on each lookup.
    possible_structures = {

        ########################################################
        #                    antmaze configs                   #
        ########################################################

        "antmaze_cql": lambda: _agent_config(
            get_cql_config,
            dict(
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="uniform",
                ),
                critic_network_kwargs={
                    "hidden_dims": [256, 256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                },
                cql_autotune_alpha=True,
                cql_target_action_gap=0.8,
            ),
        ),

        "antmaze_iql": lambda: _agent_config(
            get_iql_config,
            dict(
                expectile=0.9,
                temperature=10.0,
            ),
        ),

        "antmaze_wsrl": lambda: _agent_config(
            get_wsrl_config,
            dict(
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="uniform",
                ),
                critic_network_kwargs={
                    "hidden_dims": [256, 256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                max_target_backup=True,
            ),
        ),

        "antmaze_sac": lambda: _agent_config(
            get_sac_config,
            dict(
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="uniform",
                ),
                critic_network_kwargs={
                    "hidden_dims": [256, 256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                max_target_backup=True,
            ),
        ),

        "antmaze_td3": lambda: _agent_config(
            get_td3_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 3e-3,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-3,
                    "use_layer_norm": True,
                },
                discount=0.999,
                actor_optimizer_kwargs=ConfigDict({"learning_rate": 1e-4}),
                critic_optimizer_kwargs=ConfigDict({"learning_rate": 1e-4}),
            ),
        ),

        "antmaze_bc_deter": lambda: _agent_config(
            get_bc_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 3e-3,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-3,
                    "use_layer_norm": True,
                },
                discount=0.999,
                # Fixed std of 0.0 -> presumably a deterministic policy
                # (hence "_deter"); confirm against the policy implementation.
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="fixed",
                    fixed_std=0.0,
                ),
            ),
        ),

        "antmaze_rebrac": lambda: _agent_config(
            get_rebrac_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 3e-3,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-3,
                    "use_layer_norm": True,
                },
                discount=0.999,
                actor_optimizer_kwargs=ConfigDict({"learning_rate": 1e-4}),
                critic_optimizer_kwargs=ConfigDict({"learning_rate": 1e-4}),
            ),
        ),

        ########################################################
        #                    adroit configs                    #
        ########################################################

        "adroit_cql": lambda: _agent_config(
            get_cql_config,
            dict(
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="exp",
                ),
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "kernel_scale_final": 1e-2,
                    "activations": "relu",
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512],
                    "kernel_scale_final": 1e-2,
                    "activations": "relu",
                },
                online_cql_alpha=1.0,
                cql_alpha=1.0,
            ),
        ),

        "adroit_iql": lambda: _agent_config(
            get_iql_config,
            dict(
                policy_network_kwargs=dict(
                    hidden_dims=(256, 256),
                    kernel_init_type="var_scaling",
                    kernel_scale_final=1e-2,
                    dropout_rate=0.1,
                ),
                expectile=0.7,
                temperature=0.5,
            ),
        ),

        "adroit_wsrl": lambda: _agent_config(
            get_wsrl_config,
            dict(
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="exp",
                ),
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "kernel_scale_final": 1e-2,
                    "activations": "relu",
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512],
                    "kernel_scale_final": 1e-2,
                    "activations": "relu",
                    "use_layer_norm": True,
                },
            ),
        ),

        "adroit_sac": lambda: _agent_config(
            get_sac_config,
            dict(
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="exp",
                ),
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "kernel_scale_final": 1e-2,
                    "activations": "relu",
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512],
                    "kernel_scale_final": 1e-2,
                    "activations": "relu",
                    "use_layer_norm": True,
                },
            ),
        ),

        "adroit_td3": lambda: _agent_config(
            get_td3_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
            ),
        ),

        "adroit_bc_deter": lambda: _agent_config(
            get_bc_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="fixed",
                    fixed_std=0.0,
                ),
            ),
        ),

        "adroit_rebrac": lambda: _agent_config(
            get_rebrac_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 3e-3,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-3,
                    "use_layer_norm": True,
                },
                actor_optimizer_kwargs=ConfigDict({"learning_rate": 3e-4}),
                critic_optimizer_kwargs=ConfigDict({"learning_rate": 3e-4}),
            ),
        ),

        ########################################################
        #                    kitchen configs                   #
        ########################################################

        "kitchen_cql": lambda: _agent_config(
            get_cql_config,
            dict(
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="exp",
                ),
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                },
                online_cql_alpha=5.0,
                cql_alpha=5.0,
                cql_importance_sample=False,
            ),
        ),

        "kitchen_iql": lambda: _agent_config(
            get_iql_config,
            dict(
                policy_network_kwargs=dict(
                    hidden_dims=(256, 256),
                    activations="relu",
                    dropout_rate=0.1,
                ),
                critic_network_kwargs=dict(
                    hidden_dims=(256, 256),
                    activations="relu",
                ),
                expectile=0.7,
                temperature=0.5,
            ),
        ),

        "kitchen_wsrl": lambda: _agent_config(
            get_wsrl_config,
            dict(
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="exp",
                ),
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "use_layer_norm": True,
                },
            ),
        ),

        "kitchen_sac": lambda: _agent_config(
            get_sac_config,
            dict(
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="exp",
                ),
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "use_layer_norm": True,
                },
            ),
        ),

        "kitchen_td3": lambda: _agent_config(
            get_td3_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
            ),
        ),

        "kitchen_bc_deter": lambda: _agent_config(
            get_bc_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="fixed",
                    fixed_std=0.0,
                ),
            ),
        ),

        "kitchen_rebrac": lambda: _agent_config(
            get_rebrac_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [512, 512, 512],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                actor_optimizer_kwargs=ConfigDict({"learning_rate": 1e-3}),
                critic_optimizer_kwargs=ConfigDict({"learning_rate": 1e-3}),
            ),
        ),

        ########################################################
        #                  locomotion configs                  #
        ########################################################

        "locomotion_cql": lambda: _agent_config(
            get_cql_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                },
                online_cql_alpha=5.0,
                cql_alpha=5.0,
            ),
        ),

        "locomotion_iql": lambda: _agent_config(
            get_iql_config,
            dict(
                expectile=0.7,
                temperature=3.0,
            ),
        ),

        "locomotion_wsrl": lambda: _agent_config(
            get_wsrl_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
            ),
        ),

        "locomotion_sac": lambda: _agent_config(
            get_sac_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
            ),
        ),

        "locomotion_td3": lambda: _agent_config(
            get_td3_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 3e-3,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-3,
                    "use_layer_norm": True,
                },
            ),
        ),

        "locomotion_rebrac": lambda: _agent_config(
            get_rebrac_config,
            dict(
                critic_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 3e-3,
                    "use_layer_norm": True,
                },
                policy_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-3,
                    "use_layer_norm": True,
                },
            ),
        ),

        "locomotion_bc_dist": lambda: _agent_config(
            get_bc_config,
            dict(
                policy_network_kwargs={
                    "hidden_dims": [256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 1e-2,
                    "use_layer_norm": True,
                },
                critic_network_kwargs=dict(
                    hidden_dims=[256, 256],
                    activations="relu",
                    kernel_scale_final=1e-2,
                    use_layer_norm=True,
                ),
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="exp",
                ),
            ),
        ),

        "locomotion_bc_deter": lambda: _agent_config(
            get_bc_config,
            dict(
                decay_steps=500000,
                # NOTE(review): unlike every other td3/rebrac/bc_deter entry
                # (critic 3e-3, policy 1e-3), the policy/critic
                # kernel_scale_final values here are swapped. Kept as-is;
                # confirm whether this is intentional.
                policy_network_kwargs={
                    "hidden_dims": [256, 256, 256],
                    "activations": "relu",
                    "kernel_scale_final": 3e-3,
                    "use_layer_norm": True,
                },
                critic_network_kwargs=dict(
                    hidden_dims=[256, 256, 256],
                    activations="relu",
                    kernel_scale_final=1e-3,
                    use_layer_norm=True,
                ),
                policy_kwargs=dict(
                    tanh_squash_distribution=True,
                    std_parameterization="fixed",
                    fixed_std=0.0,
                ),
            ),
        ),
    }

    try:
        build_config = possible_structures[config_string]
    except KeyError:
        # Still a KeyError (callers may catch it), but name the valid options.
        available = ", ".join(sorted(possible_structures))
        raise KeyError(
            f"Unknown config {config_string!r}; available configs: {available}"
        ) from None
    return build_config()
