from seq2seq import Seq2seq
import argparse
import tensorflow as tf
import common.grid_world as grid_world
from common.config import Config
from queue import Queue
import copy
from ppo import Policy
import numpy as np
import random
from matplotlib import animation
import matplotlib.pyplot as plt

def save_frames_as_gif(frames, path='./Gif/', filename='gym_animation.gif'):
    """Write a sequence of image frames to an animated GIF.

    Args:
        frames: Non-empty sequence of image arrays. ``frames[0].shape[0]`` and
            ``frames[0].shape[1]`` are used as the height and width when
            sizing the figure.
        path: Prefix that is concatenated with ``filename`` to form the
            output location (default ``'./Gif/'``).
        filename: Name of the GIF file appended to ``path``.

    Raises:
        IndexError: If ``frames`` is empty.
    """
    import os

    target = path + filename

    # Mess with this to change frame size. With dpi=36 and the dimensions
    # divided by 36, the figure is one pixel per frame pixel.
    fig = plt.figure(
        figsize=(frames[0].shape[1] / 36.0, frames[0].shape[0] / 36.0),
        dpi=36,
    )
    try:
        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = animation.FuncAnimation(
            fig, animate, frames=len(frames), interval=50)

        # The writer does not create missing directories, so make sure the
        # parent of the output file exists before saving.
        out_dir = os.path.dirname(target)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # anim.save(target, writer='imagemagick', fps=60)
        anim.save(target, writer='pillow', fps=60)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def get_env_config(args):
    """Build the environment and its configuration for an environment name.

    Args:
        args: Name of the environment as a string. Only
            ``'two_way_grid_world'`` is handled here.

    Returns:
        A ``(env, config)`` tuple: the ``grid_world.TwoWayGridWorld`` instance
        and the ``Config`` constructed from it.

    Raises:
        ValueError: If ``args`` is not a supported environment name.
    """
    if args == 'two_way_grid_world':
        env = grid_world.TwoWayGridWorld()
        config = Config(env, 'two_way_grid_world')
        return env, config
    # Fail with a clear message instead of hitting unbound local variables
    # at the return statement.
    raise ValueError(
        "unsupported environment %r; expected 'two_way_grid_world'" % (args,))


if __name__ == '__main__':
    # Interactive demonstration recorder: for each episode the user types an
    # integer action per step; every transition is appended to 'demo_up.txt'
    # and every rendered frame is collected into one GIF at the end.
    parser = argparse.ArgumentParser()
    parser.add_argument('--algo', type=str, default='ppo_mod', help='algo name')
    parser.add_argument('--agent_model_dir', type=str, default='ppo/')
    parser.add_argument('--seq2seq_model_dir', type=str, default='seq2seq/model/')
    # NOTE(review): this help text duplicates the one for --algo and looks
    # copy-pasted; left unchanged.
    parser.add_argument('--ppo_load_path', type=str, default=None, help='algo name')
    parser.add_argument('--env', type=str, default='two_way_grid_world', help='env')
    parser.add_argument('--gpu', type=int, default=-1, help='running on a specify gpu, -1 indicates using cpu')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--num_episodes', type=int, default=50000, help='num_episodes')
    parser.add_argument('--test_num', type=int, default=20)

    args = parser.parse_args()
    env, config = get_env_config(args.env)

    frames = []
    for ep_cur in range(args.test_num):
        state = env.random_reset()
        # Scale every state component by the grid height.
        state = [e / env.n_height for e in state]
        env.render()
        ep_length, ep_reward = 0, 0
        # The context manager closes the demo file even if the user enters a
        # non-integer action or the environment raises mid-episode.
        with open('demo_up.txt', 'a') as demo:
            while True:
                # A scripted policy was previously used here instead of manual
                # input: it incremented ep_length each step and chose action 2
                # for steps 1-3, action 1 for steps 4-12 and action 2 after.
                print('input your aciton')
                action = int(input())
                next_state, reward, done, _ = env.random_step(action)
                print(reward)
                # Log the state the action was taken in, the action, and
                # whether the step ended the episode.
                demo.write('state:' + str(state) + '\n')
                demo.write('action:' + str(action) + '\n')
                demo.write('done:' + str(done) + '\n')
                env.render()
                frames.append(env.render(mode='rgb_array'))
                next_state = [e / env.n_height for e in next_state]

                state = next_state
                ep_reward += reward
                # NOTE(review): ep_length is never incremented in this loop,
                # so the length check below only fires if
                # config.solving_criteria[0] <= 0 -- confirm whether a step
                # cap is wanted for manual demonstrations.
                if done or ep_length >= config.solving_criteria[0]:
                    break

    # With --test_num 0 no frame is recorded; skip the GIF instead of failing
    # on an empty frame list.
    if frames:
        gif_name = 'random_demo' + '.gif'
        save_frames_as_gif(frames, filename=gif_name)
    ### force actions: force the agent to go through the lower door
