from collections import OrderedDict

from ray import tune
from ray.rllib.algorithms.a2c import A2C
from ray.rllib.algorithms.a3c import A3C, a3c_torch_policy
from ray.rllib.algorithms.dqn import DQN, DQNTorchPolicy
from ray.rllib.algorithms.es import ES, ESTorchPolicy
from ray.rllib.algorithms.pg import PG, PGTorchPolicy
from ray.rllib.algorithms.ppo import PPO, PPOTF1Policy, PPOTorchPolicy
from ray.rllib.algorithms.sac import SAC, SACTorchPolicy
from ray.rllib.algorithms.simple_q import SimpleQ, SimpleQTorchPolicy
from ray.rllib.models.torch.complex_input_net import ComplexInputNetwork
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork

from stackerlberg.envs.matrix_game import named_matrix_games
from stackerlberg.models.custom_fully_connected_torch_network import (
    CustomFullyConnectedNetwork,
)
from stackerlberg.models.linear_torch_model import LinearTorchModel
from stackerlberg.train.experiments.debug_callbacks import *
from stackerlberg.train.make_env import registered_environments
from stackerlberg.train.policies import (
    AlwaysCoop,
    AlwaysDefect,
    IPD_MostlyTFT,
    IPD_TFT_Coop_Defect,
    IPDCoopOrDefectPerEpisode,
    IPDRandomEveryEpisodePolicy,
    SmIPD_TFT_Coop_Defect,
)

# Registry of named experiment setups for the leader/follower (Stackelberg)
# training runs, keyed by experiment name. Each value is a dict with:
#   "configuration"           -- settings for one run: the shared RLlib config
#                                ("common_config", incl. env and env_config),
#                                leader/follower algorithm + policy classes,
#                                per-role config overrides, iteration counts,
#                                and optional callbacks.
#   "hyperopt_searchspace"    -- (optional) Ray Tune search domains, laid out
#                                like "configuration" (here: leader_config.lr).
#   "hyperopt_startingpoints" -- (optional) initial points for that search.
#   "hyperopt_metric"         -- (optional) result key the search optimises.
#   "hyperopt_seeds"          -- (optional) presumably the number of seeds per
#                                trial; confirm against the hyperopt driver.
# NOTE(review): how "common_config", "leader_config", "follower_config" and the
# "*_policy_config" dicts are merged is decided by the training driver, which
# is not in this file; presumably common_config is the base for both roles.
# The metric "leader_results/.../agent_0" suggests agent_0 is the leader and
# agent_1 the follower; confirm against the environment.
# Every experiment deliberately spells out its own nested dicts: the "config"
# callbacks below mutate env_config in place, so shared sub-dicts could leak
# changes between experiments.
experiment_configurations = {
    # ---- Testing ---
    # Smoke test on a hand-written 2x2 matrix game: SimpleQ leader and
    # follower, no hidden layers, a single outer iteration.
    "matrix_bots": {
        "configuration": {
            "pre_training_iterations": 4,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 4,
            "outer_iterations": 1,
            "post_training_iterations": 0,
            "common_config": {
                "env": "matrix_game_stackelberg_observed_queries",
                # Presumably matrix[row][col] = [payoff of player 0, payoff of player 1];
                # confirm against stackerlberg.envs.matrix_game.
                "env_config": {"matrix": [[[2, 1], [0, 0]], [[0, 0], [1, 2]]], "reward_offset": 0},
                "framework": "torch",
                "rollout_fragment_length": 1,
                "train_batch_size": 256,
                "min_sample_timesteps_per_iteration": 64,
                "lr": 0.008,
                "replay_buffer_config": {
                    "learning_starts": 0,
                },
            },
            "leader_config": {
                "evaluation_interval": 1,
                "evaluation_duration": 4,
                "evaluation_duration_unit": "episodes",
                # Parameter-space noise exploration. The EpsilonGreedy
                # sub-exploration has epsilon pinned at 0.0 (initial == final),
                # so it contributes no random actions.
                "exploration_config": {
                    "type": "ParameterNoise",
                    "random_timesteps": 0,
                    "initial_stddev": 1.0,
                    "sub_exploration": {
                        "type": "EpsilonGreedy",
                        "initial_epsilon": 0.0,
                        "final_epsilon": 0.0,
                        "epsilon_timesteps": 1000000,
                    },
                    # Disabled alternative: plain epsilon-greedy annealed 1.0 -> 0.2.
                    # "type": "EpsilonGreedy",
                    # "initial_epsilon": 1.0,
                    # "final_epsilon": 0.2,
                    # "epsilon_timesteps": 1000000,
                },
                # RLlib [timestep, lr] pairs: interpolated from 0.008 at
                # timestep 0 down to 1e-5 at timestep 256.
                "lr_schedule": [[0, 0.008], [256, 0.00001]],
            },
            "leader_algorithm": SimpleQ,
            "follower_algorithm": SimpleQ,
            "leader_policy_config": {
                "model": {
                    # No hidden layers.
                    "fcnet_hiddens": [],
                },
            },
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                },
            },
            # NOTE(review): presumably randomizes the leader during follower
            # pretraining; confirm in the training driver.
            "randomize_leader": True,
            "callbacks": {
                # After pretraining, record agent_1's greedy (explore=False)
                # action for two observations that differ only in "none_0"
                # (0 vs 1), storing them as results["action_0"] / ["action_1"].
                "post-pretrain": [
                    lambda **kwargs: kwargs["results"].update(
                        {
                            "action_0": kwargs["pre_trainer"].compute_single_action(
                                OrderedDict(original_space=0, none_0=0), policy_id="agent_1", explore=False
                            ),
                            "action_1": kwargs["pre_trainer"].compute_single_action(
                                OrderedDict(original_space=0, none_0=1), policy_id="agent_1", explore=False
                            ),
                        }
                    ),
                ],
            },
        },
    },
    # Exercises the pretrain-checkpoint path: one pretraining iteration, no
    # inner leader/follower training, and "pretrain_save_checkpoint" set.
    "test_save_load_checkpoint": {
        "configuration": {
            # Tests follower against 1/3 each always-coop, always-defect, and TFT leader.
            # NOTE(review): that leader mix is not set in this dict (only
            # randomize_leader=True below); confirm where it is configured.
            "common_config": {
                "env": "repeated_matrix_game_stackelberg_observed_queries",
                "env_config": {
                    "matrix_name": "prisoners_dilemma",
                    "discrete_obs": True,
                    "small_memory": True,
                    "episode_length": 5,
                    "memory": True,
                },
                "batch_mode": "complete_episodes",
            },
            "leader_algorithm": SimpleQ,
            "leader_policy_class": SimpleQTorchPolicy,
            "leader_config": {
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "lr": 0.1,
                "train_batch_size": 10,
            },
            "leader_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_algorithm": PG,
            "follower_policy_class": PGTorchPolicy,
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_config": {
                "lr": 0.008,
                "min_sample_timesteps_per_iteration": 100,
                # "train_batch_size": 256,
                # "sgd_minibatch_size": 256,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                # NOTE(review): PG has no replay buffer, so learning_starts is
                # presumably unused here (the same key recurs in the PG/PPO
                # configs below); confirm before relying on it.
                "learning_starts": 0,
            },
            "pre_training_iterations": 1,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 0,
            "outer_iterations": 1,
            "post_training_iterations": 0,
            "randomize_leader": True,
            # "_debug_dont_train_leader": True,
            "callbacks": {
                "post-pretrain": [smipd_check_follower_best_response],
            },
            "log_weights": True,
            # Presumably where the driver writes the pretrained state; confirm.
            "pretrain_save_checkpoint": "./pretrain_checkpoint.pkl",
        },
    },
    # -------- Leader Memory --
    # "smipd_leadermemory_pg_pg" and "smipd_leadernomemory_pg_pg" are
    # identical except for env_config["tell_leader"] (True vs False).
    # Both use PG for leader and follower with linear custom models.
    "smipd_leadermemory_pg_pg": {
        "configuration": {
            "common_config": {
                "env": "repeated_matrix_game_stackelberg_observed_queries",
                "env_config": {
                    "matrix_name": "prisoners_dilemma",
                    "discrete_obs": True,
                    "small_memory": True,
                    "episode_length": 5,
                    "memory": True,
                    "tell_leader": True,
                },
                "batch_mode": "complete_episodes",
            },
            "deterministic_leader": True,
            "deterministic_follower": True,
            "leader_algorithm": PG,
            "leader_policy_class": PGTorchPolicy,
            "leader_config": {
                "lr": 0.008,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "leader_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_algorithm": PG,
            "follower_policy_class": PGTorchPolicy,
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_config": {
                "lr": 0.02,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "pre_training_iterations": 500,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 1,
            "outer_iterations": 2000,
            "post_training_iterations": 50,
            "randomize_leader": True,
            # "_debug_dont_train_leader": True,
            "callbacks": {
                "post-pretrain": [smipd_check_follower_best_response],
                # Mutates the pretraining env_config in place to set
                # tell_leader_mock=True.
                "config": [lambda **kwargs: kwargs["pretrain_config"]["env_config"].update({"tell_leader_mock": True})],
            },
            "log_weights": True,
        },
        # Log-uniform search over the leader learning rate only.
        "hyperopt_searchspace": {
            "leader_config": {
                "lr": tune.loguniform(0.00001, 1.0),
                # "rollout_fragment_length": tune.choice([1, 2, 4, 8, 16, 32, 64]),
                # "exploration_config": {"final_epsilon": tune.loguniform(0.01, 1.0)},
            },
        },
        "hyperopt_startingpoints": [
            {
                "leader_config": {
                    "lr": 0.008,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.015,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.03,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.004,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
        ],
        # Leader (agent_0) mean evaluation reward.
        "hyperopt_metric": "leader_results/evaluation/policy_reward_mean/agent_0",
        "hyperopt_seeds": 10,
    },
    # Same as "smipd_leadermemory_pg_pg" but with tell_leader=False.
    "smipd_leadernomemory_pg_pg": {
        "configuration": {
            "common_config": {
                "env": "repeated_matrix_game_stackelberg_observed_queries",
                "env_config": {
                    "matrix_name": "prisoners_dilemma",
                    "discrete_obs": True,
                    "small_memory": True,
                    "episode_length": 5,
                    "memory": True,
                    "tell_leader": False,
                },
                "batch_mode": "complete_episodes",
            },
            "deterministic_leader": True,
            "deterministic_follower": True,
            "leader_algorithm": PG,
            "leader_policy_class": PGTorchPolicy,
            "leader_config": {
                "lr": 0.008,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "leader_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_algorithm": PG,
            "follower_policy_class": PGTorchPolicy,
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_config": {
                "lr": 0.02,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "pre_training_iterations": 500,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 1,
            "outer_iterations": 2000,
            "post_training_iterations": 50,
            "randomize_leader": True,
            # "_debug_dont_train_leader": True,
            "callbacks": {
                "post-pretrain": [smipd_check_follower_best_response],
                # Mutates the pretraining env_config in place to set
                # tell_leader_mock=True (kept even though tell_leader=False here).
                "config": [lambda **kwargs: kwargs["pretrain_config"]["env_config"].update({"tell_leader_mock": True})],
            },
            "log_weights": True,
        },
        "hyperopt_searchspace": {
            "leader_config": {
                "lr": tune.loguniform(0.00001, 1.0),
                # "rollout_fragment_length": tune.choice([1, 2, 4, 8, 16, 32, 64]),
                # "exploration_config": {"final_epsilon": tune.loguniform(0.01, 1.0)},
            },
        },
        "hyperopt_startingpoints": [
            {
                "leader_config": {
                    "lr": 0.008,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.015,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.03,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.004,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
        ],
        "hyperopt_metric": "leader_results/evaluation/policy_reward_mean/agent_0",
        "hyperopt_seeds": 10,
    },
    # -------- Hidden Queries --
    # Each "*_hiddenqueries_*" experiment has a "*_nothiddenqueries_*" twin
    # that differs only in env_config["hidden_queries"] (True vs False).
    "smipd_hiddenqueries_pg_pg": {
        "configuration": {
            "common_config": {
                "env": "repeated_matrix_game_stackelberg_observed_queries",
                "env_config": {
                    "matrix_name": "prisoners_dilemma",
                    "discrete_obs": True,
                    "small_memory": True,
                    "episode_length": 5,
                    "memory": True,
                    "hidden_queries": True,
                },
                "batch_mode": "complete_episodes",
            },
            "deterministic_leader": True,
            "deterministic_follower": True,
            "leader_algorithm": PG,
            "leader_policy_class": PGTorchPolicy,
            "leader_config": {
                "lr": 0.008,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "leader_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_algorithm": PG,
            "follower_policy_class": PGTorchPolicy,
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_config": {
                "lr": 0.02,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "pre_training_iterations": 500,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 1,
            "outer_iterations": 2000,
            "post_training_iterations": 50,
            "randomize_leader": True,
            # "_debug_dont_train_leader": True,
            "callbacks": {
                "post-pretrain": [smipd_check_follower_best_response],
            },
            "log_weights": True,
        },
        # Note: the search range (1e-5..1e-3) lies below the base leader
        # lr of 0.008 above; the searched value replaces it during hyperopt.
        "hyperopt_searchspace": {
            "leader_config": {
                "lr": tune.loguniform(0.00001, 0.001),
                # "rollout_fragment_length": tune.choice([1, 2, 4, 8, 16, 32, 64]),
                # "exploration_config": {"final_epsilon": tune.loguniform(0.01, 1.0)},
            },
        },
        "hyperopt_startingpoints": [
            {
                "leader_config": {
                    "lr": 0.0002,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.0008,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.00005,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
        ],
        "hyperopt_metric": "leader_results/evaluation/policy_reward_mean/agent_0",
        "hyperopt_seeds": 10,
    },
    # NOTE: despite "dqn" in the experiment name, the leader algorithm is
    # SimpleQ (with SimpleQTorchPolicy); the follower is PG.
    "smipd_hiddenqueries_dqn_pg": {
        "configuration": {
            "common_config": {
                "env": "repeated_matrix_game_stackelberg_observed_queries",
                "env_config": {
                    "matrix_name": "prisoners_dilemma",
                    "discrete_obs": True,
                    "small_memory": True,
                    "episode_length": 5,
                    "memory": True,
                    "hidden_queries": True,
                },
                "batch_mode": "complete_episodes",
            },
            "deterministic_leader": True,
            "deterministic_follower": True,
            "leader_algorithm": SimpleQ,
            "leader_policy_class": SimpleQTorchPolicy,
            "leader_config": {
                "min_sample_timesteps_per_iteration": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "lr": 0.001,
                "rollout_fragment_length": 10,
                "batch_mode": "complete_episodes",
                "train_batch_size": 1024,
                # Timesteps collected before Q-learning updates begin.
                "learning_starts": 5000,
                # Parameter-space noise; the EpsilonGreedy sub-exploration has
                # epsilon pinned at 0.0, so it adds no random actions.
                "exploration_config": {
                    "type": "ParameterNoise",
                    "random_timesteps": 0,
                    "initial_stddev": 1.0,
                    "sub_exploration": {
                        "type": "EpsilonGreedy",
                        "initial_epsilon": 0.0,
                        "final_epsilon": 0.0,
                        "epsilon_timesteps": 1000000,
                    },
                    # Disabled alternative: plain epsilon-greedy annealed 1.0 -> 0.2.
                    # "type": "EpsilonGreedy",
                    # "initial_epsilon": 1.0,
                    # "final_epsilon": 0.2,
                    # "epsilon_timesteps": 1000000,
                },
            },
            "leader_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_algorithm": PG,
            "follower_policy_class": PGTorchPolicy,
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_config": {
                "lr": 0.02,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "pre_training_iterations": 500,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 1,
            "outer_iterations": 2000,
            "post_training_iterations": 50,
            "randomize_leader": True,
            # "_debug_dont_train_leader": True,
            "callbacks": {
                "post-pretrain": [smipd_check_follower_best_response],
            },
            "log_weights": True,
        },
        "hyperopt_searchspace": {
            "leader_config": {
                "lr": tune.loguniform(0.0005, 0.01),
                # "rollout_fragment_length": tune.choice([1, 2, 4, 8, 16, 32, 64]),
                # "exploration_config": {"final_epsilon": tune.loguniform(0.01, 1.0)},
            },
        },
        "hyperopt_startingpoints": [
            {
                "leader_config": {
                    "lr": 0.004,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.008,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
        ],
        "hyperopt_metric": "leader_results/evaluation/policy_reward_mean/agent_0",
        "hyperopt_seeds": 10,
    },
    # Twin of "smipd_hiddenqueries_pg_pg" with hidden_queries=False.
    "smipd_nothiddenqueries_pg_pg": {
        "configuration": {
            "common_config": {
                "env": "repeated_matrix_game_stackelberg_observed_queries",
                "env_config": {
                    "matrix_name": "prisoners_dilemma",
                    "discrete_obs": True,
                    "small_memory": True,
                    "episode_length": 5,
                    "memory": True,
                    "hidden_queries": False,
                },
                "batch_mode": "complete_episodes",
            },
            "deterministic_leader": True,
            "deterministic_follower": True,
            "leader_algorithm": PG,
            "leader_policy_class": PGTorchPolicy,
            "leader_config": {
                "lr": 0.008,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "leader_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_algorithm": PG,
            "follower_policy_class": PGTorchPolicy,
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_config": {
                "lr": 0.02,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "pre_training_iterations": 500,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 1,
            "outer_iterations": 2000,
            "post_training_iterations": 50,
            "randomize_leader": True,
            # "_debug_dont_train_leader": True,
            "callbacks": {
                "post-pretrain": [smipd_check_follower_best_response],
            },
            "log_weights": True,
        },
        "hyperopt_searchspace": {
            "leader_config": {
                "lr": tune.loguniform(0.00001, 0.001),
                # "rollout_fragment_length": tune.choice([1, 2, 4, 8, 16, 32, 64]),
                # "exploration_config": {"final_epsilon": tune.loguniform(0.01, 1.0)},
            },
        },
        "hyperopt_startingpoints": [
            {
                "leader_config": {
                    "lr": 0.0002,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.0008,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.00005,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
        ],
        "hyperopt_metric": "leader_results/evaluation/policy_reward_mean/agent_0",
        "hyperopt_seeds": 10,
    },
    # Twin of "smipd_hiddenqueries_dqn_pg" with hidden_queries=False
    # (leader is SimpleQ despite the "dqn" name).
    "smipd_nothiddenqueries_dqn_pg": {
        "configuration": {
            "common_config": {
                "env": "repeated_matrix_game_stackelberg_observed_queries",
                "env_config": {
                    "matrix_name": "prisoners_dilemma",
                    "discrete_obs": True,
                    "small_memory": True,
                    "episode_length": 5,
                    "memory": True,
                    "hidden_queries": False,
                },
                "batch_mode": "complete_episodes",
            },
            "deterministic_leader": True,
            "deterministic_follower": True,
            "leader_algorithm": SimpleQ,
            "leader_policy_class": SimpleQTorchPolicy,
            "leader_config": {
                "min_sample_timesteps_per_iteration": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "lr": 0.001,
                "rollout_fragment_length": 10,
                "batch_mode": "complete_episodes",
                "train_batch_size": 1024,
                "learning_starts": 5000,
                # Parameter-space noise; epsilon pinned at 0.0 (see above).
                "exploration_config": {
                    "type": "ParameterNoise",
                    "random_timesteps": 0,
                    "initial_stddev": 1.0,
                    "sub_exploration": {
                        "type": "EpsilonGreedy",
                        "initial_epsilon": 0.0,
                        "final_epsilon": 0.0,
                        "epsilon_timesteps": 1000000,
                    },
                    # Disabled alternative: plain epsilon-greedy annealed 1.0 -> 0.2.
                    # "type": "EpsilonGreedy",
                    # "initial_epsilon": 1.0,
                    # "final_epsilon": 0.2,
                    # "epsilon_timesteps": 1000000,
                },
            },
            "leader_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_algorithm": PG,
            "follower_policy_class": PGTorchPolicy,
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_config": {
                "lr": 0.02,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "pre_training_iterations": 500,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 1,
            "outer_iterations": 2000,
            "post_training_iterations": 50,
            "randomize_leader": True,
            # "_debug_dont_train_leader": True,
            "callbacks": {
                "post-pretrain": [smipd_check_follower_best_response],
            },
            "log_weights": True,
        },
        "hyperopt_searchspace": {
            "leader_config": {
                "lr": tune.loguniform(0.0005, 0.01),
                # "rollout_fragment_length": tune.choice([1, 2, 4, 8, 16, 32, 64]),
                # "exploration_config": {"final_epsilon": tune.loguniform(0.01, 1.0)},
            },
        },
        "hyperopt_startingpoints": [
            {
                "leader_config": {
                    "lr": 0.004,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
            {
                "leader_config": {
                    "lr": 0.008,
                    # "exploration_config": {"final_epsilon": 0.2},
                },
            },
        ],
        "hyperopt_metric": "leader_results/evaluation/policy_reward_mean/agent_0",
        "hyperopt_seeds": 10,
    },
    # --- All matrices --- #
    # Sweep over every named matrix game (larger memory, 10-step episodes).
    # The two grid_search entries (matrix_name, seed) are expanded by Ray
    # Tune into one trial per (matrix, seed) combination. The key order of
    # these dicts is kept as-is, since Tune walks the config in order when
    # generating variants.
    "ipd_allmatrices_pg_pg": {
        "configuration": {
            "common_config": {
                "env": "repeated_matrix_game_stackelberg_observed_queries",
                "env_config": {
                    "matrix_name": tune.grid_search(list(named_matrix_games.keys())),
                    "discrete_obs": True,
                    "small_memory": False,
                    "episode_length": 10,
                    "memory": True,
                },
                "batch_mode": "complete_episodes",
            },
            "seed": tune.grid_search([1, 2, 3, 4, 5]),
            "deterministic_leader": True,
            "deterministic_follower": True,
            "leader_algorithm": PG,
            "leader_policy_class": PGTorchPolicy,
            "leader_config": {
                "lr": 0.156,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "leader_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_algorithm": PG,
            "follower_policy_class": PGTorchPolicy,
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_config": {
                "lr": 0.02,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "pre_training_iterations": 500,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 1,
            "outer_iterations": 1200,
            "post_training_iterations": 50,
            "randomize_leader": True,
            # "_debug_dont_train_leader": True,
            "callbacks": {
                "post-pretrain": [smipd_check_follower_best_response],
            },
            "log_weights": True,
        },
    },
    # Same sweep as "ipd_allmatrices_pg_pg" but with a PPO leader.
    "ipd_allmatrices_ppo_pg": {
        "configuration": {
            "common_config": {
                "env": "repeated_matrix_game_stackelberg_observed_queries",
                "env_config": {
                    "matrix_name": tune.grid_search(list(named_matrix_games.keys())),
                    "discrete_obs": True,
                    "small_memory": False,
                    "episode_length": 10,
                    "memory": True,
                },
                "batch_mode": "complete_episodes",
            },
            "seed": tune.grid_search([1, 2, 3, 4, 5]),
            "deterministic_leader": True,
            "deterministic_follower": True,
            "leader_algorithm": PPO,
            "leader_policy_class": PPOTorchPolicy,
            "leader_config": {
                "lr": 0.008,
                # No entropy bonus.
                "entropy_coeff": 0.0,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 1000,
                "train_batch_size": 1000,
                # Equal to train_batch_size: each SGD epoch treats the whole
                # train batch as a single minibatch.
                "sgd_minibatch_size": 1000,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "leader_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_algorithm": PG,
            "follower_policy_class": PGTorchPolicy,
            "follower_policy_config": {
                "model": {
                    "fcnet_hiddens": [],
                    "vf_share_layers": True,
                    "custom_model": LinearTorchModel,
                },
            },
            "follower_config": {
                "lr": 0.02,
                "min_sample_timesteps_per_iteration": 100,
                "metrics_smoothing_episodes": 1,
                "rollout_fragment_length": 100,
                "train_batch_size": 100,
                "evaluation_interval": 1,
                "evaluation_duration": 10,
                "evaluation_duration_unit": "episodes",
                "learning_starts": 0,
            },
            "pre_training_iterations": 500,
            "inner_iterations_follower": 0,
            "inner_iterations_leader": 1,
            "outer_iterations": 500,
            "post_training_iterations": 50,
            "randomize_leader": True,
            # "_debug_dont_train_leader": True,
            "callbacks": {
                "post-pretrain": [smipd_check_follower_best_response],
            },
            "log_weights": True,
        },
    },
}
