from umfavi.experiments.config import ExperimentGrid, get_all_train_params

# Fixed settings shared by every configuration in this CartPole-v1 sweep.
# The grid.add / grid.add_conditional calls further down layer swept and
# conditional parameters on top of these (merge rules live in ExperimentGrid).
grid = ExperimentGrid(
    all_params=get_all_train_params(),
    base_config=dict(
        # --- Environment ---
        env_id="CartPole-v1",
        # --- Shared training parameters ---
        vis_every_n_epochs=None,
        # --- Data ---
        subsample_factor=1,
        step_offset=1,
        obs_transform=None,
        act_transform="one_hot",  # every env used so far has a discrete action space
        # --- Model architecture ---
        reward_domain="sa",
        encoder_hidden_sizes=[256, 256],
        # --- Weights & Biases ---
        log_wandb=True,
        # --- Optimisation ---
        num_epochs=100,
        lr=1e-4,
        batch_size=128,
        gamma=0.999,  # kept close to 1 so long-horizon reward carries more weight
        # --- Validation ---
        retrain_verbose=0,
        retrain_pbar=False,
        log_every_n_steps=50,
        val_every_n_epochs=5,
        n_regret_samples=1000,
        # --- Reference (optimal) policy ---
        # NOTE(review): "~" is not expanded here; presumably the consumer calls
        # os.path.expanduser on it -- confirm.
        optimal_policy_path="~/umfavi/expert_policies/dqn/CartPole-v1_1/best_model.zip",
    ),
)

# ============================================================================
# General training
# ============================================================================
grid.add("use_importance_weights", [True])

# ============================================================================
# Imitation learning parameters
# ============================================================================
grid.add("use_imitation_learning", [True, False])


def is_imitation(c: dict) -> bool:
    """Return True if config ``c`` is an imitation-learning run.

    ``c`` is one candidate configuration produced by the grid (assumed to be
    dict-like). A missing ``use_imitation_learning`` key counts as False.
    """
    return bool(c.get("use_imitation_learning", False))


def is_not_imitation(c: dict) -> bool:
    """Return the complement of :func:`is_imitation` for config ``c``."""
    return not is_imitation(c)


# Imitation runs pin the TD-error and KL weights to 0.0; every other run
# sweeps both weights over {0.5, 1.0}.
grid.add_conditional("td_error_weight", [0.0], condition=is_imitation)
grid.add_conditional("kl_weight", [0.0], condition=is_imitation)
grid.add_conditional("td_error_weight", [0.5, 1.0], condition=is_not_imitation)
grid.add_conditional("kl_weight", [0.5, 1.0], condition=is_not_imitation)

# ============================================================================
# Preference parameters
# ============================================================================
grid.add("n_pref_samples", [0, 256, 512])


def has_pref(c: dict) -> bool:
    """Return True if config ``c`` requests any preference samples."""
    return c.get("n_pref_samples", 0) > 0


# Settings below are only added to configs with preference feedback enabled.
grid.add_conditional("n_pref_episodes", [1024], condition=has_pref)

grid.add_conditional("pref_seg_len", [32], condition=has_pref)
grid.add_conditional("pref_trajectory_rationality", [5.0], condition=has_pref)
grid.add_conditional("pref_rationality", [5.0], condition=has_pref)

# Expert policy checkpoint for preference feedback (CartPole-v1 DQN).
grid.add_conditional(
    "pref_policy_path",
    ["~/umfavi/expert_policies/dqn/CartPole-v1_1/best_model.zip"],
    condition=has_pref,
)

# ============================================================================
# Rating parameters
# ============================================================================
grid.add("n_rating_samples", [0, 32, 64, 128])


def has_rating(c: dict) -> bool:
    """Return True if config ``c`` requests any rating samples."""
    return c.get("n_rating_samples", 0) > 0


# Settings below are only added to configs with rating feedback enabled.
grid.add_conditional("n_rating_episodes", [1024], condition=has_rating)

grid.add_conditional("rating_seg_len", [32], condition=has_rating)
grid.add_conditional("rating_trajectory_rationality", [1.0, 5.0], condition=has_rating)

# Expert policy checkpoint for rating feedback. Only CartPole-v1 is configured
# in this grid, so a single DQN checkpoint is used.
grid.add_conditional(
    "rating_policy_path",
    ["~/umfavi/expert_policies/dqn/CartPole-v1_1/best_model.zip"],
    condition=has_rating,
)

# ============================================================================
# Demonstration parameters
# ============================================================================
grid.add("n_demo_samples", [0, 1, 2, 4])


def has_demo(c: dict) -> bool:
    """Return True if config ``c`` requests any demonstration samples."""
    return c.get("n_demo_samples", 0) > 0


# Expert policy checkpoint for demonstrations. Only CartPole-v1 is configured
# in this grid, so a single DQN checkpoint is used.
grid.add_conditional(
    "demo_policy_path",
    ["~/umfavi/expert_policies/dqn/CartPole-v1_1/best_model.zip"],
    condition=has_demo,
)

# Rationality of the demonstrator.
grid.add_conditional("demo_rationality", [5.0], condition=has_demo)

# ============================================================================
# Stop parameters
# ============================================================================
grid.add("n_stop_samples", [0, 64, 128])


def has_stop(c: dict) -> bool:
    """Return True if config ``c`` requests any stop-feedback samples."""
    return c.get("n_stop_samples", 0) > 0


# Expert policy and Q-value model checkpoints for stop feedback. Only
# CartPole-v1 is configured in this grid, so both point at the same DQN.
grid.add_conditional(
    "stop_policy_path",
    ["~/umfavi/expert_policies/dqn/CartPole-v1_1/best_model.zip"],
    condition=has_stop,
)
grid.add_conditional(
    "stop_q_value_model",
    ["~/umfavi/expert_policies/dqn/CartPole-v1_1/best_model.zip"],
    condition=has_stop,
)
grid.add_conditional("n_stop_episodes", [1024], condition=has_stop)
grid.add_conditional("stop_seg_len", [32], condition=has_stop)
grid.add_conditional("stop_c", [1.0], condition=has_stop)
grid.add_conditional("stop_regret_percentile", [50.0], condition=has_stop)
grid.add_conditional("stop_regret_discount", [0.1], condition=has_stop)
grid.add_conditional("stop_trajectory_rationality", [1.0], condition=has_stop)

# ============================================================================
# Filtering invalid configurations
# ============================================================================


def num_active_fb_types(c: dict) -> int:
    """Return how many of the four feedback types (pref, demo, rating, stop) ``c`` enables."""
    return sum((has_pref(c), has_demo(c), has_rating(c), has_stop(c)))


# Keep only a single feedback type, a pair of types, or all four combined.
# Configs with zero or exactly three active feedback types are rejected.
grid.add_validator(lambda c: num_active_fb_types(c) in (1, 2, 4))

# Imitation-learning runs must use demonstration feedback and nothing else.
grid.add_validator(lambda c:
    not is_imitation(c) or
    (has_demo(c) and not has_pref(c) and not has_rating(c) and not has_stop(c))
)

if __name__ == "__main__":
    # Executed as a script: print an overview of the grid, with seeds=3
    # forwarded to ExperimentGrid.summary.
    summary_text = grid.summary(seeds=3)
    print(summary_text)