
import warnings
from typing import Dict, Optional, Union, List, Any

import torch
import torch.nn as nn

from mcu.arm.models.encoders.action_encoder import ActionEncoder
from mcu.arm.models.encoders.text_encoder import TextEncoder
from mcu.arm.models.encoders.subgoal_encoder import SubGoalEncoder
from mcu.arm.models.encoders.trajectory_encoder import TrajectoryEncoder

def build_condition_encoder(
    name: Optional[str] = None, **kwargs,
) -> Union[nn.Module, None]:
    """Construct the condition encoder selected by ``name``.

    Args:
        name: Which encoder to build. Recognized values are ``'text'``
            (``TextEncoder``), ``'subgoal'`` (``SubGoalEncoder``) and
            ``'trajectory'`` (``TrajectoryEncoder``). ``None`` means no
            encoder is requested.
        **kwargs: Forwarded unchanged to the selected encoder's constructor.

    Returns:
        A newly constructed encoder, or ``None`` when ``name`` is ``None``
        or is not one of the recognized values.

    Warns:
        UserWarning: If ``name`` is not ``None`` but is not recognized.
            ``None`` is still returned (best-effort behaviour is kept), but
            a misspelled name in a config is no longer silently ignored.
    """
    if name == 'text':
        return TextEncoder(**kwargs)
    if name == 'subgoal':
        return SubGoalEncoder(**kwargs)
    if name == 'trajectory':
        return TrajectoryEncoder(**kwargs)
    if name is not None:
        # An explicit but unknown name is almost certainly a typo; surface it
        # instead of quietly building nothing. stacklevel=2 points the warning
        # at the caller's line rather than this function.
        warnings.warn(
            f"Unknown condition encoder name {name!r}; expected one of "
            "'text', 'subgoal', 'trajectory' or None. No encoder will be built.",
            stacklevel=2,
        )
    return None

if __name__ == '__main__':
    # Intentionally empty: this module is meant to be imported, and running
    # it directly performs no work.
    ...
