import dataclasses
import enum
import logging
import socket
import inspect

import tyro

from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.training import config as _config
from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.shared import path_utils

def _pi05_comparator_model() -> _model.BaseModelConfig:
    """Default model for ComparatorOfflineTrainConfig: pi05 Pi0Config with the action comparator enabled."""
    return pi0_config.Pi0Config(
        enable_action_comparator=True,
        pi05=True,
        action_horizon=10,
        discrete_state_input=False,
    )


def _pi05_comparator_libero_data() -> _config.DataConfigFactory:
    """Default LIBERO data config for ComparatorOfflineTrainConfig.

    The repo id is resolved via ``path_utils.env_path("OPENPI_LIBERO_ONLINE_REPO_ID")``
    with ``"physical-intelligence/libero"`` as the default.
    """
    repo_id = path_utils.env_path(
        "OPENPI_LIBERO_ONLINE_REPO_ID",
        default="physical-intelligence/libero",
    )
    return _config.LeRobotLiberoDataConfig(
        repo_id=repo_id,
        base_config=_config.DataConfig(prompt_from_task=True),
        extra_delta_transform=False,
    )


@dataclasses.dataclass(frozen=True)
class ComparatorOfflineTrainConfig(_config.TrainConfig):
    """pi05 train config with the action comparator enabled.

    In this script it provides the train config passed to ``create_trained_policy``
    when loading the pi05 TTRL checkpoint. NOTE(review): the class says "Offline"
    while the config name says "online_pairs" — confirm which is intended.
    """

    name: str = "pi05_comparator_online_pairs"
    model: _model.BaseModelConfig = dataclasses.field(default_factory=_pi05_comparator_model)

    # Dataset config (used to load observations; can be overridden from the command line).
    data: _config.DataConfigFactory = dataclasses.field(default_factory=_pi05_comparator_libero_data)


def _pi0_comparator_model() -> _model.BaseModelConfig:
    """Default model for ComparatorOfflineTrainConfigPi0: non-pi05 Pi0Config with the action comparator enabled."""
    return pi0_config.Pi0Config(
        enable_action_comparator=True,
        pi05=False,
        action_horizon=10,
        discrete_state_input=False,
    )


def _pi0_comparator_libero_data() -> _config.DataConfigFactory:
    """Default LIBERO data config for ComparatorOfflineTrainConfigPi0.

    The repo id is resolved via ``path_utils.env_path("OPENPI_LIBERO_ONLINE_REPO_ID")``
    with ``"physical-intelligence/libero"`` as the default.
    """
    repo_id = path_utils.env_path(
        "OPENPI_LIBERO_ONLINE_REPO_ID",
        default="physical-intelligence/libero",
    )
    return _config.LeRobotLiberoDataConfig(
        repo_id=repo_id,
        base_config=_config.DataConfig(prompt_from_task=True),
        extra_delta_transform=False,
    )


@dataclasses.dataclass(frozen=True)
class ComparatorOfflineTrainConfigPi0(_config.TrainConfig):
    """Comparator training on offline-built preferred/dispreferred action pairs.

    Samples are shuffled per sample and kept strictly aligned with their observations.
    Expects the ``shard_*.npz`` files and ``manifest.json`` produced by
    ``build_action_pairs.py`` to exist under ``pairs_dir``.
    NOTE(review): ``pairs_dir`` is not declared in this class; presumably it is
    inherited or configured elsewhere — confirm.
    """

    name: str = "pi0_comparator_online_pairs"
    model: _model.BaseModelConfig = dataclasses.field(default_factory=_pi0_comparator_model)

    # Dataset config (used to load observations; can be overridden from the command line).
    data: _config.DataConfigFactory = dataclasses.field(default_factory=_pi0_comparator_libero_data)

class EnvMode(enum.Enum):
    """Environments that have a stock policy checkpoint available."""

    # ALOHA robot.
    ALOHA = "aloha"
    # Simulated ALOHA.
    ALOHA_SIM = "aloha_sim"
    # DROID.
    DROID = "droid"
    # LIBERO benchmark.
    LIBERO = "libero"


@dataclasses.dataclass
class Checkpoint:
    """Serve a policy restored from an explicit trained checkpoint."""

    # Name of the training config the checkpoint belongs to (e.g., "pi0_aloha_sim").
    config: str
    # Path to the checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
    dir: str


@dataclasses.dataclass
class Default:
    """Serve the stock policy mapped to the selected environment (see DEFAULT_CHECKPOINT)."""


def _env_path_or_none(var_name: str):
    """Return ``path_utils.env_path(var_name)``, mapping any falsy result to ``None``."""
    return path_utils.env_path(var_name) or None


@dataclasses.dataclass
class Args:
    """Command-line arguments for the serve_policy script."""

    # Environment whose stock policy is served; only consulted when serving default policies.
    env: EnvMode = EnvMode.ALOHA_SIM

    # Fallback prompt, used when the incoming data lacks a "prompt" key or the model has no
    # default prompt of its own.
    default_prompt: str | None = None

    # Port the websocket server listens on.
    port: int = 8010
    # Wrap the policy in a recorder (writes to "policy_records") for debugging.
    record: bool = False

    # Compare-and-select server option; forwarded only if the server constructor accepts it.
    compare_and_select: bool = False
    # Forwarded to the server as default_num_candidates only if its constructor accepts it.
    default_num_candidates: int | None = None
    # Serve the TTRL comparator policy instead of the policy chosen by --policy/--env.
    ttrl: bool = False
    # pi05 TTRL checkpoint directory; defaults to path_utils.env_path("OPENPI_TTRL_PI05_CHECKPOINT_DIR"),
    # with falsy results treated as unset.
    ttrl_pi05_checkpoint_dir: str | None = dataclasses.field(
        default_factory=lambda: _env_path_or_none("OPENPI_TTRL_PI05_CHECKPOINT_DIR")
    )
    # pi0 TTRL checkpoint directory; defaults to path_utils.env_path("OPENPI_TTRL_PI0_CHECKPOINT_DIR"),
    # with falsy results treated as unset.
    ttrl_pi0_checkpoint_dir: str | None = dataclasses.field(
        default_factory=lambda: _env_path_or_none("OPENPI_TTRL_PI0_CHECKPOINT_DIR")
    )

    # How to load the policy: an explicit checkpoint, or the default for --env when omitted.
    policy: Checkpoint | Default = dataclasses.field(default_factory=Default)


# Stock checkpoint per environment, as (environment, training-config name, checkpoint directory).
# The LIBERO directory is resolved through path_utils.env_path("OPENPI_LIBERO_POLICY_DIR"),
# falling back to the public pi05_libero checkpoint.
_DEFAULT_CHECKPOINT_TABLE = (
    (EnvMode.ALOHA, "pi05_aloha", "gs://openpi-assets/checkpoints/pi05_base"),
    (EnvMode.ALOHA_SIM, "pi0_aloha_sim", "gs://openpi-assets/checkpoints/pi0_aloha_sim"),
    (EnvMode.DROID, "pi05_droid", "gs://openpi-assets/checkpoints/pi05_droid"),
    (
        EnvMode.LIBERO,
        "pi05_libero",
        path_utils.env_path(
            "OPENPI_LIBERO_POLICY_DIR",
            default="gs://openpi-assets/checkpoints/pi05_libero",
        ),
    ),
)

# Default checkpoints that should be used for each environment.
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
    env_mode: Checkpoint(config=config_name, dir=checkpoint_dir)
    for env_mode, config_name, checkpoint_dir in _DEFAULT_CHECKPOINT_TABLE
}


def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
    """Load the stock policy that DEFAULT_CHECKPOINT maps to ``env``.

    Args:
        env: Environment whose default checkpoint should be served.
        default_prompt: Passed through to ``create_trained_policy``.

    Raises:
        ValueError: If ``env`` has no entry in DEFAULT_CHECKPOINT.
    """
    checkpoint = DEFAULT_CHECKPOINT.get(env)
    if not checkpoint:
        raise ValueError(f"Unsupported environment mode: {env}")
    train_config = _config.get_config(checkpoint.config)
    return _policy_config.create_trained_policy(train_config, checkpoint.dir, default_prompt=default_prompt)


def create_policy(args: Args) -> _policy.Policy:
    """Create a policy from the given arguments."""
    match args.policy:
        case Checkpoint():
            return _policy_config.create_trained_policy(
                _config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
            )
        case Default():
            return create_default_policy(args.env, default_prompt=args.default_prompt)

def create_ttrl_policy(args: Args) -> _policy.Policy:
    """Load the pi05 TTRL comparator policy.

    The checkpoint directory comes from ``args.ttrl_pi05_checkpoint_dir`` and is validated
    by ``path_utils.require_path`` (which also references OPENPI_TTRL_PI05_CHECKPOINT_DIR
    and ``--ttrl-pi05-checkpoint-dir``). It is loaded with ``ComparatorOfflineTrainConfig``.
    """
    checkpoint_dir = path_utils.require_path(
        args.ttrl_pi05_checkpoint_dir,
        description="TTRL pi05 checkpoint 目录",
        env_vars=("OPENPI_TTRL_PI05_CHECKPOINT_DIR",),
        cli_flag="--ttrl-pi05-checkpoint-dir",
    )
    train_config = ComparatorOfflineTrainConfig()
    return _policy_config.create_trained_policy(train_config, checkpoint_dir, default_prompt=args.default_prompt)

def create_ttrl_policy_pi0(args: Args) -> _policy.Policy:
    """Load the pi0 (non-pi05) TTRL comparator policy.

    The checkpoint directory comes from ``args.ttrl_pi0_checkpoint_dir`` and is validated
    by ``path_utils.require_path`` (which also references OPENPI_TTRL_PI0_CHECKPOINT_DIR
    and ``--ttrl-pi0-checkpoint-dir``). It is loaded with ``ComparatorOfflineTrainConfigPi0``.
    """
    checkpoint_dir = path_utils.require_path(
        args.ttrl_pi0_checkpoint_dir,
        description="TTRL pi0 checkpoint 目录",
        env_vars=("OPENPI_TTRL_PI0_CHECKPOINT_DIR",),
        cli_flag="--ttrl-pi0-checkpoint-dir",
    )
    train_config = ComparatorOfflineTrainConfigPi0()
    return _policy_config.create_trained_policy(train_config, checkpoint_dir, default_prompt=args.default_prompt)


def main(args: Args) -> None:
    """Create the requested policy and serve it with ``WebsocketPolicyServer.serve_forever``.

    With ``--ttrl`` the pi0 comparator checkpoint is loaded via ``create_ttrl_policy_pi0``.
    Otherwise ``create_policy`` chooses between an explicit checkpoint and the environment
    default. With ``--record`` the policy is wrapped in ``PolicyRecorder`` writing to
    ``"policy_records"``.

    Args:
        args: Parsed command-line arguments.
    """
    # NOTE(review): --ttrl always selects the pi0 variant; create_ttrl_policy (pi05,
    # --ttrl-pi05-checkpoint-dir) is never called from here — confirm this is intended.
    if args.ttrl:
        policy = create_ttrl_policy_pi0(args)
    else:
        policy = create_policy(args)
    # Metadata is taken from the policy before the optional recorder wrapper is applied.
    policy_metadata = policy.metadata

    # Record the policy's behavior.
    if args.record:
        policy = _policy.PolicyRecorder(policy, "policy_records")

    hostname = socket.gethostname()
    try:
        local_ip = socket.gethostbyname(hostname)
    except OSError:
        # The hostname is not always resolvable (e.g. in some containers). The IP is only
        # logged, so fall back to a placeholder instead of aborting before the server starts.
        local_ip = "unknown"
    logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)

    # Be tolerant to older server signatures that may not accept compare_and_select/default_num_candidates:
    # each optional argument is passed only when WebsocketPolicyServer.__init__ declares it.
    accepted_params = inspect.signature(websocket_policy_server.WebsocketPolicyServer.__init__).parameters
    server_kwargs = {
        "policy": policy,
        "host": "0.0.0.0",
        "port": args.port,
        "metadata": policy_metadata,
    }
    optional_kwargs = {
        "compare_and_select": args.compare_and_select,
        "default_num_candidates": args.default_num_candidates,
    }
    for key, value in optional_kwargs.items():
        if key in accepted_params:
            server_kwargs[key] = value
        elif value is not None and value is not False:
            # A non-default value was requested but this server cannot take it; warn rather than drop it silently.
            logging.warning("WebsocketPolicyServer does not accept %r; ignoring %s=%r.", key, key, value)

    server = websocket_policy_server.WebsocketPolicyServer(**server_kwargs)
    server.serve_forever()


if __name__ == "__main__":
    # Configure root logging at INFO first so messages logged by main are emitted.
    logging.basicConfig(force=True, level=logging.INFO)
    cli_args = tyro.cli(Args)
    main(cli_args)
