from runner_tools import WORKERS, add_job, random_search, Categorical, ATARI_57
import numpy as np

from typing import Union

# Smoke-test switch: when True, add_run schedules only one short job
# (a single seed on a single environment, 0.1 epochs) per call.
QUICK_CHECK = False

# 128 x 128 samples per rollout (presumably agents * n_steps of the HPS config).
ROLLOUT_SIZE = 128 * 128

# Atari-3 validation games used for tuning.
ATARI_3_VAL = [
    'Assault',
    'MsPacman',
    'YarsRevenge',
]

# Atari-5 test games used for the final results.
ATARI_5 = [
    'BattleZone',
    'DoubleDunk',
    'NameThisGame',
    'Phoenix',
    'Qbert',
]


def add_run(experiment: str, run_name: str, default_args, env_args, subset:list, seeds:Union[int, list]=3, priority=0, seed_params=None, **kwargs):
    """
    Schedule one job per (seed, environment) pair for a single run configuration.

    Parameters are layered HPS_ARGS <- default_args <- env_args and handed to
    add_job as ``default_params``. Extra ``**kwargs`` and any per-seed overrides
    from ``seed_params`` are forwarded to add_job as keyword arguments.

    :param experiment: experiment name passed through to add_job.
    :param run_name: run label; each job is named "game={env} {run_name} ({seed})".
    :param default_args: algorithm settings layered on top of HPS_ARGS.
    :param env_args: environment settings layered on top of default_args.
    :param subset: environment names to schedule.
    :param seeds: a seed count (expanded to 1..seeds) or an explicit list of seeds.
    :param priority: base priority; each later seed number lowers it by 50.
    :param seed_params: optional {seed: {param: value}} extra kwargs for that seed only.
    """

    args = HPS_ARGS.copy()
    args.update(default_args)
    args.update(env_args)

    if seed_params is None:
        seed_params = {}

    if isinstance(seeds, int):
        seeds = list(range(1, seeds + 1))

    envs = subset
    if QUICK_CHECK:
        # Smoke test only: the first requested seed on the first environment,
        # for a fraction of an epoch. Using the first *requested* seed (rather
        # than always seed 1) keeps run names distinct for calls like
        # seeds=[4] and still exercises seed_params. The epochs override
        # replaces any caller-supplied value instead of raising TypeError on a
        # duplicate keyword.
        seeds = seeds[:1]
        envs = subset[:1]
        kwargs = {**kwargs, 'epochs': 0.1}

    for seed in seeds:
        for env in envs:
            add_job(
                experiment,
                run_name=f"game={env} {run_name} ({seed})",
                env_name=env,
                seed=seed,
                priority=priority - ((seed - 1) * 50),
                default_params=args,
                **seed_params.get(seed, {}),
                **kwargs,
            )


# Environment settings for the "hard mode" results: no terminal on life loss,
# unclipped rewards, full action space, repeat_action_probability of 0.25.
HARD_MODE_ARGS = {
    # hard mode
    "terminal_on_loss_of_life": False,
    "reward_clipping": "off",
    "full_action_space": True,
    "repeat_action_probability": 0.25,
}

# Environment settings for the "easy mode" results: terminal on life loss,
# unclipped rewards, minimal action space, no action repeats.
EASY_MODE_ARGS = {
    # easy mode
    "terminal_on_loss_of_life": True,
    "reward_clipping": "off",
    "full_action_space": False,
    "repeat_action_probability": 0.0,
}

# Best settings from the hyper-parameter search (HPS). Results of the later
# axis search are NOT folded in here; the derived *_ARGS configs apply those.
HPS_ARGS = dict(
    # run / infrastructure
    checkpoint_every=5_000_000,
    workers=WORKERS,
    hostname='',
    architecture='dual',
    export_video=False,
    epochs=50,
    use_compression=False,
    upload_batch=True,                  # much faster
    warmup_period=1000,
    disable_ev=False,
    seed=0,
    mutex_key="DEVICE",

    # optimisation
    max_grad_norm=5.0,
    agents=128,                         # from HPS
    n_steps=128,                        # from HPS
    policy_mini_batch_size=2048,        # from HPS
    value_mini_batch_size=512,          # 256 would be preferred; 512 for speed
    distil_mini_batch_size=512,         # 256 would be preferred; 512 for speed
    policy_epochs=2,                    # reasonable guess
    value_epochs=2,                     # reasonable guess
    distil_epochs=2,                    # reasonable guess
    ppo_epsilon=0.2,                    # allows faster policy movement
    policy_lr=2.5e-4,
    value_lr=2.5e-4,
    distil_lr=2.5e-4,
    entropy_bonus=1e-2,                 # standard
    hidden_units=512,                   # standard
    gae_lambda=0.95,                    # standard
    td_lambda=0.95,                     # standard
    repeated_action_penalty=0.25,       # HPS suggested 0, kept on deliberately

    # TVF
    use_tvf=False,

    # distillation / replay buffer (previously referred to as "h11")
    distil_period=1,
    replay_size=0,                      # replay disabled for now
    distil_beta=1.0,
    replay_mode="uniform",

    # horizon
    gamma=0.999,

    # misc: some games (e.g. Pong) fail without this, so it defaults to on
    observation_normalization=True,
)

# Settings from the original PPO paper, layered over HPS_ARGS.
PPO_ORIG_ARGS = {
    **HPS_ARGS,
    'n_steps': 128,                 # unchanged from HPS
    'agents': 8,
    'ppo_epsilon': 0.1,
    'policy_lr': 2.5e-4,
    'policy_lr_anneal': True,
    'ppo_epsilon_anneal': True,
    'entropy_bonus': 1e-2,          # unchanged from HPS
    'gamma': 0.99,
    'policy_epochs': 3,
    'td_lambda': 0.95,
    'gae_lambda': 0.95,
    'policy_mini_batch_size': 256,
    'vf_coef': 2.0,                 # value loss here is 0.5 * MSE, hence 2.0
    'value_epochs': 0,
    'distil_epochs': 0,
    'architecture': 'single',
}

# Tuned DNA settings: lower GAE lambda for advantages than the TD lambda used
# for value targets, and a single value epoch.
DNA_TUNED_ARGS = {
    **HPS_ARGS,
    'gae_lambda': 0.8,
    'td_lambda': 0.95,
    'policy_epochs': 2,
    'value_epochs': 1,
    'distil_epochs': 2,
}

# Tuned PPO settings: single-network architecture, one policy epoch, no
# separate value/distil phases, and the 'nature_fat' policy network.
PPO_TUNED_ARGS = {
    **HPS_ARGS,
    'gae_lambda': 0.95,
    'td_lambda': 0.95,
    'policy_epochs': 1,
    'value_epochs': 0,
    'distil_epochs': 0,
    'architecture': 'single',
    'policy_network': 'nature_fat',
}

# PPG-style settings: an auxiliary phase every aux_period steps, fed from a
# sequential replay buffer covering 32 rollouts.
PPG_ARGS = {
    **HPS_ARGS,
    'policy_epochs': 1,
    'value_epochs': 1,
    'distil_epochs': 0,
    'aux_epochs': 6,
    'aux_target': 'vtarg',
    'aux_source': 'value',
    'aux_period': 32,
    'replay_mode': 'sequential',
    'replay_size': 32 * 128 * 128,         # 32 rollouts (~0.5M samples); may need more
    'distil_batch_size': 32 * 128 * 128,   # whole buffer, presumably used once per aux_period
    'use_compression': True,
    'upload_batch': False,
}

# old settings
# ---------------------------------------------------------------------------------------

def dna_hps(priority: int = 0):
    """
    Legacy random hyper-parameter search for DNA on the Atari-3 validation set.

    Author's notes on this (second) search relative to the first: more
    distillation settings searched, fewer n_steps / more agents, smaller
    entropy bonuses, and only 32 sampled configurations.
    """

    search_params = {
        # PPO parameters
        'entropy_bonus':     Categorical(3e-4, 1e-3, 3e-3, 1e-2, 3e-2),
        'agents':            Categorical(64, 128, 256, 512),
        'n_steps':           Categorical(32, 64, 128),
        'gamma':             Categorical(0.99, 0.997, 0.999),
        'gae_lambda':        Categorical(0.9, 0.95, 0.975),
        # DNA parameters
        'policy_lr':         Categorical(1e-4, 2.5e-4, 5e-4),
        'distil_lr':         Categorical(1e-4, 2.5e-4, 5e-4),
        'value_lr':          Categorical(1e-4, 2.5e-4, 5e-4),
        'td_lambda':         Categorical(0.9, 0.95, 0.975),
        'policy_epochs':     Categorical(1, 2, 3),
        'value_epochs':      Categorical(1, 2, 3),
        'distil_epochs':     Categorical(1, 2, 3),
        'distil_beta':       Categorical(0.5, 1.0, 2.0),
        'policy_mini_batch_size': Categorical(256, 512, 1024, 2048),
        'value_mini_batch_size': Categorical(256, 512, 1024, 2048),
        'distil_mini_batch_size': Categorical(256, 512, 1024, 2048),
        'replay_size':       Categorical(*(8 * 1024 * k for k in (1, 2, 4, 8))),
        'repeated_action_penalty': Categorical(0, 0.25, 1.0),
        'entropy_scaling':   Categorical(True, False),

        # replay parameters
        'replay_mode':       Categorical("overwrite", "sequential", "uniform", "off"),
    }

    main_params = {
        'checkpoint_every': int(5e6),
        'workers': WORKERS,
        'architecture': 'dual',
        'export_video': False,
        'epochs': 50,
        'use_compression': True,
        'warmup_period': 1000,
        'disable_ev': False,
        'seed': 0,
        'mutex_key': "DEVICE",

        # hard-mode environment settings
        "terminal_on_loss_of_life": False,
        "reward_clipping": "off",
        "full_action_space": True,
        "repeat_action_probability": 0.25,

        'max_grad_norm': 5.0,
        'ppo_epsilon': 0.1,
        'hidden_units': 256,

        # TVF
        'use_tvf': False,

        # distillation / replay buffer (previously referred to as "h11")
        'distil_period': 1,
        'replay_size': 0,       # replay disabled by default
        'distil_beta': 2.0,     # previously 1.0

        'replay_mode': "uniform",

        # horizon
        'gamma': 0.999,         # previously 0.997

        # misc
        'observation_normalization': True,
    }

    def fixup_params(params):
        """Adjust one sampled configuration in place before it is scheduled."""
        rollout = params['agents'] * params['n_steps']

        # Distil over the whole rollout by default.
        params['distil_batch_size'] = rollout

        # A buffer size is meaningless when replay is disabled.
        if params['replay_mode'] == "off":
            params["replay_size"] = 0

        # Cap the combined epoch budget at 6 (more is too slow) by repeatedly
        # decrementing a randomly chosen epoch count, never going below 1.
        epoch_keys = ['policy_epochs', 'value_epochs', 'distil_epochs']
        while sum(params[key] for key in epoch_keys) > 6:
            victim = epoch_keys[np.random.randint(0, 3)]
            if params[victim] > 1:
                params[victim] -= 1

        params["use_compression"] = params['replay_size'] + rollout > 32 * 1024
        params["disable_ev"] = True

    random_search(
        "DNA_SEARCH",
        main_params,
        search_params,
        count=32,
        process_up_to=32,
        envs=ATARI_3_VAL,
        hook=fixup_params,
        priority=priority,
    )


def dna_lambda(priority: int=0):
    """
    Demonstrate that GAE for advantages and for return values can be different.

    First a gae_lambda x td_lambda grid on Atari-3 (5 seeds), then an extra
    td_lambda sweep at gae_lambda=0.95 on seed 4 only, to check noise. Seed 4
    always runs with are_mode='shadow' (via seed_params). Grid points already
    covered by the 5-seed grid are not re-submitted in the sweep, because they
    would produce seed-4 jobs with exactly the same run names.
    """

    gae_grid = [0.6, 0.8, 0.9, 0.95, 0.975]
    td_grid = [0.8, 0.95]
    noise_seed = 4

    def _schedule(gae_lambda, td_lambda, seeds):
        # One add_run per (gae, td) pair; seed 4 gets the shadow are_mode.
        add_run(
            experiment="DNA_LAMBDA",
            run_name=f"td_lambda={td_lambda} gae_lambda={gae_lambda}",
            default_args=DNA_TUNED_ARGS,
            env_args=HARD_MODE_ARGS,
            gae_lambda=gae_lambda,
            td_lambda=td_lambda,
            priority=priority,
            seeds=seeds,
            subset=ATARI_3_VAL,
            seed_params={noise_seed: {'are_mode': 'shadow'}},
        )

    for gae_lambda in gae_grid:
        for td_lambda in td_grid:
            _schedule(gae_lambda, td_lambda, seeds=5)

    # additional runs to check noise (seed 4 only)
    for gae_lambda in [0.95]:
        for td_lambda in [0.6, 0.8, 0.9, 0.95, 0.975, 0.9875, 0.99375, 1.0]:
            if gae_lambda in gae_grid and td_lambda in td_grid:
                # seeds 1..5 of the grid above already include this seed-4 job
                continue
            _schedule(gae_lambda, td_lambda, seeds=[noise_seed])


def dna_A57(priority=0):
    """
    Run tuned DNA and tuned ('fat') PPO on the ATARI_57 game list, one seed
    each, in both hard mode ("A57_HARD") and easy mode ("A57_EASY").
    """

    modes = (
        ("A57_HARD", HARD_MODE_ARGS),
        ("A57_EASY", EASY_MODE_ARGS),
    )
    algorithms = (
        ("dna_tuned", DNA_TUNED_ARGS),
        ("ppo_tuned_fat", PPO_TUNED_ARGS),
    )

    for experiment, mode_args in modes:
        for label, algo_args in algorithms:
            add_run(
                experiment=experiment,
                run_name=label,
                default_args=algo_args,
                env_args=mode_args,
                seeds=1,
                subset=ATARI_57,
                priority=priority,
            )

def ppg_final(priority=0):
    """
    Schedule the PPG baseline for the main results: Atari-5 test set, hard
    mode, 3 seeds, in the "DNA_FINAL" experiment under run name "ppg".

    NOTE: dna_ablations() also schedules a "ppg" run in "DNA_FINAL" with the
    same PPG_ARGS, hard-mode settings, seeds and games.
    """

    add_run(
        experiment="DNA_FINAL",
        run_name="ppg",
        default_args=PPG_ARGS,
        env_args=HARD_MODE_ARGS,
        priority=priority,
        seeds=3,
        subset=ATARI_5,
    )


def dna_final_easy(priority=0):
    """
    Easy-mode counterpart of the main results: tuned DNA, tuned ('fat') PPO
    and original-paper PPO on Atari-5 with 3 seeds, in "DNA_FINAL_EASY".
    """

    runs = [
        ("dna_tuned", DNA_TUNED_ARGS),
        ("ppo_tuned_fat", PPO_TUNED_ARGS),
        ("ppo_orig", PPO_ORIG_ARGS),
    ]
    for label, algo_args in runs:
        add_run(
            experiment="DNA_FINAL_EASY",
            run_name=label,
            default_args=algo_args,
            env_args=EASY_MODE_ARGS,
            priority=priority,
            seeds=3,
            subset=ATARI_5,
        )


def dna_final(priority=0):
    """
    Main (hard-mode) results: tuned DNA and tuned ('fat') PPO on Atari-5 with
    3 seeds, in the "DNA_FINAL" experiment.
    """

    runs = [
        ("dna_tuned", DNA_TUNED_ARGS),
        ("ppo_tuned_fat", PPO_TUNED_ARGS),
    ]
    for label, algo_args in runs:
        add_run(
            experiment="DNA_FINAL",
            run_name=label,
            default_args=algo_args,
            env_args=HARD_MODE_ARGS,
            priority=priority,
            seeds=3,
            subset=ATARI_5,
        )

def dna_tuning(priority=0):
    """
    Axis search over the (policy, value, distil) epoch counts on the Atari-3
    validation set, hard mode, 3 seeds, starting from HPS_ARGS.

    Runs are named "epochs=PVD" (policy, value, distil epochs). Three sweeps
    vary one axis each:
      - value epochs 1..4 with policy=2, distil=2
      - policy epochs 1..4 with value=1, distil=2
      - distil epochs 0..3 with policy=2, value=1
    All three pass through 2/1/2, which is scheduled only once. Previously it
    was submitted three times under the same run name "epochs=212".
    """

    COMMON_ARGS = {
        'experiment': "DNA_TUNING",
        'seeds': 3,
        'subset': ATARI_3_VAL,
        'priority': priority,
    }

    configs = (
        [(2, value, 2) for value in [1, 2, 3, 4]]
        + [(policy, 1, 2) for policy in [1, 2, 3, 4]]
        + [(2, 1, distil) for distil in [0, 1, 2, 3]]
    )

    scheduled = set()
    for policy_epochs, value_epochs, distil_epochs in configs:
        key = (policy_epochs, value_epochs, distil_epochs)
        if key in scheduled:
            continue
        scheduled.add(key)
        add_run(
            run_name=f"epochs={policy_epochs}{value_epochs}{distil_epochs}",
            default_args=HPS_ARGS,
            env_args=HARD_MODE_ARGS,
            policy_epochs=policy_epochs,
            value_epochs=value_epochs,
            distil_epochs=distil_epochs,
            **COMMON_ARGS
        )


def dna_distil(priority=0):
    """
    Compare distil_mode settings ('off', 'value', 'features', 'projection')
    on Atari-3 validation, hard mode, 3 seeds. 'off' uses zero distil epochs;
    all other modes use 2.
    """

    for mode in ('off', 'value', 'features', 'projection'):
        distil_epochs = 2 if mode != "off" else 0
        add_run(
            experiment="DNA_DISTIL",
            run_name=f"distil_mode={mode}",
            default_args=HPS_ARGS,
            env_args=HARD_MODE_ARGS,
            policy_epochs=2,
            value_epochs=1,
            distil_epochs=distil_epochs,
            distil_mode=mode,
            seeds=3,
            subset=ATARI_3_VAL,
            priority=priority,
        )


def dna_noise(priority=0):
    """
    Single-seed HPS_ARGS run on Atari-3 validation (hard mode) with
    are_mode="shadow", in the "DNA_NOISE" experiment. The run label is empty.
    """

    add_run(
        experiment="DNA_NOISE",
        run_name="",
        default_args=HPS_ARGS,
        env_args=HARD_MODE_ARGS,
        are_mode="shadow",
        seeds=1,
        subset=ATARI_3_VAL,
        priority=priority,
    )


def ppo_tuning(priority: int = 0):
    """
    Tune PPO, starting from the original-paper settings, on Atari-3
    validation in hard mode with 3 seeds.

    Schedules a "reference" run, a sweep over policy epochs (1..4) at
    lambda=0.95, and a sweep over lambda (used for both GAE and TD) at one
    policy epoch. The sweeps meet at (epochs=1, lambda=0.95). That point is
    scheduled once; previously it was submitted twice under the same run name.
    """

    COMMON_ARGS = {
        'experiment': "PPO_TUNING",
        'seeds': 3,
        'subset': ATARI_3_VAL,
        'priority': priority,
    }

    add_run(
        run_name="reference",
        default_args=PPO_ORIG_ARGS,
        env_args=HARD_MODE_ARGS,
        **COMMON_ARGS,
    )

    # NOTE(review): PPO_ORIG_ARGS already sets policy_epochs=3 and
    # gae/td lambda=0.95, so "epochs=3 lambda=0.95" repeats the reference
    # configuration under a different name (possibly intended as a noise check).
    sweep = (
        [(policy_epochs, 0.95) for policy_epochs in [1, 2, 3, 4]]
        + [(1, gae_lambda) for gae_lambda in [0.8, 0.9, 0.95, 0.975]]
    )

    scheduled = set()
    for policy_epochs, gae_lambda in sweep:
        if (policy_epochs, gae_lambda) in scheduled:
            continue
        scheduled.add((policy_epochs, gae_lambda))
        add_run(
            run_name=f"epochs={policy_epochs} lambda={gae_lambda}",
            default_args=PPO_ORIG_ARGS,
            env_args=HARD_MODE_ARGS,
            policy_epochs=policy_epochs,
            gae_lambda=gae_lambda,
            td_lambda=gae_lambda,
            **COMMON_ARGS,
        )


def dna_ablations(priority:int = 0):
    """
    Ablations of the tuned DNA, PPO and PPG agents on the Atari-5 test set in
    hard mode, written to the "DNA_FINAL" experiment with 3 seeds each
    (presumably so they sit next to the dna_final() runs).

    NOTE: the "ppg" run scheduled here has the same experiment, run name and
    settings (PPG_ARGS, hard mode, 3 seeds, Atari-5) as the one in ppg_final().
    """

    COMMON_ARGS = {
        'experiment': "DNA_FINAL",
        'seeds': 3,
        'subset': ATARI_5,
        'priority': priority,
        'hostname': "",
        'env_args': HARD_MODE_ARGS,
    }

    # Every ablation uses the same 3 seeds as the main results (see COMMON_ARGS).

    # ---------------------------
    # DNA ablations

    add_run(
        run_name="dna_tuned_pmbs_512",
        default_args=DNA_TUNED_ARGS,
        policy_mini_batch_size=512,
        **COMMON_ARGS
    )

    add_run(
        run_name="dna_no_distil",
        default_args=DNA_TUNED_ARGS,
        distil_epochs=0,
        **COMMON_ARGS
    )

    add_run(
        run_name="dna_fixed_lambda",
        default_args=DNA_TUNED_ARGS,
        gae_lambda=0.95,
        td_lambda=0.95,
        **COMMON_ARGS
    )

    # PPO_TUNED_ARGS with 2 policy epochs, the 'nature' network, equal
    # lambdas and a 512 policy mini-batch.
    add_run(
        run_name="ppo_basic",
        default_args=PPO_TUNED_ARGS,
        policy_epochs=2,
        policy_network='nature',
        gae_lambda=0.95,
        td_lambda=0.95,
        policy_mini_batch_size=512,
        **COMMON_ARGS
    )

    # ---------------------------
    # PPO ablations

    add_run(
        run_name="ppo_2_fat",
        default_args=PPO_TUNED_ARGS,
        policy_epochs=2,
        **COMMON_ARGS
    )

    add_run(
        run_name="ppo_nature",
        default_args=PPO_TUNED_ARGS,
        policy_network="nature",
        **COMMON_ARGS
    )

    add_run(
        run_name="ppo_lambda",
        default_args=PPO_TUNED_ARGS,
        gae_lambda=0.8,
        td_lambda=0.95,
        **COMMON_ARGS
    )

    add_run(
        run_name="ppo_orig",
        default_args=PPO_ORIG_ARGS,
        **COMMON_ARGS
    )

    # ---------------------------
    # PPG ablations

    add_run(
        run_name="ppg",
        default_args=PPG_ARGS,
        **COMMON_ARGS
    )

    add_run(
        run_name="ppg_tuned",
        default_args=PPG_ARGS,
        policy_epochs=2,
        value_epochs=1,
        aux_epochs=2,
        **COMMON_ARGS
    )


def setup():
    """
    Register every active experiment with the job runner: Atari-3 validation
    sweeps, then Atari-5 test-set runs, then the Atari-57 runs.
    """

    # dna_hps()  # legacy random search, not currently scheduled

    # ---------------------------
    # atari-3 validation

    ppo_tuning()
    dna_tuning()

    dna_noise()
    dna_distil()
    dna_lambda()

    # ---------------------------
    # atari-5 test set

    dna_final()
    dna_final_easy()
    # ppg_final() is not called here: dna_ablations() already schedules a
    # "ppg" run in "DNA_FINAL" with the same PPG_ARGS, hard-mode settings,
    # seeds and games. Calling both would submit every such job twice under
    # the same run name.
    dna_ablations()

    # ---------------------------
    # atari-57
    dna_A57()


