from umfavi.experiments.config import ExperimentGrid, get_all_transfer_params

# Experiment grid for the grid_trap transfer study. The values in base_config
# are fixed for every generated config; swept parameters are registered on
# `grid` further down via grid.add / grid.add_conditional.
grid = ExperimentGrid(
    all_params=get_all_transfer_params(),
    base_config=dict(
        # --- Environment ---
        env_id="grid_trap",
        # --- Data ---
        obs_transform="one_hot",
        act_transform="one_hot",  # every current env has a discrete action space
        # --- Model architecture ---
        reward_domain="s",
        encoder_hidden_sizes=[64, 64],
        # --- Weights & Biases logging ---
        log_wandb=True,
        wandb_project="transfer_grid_trap",
        # --- Training ---
        # NOTE(review): same value as the "env_params.gamma" sweep below; keep in sync.
        gamma=0.95,
    ),
)

# ============================================================================
# Environment perturbations (use "env_params." prefix for env-specific args)
# ============================================================================

# Grid-trap perturbation: grid_size and the env discount stay fixed while
# p_rand is swept from 0.0 to 1.0 in steps of 0.1 (11 values).
grid.add("env_params.grid_size", [10])
# NOTE(review): mirrors base_config["gamma"] (0.95); keep the two in sync.
grid.add("env_params.gamma", [0.95])
grid.add("env_params.p_rand", [step / 10 for step in range(11)])


# ============================================================================
# Feedback combinations and reward model paths
# ============================================================================
# Feedback-source combinations to sweep. Every value here needs a matching
# fb_model_path entry below, otherwise the fb_model_path validator rejects it.
grid.add(
    "feedback_combo",
    [
        "demo+pref",
        "demo+pref+rating",
        "demo+rating",
        "demo_only",
        "pref+rating",
        "pref_only",
        "rating_only",
        "imitation",
    ],
)


def _feedback_combo_is(combo):
    """Return a predicate matching configs whose ``feedback_combo`` equals *combo*.

    A factory is used (rather than a lambda written directly in the loop
    below) so that each predicate captures its own ``combo`` value instead of
    late-binding to the loop variable's final value.
    """
    return lambda c: c.get("feedback_combo") == combo


# Pretrained feedback-model path(s) to sweep for each feedback combination.
# Each list is only applied to configs whose "feedback_combo" equals its key,
# so keys must match the swept "feedback_combo" values exactly: a combo with
# no entry gets no fb_model_path and fails the fb_model_path validator.
# Paths to models will be provided at acceptance.
_FB_MODEL_PATHS = {
    "demo+pref": ["<path_to_best_model>"],
    # Must be "demo+pref+rating" (not "demo+pref+rating+stop") to match the
    # swept combo; otherwise no demo+pref+rating run receives a model path.
    "demo+pref+rating": ["<path_to_best_model>"],
    "demo+rating": ["<path_to_best_model>"],
    "demo+stop": ["<path_to_best_model>"],
    "demo_only": ["<path_to_best_model>"],
    "imitation": ["<path_to_best_model>"],
    "pref+rating": ["<path_to_best_model>"],
    "pref+stop": ["<path_to_best_model>"],
    "pref_only": ["<path_to_best_model>"],
    "rating+stop": ["<path_to_best_model>"],
    "rating_only": ["<path_to_best_model>"],
    # NOTE: the "*stop" entries match no currently swept feedback_combo and
    # are therefore inert; they only take effect if such combos are added.
    "stop_only": ["<path_to_best_model>"],
}

# Register one conditional per combo, in the table's (insertion) order.
for _combo, _paths in _FB_MODEL_PATHS.items():
    grid.add_conditional(
        "fb_model_path", _paths, condition=_feedback_combo_is(_combo)
    )


# Select the run mode: the "imitation" combo runs in imitation mode, every
# other feedback combo runs in reward_model mode.
def _uses_imitation(c):
    """Return True when the config's feedback_combo is "imitation"."""
    return c.get("feedback_combo") == "imitation"


grid.add_conditional("mode", ["reward_model"], condition=lambda c: not _uses_imitation(c))
grid.add_conditional("mode", ["imitation"], condition=_uses_imitation)

# ============================================================================
# Validation
# ============================================================================

def _has_fb_model_path(c):
    """Return True when the config carries a non-None ``fb_model_path``."""
    return c.get("fb_model_path") is not None


# Every generated config must have a feedback-model path assigned.
grid.add_validator(_has_fb_model_path)