"""
Defines
- parsers (`parse_args` for StarMDP/gridworld, `parse_args_mujoco` for MuJoCo) that load a default config
  and override it with command line arguments.
- functions (`print_config`, `print_config_mujoco`) that print the current configuration in a readable format.
TODO: this is unused for now; it will be used once we have default configs and ship the code.

plan:
- have configs in configs/ as yaml files
- subfolders organise the configs.
    - configs/special/ contains per-environment defaults named {ENV_NAME}_default.yaml (e.g. starmdp_default.yaml, default with few seeds) and debug_starmdp.yaml (tiny, to check that the code runs)
    - configs/paper/ contains finalised configs (never touched)
    - configs/exps/ contains configs for experiments that are run, and they're named {run_id}.yaml
- the parser should first check if the user wants to load a specific config with (-cn CONFIG_NAME) where CONFIG_NAME is e.g. special/debug_starmdp or exps/123
- if no config is specified in CLI, it should check if the user specified an environment with (-env ENV_NAME) and load that environment's default config, configs/special/{ENV_NAME}_default.yaml
- if no environment is specified in CLI, load a global default, i.e. configs/special/starmdp_default.yaml (parse_args_mujoco first falls back to a caller-supplied base_config dict, if given)
- finally, override the loaded params dict with command line arguments
"""

import argparse
import yaml
import os


def str_to_bool(v):
    """Interpret a command-line token as a boolean (for use as an argparse ``type=``).

    Booleans are returned unchanged. Strings are matched case-insensitively
    against the spellings yes/true/t/y/1 (-> True) and no/false/f/n/0 (-> False).

    Raises:
        argparse.ArgumentTypeError: if the string matches neither list.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    truthy_spellings = ("yes", "true", "t", "y", "1")
    falsy_spellings = ("no", "false", "f", "n", "0")
    if token in truthy_spellings:
        return True
    if token in falsy_spellings:
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got: {v}")


def parse_args(args=None):
    """Parse StarMDP/gridworld command line arguments on top of a YAML config.

    The base configuration is chosen in priority order:
      1. ``-cn/--config-name NAME``  -> ``configs/NAME.yaml``
      2. ``-env/--environment ENV``  -> ``configs/special/{ENV}_default.yaml``
      3. otherwise                   -> ``configs/special/starmdp_default.yaml``
    Paths are relative to the current working directory.

    Every CLI argument whose parsed value is not ``None`` then overrides the
    matching key in the loaded config. The selector arguments ``config_name``
    and ``environment`` are not copied into the result.

    Args:
        args: Argument list to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        dict: the merged parameters. ``params["verbose"]`` is set to ``[]``
        when it is missing or null.

    Raises:
        FileNotFoundError: if the selected config file does not exist.
        TypeError: if the config file's top-level YAML value is not a mapping.
    """
    parser = argparse.ArgumentParser(description="Run preference-based RL experiments")

    # Config selection arguments
    parser.add_argument(
        "-cn",
        "--config-name",
        type=str,
        default=None,
        help="Config name to load (e.g., 'special/debug_starmdp' or 'exps/123')",
    )
    parser.add_argument(
        "-env",
        "--environment",
        type=str,
        default=None,
        choices=["starmdp", "gridworld"],
        help="Environment short (!!) name to load default config for. Only used if no config name specified.",
    )

    # Broad experiment params
    parser.add_argument(
        "--N_experiments",
        "--seeds",
        type=int,
        dest="N_experiments",
        help="Number of experiments (seeds)",
    )
    parser.add_argument("--N_iterations", type=int, help="Online iterations per seed")
    parser.add_argument("--episode_length", type=int, help="Episode length")
    parser.add_argument("--env_move_prob", type=float, help="Environment movement probability")
    parser.add_argument(
        "--phi_name",
        type=str,
        choices=["state_counts", "id_short", "id_long", "final_state"],
        help="Embedding function",
    )
    parser.add_argument("--do_offline_BC", type=str_to_bool, help="Use offline behavioral cloning")
    parser.add_argument("--N_offline_trajs", type=int, help="Number of offline trajectories")

    # Offline learning parameters
    parser.add_argument("--delta_offline", type=float, help="Offline delta parameter")
    parser.add_argument("--N_confset_size", type=int, help="Number of initial policies sampled")
    parser.add_argument(
        "--which_confset_construction_method",
        type=str,
        choices=["noise-matrices", "rejection-sampling"],
        help="Confidence set construction method",
    )
    parser.add_argument(
        "--which_hellinger_calc",
        type=str,
        choices=["exact", "approx"],
        help="Hellinger distance calculation method",
    )
    parser.add_argument(
        "--n_transition_model_epochs_offline",
        type=int,
        help="Number of transition model epochs offline",
    )
    parser.add_argument(
        "--offlineradius_formula",
        type=str,
        choices=[
            "full",
            "ignore_bracket",
            "only_alpha",
            "hardcode_radius_scaled",
            "hardcode_radius",
        ],
        help="Offline radius formula",
    )
    parser.add_argument(
        "--offlineradius_override_value", type=float, help="Offline radius override value"
    )
    parser.add_argument(
        "--replace_mle_with_optimal_policy_in_offline_confset",
        type=str_to_bool,
        help="Replace MLE with optimal policy in offline confidence set",
    )

    # Online learning parameters
    parser.add_argument("--N_rollouts", type=int, help="Number of rollouts per iteration")
    parser.add_argument("--delta_online", type=float, help="Online delta parameter")
    parser.add_argument("--W", type=int, help="W parameter")
    parser.add_argument("--w_MLE_epochs", type=int, help="MLE epochs for w")
    parser.add_argument("--w_initialization", type=str, help="Weight initialization method")
    parser.add_argument("--w_sigmoid_slope", type=float, help="Sigmoid slope for weights")
    parser.add_argument(
        "--xi_formula", type=str, choices=["full", "smaller_start"], help="Xi formula"
    )
    parser.add_argument(
        "--n_transition_model_epochs_online",
        type=int,
        help="Number of transition model epochs online",
    )
    parser.add_argument(
        "--online_confset_recalc_phi",
        type=str_to_bool,
        help="Recalculate phi in online confidence set",
    )
    parser.add_argument(
        "--online_confset_bonus_multiplier",
        type=float,
        help="Online confidence set bonus multiplier",
    )
    parser.add_argument(
        "--use_true_T_in_online",
        type=str_to_bool,
        help="Use true transition matrix in online learning",
    )
    parser.add_argument("--gamma_t_hardcoded_value", type=float, help="Hardcoded gamma_t value")
    parser.add_argument(
        "--baseline_search_space",
        type=str,
        choices=["all_policies", "random_sample", "augmented_ball"],
        help="Baseline search space. 'all_policies', or 'random_sample' (of size N_confset_size), or 'augmented_ball' (augmenting BRIDGE's ball to N_confset_size's size with random policies)",
    )

    # Verbosity and saving
    parser.add_argument(
        "--verbose",
        nargs="*",  # user can pass multiple options that are concatted into list
        help="Verbosity options",
    )
    parser.add_argument("--run_baseline", type=str_to_bool, help="Run baseline")
    parser.add_argument("--run_bridge", type=str_to_bool, help="Run bridge")
    parser.add_argument("--save_results", type=str_to_bool, help="Save results")
    parser.add_argument("--run_ID", type=str, help="ID")
    parser.add_argument(
        "--loaded_run_behaviour",
        "--load",
        dest="loaded_run_behaviour",
        type=str,
        choices=["continue", "redo", "overwrite"],
        help="Purpose of loaded run",
    )
    parser.add_argument(
        "--which_plot_subopt",
        type=str,
        choices=["suboptimality_percent", "regret", "cumulative_regret"],
        help="Which plot to show",
    )
    parser.add_argument(
        "--plot_slim",
        type=str_to_bool,
        help="Plot in slim mode",
    )
    parser.add_argument(
        "--plot_logy",
        type=str_to_bool,
        help="Plot in logy mode",
    )

    parsed_args = parser.parse_args(args)

    # Choose the base config file in priority order.
    if parsed_args.config_name:
        # 1. Explicit, user-specified config.
        config_path = os.path.join("configs", f"{parsed_args.config_name}.yaml")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"User-specified config file not found: {config_path}")
    elif parsed_args.environment:
        # 2. Default config of the user-specified environment.
        config_path = os.path.join("configs", "special", f"{parsed_args.environment}_default.yaml")
        if not os.path.exists(config_path):
            raise FileNotFoundError(
                f"User-specified environment config at {config_path} not found. Aborting."
            )
    else:
        # 3. Global default config.
        config_path = os.path.join("configs", "special", "starmdp_default.yaml")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Global default config at {config_path} not found. Aborting.")

    with open(config_path, "r", encoding="utf-8") as f:
        params = yaml.safe_load(f)
    # An empty YAML file parses to None; treat it as an empty config so CLI
    # overrides can still be applied instead of failing on item assignment.
    if params is None:
        params = {}
    if not isinstance(params, dict):
        raise TypeError(
            f"Config file {config_path} must contain a YAML mapping, got {type(params).__name__}"
        )

    print(f"Loaded config from: {config_path}")

    # Explicitly provided CLI values (anything not None) override the config.
    # The config-selection arguments only choose the file and are not stored.
    for key, value in vars(parsed_args).items():
        if key in ("config_name", "environment"):
            continue
        if value is not None:
            params[key] = value

    # Downstream code iterates/joins `verbose`, so guarantee a list when absent or null.
    if params.get("verbose") is None:
        params["verbose"] = []

    return params


def print_config(params):
    """Write a human-readable summary of a StarMDP/gridworld config to stdout.

    Args:
        params: Configuration dict (as returned by ``parse_args``). Every key
            read below must be present; a missing key raises ``KeyError``
            after the preceding lines have already been printed.
    """
    banner = "=" * 60

    def emit_section(title, rows):
        # Blank line + section title, then one indented "label: value" line per key.
        print("\n" + title + ":")
        for label, key in rows:
            print(f"  {label}: {params[key]}")

    print(banner)
    print(f"EXPERIMENT CONFIGURATION of {params['run_ID']}")
    print(banner)

    emit_section(
        "Environment Settings",
        [
            ("Environment", "env"),
            ("Episode length", "episode_length"),
            ("Movement probability", "env_move_prob"),
            ("Embedding function", "phi_name"),
        ],
    )
    emit_section(
        "Experiment Settings",
        [
            ("Number of seeds", "N_experiments"),
            ("Iterations per seed", "N_iterations"),
            ("Rollouts per iteration", "N_rollouts"),
            ("Offline trajectories", "N_offline_trajs"),
            ("Initial policies sampled", "N_confset_size"),
        ],
    )
    emit_section(
        "Method Settings",
        [
            ("Use offline BC", "do_offline_BC"),
            ("Confidence set method", "which_confset_construction_method"),
            ("Hellinger calculation", "which_hellinger_calc"),
        ],
    )
    # online_confset_recalc_phi is intentionally not listed here.
    emit_section(
        "Override Settings",
        [
            ("Offlineradius formula", "offlineradius_formula"),
            ("Offlineradius override value", "offlineradius_override_value"),
            (
                "Replace MLE with optimal policy in offline confset",
                "replace_mle_with_optimal_policy_in_offline_confset",
            ),
            ("Use true T in online", "use_true_T_in_online"),
            ("-> gamma_t hardcoded value", "gamma_t_hardcoded_value"),
            ("-> Pi_t bonus multiplier", "online_confset_bonus_multiplier"),
        ],
    )

    verbosity = params["verbose"]
    if verbosity:
        print("\nVerbosity: " + ", ".join(verbosity))

    emit_section(
        "Run Settings",
        [
            ("Run baseline", "run_baseline"),
            ("Run bridge", "run_bridge"),
            ("Save results", "save_results"),
            ("Run ID", "run_ID"),
            ("Loaded run behaviour", "loaded_run_behaviour"),
        ],
    )

    print(banner)


def _load_yaml_config(config_path, missing_message):
    """Load a YAML config file and return it as a dict.

    Args:
        config_path: Path of the YAML file to read.
        missing_message: Message for the FileNotFoundError raised when the
            file does not exist.

    Returns:
        dict: the parsed mapping (``{}`` for an empty file).

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        TypeError: if the file's top-level YAML value is not a mapping.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(missing_message)
    with open(config_path, "r", encoding="utf-8") as f:
        loaded = yaml.safe_load(f)
    # An empty YAML file parses to None; treat it as an empty config.
    if loaded is None:
        loaded = {}
    if not isinstance(loaded, dict):
        raise TypeError(
            f"Config file {config_path} must contain a YAML mapping, got {type(loaded).__name__}"
        )
    print(f"Loaded config from: {config_path}")
    return loaded


def parse_args_mujoco(args=None, base_config=None):
    """Parse MuJoCo command line arguments on top of a base configuration.

    The base configuration is chosen in priority order:
      1. ``-cn/--config-name NAME``     -> ``configs/NAME.yaml``
      2. ``-env/--environment ENV_ID``  -> ``configs/special/{ENV_ID}_default.yaml``
      3. ``base_config`` if given       -> shallow copy of that dict
      4. otherwise                      -> ``configs/special/starmdp_default.yaml``
    Paths are relative to the current working directory.

    Every CLI argument whose parsed value is not ``None`` then overrides the
    matching key. The environment id is stored under ``env_id``. The
    ``config_name`` selector is not copied into the result.

    Args:
        args: Argument list to parse; ``None`` lets argparse read ``sys.argv[1:]``.
        base_config: Optional fallback config dict; its top level is not mutated.

    Returns:
        dict: the merged parameters. ``params["verbose"]`` is set to ``[]``
        when it is missing or null.

    Raises:
        FileNotFoundError: if the selected config file does not exist.
        TypeError: if a loaded YAML file's top-level value is not a mapping.
    """
    parser = argparse.ArgumentParser(description="Run preference-based RL experiments")

    # Config selection arguments
    parser.add_argument(
        "-cn",
        "--config-name",
        type=str,
        default=None,
        help="Config name to load (e.g., 'special/debug_starmdp' or 'exps/123')",
    )
    parser.add_argument(
        "-env",
        "--environment",
        type=str,
        default=None,
        choices=["Reacher-v5", "HalfCheetah-v5"],
        dest="env_id",
        help="Environment name (gym ID) to load default config for. Only used if no config name specified. Options: Reacher-v5, HalfCheetah-v5",
    )

    # Broad experiment params
    parser.add_argument(
        "--N_experiments",
        "-seeds",  # kept for backward compatibility
        "--seeds",  # same spelling as parse_args
        type=int,
        dest="N_experiments",
        help="Number of experiments (seeds)",
    )
    parser.add_argument("--N_iterations", type=int, help="Online iterations per seed")
    parser.add_argument("--episode_length", type=int, help="Episode length")
    parser.add_argument(
        "--phi_name",
        "--embedding_name",
        type=str,
        choices=["avg_sa", "avg_s", "last_s", "actionenergy", "psm"],
        dest="embedding_name",
        help="Embedding function",
    )
    parser.add_argument("--N_offline_trajs", type=int, help="Number of offline trajectories")
    parser.add_argument(
        "--fresh_offline_trajs", type=str_to_bool, help="Fresh offline trajectories"
    )

    # Offline learning parameters
    parser.add_argument("--N_confset_size", type=int, help="Number of initial policies sampled")
    parser.add_argument("--n_bc_epochs", type=int, help="Number of BC epochs")

    # Online learning parameters
    parser.add_argument("--N_rollouts", type=int, help="Number of rollouts per iteration")
    parser.add_argument("--W", type=int, help="W parameter")
    parser.add_argument("--w_epochs", type=int, help="MLE epochs for w")
    parser.add_argument(
        "--w_initialization",
        type=str,
        choices=["uniform", "small"],
        help="Weight initialization method",
    )
    parser.add_argument(
        "--project_w",
        type=str_to_bool,
        help="Project w s.t. ||w|| <= W",
    )
    parser.add_argument("--w_sigmoid_slope", type=float, help="Sigmoid slope for weights")
    parser.add_argument(
        "--which_policy_selection",
        type=str,
        choices=["random", "ucb", "max_uncertainty"],
        help="Policy selection method",
    )
    parser.add_argument(
        "--V_init", type=str, choices=["small", "bounds"], help="V initialization method"
    )
    parser.add_argument(
        "--n_embedding_samples", type=int, help="Number of samples for estimating policy embedding"
    )

    # policy model params
    parser.add_argument(
        "--hidden_dim",
        type=int,
        help="Hidden dimension of policy model. SB3 defaults to 64 x2, halfcheetah: 256 x2",
    )

    # Verbosity and saving
    parser.add_argument(
        "--verbose",
        nargs="*",  # user can pass multiple options that are concatted into list
        help="Verbosity options",
    )
    parser.add_argument("--run_baseline", type=str_to_bool, help="Run baseline")
    parser.add_argument("--run_bridge", type=str_to_bool, help="Run bridge")
    parser.add_argument("--save_results", type=str_to_bool, help="Save results")
    parser.add_argument("--run_ID", type=str, help="ID")
    parser.add_argument(
        "--loaded_run_behaviour",
        "--load",  # same short alias as parse_args
        dest="loaded_run_behaviour",
        type=str,
        choices=["continue", "redo", "overwrite"],
        help="Purpose of loaded run",
    )
    parser.add_argument(
        "--which_plot_subopt",
        type=str,
        choices=["suboptimality_percent", "regret", "cumulative_regret"],
        help="Which plot to show",
    )
    parser.add_argument(
        "--baseline_or_bridge",
        type=str,
        choices=["baseline", "bridge"],
        help="Which method to run",
    )
    parser.add_argument(
        "--plot_scores",
        type=str_to_bool,
        help="Plot histogram of all candidates' scores at each loop iteration",
    )

    parsed_args = parser.parse_args(args)

    # Choose the base configuration in priority order.
    if parsed_args.config_name:
        # 1. Explicit, user-specified config file.
        config_path = os.path.join("configs", f"{parsed_args.config_name}.yaml")
        params = _load_yaml_config(
            config_path, f"User-specified config file not found: {config_path}"
        )
    elif parsed_args.env_id:
        # 2. Environment default. `-env/--environment` is stored under
        #    dest="env_id", so the namespace has no `environment` attribute.
        config_path = os.path.join("configs", "special", f"{parsed_args.env_id}_default.yaml")
        params = _load_yaml_config(
            config_path,
            f"User-specified environment config at {config_path} not found. Aborting.",
        )
    elif base_config is not None:
        # 3. Caller-supplied dict, shallow-copied so overrides don't touch the caller's dict.
        params = base_config.copy()
        print("Using base config hardcoded in main_mujoco.py file")
    else:
        # 4. Global default.
        # NOTE(review): this is the StarMDP default; confirm it is really the
        # intended fallback for MuJoCo runs.
        config_path = os.path.join("configs", "special", "starmdp_default.yaml")
        params = _load_yaml_config(
            config_path, f"Global default config at {config_path} not found. Aborting."
        )

    # Explicitly provided CLI values (anything not None) override the base
    # config. `env_id` is a real parameter and is kept; only the config-file
    # selector is dropped.
    for key, value in vars(parsed_args).items():
        if key == "config_name":
            continue
        if value is not None:
            params[key] = value

    # Downstream code iterates/joins `verbose`, so guarantee a list when absent or null.
    if params.get("verbose") is None:
        params["verbose"] = []

    return params


def print_config_mujoco(params):
    """Write a human-readable summary of a MuJoCo experiment config to stdout.

    Args:
        params: Configuration dict (as returned by ``parse_args_mujoco``).
            Every key read below must be present; a missing key raises
            ``KeyError`` after the preceding lines have already been printed.
    """
    banner = "=" * 60

    def emit_section(title, rows):
        # Blank line + section title, then one indented line per (template, key);
        # each template has a single "{}" slot for the value.
        print("\n" + title + ":")
        for template, key in rows:
            print("  " + template.format(params[key]))

    print(banner)
    print(f"EXPERIMENT CONFIGURATION of {params['run_ID']}")
    print(banner)

    emit_section(
        "Environment Settings",
        [
            ("Environment: {}", "env_id"),
            ("Episode length: {}", "episode_length"),
            ("Embedding function: {}", "embedding_name"),
        ],
    )
    emit_section(
        "Experiment Settings",
        [
            ("No. of seeds: {}", "N_experiments"),
            ("Iterations per seed: {}", "N_iterations"),
            ("No. queried preferences per iteration: {}", "N_rollouts"),
            ("Offline trajectories: {}", "N_offline_trajs"),
            ("Offline confset size: {}", "N_confset_size"),
        ],
    )
    emit_section(
        "Training Settings",
        [
            ("PPO | arch: [{} x2]", "hidden_dim"),
            ("BC  | epochs: {}", "n_bc_epochs"),
            ("w   | epochs: {}", "w_epochs"),
            ("w   | initialization: {}", "w_initialization"),
            ("w   | sigmoid slope: {}", "w_sigmoid_slope"),
            ("w   | bound W: {}", "W"),
            ("w   | project: {}", "project_w"),
        ],
    )
    emit_section(
        "Online Learning Settings",
        [
            ("V   | initialization: {}", "V_init"),
            ("ϕ(π)| samples: {}", "n_embedding_samples"),
            ("policy selection: {}", "which_policy_selection"),
        ],
    )

    verbosity = params["verbose"]
    if verbosity:
        print("\nVerbosity: " + ", ".join(verbosity))

    emit_section(
        "Run Settings",
        [
            ("Run baseline: {}", "run_baseline"),
            ("Run bridge: {}", "run_bridge"),
            ("Save results: {}", "save_results"),
            ("Run ID: {}", "run_ID"),
            ("Loaded run behaviour: {}", "loaded_run_behaviour"),
            ("Regret plot: {}", "which_plot_subopt"),
            ("Plot scores: {}", "plot_scores"),
        ],
    )