"""
Experiment grid for studying the effect of different feedback compositions.

This grid sweeps over:
- Different environments
- Different mixtures of preference and demonstration feedback
- Different hyperparameters (kl_weight, td_error_weight)

Usage:
    python -m umfavi.experiments.cli add-grid feedback_mix --seeds 5
"""

from umfavi.experiments.config import ExperimentGrid


# Define the experiment grid.
# The dict below is passed to ExperimentGrid as `base_config`; the swept
# parameters are registered afterwards through grid.add(...) and
# grid.add_conditional(...).
grid = ExperimentGrid(
    base_config={
        # Training parameters
        "num_epochs": 1000,
        "batch_size": 128,
        "lr": 1e-4,
        "val_every_n_epochs": 50,
        "vis_every_n_epochs": None,  # Disable visualization to save time
        "log_every_n_steps": 50,
        "skip_first_val_epoch": True,
        
        # Model architecture
        "encoder_hidden_sizes": [256, 256, 256],
        "q_value_hidden_sizes": [256, 256, 256],
        "reward_domain": "s",
        
        # Environment defaults
        "gamma": 0.99,
        
        # Wandb
        "log_wandb": True,
        "wandb_project": "umfavi-experiments",

        # No one-hot encoding of actions: the environment swept in this grid
        # (HalfCheetah-v4) has a continuous action space.
        "act_transform": None,
    }
)

# Environment sweep
grid.add("env_id", [
    "HalfCheetah-v4"
])

# Feedback composition sweep
# n_pref_samples and n_demo_samples are registered as two separate sweep
# parameters. Only one mixed setting is currently listed
# (1024 preference samples together with 10 demonstration samples).
grid.add("n_pref_samples", [1024])
grid.add("n_demo_samples", [10])


def has_pref_samples(c: dict) -> bool:
    """Return True when the config requests at least one preference sample."""
    return c.get("n_pref_samples", 0) > 0


def has_demo_samples(c: dict) -> bool:
    """Return True when the config requests at least one demonstration sample."""
    return c.get("n_demo_samples", 0) > 0


# Preference-specific parameters (only when preferences are used)
grid.add_conditional("n_pref_episodes", [1024], condition=has_pref_samples)

grid.add_conditional(
    "pref_seg_len",
    [64],
    condition=has_pref_samples,
)

grid.add_conditional(
    "pref_trajectory_rationality",
    [1.0, 5.0],
    condition=has_pref_samples,
)

# Demonstration-specific parameters (only when demos are used)
grid.add_conditional(
    "demo_rationality",
    [float("inf")],  # Use optimal demonstrations
    condition=has_demo_samples,
)

# Loss weighting hyperparameters
grid.add("kl_weight", [1.0])
grid.add("td_error_weight", [1.0])

# The same PPO checkpoint is used for preference generation, demonstration
# generation and as the optimal-policy reference.
# NOTE(review): the path starts with "~" and is not expanded here; presumably
# the consumer calls expanduser -- confirm.
_EXPERT_POLICY_PATH = "~/umfavi/umfavi/expert_policies/ppo/HalfCheetah-v4_2/best_model.zip"

grid.add("pref_policy_path", [_EXPERT_POLICY_PATH])
grid.add("demo_policy_path", [_EXPERT_POLICY_PATH])
grid.add("optimal_policy_path", [_EXPERT_POLICY_PATH])


# ============================================================================
# Additional grid definitions for specific studies
# ============================================================================

def create_minimal_grid() -> ExperimentGrid:
    """
    Build a small grid for fast smoke tests.

    Intended for checking that the pipeline runs end to end before launching
    the full experiments: short training, small networks, wandb logging off,
    and exactly one value per swept parameter.

    Returns:
        The configured ExperimentGrid.
    """
    smoke_config = {
        "num_epochs": 100,
        "batch_size": 64,
        "lr": 1e-4,
        "val_every_n_epochs": 25,
        "log_wandb": False,
        "encoder_hidden_sizes": [64, 64],
        "q_value_hidden_sizes": [64, 64],
    }
    minimal = ExperimentGrid(base_config=smoke_config)
    # Every parameter has a single value; the calls are chained in the same
    # order in which they should be registered on the grid.
    return (
        minimal
        .add("env_id", ["CartPole-v1"])
        .add("n_pref_samples", [100])
        .add("n_demo_samples", [50])
        .add("pref_seg_len", [64])
        .add("kl_weight", [1.0])
        .add("td_error_weight", [1.0])
    )


def create_feedback_ratio_grid() -> ExperimentGrid:
    """
    Grid focused on studying the ratio of preference to demonstration feedback.

    Each preference budget is paired with exactly one demonstration budget:
    every additional 100 preference samples replace 50 demonstration samples
    (0/250, 100/200, ..., 500/0). A validator restricts the grid to these
    pairs, so the mix varies along a single axis and the combination with no
    feedback at all (0 preferences, 0 demonstrations) is never accepted.

    Returns:
        The configured ExperimentGrid.
    """
    # Position i of demo_values is the demo budget paired with position i of
    # pref_values (the demo budget shrinks as the preference budget grows).
    pref_values = (0, 100, 200, 300, 400, 500)
    demo_values = (250, 200, 150, 100, 50, 0)
    demo_for_pref = dict(zip(pref_values, demo_values))

    def is_paired(c: dict) -> bool:
        """Return True when the config's demo budget matches its pref budget."""
        pref = c.get("n_pref_samples")
        return pref in demo_for_pref and c.get("n_demo_samples") == demo_for_pref[pref]

    ratio_grid = ExperimentGrid(
        base_config={
            "num_epochs": 1000,
            "batch_size": 128,
            "lr": 1e-4,
            "val_every_n_epochs": 50,
            "kl_weight": 1.0,
            "td_error_weight": 1.0,
            "pref_seg_len": 128,
            "log_wandb": True,
            "wandb_project": "umfavi-feedback-ratio",
        }
    ).add("env_id", ["LunarLander-v3"]) \
     .add("n_pref_samples", list(pref_values)) \
     .add_conditional(
         "n_demo_samples",
         list(demo_values),
         condition=lambda c: True  # Apply to all
     )

    # The two sweeps above are registered independently of each other, so the
    # pairing is enforced here rather than by the order of the value lists.
    # NOTE(review): assumes add_validator discards configurations for which
    # the validator returns False -- confirm against ExperimentGrid.
    ratio_grid.add_validator(is_paired)
    return ratio_grid

# Validator for the main grid: a config passes only if it requests at least
# one preference sample or at least one demonstration sample.
grid.add_validator(
    lambda c: any(
        c.get(key, 0) > 0 for key in ("n_pref_samples", "n_demo_samples")
    )
)


if __name__ == "__main__":
    # Running this module directly prints the grid summary for 5 seeds
    summary_text = grid.summary(seeds=5)
    print(summary_text)

