from ray import tune
from concept_config import concept_config
from task_env_func import curriculum_fn
from utils import possibly_gridsearch


def impala_config(args, env_config, custom_callback):
    """Assemble the trainer configuration dictionary used for IMPALA runs.

    Args:
        args: Parsed run options. This function reads ``num_workers``,
            ``num_gpu``, ``num_envs_per_worker`` and ``max_episode_steps``.
        env_config: Stored unchanged under the ``"env_config"`` key.
        custom_callback: Stored unchanged under the ``"callbacks"`` key.

    Returns:
        A freshly built ``dict`` holding rollout, evaluation, optimizer and
        model settings.
    """
    # The same value seeds both the plain "lr" entry and the schedule start.
    start_lr = 0.0005

    # Overrides applied during evaluation: render and record to "videos".
    eval_overrides = {"render_env": True, "record_env": "videos"}

    model_settings = {
        # Value function keeps its own layers; if this is switched to True,
        # vf_loss_coeff should be tuned as well.
        "vf_share_layers": False,
        # Placeholder for custom model options (e.g. an "n_heads" search).
        "custom_model_config": {},
        "max_seq_len": 5,
    }

    # --- Environment, resources and rollout collection ---
    config = {
        "num_workers": args.num_workers,
        "env": "custom_env",
        "env_config": env_config,
        "callbacks": custom_callback,
        # GPU count is taken directly from the command-line arguments.
        "num_gpus": args.num_gpu,
        "num_envs_per_worker": args.num_envs_per_worker,
        "compress_observations": False,
        "framework": "torch",
        "batch_mode": "truncate_episodes",
        # One step more than the configured maximum episode length.
        "horizon": args.max_episode_steps + 1,
    }

    # --- Evaluation cadence ---
    config.update(
        {
            "evaluation_interval": 1,
            "evaluation_num_episodes": 1,
            "evaluation_num_workers": 1,
        }
    )

    # --- Optimizer ---
    config.update(
        {
            "opt_type": "rmsprop",
            "epsilon": 0.01,
            "lr": start_lr,
            # [timestep, lr] pairs: from start_lr at step 0 down to 1e-12.
            "lr_schedule": [[0, start_lr], [20000000, 1e-12]],
            "train_batch_size": 32,
        }
    )

    config["evaluation_config"] = eval_overrides
    config["model"] = model_settings
    return config


def ppo_config(args, env_config, custom_callback):
    """Assemble the trainer configuration dictionary used for PPO runs.

    Args:
        args: Parsed run options. This function reads ``num_workers``,
            ``num_gpu``, ``num_envs_per_worker``, ``num_eval_workers``,
            ``entropy_coeff``, ``entropy_coeff_schedule``, ``lr``,
            ``render_env`` and ``use_balanced``.
        env_config: Stored unchanged under the ``"env_config"`` key.
        custom_callback: Stored unchanged under the ``"callbacks"`` key.

    Returns:
        A freshly built ``dict`` holding rollout, evaluation, optimizer,
        model and concept-loss settings. Values passed through
        ``possibly_gridsearch`` are whatever that helper returns for the
        corresponding argument.
    """
    model_settings = {
        # Value function keeps its own layers; if this is switched to True,
        # vf_loss_coeff should be tuned as well.
        "vf_share_layers": False,
        # Placeholder for custom model options (e.g. an "n_heads" search).
        "custom_model_config": {},
        "max_seq_len": 50,
    }

    # --- Environment, resources and rollout collection ---
    config = {
        "num_workers": args.num_workers,
        "env": "custom_env",
        "env_config": env_config,
        "callbacks": custom_callback,
        # GPU count is taken directly from the command-line arguments.
        "num_gpus": args.num_gpu,
        "num_envs_per_worker": args.num_envs_per_worker,
        "compress_observations": False,
        "framework": "torch",
        "batch_mode": "complete_episodes",
        "horizon": 200,
    }

    # --- Evaluation cadence ---
    config.update(
        {
            "evaluation_interval": 10,
            "evaluation_num_episodes": 10,
            "evaluation_num_workers": args.num_eval_workers,
        }
    )

    # --- Optimization (entropy and learning rate may be grid searches) ---
    config.update(
        {
            "entropy_coeff": possibly_gridsearch(args.entropy_coeff),
            "entropy_coeff_schedule": possibly_gridsearch(
                args.entropy_coeff_schedule
            ),
            "lr": possibly_gridsearch(args.lr),
            "train_batch_size": 2 * 5120,
            "sgd_minibatch_size": 50 * 32,
        }
    )

    # --- Rendering, during training and during evaluation alike ---
    config["render_env"] = args.render_env
    config["evaluation_config"] = {"render_env": args.render_env}

    config["disable_env_checking"] = True
    config["model"] = model_settings

    # --- Concept-loss settings ---
    config.update(
        {
            "concept_loss_coeff": 10.0,
            "balanced_beta": 0.99,
            "balanced_gamma": 1,
            "use_balanced": possibly_gridsearch(args.use_balanced),
            # Other values tried previously: "sigmoid", "softmax",
            # "not_balanced".
            "loss_type": "focal",
        }
    )
    return config


def r2d2_config(args, env_config, custom_callback):
    """Assemble the trainer configuration dictionary used for R2D2 runs.

    Args:
        args: Parsed run options. This function reads ``num_gpu``,
            ``num_envs_per_worker`` and ``num_workers``.
        env_config: Stored unchanged under the ``"env_config"`` key.
        custom_callback: Stored unchanged under the ``"callbacks"`` key.

    Returns:
        A freshly built ``dict``. The ``"lr"`` entry is a
        ``tune.grid_search`` over three learning rates rather than a
        single number.
    """
    # Epsilon-greedy exploration going from 0.96 to 0.01 over 500k timesteps.
    exploration = {
        "type": "EpsilonGreedy",
        "initial_epsilon": 0.96,
        "final_epsilon": 0.01,
        "epsilon_timesteps": 500000,
    }

    # Overrides applied during evaluation: render and record to "videos".
    eval_overrides = {"render_env": True, "record_env": "videos"}

    model_settings = {
        # Value function keeps its own layers; if this is switched to True,
        # vf_loss_coeff should be tuned as well.
        "vf_share_layers": False,
        # Placeholder for custom model options (e.g. an "n_heads" search).
        "custom_model_config": {},
        "max_seq_len": 50,
    }

    # --- Environment and resources ---
    config = {
        "framework": "torch",
        "env": "custom_env",
        "env_config": env_config,
        "callbacks": custom_callback,
        # GPU count is taken directly from the command-line arguments.
        "num_gpus": args.num_gpu,
        "num_envs_per_worker": args.num_envs_per_worker,
    }

    # --- Optimization and sampling ---
    config.update(
        {
            # Learning rate for the Adam optimizer (grid-searched).
            "lr": tune.grid_search([5e-4, 1e-5, 5e-6]),
            # Discount factor.
            "gamma": 0.99,
            # Train batch size, counted in single timesteps.
            "train_batch_size": 64 * 50,
            # Adam epsilon hyperparameter.
            "adam_epsilon": 5e-4,
            "num_workers": args.num_workers,
            # R2D2 requires complete episodes per batch.
            "batch_mode": "complete_episodes",
            "exploration_config": exploration,
        }
    )

    # --- Network head and recurrent-state handling ---
    config.update(
        {
            "hiddens": [512],
            "dueling": True,
            # True: assume a zero-initialized state input wherever in the
            # episode a sequence starts. False: store the initial states with
            # each SampleBatch, use them as the initial state during
            # training, and update them from the internal state outputs of
            # the immediately preceding sequence.
            "zero_init_states": False,
            # If > 0, the first `burn_in` steps of each replay-sampled
            # sequence are used only to compute a more accurate initial state
            # (starting from zeros or from the stored values, depending on
            # `zero_init_states`). The loss is then computed over the
            # remaining `n - burn_in` steps, where n is the model's
            # max_seq_len.
            "burn_in": 4,
            # Scale target values in the R2D2 loss with the h-function:
            #   h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x
            "use_h_function": True,
            # Epsilon of the h-function; only used when `use_h_function` is
            # True.
            "h_function_epsilon": 1e-3,
        }
    )

    # --- Replay buffer and target network ---
    config.update(
        {
            # Replay buffer size, in sequences (not timesteps).
            "buffer_size": 100000,
            # Use a prioritized replay buffer.
            "prioritized_replay": True,
            # Filled in automatically as model max_seq_len + burn_in; it must
            # stay at this sentinel rather than a valid length.
            "replay_sequence_length": -1,
            # Update the target network every this many steps.
            "target_network_update_freq": 2500,
        }
    )

    config["evaluation_config"] = eval_overrides
    config["model"] = model_settings
    return config


def multiagent_config(args, config, policies, policy_mapping_fn):
    """Attach the multi-agent section to ``config`` and return it.

    Args:
        args: Parsed run options; only ``policies_to_train`` is read.
        config: Configuration dict, modified in place.
        policies: Stored unchanged under ``"policies"``.
        policy_mapping_fn: Stored unchanged under ``"policy_mapping_fn"``.

    Returns:
        The same ``config`` object, now carrying a ``"multiagent"`` entry
        (any previous entry under that key is replaced).
    """
    multiagent_section = dict(
        policies=policies,
        policy_mapping_fn=policy_mapping_fn,
        # Names such as "good_policy" / "adversary_policy".
        policies_to_train=args.policies_to_train,
    )
    config["multiagent"] = multiagent_section
    return config


def singleagent_config(args, config, model):
    """Install ``model`` as the ``"model"`` entry of ``config`` and return it.

    Args:
        args: Unused; kept so the signature parallels ``multiagent_config``.
        config: Configuration dict, modified in place.
        model: Replaces whatever was stored under ``"model"`` wholesale.

    Returns:
        The same ``config`` object.
    """
    config.update({"model": model})
    return config


def return_config(args, env_config, custom_callback, policies, policy_mapping_fn):
    """Build the complete trainer config for the chosen algorithm and scenario.

    Args:
        args: Parsed run options. This function reads ``algo``,
            ``scenario``, ``cirriculum`` and ``test``; the algorithm and
            scenario helpers read further attributes.
        env_config: Forwarded to the algorithm-specific builder.
        custom_callback: Forwarded to the algorithm-specific builder.
        policies: Policy specs in the ``"fort_attack"`` scenario; in the
            ``"atari"`` scenario this is the model config instead.
        policy_mapping_fn: Only used in the ``"fort_attack"`` scenario.

    Returns:
        The assembled configuration ``dict``.

    Raises:
        ValueError: If ``args.algo`` or ``args.scenario`` is not one of the
            supported values.
    """
    algo_builders = {
        "impala": impala_config,
        "ppo": ppo_config,
        "r2d2": r2d2_config,
    }
    build = algo_builders.get(args.algo)
    if build is None:
        # Raised explicitly: an `assert` would be stripped under `python -O`
        # and the failure would only show up later as an unbound variable.
        raise ValueError(f"Unknown algorithm: {args.algo!r}")
    config = build(args, env_config, custom_callback)

    is_fort_attack = args.scenario == "fort_attack"
    if is_fort_attack:
        config = multiagent_config(args, config, policies, policy_mapping_fn)
    elif args.scenario == "atari":
        # policies is just the model config in the single agent scenario
        config = singleagent_config(args, config, policies)
    else:
        raise ValueError(f"Unknown scenario: {args.scenario!r}")

    if args.cirriculum == "all":
        config["env_task_fn"] = curriculum_fn

    if args.test:
        # Only multi-agent configs carry a "multiagent" section; indexing it
        # unconditionally raised KeyError for single-agent test runs.
        if "multiagent" in config:
            config["multiagent"]["policies_to_train"] = []

        # Evaluation setup for test runs.
        config["evaluation_interval"] = 1
        config["evaluation_num_episodes"] = 20
        config["evaluation_num_workers"] = 2

    if is_fort_attack:
        config["concept_configs"] = concept_config(args)

    return config
