from environments.kitchen.spirl.configs.hrl.kitchen.spirl.conf import *
from environments.kitchen.spirl.models.closed_loop_spirl_mdl import ClSPiRLMdl
from environments.kitchen.spirl.rl.policies.cl_model_policies import ClModelPolicy

# Switch the low-level (LL) model params to a state-conditioned decoder.
# `ll_model_params` is not defined in this file (it is expected to arrive via
# the star import above) and is modified in place here.
ll_model_params.cond_decode = True

# Parameters for the LL closed-loop policy.
# NOTE(review): `AttrDict` and `os` are not imported explicitly in this file;
# presumably both are provided by the star import -- confirm.
ll_policy_params = AttrDict(
    policy_model=ClSPiRLMdl,  # model class backing the LL policy
    policy_model_params=ll_model_params,  # the params object modified above
    # Checkpoint location for the closed-loop skill-prior model. The path is
    # built from relative components only.
    # NOTE(review): presumably resolved against the process working
    # directory -- confirm against the checkpoint loader.
    policy_model_checkpoint=os.path.join(
        "environments/kitchen/spirl", "skill_prior_learning/kitchen/hierarchical_cl"
    ),
)
# Merge the model params into the top level of the policy params as well, so
# they are available both directly and under `policy_model_params`.
# NOTE(review): a key in `ll_model_params` that matches one of the three keys
# set above would overwrite it here -- presumably none does; verify.
ll_policy_params.update(ll_model_params)

# Configuration for the LL SAC agent. By default this agent is only used to
# roll out decoded skills; the skill decoder itself is not fine-tuned.
ll_agent_config = AttrDict(
    policy=ClModelPolicy,  # closed-loop model policy class
    policy_params=ll_policy_params,
    critic=MLPCritic,  # LL critic is not used since we are not finetuning LL
    # The high-level critic params are reused for the LL critic entry.
    # NOTE(review): presumably just a placeholder, given that the LL critic
    # is unused -- confirm before enabling LL updates.
    critic_params=hl_critic_params,
)

# Point the high-level (HL) policy's prior at the same model class, model
# params and checkpoint that the LL policy uses. The values are read back
# from `ll_policy_params` rather than repeated, so the two stay in sync.
hl_policy_params.update(
    AttrDict(
        prior_model=ll_policy_params.policy_model,
        prior_model_params=ll_policy_params.policy_model_params,
        prior_model_checkpoint=ll_policy_params.policy_model_checkpoint,
    )
)

# Register the new LL agent in the top-level agent config and turn off LL
# agent updates.
agent_config.update(
    AttrDict(
        ll_agent=SACAgent,  # agent class used for the low level
        ll_agent_params=ll_agent_config,  # LL agent configuration
        update_ll=False,  # turn off LL agent updates
    )
)
