import argparse

def create_parser() -> argparse.ArgumentParser:
    """Create the argument parser with all training arguments."""
    parser = argparse.ArgumentParser()

    # Reproducibility
    parser.add_argument("--seed", type=int, default=0, help="Global seed")

    # Reward model parameters
    parser.add_argument("--reward_domain", type=str, default="s", help="Either state-only ('s'), state-action ('sa'), state-action-next-state ('sas')")
    parser.add_argument("--td_error_weight", type=float, default=1.0, help="Weight for TD-error constraint in demonstrations")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    
    # Dataset parameters
    parser.add_argument("--step_offset", type=int, default=1, help="Offset applied to next_obs and next_act")
    parser.add_argument("--subsample_factor", type=int, default=1, help="Subsample factor for demonstrations")

    # Optimal policy
    parser.add_argument("--optimal_policy_path", type=str, default=None, help="Path to optimal policy for generating evaluations")

    # Preferences
    parser.add_argument("--n_pref_episodes", type=int, default=0, help="Number of episodes to extract preference segments from")
    parser.add_argument("--n_pref_samples", type=int, default=0, help="Number of preference samples (0 to disable)")
    parser.add_argument("--pref_policy_path", type=str, default="expert_policies/ppo/LunarLander-v3_1/best_model.zip", help="Path to preference policy")
    parser.add_argument("--pref_trajectory_rationality", type=float, default=0.5, help="Rationality of the expert policy generating the comparison trajectories")
    parser.add_argument("--pref_rationality", type=float, default=1.0, help="Rationality for Bradley-Terry model")
    parser.add_argument("--pref_seg_len", type=lambda x: None if x.lower() == "none" else int(x), default=128, help="Length of the extracted segments from episodes. None for full episodes")
    parser.add_argument("--min_reward_pref", type=lambda x: None if x.lower() == "none" else float(x), default=None, help="If not 'None', preferences with a lower reward will be rejected if their cumulative reward is lower than this threshold")

    # Demonstrations
    parser.add_argument("--n_demo_samples", type=int, default=0, help="Number of demonstration samples (0 to disable)")
    parser.add_argument("--demo_policy_path", type=str, default="expert_policies/dqn/LunarLander-v3_1/best_model.zip", help="Path to demonstration policy")
    parser.add_argument("--min_reward_demo", type=lambda x: None if x.lower() == "none" else float(x), default=None, help="If not 'None', demonstrations with a lower reward will be rejected if their cumulative reward is lower than this threshold")
    parser.add_argument("--demo_rationality", type=float, default=float("inf"), help="Rationality for expert policy")

    # Ratings
    parser.add_argument("--n_rating_episodes", type=int, default=0, help="Number of episodes to extract rating segments from")
    parser.add_argument("--n_rating_samples", type=int, default=0, help="Number of rating samples (0 to disable)")
    parser.add_argument("--rating_policy_path", type=str, default="expert_policies/ppo/LunarLander-v3_1/best_model.zip", help="Path to rating policy")
    parser.add_argument("--rating_trajectory_rationality", type=float, default=5.0, help="Rationality of the expert policy generating rated trajectories")
    parser.add_argument("--rating_seg_len", type=lambda x: None if x.lower() == "none" else int(x), default=128, help="Length of the extracted segments for ratings. None for full episodes")
    parser.add_argument("--min_reward_rating", type=lambda x: None if x.lower() == "none" else float(x), default=None, help="If not 'None', ratings with a lower reward will be rejected")

    # Rankings (Plackett-Luce)
    parser.add_argument("--n_ranking_episodes", type=int, default=0, help="Number of episodes to extract ranking segments from")
    parser.add_argument("--n_ranking_samples", type=int, default=0, help="Number of ranking samples (0 to disable)")
    parser.add_argument("--ranking_policy_path", type=str, default="expert_policies/ppo/LunarLander-v3_1/best_model.zip", help="Path to ranking policy")
    parser.add_argument("--ranking_trajectory_rationality", type=float, default=0.5, help="Rationality of the expert policy generating ranked trajectories")
    parser.add_argument("--ranking_rationality", type=float, default=1.0, help="Rationality (beta) for Plackett-Luce model")
    parser.add_argument("--ranking_seg_len", type=lambda x: None if x.lower() == "none" else int(x), default=128, help="Length of the extracted segments for rankings. None for full episodes")
    parser.add_argument("--num_ranked_items", type=int, default=4, help="Number of items (k) per ranking")
    parser.add_argument("--min_reward_ranking", type=lambda x: None if x.lower() == "none" else float(x), default=None, help="If not 'None', rankings with a lower reward will be rejected")

    # Stops
    parser.add_argument("--n_stop_samples", type=int, default=0, help="Number of stop samples (0 to disable)")
    parser.add_argument("--stop_policy_path", type=str, default="expert_policies/ppo/LunarLander-v3_1/best_model.zip", help="Path to stop policy")
    parser.add_argument("--stop_trajectory_rationality", type=float, default=0.5, help="Rationality of the expert policy generating the comparison trajectories")
    parser.add_argument("--stop_c", type=float, default=1.0, help="Calibration constant for lambda (higher = more aggressive stopping)")
    parser.add_argument("--stop_regret_percentile", type=float, default=75.0, help="Percentile of final regrets to use as reference")
    parser.add_argument("--stop_regret_discount", type=float, default=0.8, help="Discount factor for old regret (0-1). Lower = faster forgetting.")
    parser.add_argument("--min_reward_stop", type=lambda x: None if x.lower() == "none" else float(x), default=None, help="If not 'None', stops with a lower reward will be rejected if their cumulative reward is lower than this threshold")
    parser.add_argument("--n_stop_episodes", type=int, default=0, help="Number of episodes to extract stop segments from")
    parser.add_argument("--stop_seg_len", type=lambda x: None if x.lower() == "none" else int(x), default=128, help="Length of the extracted segments for stops. None for full episodes")
    parser.add_argument("--stop_q_value_model", type=str, default=None, help="Path to q-value model used to estimate immediate regret in dataset generation. Does not need to be provided for tabular environments.")

    # Training parameters
    parser.add_argument("--num_epochs", type=int, default=2000)
    parser.add_argument("--val_every_n_epochs", type=lambda x: None if x.lower() == "none" else int(x), default=None)
    parser.add_argument("--vis_every_n_epochs", type=lambda x: None if x.lower() == "none" else int(x), default=None)
    parser.add_argument("--retrain_verbose", type=int, default=1, help="Verbosity of the retraining process")
    parser.add_argument("--retrain_pbar", type=bool, default=True, help="Show progress bar during retraining")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate for model")
    parser.add_argument("--kl_weight", type=float, default=1.0, help="KL weight")
    parser.add_argument("--encoder_hidden_sizes", type=int, nargs="+", default=[256, 256, 256], help="Hidden sizes for encoder MLP")
    parser.add_argument("--skip_first_val_epoch", action="store_true", help="Skip the first validation epoch")
    parser.add_argument("--retrain_reward_thresh", type=lambda x: None if x.lower() == "none" else float(x), default=None, help="Stop regret computation early if mean true reward goes below this threshold")
    parser.add_argument("--use_importance_weights", action="store_true", help="Reweight loss components by importance weights derived from dataset sizes")
    parser.add_argument("--model_save_dir", type=str, default=None, help="Directory to save models and policies")
    parser.add_argument("--save_behavior", type=str, choices=["best", "all"], default="best", help="Save behavior: 'best' saves only when val loss improves, 'all' saves at every eval epoch")
    parser.add_argument("--n_regret_samples", type=int, default=1000, help="Number of samples for regret computation")
    
    # Environment parameters
    parser.add_argument("--grid_size", type=int, default=10)
    parser.add_argument("--env_id", type=str, default="CartPole-v1")
    parser.add_argument("--p_rand", type=float, default=0.0, help="Randomness in transitions (0 for deterministic)")
    parser.add_argument("--obs_transform", choices=["one_hot", "continuous_coordinate", "dct", None], default=None, help="Apply a transform to the observation space")
    parser.add_argument("--act_transform", choices=["one_hot", None], default="one_hot", help="Apply a transform to the action space")
        
    # Wandb parameters
    parser.add_argument("--log_wandb", action="store_true", help="Log to weights and biases")
    parser.add_argument("--log_every_n_steps", type=int, default=10, help="Log every n steps")
    parser.add_argument("--wandb_project", type=str, default="var-rew-learning", help="Wandb project name")
    parser.add_argument("--wandb_run_name", type=str, default=None, help="Custom wandb run name")
    parser.add_argument("--wandb_log_dir", type=str, default="wandb", help="Wandb log directory")

    # Baselines
    parser.add_argument("--use_imitation_learning", action="store_true", help="Use imitation learning")
    
    return parser