import matplotlib.pyplot as plt
from IPython import display
import numpy as np

def atariVideoOut(env, episode_count=2, max_steps=108000, policy=None, render=False):
    """Roll out ``policy`` in ``env`` and record every rendered frame.

    Plays ``episode_count`` episodes of at most ``max_steps`` steps each.
    Before every step, and once more after each episode finishes, the frame
    returned by ``env.env.render("rgb_array")`` is appended to the video.

    Args:
        env: Environment exposing ``reset()``, a 4-tuple
            ``step(action) -> (obs, reward, done, info)`` and an inner
            ``env.env.render(mode)``. ``env.action_space.sample()`` is only
            needed when ``policy`` is None.
        episode_count: Number of episodes to play.
        max_steps: Maximum number of steps per episode.
        policy: Callable mapping the current observation (converted with
            ``np.array``) to an action. When None, actions come from
            ``env.action_space.sample()``.
        render: If True, show frames live in a notebook via matplotlib and
            IPython.display.

    Returns:
        Tuple ``(average_reward, {}, video)``. ``average_reward`` is the total
        reward divided by ``episode_count`` (0.0 when ``episode_count`` is not
        positive). ``video`` is ``np.array`` of all recorded frames.
    """
    # The original called policy(None) and crashed; fall back to sampling.
    act = policy if policy is not None else (lambda _obs: env.action_space.sample())
    custom_video = []
    rewards = 0

    if render:
        # NOTE(review): presumably reset first so the env can render a frame.
        env.reset()
        img = plt.imshow(env.env.render("rgb_array"))  # create the image once

        def render_view(view):
            img.set_data(view)  # only swap pixel data for subsequent frames
            display.display(plt.gcf())
            display.clear_output(wait=True)

    def record_frame():
        """Grab the current frame, store it, and optionally display it."""
        frame = env.env.render("rgb_array")
        custom_video.append(frame)
        if render:
            render_view(frame)

    for _episode in range(episode_count):
        s = np.array(env.reset())

        for _step in range(max_steps):
            record_frame()
            # Act on the current observation (was `policy(a)` with `a` unbound).
            a = act(s)
            s, r, d, _info = env.step(a)
            s = np.array(s)
            rewards += r
            if d:
                break

        # Also capture the frame produced by the episode's last step.
        record_frame()

    custom_video = np.array(custom_video)
    avg_score = rewards / episode_count if episode_count > 0 else 0.0
    print("Avg Score:", avg_score)
    return avg_score, {}, custom_video