#!/usr/bin/env python
# coding: utf-8

# In[1]:


from tqdm import tqdm
import argparse
import numpy as np

import argparse
import copy
import importlib
import json
import os

import numpy as np
import torch

import discrete_BCQ
import discrete_BCQ2
import discrete_BCQ3
import DQN
import utils

from IPython import display
import matplotlib.pyplot as plt
from main import *

from bigmdp.utils.tmp_vi_helper import *
from bigmdp.mdp.mdp_agent import SimpleAgent
from bigmdp.mdp.MDP_GPU_KDD import FullMDP
import Experiments as eXP
from mdp_helper_files.arg_def import *
import pickle as pk
import mdp_helper_files.dqn_code as dqn
from functools import partial


# In[2]:


# ! jupyter nbconvert --to script batch_mdp_test.ipynb


# In[3]:



# Load parameters
parser = argparse.ArgumentParser()
# Maps a display-group name to the argument names that print_args shows for that group.
ArgumentDict = {}

parser.add_argument("--env", default="PongNoFrameskip-v0")     # OpenAI gym environment name
parser.add_argument("--seed", default=0, type=int)             # Sets Gym, PyTorch and Numpy seeds
parser.add_argument("--buffer_name", default="Default")        # Prepends name to filename
# argparse only applies `type` to string defaults, so the default itself must be an int
# (a bare 1e6 would leave args.max_timesteps as a float).
parser.add_argument("--max_timesteps", default=int(1e6), type=int)  # Max time steps to run environment or train for
parser.add_argument("--BCQ_threshold", default=0.3, type=float)# Threshold hyper-parameter for BCQ
parser.add_argument("--low_noise_p", default=0.2, type=float)  # Probability of a low noise episode when generating buffer
parser.add_argument("--rand_action_p", default=0.2, type=float)# Probability of taking a random action when generating buffer, during non-low noise episode
parser.add_argument("--train_behavioral", action="store_true") # If true, train behavioral policy
parser.add_argument("--generate_buffer", action="store_true")  # If true, generate buffer
parser.add_argument("--dummy_run", action="store_true")        # If true, shrink buffer, MDP and evaluation for a quick smoke test

parser.add_argument("--save_full_checkpoints", action="store_true") # If true, file names carry the _T<eval_checkpoint> checkpoint suffix
parser.add_argument("--fast_run", action="store_true") # If true, skip the metrics dump and evaluate only the first MDP policy
parser.add_argument("--eval_checkpoint", default=0, type=int) # Checkpoint index used in the _T<n> suffix when --save_full_checkpoints is set



ArgumentDict.update({"mainArgs": ["env", "seed", "buffer_name" , "max_timesteps", "BCQ_threshold", "low_noise_p", "rand_action_p",
                    "train_behavioral" , "generate_buffer",]})

parser.add_argument("--exp_meta", help="Set something sensible for a simple experiment with a small number of runs, used for Sweep generation", type=str, default="Default_Exp_Name")
parser.add_argument("--exp_id", help="Used for grouping the runs, add a hash of the experiment used for sweep generation", type=str, default="E404")

# MDP Build parameters
mdpBuildArgs = parser.add_argument_group("MDP build arguments")
mdpBuildArgs.add_argument("--unknown_transition_reward", help="default reward for unknown transitions to absorbing state", type=int, default=-1000)
mdpBuildArgs.add_argument("--rmax_reward", help="Default reward for RMAX reward", type=int, default= 10000)
mdpBuildArgs.add_argument("--balanced_exploration", help="Try to go to all states equally often", type=int, default= 0)
mdpBuildArgs.add_argument("--rmax_threshold", help="Number of travesal before annealing rmax reward", type=int, default= 2)
mdpBuildArgs.add_argument("--MAX_S_COUNT", help="maximum state count  for gpu rewource allocation", type=int, default= 250000)
mdpBuildArgs.add_argument("--MAX_NS_COUNT", help="maximum nest state count  for gpu rewource allocation", type=int, default=20)
mdpBuildArgs.add_argument("--def_device", help="Default device to use for building the MDP", type=str, default= "GPU")
mdpBuildArgs.add_argument("--weight_transitions", help="Caluclate transition prob based on transition frequencies?", type=int, default= 1)
# mdpBuildArgs.add_argument("--weight_neighbors", help="Caluclate transition prob based on neighbor distances ?", type=int, default= 1)
mdpBuildArgs.add_argument("--fill_with", help="Define how to fill missing state actions", type=str, default="0Q_src-KNN", choices=["0Q_src-KNN", "1Q_dst-KNN","kkQ_dst-1NN", "none"])
mdpBuildArgs.add_argument("--mdp_build_k", help="Number of Nearest neighbor to consider k", type=int, default= 1)
mdpBuildArgs.add_argument("--knn_delta", help="Define the bias parmeter for nearest neighbor distance", type=float, default=1e-8)
mdpBuildArgs.add_argument("--penalty_type", help="penalized predicted rewards based on the distance to the state", type=str, default="linear", choices=["none", "linear", "exponential"])
mdpBuildArgs.add_argument("--penalty_beta", help="beta multiplyer for penalizing rewards based on distance", type=float, default= 1)
mdpBuildArgs.add_argument("--within_radius", help="Radius to cap the prediction to absorbing state", type=int, default= 100)
mdpBuildArgs.add_argument("--filter_with_abstraction", help="Set to true, to filter the states to be added based on the radius.", type=int, default= 0)
mdpBuildArgs.add_argument("--normalize_by_distance", help="set it on if the transition probabilities should be normalized by distance.", action = "store_true")

ArgumentDict.update({"mdpBuildArgs": ["unknown_transition_reward", "rmax_reward", "balanced_exploration" , "rmax_threshold", "MAX_S_COUNT", "MAX_NS_COUNT", "def_device", "weight_transitions",
                    "fill_with" , "mdp_build_k", "knn_delta", "penalty_type" ,"penalty_beta", "within_radius", "filter_with_abstraction", "normalize_by_distance"]})

# MDP solve and lift up parameters
mdpSolveArgs = parser.add_argument_group("MDP solve arguments")
mdpSolveArgs.add_argument("--gamma", help="Discount Factor for Value iteration", type=float, default= 0.99)
mdpSolveArgs.add_argument("--slip_probability", help="Slip probability for safe policy", type=float, default= 0.1)
mdpSolveArgs.add_argument("--target_vi_error", help="target belllman backup error for considering solved", type=float, default= 0.001)
mdpSolveArgs.add_argument("--bellman_backup_every", help="Do a bellman backups every __k frames", type=int, default= 100)
mdpSolveArgs.add_argument("--n_backups", help="The number of backups for every backup step", type=int, default= 10)
ArgumentDict.update({"mdpSolveArgs":["gamma", "slip_probability", "target_vi_error","bellman_backup_every", "n_backups", "policy_k",]})

# Network parameters
netDefArgs =parser.add_argument_group("Network Definition arguments")
netDefArgs.add_argument("--hidden_state_size", help="Size of hidden state for MLP", type=int, default= 64)
netDefArgs.add_argument("--encoder_type", help = "set to one of the choices to change network encoder type", type = str, choices = ["conv", "conv_bn", "conv_small", "linear", "none"], default = "none")
netDefArgs.add_argument("--bottleneck_size", help="Size of the latent space.", type=int, default= 16)
netDefArgs.add_argument("--dont_do_pca", help="set to skip principal component analysis when using dqn representation, peertinent for Atari for now",action = "store_true")
# Types handled by load_Agent: BCQ, BCQ2, BCQ3, PreTrainedDQN, OfflineDQN, Random (Random leaves the network untrained).
netDefArgs.add_argument("--latent_type", help="Set appropriately amont the choices contrastive and Q learning", type = str, default="Contrastive")
# Only names that exist right after parse_args(); do_pca is derived later from dont_do_pca.
ArgumentDict.update({"netDefArgs":["hidden_state_size", "encoder_type", "bottleneck_size","dont_do_pca", "latent_type", ]})

# Evaluation Parameters
evalArgs =parser.add_argument_group("Evaluation Arguments")
evalArgs.add_argument("--build_mdp", help="set to build the MDP from the data and the network",action = "store_true")
evalArgs.add_argument("--load_mdp", help="set to load the MDP from memory",action = "store_true")
evalArgs.add_argument("--save_mdp", help="set to save the built MDP",action = "store_true")
evalArgs.add_argument("--test", help="set to test the policy in the environment", action = "store_true")
evalArgs.add_argument("--test_path_following", help="set to test the policy in the environment", action = "store_true")
evalArgs.add_argument("--video_count", help="Set to greater than 0 to log video", type=int, default=0)
evalArgs.add_argument("--eval_episode_count", help="Number of episodes to evaluate the policy", type=int, default=249)
evalArgs.add_argument("--smoothing", help="Use more than one K NN for lifting up the policy", action="store_true")
evalArgs.add_argument("--soft_q", help="Sample according to Q values rather than max action", action="store_true")
evalArgs.add_argument("--smooth_with_seen", help="do nearest neighbor on seen state action pair than seen states", action="store_true")
evalArgs.add_argument("--policy_k", help="List the lift up parameter policy_k you want to test with", nargs="+", type= int, default=[1])
ArgumentDict.update({"evalArgs":["build_mdp", "load_mdp", "save_mdp","test", "eval_episode_count","test_path_following",
                                 "video_count", "smoothing", "soft_q", "smooth_with_seen","policy_k"]})


# exp = eXP.ExpPool.get_by_id("AD-BreakoutNoFrameskip-v0-BuildEvalMDP-LOfflineDQN-5NN-P0.01-G0.99-B16")
# args = parser.parse_args(exp.expSuffix.split(" "))

args = parser.parse_args()


# In[4]:


# Echo every argument group, in registration order, so the run configuration shows up in the log.
for group_name in ArgumentDict:
    print_args(args, to_show_args=ArgumentDict[group_name], title=group_name)
    print("\n")


# In[5]:



# Make env and determine properties
ATARI_PREPROCESSING_PARAMS["max_noop"] = 0 if args.env[-2:]=="v0" else 30 # if no sticky actions creates an initial state distribution.
env, is_atari, state_dim, num_actions = utils.make_env(args.env, ATARI_PREPROCESSING_PARAMS)
parameters = ATARI_PARAMETERS if is_atari else REGULAR_PARAMETERS

env.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize buffer
parameters["buffer_size"] = max(parameters["buffer_size"], args.max_timesteps)
print("New Buffer Size:",parameters["buffer_size"])
train_buffer = utils.ReplayBuffer(state_dim, is_atari, ATARI_PREPROCESSING_PARAMS, parameters["batch_size"], parameters["buffer_size"], device)


# For saving files
if args.save_full_checkpoints:
    env_setting = f"{args.env}_{args.seed}_T{args.eval_checkpoint}"
else:
    env_setting = f"{args.env}_{args.seed}"

full_buffer_name = f"{args.buffer_name}_{env_setting}"

args.env_setting = env_setting
print(f"Env setting used:{args.env_setting}")

# Load replay buffer
train_buffer.load(f"./buffers/{args.buffer_name}_{args.env}_{args.seed}")
print("Length of Traning Buffer", len(train_buffer))


# In[8]:


# Uncomment to force a dummy run when executing as a notebook.
# args.dummy_run = True
if args.dummy_run:
    print("Big Warning here dummy run is being triggered.")
    # Shrink the workload for a quick smoke test. Only ever truncate the buffer: assigning
    # 50000 unconditionally would claim more transitions than were loaded for small buffers.
    train_buffer.crt_size = min(train_buffer.crt_size, 50000)
    args.MAX_S_COUNT = 50000
    args.eval_episode_count = 1
    args.policy_k = [1,5]
    # args.load_mdp = False
    # args.build_mdp = False
    # args.save_mdp = False


# PCA on the latent representation runs unless explicitly disabled with --dont_do_pca.
args.do_pca = not args.dont_do_pca

def load_Agent(args):
    """Build the latent-representation agent used to embed states for MDP construction.

    Instantiates the policy class selected by ``args.latent_type`` and loads its saved
    weights from ``./models`` (``Random`` is left untrained). It then wraps the policy's
    Q-network in a ``dqn.DQNAgent``. When ``args.do_pca`` is set and the MDP is not being
    loaded from cache, it also fits a PCA of size ``args.bottleneck_size`` on the encoded
    training buffer.

    Relies on module-level globals: ``is_atari``, ``num_actions``, ``state_dim``,
    ``device``, ``parameters``, ``env``, ``full_buffer_name`` and ``train_buffer``.

    Args:
        args: parsed argument namespace. Reads ``latent_type``, ``env_setting``,
            ``do_pca``, ``load_mdp`` and ``bottleneck_size``.

    Returns:
        The ``dqn.DQNAgent`` whose ``net`` is the loaded policy's Q-network.

    Raises:
        ValueError: if ``args.latent_type`` is not one of the supported types.
    """
    print(f"Using learnt represetntation {args.latent_type}")

    # Initialize and load policy
    policy_class_map = {"BCQ": discrete_BCQ.discrete_BCQ,
                        "BCQ2": discrete_BCQ2.discrete_BCQ,
                        "BCQ3": discrete_BCQ3.discrete_BCQ,
                        "PreTrainedDQN": DQN.DQN,
                        "OfflineDQN":DQN.DQN,
                        "Random": DQN.DQN}
    if args.latent_type not in policy_class_map:
        # Fail with an explicit message instead of a bare KeyError from the lookup below.
        raise ValueError(f"Unsupported latent_type {args.latent_type!r}; "
                         f"expected one of {sorted(policy_class_map)}")
    policy_class = policy_class_map[args.latent_type]
    policy = policy_class(
                    is_atari,
                    num_actions,
                    state_dim,
                    device,
                    discount = parameters["discount"],
                    optimizer = parameters["optimizer"],
                    optimizer_parameters= parameters["optimizer_parameters"],
                    polyak_target_update= parameters["polyak_target_update"],
                    target_update_frequency= parameters["target_update_freq"],
                    tau= parameters["tau"],
                    initial_eps= parameters["initial_eps"],
                    end_eps= parameters["end_eps"],
                    eps_decay_period= parameters["eps_decay_period"],
                    eval_eps= parameters["eval_eps"],
                    )

    if args.latent_type == "PreTrainedDQN":
        # Prefer the best behavioral checkpoint and fall back to the last one.
        best_path = f"./models/best_behavioral_{args.env_setting}"
        if os.path.exists(best_path):
            policy.load(best_path)
        else:
            policy.load(f"./models/behavioral_{args.env_setting}")
    elif args.latent_type == "Random":
        print("random projection used, network is untrained")
    else:
        # OfflineDQN / BCQ / BCQ2 / BCQ3 checkpoints are all named "<latent_type>_<full_buffer_name>".
        model_path = f"./models/{args.latent_type}_{full_buffer_name}"
        if args.latent_type == "OfflineDQN":
            print("Loading, ", model_path)
        policy.load(model_path)
        if args.latent_type == "BCQ3":
            print(f"loaded policy from {model_path}")

    # Previously hard-coded to True; follow actual CUDA availability, consistent with `device`.
    use_cuda = torch.cuda.is_available()
    myDQNAgent = dqn.DQNAgent((ATARI_PREPROCESSING_PARAMS['state_history'], *env.observation_space.shape),
                              env.action_space.n, wandb_logger=None,
                              use_cuda=use_cuda, double_dqn=False, multiplyer = 1)

    # Reuse the loaded policy's Q-network as the encoder.
    myDQNAgent.net = policy.Q

    # A cached MDP already contains its latent states, so PCA is only fitted when building fresh.
    if args.do_pca and not args.load_mdp:
        print("PCA Started")
        myDQNAgent.use_cuda = use_cuda
        # NOTE: mutates the shared global buffer's batch size (kept for backward compatibility).
        train_buffer.batch_size = 256
        all_latent_batches = []
        start_end_indexes = get_iter_indexes(len(train_buffer), train_buffer.batch_size)

        with torch.no_grad():
            for start_i, end_i in tqdm(start_end_indexes):
                batch = train_buffer.sample_indices(list(range(start_i, end_i)))
                batch_s, _batch_a, _batch_ns, _batch_r, _batch_d = batch
                all_latent_batches.append(myDQNAgent.net.encode(batch_s))
        latent_dataset = torch.cat(all_latent_batches)

        print("Latent Dataset ccollected")

        myDQNAgent.fit_pca(latent_size=args.bottleneck_size, data = latent_dataset.cpu().numpy())
        myDQNAgent.pca_flag = 1

        print("PCA Complete")

    return myDQNAgent


# In[11]:


latentNetwork = load_Agent(args)


# Sanity check: evaluate the loaded network's greedy policy (epsilon=0) directly in the
# environment and save its average reward under results/.
print("SanityCheck , getting dqn performance")
env.max_episode_length = ATARI_PREPROCESSING_PARAMS["max_episode_timesteps"]
if args.latent_type in ["OfflineDQN", "PreTrainedDQN", "Random"]:
    policy = lambda s: latentNetwork.get_action(s, epsilon=0)
elif args.latent_type in ["BCQ", "BCQ2", "BCQ3"]:
    policy = lambda s: latentNetwork.get_bcq_action(s, epsilon=0)
else:
    # Without this branch an unknown type would surface as a NameError on `policy` below.
    raise ValueError(f"No sanity-check policy for latent_type {args.latent_type!r}")
avg_reward = evaluate_on_env(env,policy, eps_count=args.eval_episode_count, render=False)[0]
os.makedirs("results", exist_ok=True)
np.save(f"results/{args.latent_type}_{full_buffer_name}", np.array([avg_reward]))
print("---------------------------------------")
# Report the episode count actually used (the message was previously hard-coded to 10).
print(f"Evaluation over {args.eval_episode_count} episodes: {avg_reward:.3f}")
print("---------------------------------------")


# In[14]:


# print(info)


# In[15]:


# start_end_indexes = get_iter_indexes(len(train_buffer)-1, train_buffer.batch_size)
# train_buffer.sample_indices(start_end_indexes[-1])[0].shape


# In[16]:


# Re-show the MDP build/solve and evaluation settings now that any dummy-run overrides are applied.
for group_name in ("mdpBuildArgs", "mdpSolveArgs", "evalArgs"):
    print_args(args, ArgumentDict[group_name], title=group_name)


# In[ ]:





# In[17]:


# train_buffer.sample_indices([1,2,3])[0].shape


# In[18]:


# train_buffer.sample_indices([1,2,3])[0].shape


# In[20]:


def build_mdp(self, train_buffer, batch_parse=False, batch_size=256):
    """Populate, complete and solve the agent's MDP from every transition in ``train_buffer``.

    This is a module-level function that takes the agent explicitly as ``self``. It feeds
    the buffer to ``consume_batch`` in chunks, with all commit/solve flags off. It then
    commits the seen and predicted transitions, fills any fully unknown states when
    ``fill_with == "none"``, solves the MDP and seeds the agent's policies.

    Args:
        self: the agent (e.g. ``SimpleAgent``) whose ``mdp_T`` is built.
        train_buffer: buffer exposing ``__len__`` and ``sample_indices``. Each batch
            unpacks as ``(obs, action, next_obs, reward, not_done)``.
        batch_parse: only reported in the log line; the data is always consumed in batches.
        batch_size: number of transitions handed to ``consume_batch`` per call.
    """
    print("Parsing Dataset Batch Parse:", batch_parse)

    start_end_indexes = get_iter_indexes(len(train_buffer), batch_size)
    for start_i, end_i in tqdm(start_end_indexes):
        batch = train_buffer.sample_indices(list(range(start_i, end_i)))
        batch_ob, batch_a, batch_ob_prime, batch_r, batch_nd = batch
        # The buffer stores "not done" flags; the MDP consumes "done" flags.
        batch_d = [not nd for nd in batch_nd]
        # atleast_1d keeps a single-transition batch 1-D; a bare squeeze() would collapse it to 0-d.
        self.consume_batch(np.array(batch_ob), np.atleast_1d(np.array(batch_a).squeeze()), np.array(batch_ob_prime),
                           np.atleast_1d(np.array(batch_r).squeeze()), np.atleast_1d(np.array(batch_d).squeeze()), False,
                           commit_seen_flag=False, commit_pred_flag=False, update_funknown_flag=False,
                           solve_mdp_flag=False)

    print("Parsing Complete")
    print("Commiting Seen Transitions")
    self.commit_seen_transitions()
    print("Commit Complete")

    self.commit_predicted_transitions(verbose=True)
    unknown_count = self.mdp_T.unknown_state_action_count
    if self.fill_with == "none":
        print(f"filling fully unknown states here, total unknown state count:{unknown_count}")
        self.mdp_T.fill_fully_unknown_states()
    else:
        # Every fill strategy other than "none" is expected to leave no unknown (state, action) pairs.
        assert unknown_count == 0, f"{unknown_count} unknown state-action pairs left after filling"

    self.solve_mdp()
    # Guard against an empty MDP so the report cannot divide by zero.
    total_state_actions = len(self.mdp_T.tD) * len(self.mdp_T.A)
    missing_fraction = (self.mdp_T.unknown_state_action_count / total_state_actions
                        if total_state_actions else 0.0)
    print("% of missing transitions", missing_fraction)
    self.seed_policies(smoothing=False, soft_q=False)


# In[ ]:





# In[22]:



# Empty MDP sized by MAX_S_COUNT / MAX_NS_COUNT; it is filled below either by build_mdp or
# from the on-disk cache.
empty_MDP = FullMDP(A= list(range(env.action_space.n)),
                    ur=args.unknown_transition_reward,
                    vi_params={"gamma":  args.gamma,
                               "slip_prob": args.slip_probability,
                               "rmax_reward": args.rmax_reward,
                               "rmax_thres": args.rmax_threshold,
                               "balanced_explr": args.balanced_exploration,
                              "rmin":-1000},
                    mdp_params={"weight_transitions": args.weight_transitions,
                            "mdp_build_k":args.mdp_build_k,
                            "plcy_lift_k": 1,
                            "knn_delta" : args.knn_delta,
                            "calc_action_vector":False,
                            "penalty_beta":args.penalty_beta,
                            "penalize_unknown_transitions": args.penalty_type != "none"
                            },
                    MAX_S_COUNT= args.MAX_S_COUNT,
                    MAX_NS_COUNT= args.MAX_NS_COUNT,
                    default_mode=args.def_device)

myAgent =  SimpleAgent(empty_MDP, net = latentNetwork,  fill_with = args.fill_with,
                       sample_neighbors = False, pred_error_threshold = args.within_radius,
                       penalty_type=args.penalty_type, penalty_beta = args.penalty_beta,
                       abstraction_flag=False, normalize_by_distance= args.normalize_by_distance)


# mdp cache is for one value of k and can be share by multiple values of penalty_beta and gamma #
mdp_cache_id = f"L[{args.latent_type}]-B[{args.bottleneck_size}]-Fill[{args.fill_with}_K{args.mdp_build_k}]"
mdp_id = f"L[{args.latent_type}]-B[{args.bottleneck_size}]-Fill[{args.fill_with}_K{args.mdp_build_k}]-Penalty[{args.penalty_type}_{args.penalty_beta}]-G[{args.gamma}]"
mdp_path = f"mdps/cache_{args.buffer_name}_{env_setting}_{mdp_cache_id}.pk"


if args.build_mdp:
    # Batches are sampled on the CPU while parsing; restore the buffer's device afterwards,
    # even if building fails.
    train_buffer.device = "cpu"
    try:
        build_mdp(myAgent, train_buffer, batch_parse = True)
    finally:
        train_buffer.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: when both --build_mdp and --load_mdp are given, the cached MDP replaces the freshly built one.
if args.load_mdp:
    myAgent.load_mdp_from_cache( file_path = mdp_path)

if args.save_mdp:
    os.makedirs(os.path.dirname(mdp_path), exist_ok=True)
    myAgent.cache_mdp( file_path = mdp_path)


myAgent.smoothing = args.smoothing
myAgent.soft_q = args.soft_q
myAgent.mdp_T.smooth_with_seen = args.smooth_with_seen
# Outside a notebook a bare expression displays nothing, so print the settings explicitly.
print("smoothing:", myAgent.smoothing, "soft_q:", myAgent.soft_q,
      "smooth_with_seen:", myAgent.mdp_T.smooth_with_seen, "state count:", len(myAgent.mdp_T.tD))


# In[19]:


# for s in list(myAgent.mdp_T.tD.keys())[2:4]:
#     for a in myAgent.mdp_T.tD[s]:
#         for ns in myAgent.mdp_T.tD[s][a]:
#             print(myAgent.mdp_T.s2idx[s], a,myAgent.mdp_T.s2idx[ns], round(myAgent.mdp_T.tD[s][a][ns],3) )


# In[40]:

# Pickle the metric distributions returned by log_all_mdp_metrics (skipped for fast runs).
if args.fast_run:
    print("skipping metrics dump")
else:
    all_distr = myAgent.log_all_mdp_metrics(1)
    os.makedirs("results", exist_ok=True)
    # Context manager so the file is flushed and closed deterministically.
    with open(f"results/{args.buffer_name}_{env_setting}_{mdp_id}_metricDistr.pk", "wb") as metric_file:
        pk.dump(all_distr, metric_file)


# In[20]:


# import seaborn as sns
# f, axes = plt.subplots(3, 3, figsize = (8,16))
# for x in range(3):
#     for y in range(3):
#         idx = x*3 +y
#         label, data = list(all_distr.keys())[idx], list(all_distr.values())[idx]
#         ax = sns.distplot(data, hist = True, kde= False, ax = axes[x][y])
#         ax.set(ylabel='State Count')
#         ax.set_title(label)
# plt.subplots_adjust(top = 1.2, right= 2)


# In[22]:


# args.eval_episode_count = 10


# In[ ]:

# Evaluate the MDP policies for each requested lift-up k and pickle the results table.
# Fast runs evaluate only the first policy in myAgent.policies and skip the eps-greedy safe-policy run.
policy_names = list(myAgent.policies)
if args.fast_run:
    policy_names = policy_names[:1]

# Column names matching each row of table_rows (kept for interactive inspection).
table_header = ["fill_with", "Prediction Model ", "Penalty type", "penalty beta", "Policy K"] + [ k + 'Policy Avg Score' for k in policy_names]
if not args.fast_run:
    table_header += ['eps Env - Safe Policy']
table_rows = []
env.max_episode_length = ATARI_PREPROCESSING_PARAMS["max_episode_timesteps"]
fill_label = args.fill_with.replace("KNN", str(args.mdp_build_k )+ "NN")
for policy_k in args.policy_k:
    # Smoothed lift-up over policy_k neighbors, taking the max action rather than soft-Q sampling.
    myAgent.smoothing = True
    myAgent.soft_q = False
    myAgent.mdp_T.plcy_lift_k = policy_k
    myAgent.mdp_T.smooth_with_seen = False

    sum_reward_running = {policy_name: evaluate_on_env(env, myAgent.policies[policy_name], eps_count=args.eval_episode_count, progress_bar=True)[0]
                          for policy_name in policy_names}

    row = [fill_label, str(args.latent_type), args.penalty_type, args.penalty_beta, policy_k] + list(sum_reward_running.values())
    if not args.fast_run:
        # Re-evaluate the safe policy with eval_eps=0.1.
        eps_safe_reward = evaluate_on_env(env, myAgent.policies["safe"], eps_count=args.eval_episode_count, progress_bar=True, eval_eps= 0.1)[0]
        row.append(eps_safe_reward)
    table_rows.append(row)

    print(policy_k, sum_reward_running)

os.makedirs("results", exist_ok=True)
# Context manager so the pickle file is flushed and closed deterministically.
with open(f"results/{args.buffer_name}_{env_setting}_{mdp_id}_evalTable.pk", "wb") as table_file:
    pk.dump(table_rows, table_file)
print(table_rows)




# In[ ]:





# In[ ]:





# In[ ]:





# In[42]:


# def atariVideoOut(env, episode_count=2, max_steps=108000, policy = None, render = False):
#     policy = policy
#     custom_video = []
#     rewards = 0

#     if render:
#         s = env.reset()
#         combined_view =env.env.render("rgb_array")
#         img = plt.imshow(combined_view) # only call this once

#         def render_view(view):
#             img.set_data(view) # just update the data
#             display.display(plt.gcf())
#             display.clear_output(wait=True)
# #             time.sleep(0.1)

#     for i in range(episode_count):
#         s = env.reset()
#         s = np.array(s)

#         for i in range(max_steps):
#             combined_view = env.env.render("rgb_array") #add_title_on_top(combined_view, title_height=50, title_text =qval_text,font_size=25)
#             custom_video.append(combined_view)

#             if render:
#                 render_view(combined_view)

# #             a = policy.select_action(np.array(s), eval=True)

#             a = policy(s) #.select_action(np.array(s), eval=True)
#             s,r,d, i = env.step(a)
#             print(a, "life count", env.env.ale.lives())
#             s = np.array(s)
#             rewards += r
# #             if r != 0 :
# #                 print(r)
#             if d:
#                 break
#         combined_view = env.env.render("rgb_array") #add_title_on_top(combined_view, title_height=50, title_text =qval_text,font_size=25)
#         custom_video.append(combined_view)

#         if render:
#             render_view(combined_view)

#     custom_video = np.array(custom_video)
#     print("Avg Score:",rewards/episode_count)
#     return  rewards/episode_count, {}, custom_video


# In[43]:


# args.env


# In[44]:


# env.env.get_action_meanings()


# In[45]:


# # myAgent.smoothing = True
# # myAgent.soft_q = False
# # myAgent.mdp_T.plcy_lift_k = 11
# # myAgent.mdp_T.smooth_with_seen = False

# # policy = myAgent.policies["safe"]
# atariVideoOut(env, policy = lambda s : 3, episode_count = 1, render = False)
# # sum_reward_running = {policy_name:evaluate_on_env(env,  policy, eps_count=args.eval_episode_count,progress_bar=True)[0]
# #                       for policy_name,policy in myAgent.policies.items()}


# In[46]:


# sum_reward_running


# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:





# In[ ]:
