#!/usr/bin/env python3

def get_adapt_agent_conf(agent, algorithm):
    """Build the module config dict used to set up an adaptation agent.

    The returned dict always has the keys ``'encoder'``, ``'actor'`` and
    ``'inv_dynamics'``; how the values are assembled depends on
    ``algorithm``:

    * ``'drqv2'``: ``encoder`` is ``agent.encoder`` chained with
      ``agent.actor.trunk``, and ``inv_dynamics`` is ``agent.encoder``
      chained with ``agent.inv_dynamics_head.trunk`` (both wrapped in
      ``nn.Sequential``).
    * ``'sac'``, ``'svea'``, ``'pad'``, ``'soda'``: the agent's own
      ``encoder`` and ``inv_dynamics_head`` modules are passed through
      unchanged.

    In both cases ``actor`` is ``agent.actor``.

    Args:
        agent: Object exposing ``encoder``, ``actor`` and
            ``inv_dynamics_head`` attributes. For ``'drqv2'`` the actor and
            the inverse-dynamics head must additionally expose ``trunk``.
        algorithm: Name of the algorithm the agent was trained with.

    Returns:
        A dict with keys ``'encoder'``, ``'actor'`` and ``'inv_dynamics'``.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
    """
    # Imported lazily, as in the original, so merely importing this module
    # does not pull in torch.
    from torch import nn

    # NOTE(review): the original tested `Agent.algorithm`, an undefined name
    # that raised NameError on every call; the `algorithm` argument is used
    # instead. The resulting dict is presumably consumed by
    # `AdaptationAgent` in `.drqv2_invar` -- confirm against the caller.
    if algorithm == 'drqv2':
        # Both wrappers share the same `agent.encoder` instance, so the
        # encoder weights stay tied between the two heads.
        return dict(
            encoder=nn.Sequential(agent.encoder, agent.actor.trunk),
            actor=agent.actor,
            inv_dynamics=nn.Sequential(
                agent.encoder, agent.inv_dynamics_head.trunk
            ),
        )

    if algorithm in ('sac', 'svea', 'pad', 'soda'):
        # These algorithms hand over the raw modules; no trunk wrapping.
        return dict(
            encoder=agent.encoder,
            actor=agent.actor,
            inv_dynamics=agent.inv_dynamics_head,
        )

    raise ValueError(f'invalid algorithm: {algorithm}')
