import dataclasses
import enum
import logging
import socket

import tyro

from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.training import config as _config


@enum.unique
class EnvMode(enum.Enum):
    """Supported environments.

    Apart from the two ALOHA entries (lower-case values), every member's value is identical to
    its name. Many members identify a training run plus a checkpoint step (e.g. ``CALVIN_D_JOINT_20K``
    is mapped to ``.../checkpoint-20000/...`` in ``DEFAULT_CHECKPOINT``). Only members that have an
    entry in ``DEFAULT_CHECKPOINT`` can be served with the ``Default`` policy option.

    ``@enum.unique`` makes a duplicated value (e.g. from copy-pasting a line and forgetting to edit
    the string) an import-time error instead of silently turning the new member into an alias.
    """

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"

    SMALL_PRETRAIN = "SMALL_PRETRAIN"
    CALVIN_FINETUNE_D_JOINT = "CALVIN_FINETUNE_D_JOINT"
    CALVIN_D_JOINT = "CALVIN_D_JOINT"
    CALVIN_JOINT = "CALVIN_JOINT"
    CALVIN_ABC_EEF_D_JOINT = "CALVIN_ABC_EEF_D_JOINT"
    CALVIN_EEF = "CALVIN_EEF"

    CALVIN_D_FINETUNE_JOINT_10K = "CALVIN_D_FINETUNE_JOINT_10K"
    CALVIN_D_FINETUNE_JOINT_20K = "CALVIN_D_FINETUNE_JOINT_20K"
    CALVIN_D_FINETUNE_JOINT_30K = "CALVIN_D_FINETUNE_JOINT_30K"
    CALVIN_D_FINETUNE_JOINT_35K = "CALVIN_D_FINETUNE_JOINT_35K"
    CALVIN_D_FINETUNE_JOINT_40K = "CALVIN_D_FINETUNE_JOINT_40K"
    CALVIN_D_FINETUNE_JOINT_45K = "CALVIN_D_FINETUNE_JOINT_45K"
    CALVIN_D_FINETUNE_JOINT_50K = "CALVIN_D_FINETUNE_JOINT_50K"
    CALVIN_D_FINETUNE_JOINT_55K = "CALVIN_D_FINETUNE_JOINT_55K"
    CALVIN_D_FINETUNE_JOINT_60K = "CALVIN_D_FINETUNE_JOINT_60K"
    CALVIN_D_FINETUNE_JOINT_65K = "CALVIN_D_FINETUNE_JOINT_65K"
    CALVIN_D_FINETUNE_JOINT_70K = "CALVIN_D_FINETUNE_JOINT_70K"
    CALVIN_D_FINETUNE_JOINT_75K = "CALVIN_D_FINETUNE_JOINT_75K"
    CALVIN_D_FINETUNE_JOINT_80K = "CALVIN_D_FINETUNE_JOINT_80K"
    CALVIN_D_FINETUNE_JOINT_85K = "CALVIN_D_FINETUNE_JOINT_85K"
    CALVIN_D_FINETUNE_JOINT_90K = "CALVIN_D_FINETUNE_JOINT_90K"

    CALVIN_D_FINETUNE_EEF_10K = "CALVIN_D_FINETUNE_EEF_10K"
    CALVIN_D_FINETUNE_EEF_20K = "CALVIN_D_FINETUNE_EEF_20K"
    CALVIN_D_FINETUNE_EEF_30K = "CALVIN_D_FINETUNE_EEF_30K"
    CALVIN_D_FINETUNE_EEF_40K = "CALVIN_D_FINETUNE_EEF_40K"
    CALVIN_D_FINETUNE_EEF_50K = "CALVIN_D_FINETUNE_EEF_50K"
    CALVIN_D_FINETUNE_EEF_60K = "CALVIN_D_FINETUNE_EEF_60K"
    CALVIN_D_FINETUNE_EEF_70K = "CALVIN_D_FINETUNE_EEF_70K"
    CALVIN_D_FINETUNE_EEF_80K = "CALVIN_D_FINETUNE_EEF_80K"
    CALVIN_D_FINETUNE_EEF_90K = "CALVIN_D_FINETUNE_EEF_90K"

    CALVIN_D_JOINT_20K = "CALVIN_D_JOINT_20K"
    CALVIN_D_JOINT_25K = "CALVIN_D_JOINT_25K"
    CALVIN_D_JOINT_30K = "CALVIN_D_JOINT_30K"
    CALVIN_D_JOINT_35K = "CALVIN_D_JOINT_35K"
    CALVIN_D_JOINT_40K = "CALVIN_D_JOINT_40K"
    CALVIN_D_JOINT_45K = "CALVIN_D_JOINT_45K"
    CALVIN_D_JOINT_50K = "CALVIN_D_JOINT_50K"
    CALVIN_D_JOINT_55K = "CALVIN_D_JOINT_55K"
    CALVIN_D_JOINT_60K = "CALVIN_D_JOINT_60K"
    CALVIN_D_JOINT_65K = "CALVIN_D_JOINT_65K"
    CALVIN_D_JOINT_70K = "CALVIN_D_JOINT_70K"
    CALVIN_D_JOINT_75K = "CALVIN_D_JOINT_75K"
    CALVIN_D_JOINT_80K = "CALVIN_D_JOINT_80K"
    CALVIN_D_JOINT_85K = "CALVIN_D_JOINT_85K"
    CALVIN_D_JOINT_90K = "CALVIN_D_JOINT_90K"
    CALVIN_D_JOINT_95K = "CALVIN_D_JOINT_95K"
    CALVIN_D_JOINT_100K = "CALVIN_D_JOINT_100K"

    CALVIN_ABC_EEF_D_JOINT_20K = "CALVIN_ABC_EEF_D_JOINT_20K"
    CALVIN_ABC_EEF_D_JOINT_40K = "CALVIN_ABC_EEF_D_JOINT_40K"
    CALVIN_ABC_EEF_D_JOINT_60K = "CALVIN_ABC_EEF_D_JOINT_60K"
    CALVIN_ABC_EEF_D_JOINT_70K = "CALVIN_ABC_EEF_D_JOINT_70K"
    CALVIN_ABC_EEF_D_JOINT_75K = "CALVIN_ABC_EEF_D_JOINT_75K"
    CALVIN_ABC_EEF_D_JOINT_80K = "CALVIN_ABC_EEF_D_JOINT_80K"
    CALVIN_ABC_EEF_D_JOINT_85K = "CALVIN_ABC_EEF_D_JOINT_85K"
    CALVIN_ABC_EEF_D_JOINT_90K = "CALVIN_ABC_EEF_D_JOINT_90K"
    CALVIN_ABC_EEF_D_JOINT_95K = "CALVIN_ABC_EEF_D_JOINT_95K"
    CALVIN_ABC_EEF_D_JOINT_100K = "CALVIN_ABC_EEF_D_JOINT_100K"
    CALVIN_ABC_EEF_D_JOINT_105K = "CALVIN_ABC_EEF_D_JOINT_105K"
    CALVIN_ABC_EEF_D_JOINT_110K = "CALVIN_ABC_EEF_D_JOINT_110K"
    CALVIN_ABC_EEF_D_JOINT_115K = "CALVIN_ABC_EEF_D_JOINT_115K"
    CALVIN_ABC_EEF_D_JOINT_120K = "CALVIN_ABC_EEF_D_JOINT_120K"
    CALVIN_ABC_EEF_D_JOINT_125K = "CALVIN_ABC_EEF_D_JOINT_125K"
    CALVIN_ABC_EEF_D_JOINT_130K = "CALVIN_ABC_EEF_D_JOINT_130K"
    CALVIN_ABC_EEF_D_JOINT_135K = "CALVIN_ABC_EEF_D_JOINT_135K"

    CALVIN_ABC_EEF_D_JOINT_NO_AUX_LOSS_60K = "CALVIN_ABC_EEF_D_JOINT_NO_AUX_LOSS_60K"
    CALVIN_ABC_EEF_D_JOINT_NO_AUX_LOSS_80K = "CALVIN_ABC_EEF_D_JOINT_NO_AUX_LOSS_80K"
    CALVIN_ABC_EEF_D_JOINT_NO_AUX_LOSS_90K = "CALVIN_ABC_EEF_D_JOINT_NO_AUX_LOSS_90K"

    CALVIN_ABC_EEF_D_JOINT_TOPK4_90K = "CALVIN_ABC_EEF_D_JOINT_TOPK4_90K"
    CALVIN_ABC_EEF_D_JOINT_TOPK1_90K = "CALVIN_ABC_EEF_D_JOINT_TOPK1_90K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS16_90K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS16_90K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS4_90K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS4_90K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS2_90K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS2_90K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS16_TOPK4_90K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS16_TOPK4_90K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS16_TOPK8_90K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS16_TOPK8_90K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS16_TOPK8_100K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS16_TOPK8_100K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS32_TOPK4_90K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS32_TOPK4_90K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS32_TOPK8_90K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS32_TOPK8_90K"

    # The value previously read "..._EXPERTS128_TOPK8_90K", contradicting the member name;
    # it now matches the name like every other member.
    CALVIN_ABC_EEF_D_JOINT_EXPERTS128_TOPK4_90K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS128_TOPK4_90K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS64_TOPK4_90K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS64_TOPK4_90K"
    CALVIN_ABC_EEF_D_JOINT_EXPERTS64_TOPK4_100K = "CALVIN_ABC_EEF_D_JOINT_EXPERTS64_TOPK4_100K"

    CO_TRAINING_10K = "CO_TRAINING_10K"
    CO_TRAINING_15K = "CO_TRAINING_15K"
    CO_TRAINING_20K = "CO_TRAINING_20K"
    CO_TRAINING_25K = "CO_TRAINING_25K"
    CO_TRAINING_30K = "CO_TRAINING_30K"
    CO_TRAINING_35K = "CO_TRAINING_35K"
    CO_TRAINING_40K = "CO_TRAINING_40K"
    CO_TRAINING_45K = "CO_TRAINING_45K"
    CO_TRAINING_50K = "CO_TRAINING_50K"
    CO_TRAINING_60K = "CO_TRAINING_60K"
    CO_TRAINING_70K = "CO_TRAINING_70K"
    CO_TRAINING_80K = "CO_TRAINING_80K"
    CO_TRAINING_90K = "CO_TRAINING_90K"

    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_20K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_20K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_25K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_25K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_30K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_30K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_35K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_35K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_40K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_40K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_45K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_45K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_50K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_50K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_55K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_55K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_60K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_60K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_65K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_65K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_70K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_70K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_75K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_75K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_80K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_80K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_85K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_85K"
    CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_90K = "CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_90K"

    CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_30K = "CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_30K"
    CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_40K = "CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_40K"
    CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_50K = "CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_50K"
    CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_60K = "CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_60K"
    CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_70K = "CALVIN_D_FINETUNE_JOINT_NO_INIT_MOE_70K"

    # LIBERO members use a lower-case "k" step suffix; names are kept as-is for CLI compatibility.
    LIBERO_10_20k = "LIBERO_10_20k"
    LIBERO_10_25k = "LIBERO_10_25k"
    LIBERO_10_30k = "LIBERO_10_30k"
    LIBERO_10_35k = "LIBERO_10_35k"
    LIBERO_10_40k = "LIBERO_10_40k"
    LIBERO_10_45k = "LIBERO_10_45k"
    LIBERO_10_50k = "LIBERO_10_50k"
    LIBERO_10_55k = "LIBERO_10_55k"
    LIBERO_10_60k = "LIBERO_10_60k"
    LIBERO_10_65k = "LIBERO_10_65k"
    LIBERO_10_70k = "LIBERO_10_70k"
    LIBERO_10_75k = "LIBERO_10_75k"
    LIBERO_10_80k = "LIBERO_10_80k"

    LIBERO_goal_20k = "LIBERO_goal_20k"
    LIBERO_goal_25k = "LIBERO_goal_25k"
    LIBERO_goal_30k = "LIBERO_goal_30k"
    LIBERO_goal_35k = "LIBERO_goal_35k"
    LIBERO_goal_40k = "LIBERO_goal_40k"
    LIBERO_goal_45k = "LIBERO_goal_45k"
    LIBERO_goal_50k = "LIBERO_goal_50k"
    LIBERO_goal_55k = "LIBERO_goal_55k"
    LIBERO_goal_60k = "LIBERO_goal_60k"
    LIBERO_goal_65k = "LIBERO_goal_65k"
    LIBERO_goal_70k = "LIBERO_goal_70k"
    LIBERO_goal_75k = "LIBERO_goal_75k"
    LIBERO_goal_80k = "LIBERO_goal_80k"

    LIBERO_object_20k = "LIBERO_object_20k"
    LIBERO_object_25k = "LIBERO_object_25k"
    LIBERO_object_30k = "LIBERO_object_30k"
    LIBERO_object_35k = "LIBERO_object_35k"
    LIBERO_object_40k = "LIBERO_object_40k"
    LIBERO_object_45k = "LIBERO_object_45k"
    LIBERO_object_50k = "LIBERO_object_50k"
    LIBERO_object_55k = "LIBERO_object_55k"

    LIBERO_spatial_20k = "LIBERO_spatial_20k"
    LIBERO_spatial_25k = "LIBERO_spatial_25k"
    LIBERO_spatial_30k = "LIBERO_spatial_30k"
    LIBERO_spatial_35k = "LIBERO_spatial_35k"
    LIBERO_spatial_40k = "LIBERO_spatial_40k"
    LIBERO_spatial_45k = "LIBERO_spatial_45k"
    LIBERO_spatial_50k = "LIBERO_spatial_50k"
    LIBERO_spatial_55k = "LIBERO_spatial_55k"
    LIBERO_spatial_60k = "LIBERO_spatial_60k"
    LIBERO_spatial_65k = "LIBERO_spatial_65k"
    LIBERO_spatial_70k = "LIBERO_spatial_70k"
    LIBERO_spatial_75k = "LIBERO_spatial_75k"
    LIBERO_spatial_80k = "LIBERO_spatial_80k"

    LIBERO_spatial_from_scratch_10k = "LIBERO_spatial_from_scratch_10k"
    LIBERO_spatial_from_scratch_30k = "LIBERO_spatial_from_scratch_30k"
    LIBERO_spatial_from_scratch_40k = "LIBERO_spatial_from_scratch_40k"
    LIBERO_spatial_from_scratch_50k = "LIBERO_spatial_from_scratch_50k"
    

@dataclasses.dataclass
class Checkpoint:
    """Point the server at a specific trained checkpoint instead of an environment default.

    Attributes:
        config: Name of the training config, resolved with ``_config.get_config``
            (e.g. "pi0_aloha_sim").
        dir: Checkpoint location handed to ``create_trained_policy``
            (e.g. "checkpoints/pi0_aloha_sim/exp/10000").
    """

    config: str  # training config name, e.g. "pi0_aloha_sim"
    dir: str  # checkpoint location, e.g. "checkpoints/pi0_aloha_sim/exp/10000"


@dataclasses.dataclass
class Default:
    """Marker selecting the checkpoint registered for ``Args.env``.

    Carries no fields; ``create_policy`` dispatches on its type and hands ``args.env`` to
    ``create_default_policy``.
    """


@dataclasses.dataclass
class Args:
    """Command-line options for the serve_policy script."""

    # Environment whose default checkpoint is served. Only consulted when `policy` is `Default`.
    env: EnvMode = EnvMode.ALOHA_SIM

    # Fallback prompt: used when the incoming data carries no "prompt" key, or when the model has
    # no default prompt of its own.
    default_prompt: str | None = None

    # Port the websocket policy server listens on.
    port: int = 8000
    # Wrap the policy in a recorder (writing to "policy_records") to debug its behavior.
    record: bool = False

    # How the policy is loaded; when omitted, the default checkpoint for `env` is used.
    policy: Checkpoint | Default = dataclasses.field(default_factory=Default)


def _checkpoint_series(
    name_prefix: str,
    config: str,
    run_dir: str,
    steps_k: range | tuple[int, ...],
) -> dict[EnvMode, Checkpoint]:
    """Build one ``EnvMode -> Checkpoint`` entry per saved step of a single training run.

    For every ``k`` in ``steps_k`` (in the given order), ``EnvMode[f"{name_prefix}_{k}K"]`` is mapped
    to ``f"{run_dir}/checkpoint-{k * 1000}/pytorch_model.pth"`` loaded with training config ``config``.
    A missing enum member raises ``KeyError`` at import time.
    """
    return {
        EnvMode[f"{name_prefix}_{k}K"]: Checkpoint(
            config=config,
            dir=f"{run_dir}/checkpoint-{k * 1000}/pytorch_model.pth",
        )
        for k in steps_k
    }


# Default checkpoints that should be used for each environment.
# Environments without an entry here cannot be served through the `Default` policy option;
# create_default_policy raises ValueError for them.
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
    EnvMode.ALOHA: Checkpoint(
        config="aloha_moevla",
        dir="/mnt/blob/task/moe_v2_full_aloha_moevla_9_8/checkpoints/aloha_moevla_finetune/moevla/checkpoint-45000/pytorch_model.pth",
    ),
    EnvMode.ALOHA_SIM: Checkpoint(
        config="pi0_aloha_sim",
        dir="s3://openpi-assets/checkpoints/pi0_aloha_sim",
    ),
    # NOTE(review): this "no freeze vision" entry loads the very same file as CALVIN_D_JOINT_20K
    # (run moe_v2_gaussion_calvin_d_joint_finetune_8_19); looks like a copy-paste placeholder —
    # confirm the intended run directory.
    EnvMode.CALVIN_D_FINETUNE_JOINT_NO_FREEZE_VISION_20K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_gaussion_calvin_d_joint_finetune_8_19/checkpoints/calvin_d_joint_finetune_bs64/calvin_d_joint/checkpoint-20000/pytorch_model.pth",
    ),
    # Run moe_v2_gaussion_calvin_d_eef_finetune_8_13: steps 10K..90K, every 10K.
    **_checkpoint_series(
        "CALVIN_D_FINETUNE_EEF",
        "calvin_d_eef",
        "/mnt/blob/task/moe_v2_gaussion_calvin_d_eef_finetune_8_13/checkpoints/calvin_d_eef_finetune/calvin_d_eef",
        range(10, 100, 10),
    ),
    # Run moe_v2_gaussion_calvin_d_joint_finetune_8_19: steps 20K..100K, every 5K.
    **_checkpoint_series(
        "CALVIN_D_JOINT",
        "calvin_d_joint",
        "/mnt/blob/task/moe_v2_gaussion_calvin_d_joint_finetune_8_19/checkpoints/calvin_d_joint_finetune_bs64/calvin_d_joint",
        range(20, 105, 5),
    ),
    # Run moe_v2_full_calvin_abc_eef_d_joint_from_scratch_7_31: 40K, 60K, then 70K..135K every 5K.
    # (CALVIN_ABC_EEF_D_JOINT_20K has no registered checkpoint.)
    **_checkpoint_series(
        "CALVIN_ABC_EEF_D_JOINT",
        "calvin_d_joint",
        "/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_7_31/checkpoints/calvin_abc_eef_d_joint_pretrain/calvin_abc_eef_d_joint",
        (40, 60, *range(70, 140, 5)),
    ),
    # Run moe_v2_full_calvin_abc_eef_d_joint_from_scratch_no_aux_8_20: 60K, 80K, 90K.
    **_checkpoint_series(
        "CALVIN_ABC_EEF_D_JOINT_NO_AUX_LOSS",
        "calvin_d_joint",
        "/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_no_aux_8_20/checkpoints/calvin_abc_eef_d_joint_pretrain/calvin_abc_eef_d_joint",
        (60, 80, 90),
    ),
    # Expert-count / top-k ablations; each lives in its own run directory with its own layout.
    EnvMode.CALVIN_ABC_EEF_D_JOINT_TOPK4_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_topk4_8_22/checkpoints/calvin_abc_eef_d_joint_pretrain_topk_4/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_TOPK1_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_topk1_8_22/checkpoints/calvin_abc_eef_d_joint_pretrain_topk_1/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS16_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts16_8_22/checkpoints/calvin_abc_eef_d_joint_pretrain_experts_16/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS16_TOPK4_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts16_topk4_8_30/checkpoints/calvin_abc_eef_d_joint_pretrain_experts16_topk4/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS16_TOPK8_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts16_topk8_8_30/checkpoints/calvin_abc_eef_d_joint_pretrain_experts16_topk8/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS32_TOPK4_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts32_topk4_9_1/checkpoints/calvin_abc_eef_d_joint_pretrain_experts32_topk4/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS32_TOPK8_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts32_topk8_9_1/checkpoints/calvin_abc_eef_d_joint_pretrain_experts32_topk8/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    # NOTE(review): both EXPERTS64_* keys load from an *experts128*_topk4 run directory; either the
    # enum member or the path is mislabelled — confirm which model these should serve.
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS64_TOPK4_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_experts128_topk4_9_7/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts128_topk4_9_4/checkpoints/calvin_abc_eef_d_joint_pretrain_experts128_topk4/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS64_TOPK4_100K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_experts128_topk4_9_7/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts128_topk4_9_4/checkpoints/calvin_abc_eef_d_joint_pretrain_experts128_topk4/calvin_d_joint/checkpoint-100000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS16_TOPK8_100K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts16_topk8_8_30/checkpoints/calvin_abc_eef_d_joint_pretrain_experts16_topk8/calvin_d_joint/checkpoint-100000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS4_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts4_8_22/checkpoints/calvin_abc_eef_d_joint_pretrain_experts_4/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    EnvMode.CALVIN_ABC_EEF_D_JOINT_EXPERTS2_90K: Checkpoint(
        config="calvin_d_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_abc_eef_d_joint_from_scratch_experts2_8_22/checkpoints/calvin_abc_eef_d_joint_pretrain_experts_2/calvin_d_joint/checkpoint-90000/pytorch_model.pth",
    ),
    EnvMode.CO_TRAINING_10K: Checkpoint(
        config="calvin_joint",
        dir="/mnt/blob/task/moe_v2_full_calvin_libero_aloha_sim_8_5/checkpoints/calvin_libero_aloha_sim/calvin_libero_aloha_sim/checkpoint-10000/pytorch_model.pth",
    ),
    # LIBERO fine-tunes (one run directory per suite) and one from-scratch run.
    EnvMode.LIBERO_10_20k: Checkpoint(
        config="libero_10_no_noops_lerobot",
        dir="/mnt/blob/task/moe_v2_full_libero_10_no_noops_finetune_816/checkpoints/libero_10_no_noops_lerobot_finetune_bs64/libero/checkpoint-20000/pytorch_model.pth",
    ),
    EnvMode.LIBERO_object_20k: Checkpoint(
        config="libero_object_no_noops_lerobot",
        dir="/mnt/blob/task/moe_v2_full_libero_object_no_noops_finetune_816/checkpoints/libero_object_no_noops_lerobot_finetune_bs64/libero/checkpoint-20000/pytorch_model.pth",
    ),
    EnvMode.LIBERO_spatial_20k: Checkpoint(
        config="libero_spatial_no_noops_lerobot",
        dir="/mnt/blob/task/moe_v2_full_libero_spatial_no_noops_finetune_816/checkpoints/libero_spatial_no_noops_lerobot_finetune_bs64/libero/checkpoint-20000/pytorch_model.pth",
    ),
    EnvMode.LIBERO_spatial_25k: Checkpoint(
        config="libero_spatial_no_noops_lerobot",
        dir="/mnt/blob/task/moe_v2_full_libero_spatial_no_noops_finetune_816/checkpoints/libero_spatial_no_noops_lerobot_finetune_bs64/libero/checkpoint-25000/pytorch_model.pth",
    ),
    EnvMode.LIBERO_spatial_from_scratch_10k: Checkpoint(
        config="libero_spatial_no_noops_lerobot",
        dir="/mnt/blob/task/moe_v2_full_libero_spatial_no_noops_from_scratch_814/checkpoints/libero_spatial_no_noops_lerobot_from_scratch/libero/checkpoint-10000/pytorch_model.pth",
    ),
}


def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
    """Load the checkpoint registered for ``env`` in ``DEFAULT_CHECKPOINT``.

    Args:
        env: Environment whose registered checkpoint should be served.
        default_prompt: Forwarded unchanged to ``create_trained_policy``.

    Raises:
        ValueError: if ``env`` has no entry in ``DEFAULT_CHECKPOINT``.
    """
    checkpoint = DEFAULT_CHECKPOINT.get(env)
    if not checkpoint:
        raise ValueError(f"Unsupported environment mode: {env}")

    train_config = _config.get_config(checkpoint.config)
    return _policy_config.create_trained_policy(train_config, checkpoint.dir, default_prompt=default_prompt)


def create_policy(args: Args) -> _policy.Policy:
    """Create a policy from the given arguments."""
    match args.policy:
        case Checkpoint():
            return _policy_config.create_trained_policy(
                _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
            )
        case Default():
            return create_default_policy(args.env, default_prompt=args.default_prompt)


def create_policy_for_simpler(
    config: str, ckpt_dir: str, *, default_prompt: str | None = None
) -> _policy.Policy:
    """Create a trained policy directly from a training-config name and a checkpoint location.

    This bypasses ``Args``/``EnvMode`` entirely. The name suggests it is the entry point for a
    SimplerEnv-style evaluation harness — NOTE(review): confirm the caller.

    Args:
        config: Training config name, resolved with ``_config.get_config``.
        ckpt_dir: Checkpoint location passed through to ``create_trained_policy``.
        default_prompt: Prompt used when the input has no "prompt" key. Defaults to ``None``,
            which matches the previously hard-coded behavior, so existing two-argument calls are
            unchanged.
    """
    return _policy_config.create_trained_policy(
        _config.get_config(config), ckpt_dir, default_prompt=default_prompt
    )


def main(args: Args) -> None:
    """Build the policy described by ``args`` and serve it over a websocket.

    Binds on all interfaces at ``args.port`` and then blocks in ``server.serve_forever()``.
    """
    policy = create_policy(args)
    # Read metadata from the policy itself, before it is optionally wrapped below.
    policy_metadata = policy.metadata

    if args.record:
        # Record the policy's behavior under "policy_records" for debugging.
        policy = _policy.PolicyRecorder(policy, "policy_records")

    hostname = socket.gethostname()
    try:
        local_ip = socket.gethostbyname(hostname)
    except socket.gaierror:
        # The IP is only used for the log line below. Hosts or containers whose own hostname does
        # not resolve should still be able to start the server.
        local_ip = "unknown"
    logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)

    server = websocket_policy_server.WebsocketPolicyServer(
        policy=policy,
        host="0.0.0.0",  # listen on all interfaces
        port=args.port,
        metadata=policy_metadata,
    )
    server.serve_forever()


if __name__ == "__main__":
    # Set up INFO-level root logging (force=True replaces any handlers already installed)
    # before parsing the command line and starting the server.
    logging.basicConfig(level=logging.INFO, force=True)
    cli_args = tyro.cli(Args)
    main(cli_args)
