#!/usr/bin/env python
# coding: utf-8

# # Experiment Parameter List

# In[2]:
import atari_py as ap
from mdp_helper_files.exp_track_helper import Experiment, ExperimentPool
import sys

def conv_name(n, suffix = "NoFrameskip-v0"):
    """Convert an underscore-separated game name into a CamelCase env id.

    Args:
        n: Game name with words separated by underscores,
            e.g. ``"kung_fu_master"``.
        suffix: Text appended after the CamelCase name.

    Returns:
        Each underscore-separated word with its first character upper-cased
        (the rest of the word is left untouched), concatenated and followed
        by ``suffix`` -- e.g. ``"KungFuMasterNoFrameskip-v0"``.
    """
    # Slice with [:1] instead of indexing [0] so that empty words (from an
    # empty name, or from leading/trailing/double underscores) contribute
    # nothing instead of raising IndexError.
    return "".join(word[:1].upper() + word[1:] for word in n.split("_")) + suffix


# Build the environment ids for every game reported by atari_py, in sorted
# game-name order: first all "-v0" ids, then all "-v4" ids.
# NOTE(review): in Gym's Atari registration "-v0" is the variant with sticky
# actions and "-v4" the one without, which looks inverted relative to the
# variable names below ("sotchastic" is presumably "stochastic") -- confirm.
game_list = ap.list_games()
sorted_games = sorted(game_list)
discrete_envs = [conv_name(game, suffix="NoFrameskip-v0") for game in sorted_games]
sotchastic_envs = [conv_name(game, suffix="NoFrameskip-v4") for game in sorted_games]
all_envs = discrete_envs + sotchastic_envs

# Atari MDP experiments.
# For every point of the hyperparameter grid below, two jobs are registered:
#   * "BuildEvalMDP": runs batch_mdp_test.py with --build_mdp --save_mdp
#   * "EvalMDP":      runs batch_mdp_test.py with --load_mdp
# NOTE: the expSuffix f-strings use backslash line continuations, so the
# indentation of their continuation lines is part of the generated command
# string. Do not re-indent those lines.
# NOTE(review): both jobs carry the same `meta` text ("Build  MDP ..."), which
# looks copy-pasted for the EvalMDP job -- confirm.
ataridetExps = ExperimentPool()
for env in all_envs:
    for bottleneck_size in (16, 32, 64):
        # Original author's note on latent_type: "Doesnt matter in this case"
        # (the reason is not stated in this file).
        for latent_type in ("PreTrainedDQN", "OfflineDQN", "BCQ", "BCQ2", "BCQ3", "Random"):
            for mdp_build_k in (1, 5, 11):
                for penalty_beta in (0, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000, 1000000):
                    for discount_factor in (0.9, 0.99, 0.999):
                        # Job 1: build the MDP, save it, then evaluate.
                        build_eval_job_id = f"AD-{env}-BuildEvalMDP-L{latent_type}-{mdp_build_k}NN-P{penalty_beta}-G{discount_factor}-B{bottleneck_size}"
                        build_eval_exp = Experiment(
                            id=build_eval_job_id,
                            meta="Build  MDP for optimal Policy",
                            expPrefix="python batch_mdp_test.py ",
                            expSuffix=f"--env {env} --bottleneck_size {bottleneck_size} --MAX_S_COUNT 2500000 --max_timesteps 2500000 --MAX_NS_COUNT {mdp_build_k+1}\
                                    --fill_with 0Q_src-KNN --mdp_build_k {mdp_build_k} --penalty_type linear --penalty_beta {penalty_beta} --gamma {discount_factor} --build_mdp --save_mdp\
                                    --eval_episode_count 10 --smoothing --policy_k 1 5 11 31 51 --normalize_by_distance --latent_type {latent_type}",
                        )
                        ataridetExps.add_experiment(build_eval_exp)

                        # Job 2: load a previously saved MDP and evaluate it.
                        eval_job_id = f"AD-{env}-EvalMDP-L{latent_type}-{mdp_build_k}NN-P{penalty_beta}-G{discount_factor}-B{bottleneck_size}"
                        eval_exp = Experiment(
                            id=eval_job_id,
                            meta="Build  MDP for optimal Policy",
                            expPrefix="python batch_mdp_test.py ",
                            expSuffix=f"--env {env} --bottleneck_size {bottleneck_size} --MAX_S_COUNT 2500000 --max_timesteps 2500000 --MAX_NS_COUNT {mdp_build_k+1}\
                                                            --fill_with 0Q_src-KNN --mdp_build_k {mdp_build_k} --penalty_type linear --penalty_beta {penalty_beta} --gamma {discount_factor} --load_mdp \
                                                            --eval_episode_count 10 --smoothing --policy_k 1 5 11 31 51 --normalize_by_distance --latent_type {latent_type}",
                        )
                        ataridetExps.add_experiment(eval_exp)


                         
                            
# Atari training pipeline.
# For every environment, register six main.py jobs in this order:
# TrainOnlineDQN, CollectDataset, TrainBCQ, TrainBCQtwo, TrainBCQthree,
# TrainOfflineDQN.
# NOTE(review): several `meta` descriptions look copy-pasted (TrainOnlineDQN,
# TrainBCQtwo and TrainBCQthree are all described as "Train BCQ policy ...")
# -- confirm before relying on them.
for env in all_envs:
    atari_pipeline_exps = (
        Experiment(
            id=f"AD-{env}-TrainOnlineDQN",
            meta="Train BCQ policy using the collected dataset",
            expPrefix="python main.py ",
            expSuffix=f"--train_behavioral --max_timesteps 10000000 --buffer_size 1000000 --env {env}",
        ),
        Experiment(
            id=f"AD-{env}-CollectDataset",
            meta="Collect Dataset using trained policy",
            expPrefix="python main.py ",
            expSuffix=f"--generate_buffer --max_timesteps 2500000 --buffer_size 2500000 --env {env}",
        ),
        Experiment(
            id=f"AD-{env}-TrainBCQ",
            meta="Train BCQ policy using the collected dataset",
            expPrefix="python main.py ",
            expSuffix=f"--train_BCQ --max_timesteps 1000000 --buffer_size 2500000 --env {env}",
        ),
        Experiment(
            id=f"AD-{env}-TrainBCQtwo",
            meta="Train BCQ policy using the collected dataset",
            expPrefix="python main.py ",
            expSuffix=f"--train_BCQ2 --max_timesteps 1000000 --buffer_size 2500000 --env {env}",
        ),
        Experiment(
            id=f"AD-{env}-TrainBCQthree",
            meta="Train BCQ policy using the collected dataset",
            expPrefix="python main.py ",
            expSuffix=f"--train_BCQ3 --max_timesteps 1000000 --buffer_size 2500000 --env {env}",
        ),
        Experiment(
            id=f"AD-{env}-TrainOfflineDQN",
            meta="Train OfflineDQN policy using the collected dataset",
            expPrefix="python main.py ",
            expSuffix=f"--train_OfflineDQN --max_timesteps 1000000 --buffer_size 2500000 --env {env}",
        ),
    )
    # Register in the order listed above.
    for pipeline_exp in atari_pipeline_exps:
        ataridetExps.add_experiment(pipeline_exp)
    
    
# Named dataset sizes; the name is passed as --buffer_name and the number is
# substituted into the timestep / buffer-size flags below.
data_size_name_map = {"verySmall":100000, "Small":500000, "Medium":1000000}

# Atari dataset-size variants.
# For every environment and every named dataset size, register five main.py
# jobs in this order: CollectDataset, TrainBCQ, TrainBCQtwo, TrainBCQthree,
# TrainOfflineDQN. A per-size TrainOnlineDQN job existed here only as
# commented-out code and is not registered.
# NOTE(review): CollectDataset passes a fixed "--buffer_size 2500000" while
# the training jobs pass the per-size value -- confirm this is intended.
for env in all_envs:
    for buffer_name, buffer_size in data_size_name_map.items():
        sized_pipeline_exps = (
            Experiment(
                id=f"AD-{env}-{buffer_name}-CollectDataset",
                meta="Collect Dataset using trained policy",
                expPrefix="python main.py ",
                expSuffix=f"--generate_buffer --buffer_name {buffer_name} --max_timesteps {buffer_size} --buffer_size 2500000 --env {env}",
            ),
            Experiment(
                id=f"AD-{env}-{buffer_name}-TrainBCQ",
                meta="Train BCQ policy using the collected dataset",
                expPrefix="python main.py ",
                expSuffix=f"--train_BCQ --buffer_name {buffer_name} --max_timesteps {buffer_size} --buffer_size {buffer_size} --env {env}",
            ),
            Experiment(
                id=f"AD-{env}-{buffer_name}-TrainBCQtwo",
                meta="Train BCQ policy using the collected dataset",
                expPrefix="python main.py ",
                expSuffix=f"--train_BCQ2 --buffer_name {buffer_name} --max_timesteps {buffer_size} --buffer_size {buffer_size} --env {env}",
            ),
            Experiment(
                id=f"AD-{env}-{buffer_name}-TrainBCQthree",
                meta="Train BCQ policy using the collected dataset",
                expPrefix="python main.py ",
                expSuffix=f"--train_BCQ3 --buffer_name {buffer_name} --max_timesteps {buffer_size} --buffer_size {buffer_size} --env {env}",
            ),
            Experiment(
                id=f"AD-{env}-{buffer_name}-TrainOfflineDQN",
                meta="Train OfflineDQN policy using the collected dataset",
                expPrefix="python main.py ",
                expSuffix=f"--train_OfflineDQN --buffer_name {buffer_name} --max_timesteps {buffer_size} --buffer_size {buffer_size} --env {env}",
            ),
        )
        # Register in the order listed above.
        for sized_exp in sized_pipeline_exps:
            ataridetExps.add_experiment(sized_exp)


    
    
# CartPole MDP experiments.
# For both CartPole variants and every point of the hyperparameter grid below,
# two jobs are registered:
#   * "BuildEvalMDP": runs batch_mdp_test.py with --build_mdp --save_mdp
#   * "EvalMDP":      runs batch_mdp_test.py with --load_mdp
# NOTE: the expSuffix f-strings use backslash line continuations, so the
# indentation of their continuation lines is part of the generated command
# string. Do not re-indent those lines.
# NOTE(review): both jobs carry the same `meta` text ("Build  MDP ..."), which
# looks copy-pasted for the EvalMDP job -- confirm.
cartPoleExps = ExperimentPool()
for env in ("CartPoleRandomFeature-v1", "CartPoleImage-v1"):
    for bottleneck_size in (16, 32, 64):
        # Original author's note on latent_type: "Doesnt matter in this case"
        # (the reason is not stated in this file).
        for latent_type in ("PreTrainedDQN", "OfflineDQN", "BCQ", "BCQ2", "BCQ3", "Random"):
            for mdp_build_k in (1, 5, 11):
                for penalty_beta in (0, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000, 1000000):
                    for discount_factor in (0.9, 0.99, 0.999):
                        # Job 1: build the MDP, save it, then evaluate.
                        build_eval_job_id = f"CP-{env}-BuildEvalMDP-L{latent_type}-{mdp_build_k}NN-P{penalty_beta}-G{discount_factor}-B{bottleneck_size}"
                        cp_build_eval_exp = Experiment(
                            id=build_eval_job_id,
                            meta="Build  MDP for optimal Policy",
                            expPrefix="python batch_mdp_test.py ",
                            expSuffix=f"--env {env} --bottleneck_size {bottleneck_size} --MAX_S_COUNT 100000 --max_timesteps 100000 --MAX_NS_COUNT {mdp_build_k+1}\
                                    --fill_with 0Q_src-KNN --mdp_build_k {mdp_build_k} --penalty_type linear --penalty_beta {penalty_beta} --gamma {discount_factor} --build_mdp --save_mdp\
                                    --eval_episode_count 50 --smoothing --policy_k 1 5 11 31 51 --normalize_by_distance --latent_type {latent_type}",
                        )
                        cartPoleExps.add_experiment(cp_build_eval_exp)

                        # Job 2: load a previously saved MDP and evaluate it.
                        eval_job_id = f"CP-{env}-EvalMDP-L{latent_type}-{mdp_build_k}NN-P{penalty_beta}-G{discount_factor}-B{bottleneck_size}"
                        cp_eval_exp = Experiment(
                            id=eval_job_id,
                            meta="Build  MDP for optimal Policy",
                            expPrefix="python batch_mdp_test.py ",
                            expSuffix=f"--env {env} --bottleneck_size {bottleneck_size} --MAX_S_COUNT 100000 --max_timesteps 100000 --MAX_NS_COUNT {mdp_build_k+1}\
                                                            --fill_with 0Q_src-KNN --mdp_build_k {mdp_build_k} --penalty_type linear --penalty_beta {penalty_beta} --gamma {discount_factor} --load_mdp \
                                                            --eval_episode_count 50 --smoothing --policy_k 1 5 11 31 51 --normalize_by_distance --latent_type {latent_type}",
                        )
                        cartPoleExps.add_experiment(cp_eval_exp)
                        
                        
# Dummy experiments: three placeholder jobs ("dummyExp-A1" .. "dummyExp-A3")
# that run dummyExp.py with no arguments.
dummyExps = ExperimentPool()

# Register the jobs on dummyExps itself. Adding them to ataridetExps instead
# would leave this pool empty and mix the placeholders into the Atari pool.
for k in ["A1", "A2", "A3"]:
    dummyExps.add_experiment(Experiment(id=f"dummyExp-{k}",
                                        meta="DummyExperiments",
                                        expPrefix="python dummyExp.py ",
                                        expSuffix=""))


def _duplicate_exp_ids(exp_ids):
    """Return the sorted list of ids that occur more than once in ``exp_ids``."""
    seen = set()
    duplicates = set()
    for exp_id in exp_ids:
        if exp_id in seen:
            duplicates.add(exp_id)
        else:
            seen.add(exp_id)
    return sorted(duplicates)


# Collect the experiment ids of every pool and make sure no id is defined in
# more than one pool before the pools are joined.
pools = [ataridetExps, dummyExps, cartPoleExps]
all_exp_keys = []
for pool in pools:
    all_exp_keys.extend(pool.expPool.keys())

# Raise explicitly rather than `assert`: asserts are removed when Python runs
# with -O, and the error should name the offending ids.
_duplicated_exp_keys = _duplicate_exp_ids(all_exp_keys)
if _duplicated_exp_keys:
    raise ValueError(
        f"Experiment ids defined in more than one pool: {_duplicated_exp_keys}"
    )


# Single pool containing the experiments of all pools above.
ExpPool = ExperimentPool.joinPools(*pools)



if __name__ == '__main__':
    # Command-line lookup: print every experiment id that contains the
    # substring given as the first argument.
    if len(sys.argv) < 2:
        # Exit with a usage message (non-zero status) instead of failing
        # with an IndexError when the query argument is missing.
        sys.exit(f"usage: {sys.argv[0]} <experiment-id-substring>")
    query_exp_id = sys.argv[1]
    for exp_id in ExpPool.expPool:
        if query_exp_id in exp_id:
            print(exp_id)
