from umfavi.experiments.config import ExperimentGrid, get_all_train_params

grid = ExperimentGrid(
    all_params=get_all_train_params(),
    # Defaults shared by every configuration in the sweep; the grid.add* calls
    # below layer the swept / conditional parameters on top of these.
    base_config=dict(
        # Environment
        env_id="grid_trap",
        grid_size=10,
        # Shared training parameters
        vis_every_n_epochs=None,
        # Data
        subsample_factor=1,
        step_offset=1,
        obs_transform="one_hot",
        act_transform="one_hot",  # All current envs have discrete actions
        # Model architecture
        reward_domain="s",
        encoder_hidden_sizes=[64, 64],
        # Wandb
        log_wandb=True,
        # Training parameters
        num_epochs=1000,
        lr=5e-4,
        batch_size=32,
        gamma=0.95,
        # Validation
        retrain_verbose=0,
        retrain_pbar=False,
        log_every_n_steps=50,
        val_every_n_epochs=1,
    ),
)

# ============================================================================
# General training
# ============================================================================
grid.add("use_importance_weights", [True])

# ============================================================================
# Imitation learning parameters
# ============================================================================
grid.add("use_imitation_learning", [True, False])


def is_imitation(c):
    """Return whether config ``c`` has ``use_imitation_learning`` set to True.

    Uses an equality comparison with ``True``; a missing key yields ``None``,
    which compares unequal, so such configs count as non-imitation.
    """
    return c.get("use_imitation_learning") == True  # noqa: E712


# Imitation configs get TD-error and KL weights of 0.0 ...
for _weight_param in ("td_error_weight", "kl_weight"):
    grid.add_conditional(_weight_param, [0.0], condition=is_imitation)
# ... every other config gets 1.0 for both (a fresh predicate per call).
for _weight_param in ("td_error_weight", "kl_weight"):
    grid.add_conditional(_weight_param, [1.0], condition=lambda c: not is_imitation(c))
del _weight_param

# ============================================================================
# Preference parameters
# ============================================================================
grid.add("n_pref_samples", [0, 64, 128, 256])


def has_pref(c):
    """Return True when config ``c`` requests a positive number of preference samples."""
    return c.get("n_pref_samples", 0) > 0


# Preference-specific settings, each registered with has_pref as its condition.
for _pref_key, _pref_values in (
    ("n_pref_episodes", [256]),
    ("pref_seg_len", [10]),
    ("pref_trajectory_rationality", [0.0]),
    ("pref_rationality", [5.0]),
):
    grid.add_conditional(_pref_key, _pref_values, condition=has_pref)
del _pref_key, _pref_values

# ============================================================================
# Rating parameters
# ============================================================================
grid.add("n_rating_samples", [0, 32, 64, 128])


def has_rating(c):
    """Return True when config ``c`` requests a positive number of rating samples."""
    return c.get("n_rating_samples", 0) > 0


# Rating-specific settings, each registered with has_rating as its condition.
for _rating_key, _rating_values in (
    ("n_rating_episodes", [256]),
    ("rating_seg_len", [10]),
    ("rating_trajectory_rationality", [0.0]),
):
    grid.add_conditional(_rating_key, _rating_values, condition=has_rating)
del _rating_key, _rating_values

# ============================================================================
# Demonstration parameters
# ============================================================================
grid.add("n_demo_samples", [0, 1, 2])


def has_demo(c):
    """Return True when config ``c`` requests a positive number of demonstrations."""
    return c.get("n_demo_samples", 0) > 0


# Demo rationality is registered only under the has_demo condition.
grid.add_conditional("demo_rationality", [10.0], condition=has_demo)

# ============================================================================
# Stop parameters
# ============================================================================
grid.add("n_stop_samples", [0, 128, 256])


def has_stop(c):
    """Return True when config ``c`` requests a positive number of stop samples."""
    return c.get("n_stop_samples", 0) > 0


# Stop-feedback settings, each registered with has_stop as its condition.
for _stop_key, _stop_values in (
    ("n_stop_episodes", [256]),
    ("stop_seg_len", [10]),
    ("stop_regret_percentile", [50.0]),
    ("stop_regret_discount", [0.1]),
    ("stop_c", [2.0]),
    ("stop_trajectory_rationality", [0.0]),
):
    grid.add_conditional(_stop_key, _stop_values, condition=has_stop)
del _stop_key, _stop_values

# ============================================================================
# Filtering invalid configurations
# ============================================================================


def num_active_fb_types(c):
    """Return how many of the four feedback types (pref, demo, rating, stop) ``c`` enables."""
    return sum([has_pref(c), has_demo(c), has_rating(c), has_stop(c)])


# Only single feedback types, pairs, or all four combined are kept; configs with
# zero or exactly three active types are rejected. Counting once per config
# avoids re-evaluating every predicate up to three times.
grid.add_validator(lambda c: num_active_fb_types(c) in (1, 2, 4))

# Imitation learning is only valid with demonstration feedback alone: a positive
# n_demo_samples and no preference, rating, or stop feedback.
grid.add_validator(lambda c:
    not is_imitation(c) or
    (has_demo(c) and not has_pref(c) and not has_rating(c) and not has_stop(c))
)

if __name__ == "__main__":
    # Running this module as a script prints the grid summary for 3 seeds.
    summary_text = grid.summary(seeds=3)
    print(summary_text)