from flask import Flask, render_template,request
from envs.Overcooked_Env_new import Overcooked_NEW
from envs.Overcooked_Env import Overcooked
from replay_buffer import ReplayBuffer
import os
import numpy as np
import imageio
import copy
import sys
app = Flask(__name__)

# Global state shared by all request handlers. Every handler reads/writes these
# module-level values, so the server hosts a single game session at a time.
env = None                  # environment instance built by load_Env()
env_name = None             # layout name passed to load_Env()
actions = [-1,-1]           # pending per-player actions; test() uses -2 for "not chosen yet"
timestep = 0                # steps taken in the current episode (reset to 0 by test())
score = 0                   # accumulated reward displayed on the page
name1 = 'aaa'               # player names; overwritten by submit()
name2 = 'bbb'
replay_buffer = []          # placeholder; submit() replaces it with a ReplayBuffer
index_id = 0                # number of saved episodes, used in replay-buffer file names
obs = None                  # current stacked observation, set when test() resets the env
save_path = r"./save_BCData"  # directory where finished episodes' replay buffers are saved
# exist_ok=True replaces the exists()/makedirs() pair, which could race with
# another process creating the directory in between the two calls.
os.makedirs(save_path, exist_ok=True)
    
def load_Env(layout, seed=3):
    """Create the Overcooked environment for ``layout`` and store it globally.

    Sets the module-level ``env`` and ``env_name``. The layouts ``mo1``,
    ``many_orders``, ``distant_tomato`` and ``soup_coordination`` are built
    with ``Overcooked_NEW``; ``random3`` and ``unident_s`` with ``Overcooked``.
    Both use ``featurize_type=("ppo", "ppo")``.

    :param layout: layout name
    :param seed: seed forwarded to the environment constructor (default 3,
        the value previously hard-coded)
    :raises ValueError: if ``layout`` is not a supported layout. Previously
        ``env`` was silently left as ``None`` and the failure only surfaced
        later, when a request handler called a method on it.
    """
    global env
    global env_name
    if layout in ('mo1', 'many_orders', 'distant_tomato', 'soup_coordination'):
        env_cls = Overcooked_NEW
    elif layout in ('random3', 'unident_s'):
        env_cls = Overcooked
    else:
        raise ValueError('Unsupported layout: %r' % (layout,))
    env_name = layout
    env = env_cls(layout, seed=seed, featurize_type=("ppo", "ppo"))

@app.route('/')
def index():
    """Serve the player-registration page (``register.html``)."""
    page = 'register.html'
    return render_template(page)

@app.route("/cook", methods=["GET", "POST"])
def cook():
    """Render the single-player baseline page.

    The player name is read from the submitted form (POST) or from the query
    string (GET/HEAD) and stored in the module-level global ``name`` before
    being passed to ``index_baseline.html``.
    """
    global name
    if request.method == "POST":
        name = request.form.get("name")
    else:
        # Fix: this branch previously assigned a stray local ``name1``, which
        # left ``name`` unset and raised NameError on a first GET request.
        name = request.args.get("name")
    img_stream = return_img_stream('static/overcooked2.png')
    return render_template('index_baseline.html', img_stream=img_stream,
                           timestep='timestep', score='score', name=name)

# Read a server-local image and return it as a base64 stream for the front end.
def return_img_stream(img_local_path):
    """Return the contents of an image file, base64-encoded, as a str.

    :param img_local_path: path of a single local image file
    :return: base64 text of the file's bytes
    """
    import base64
    with open(img_local_path, 'rb') as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
    

# Replay-buffer observation shape for each supported layout. The first
# dimension (2) matches the ``reshape(2, -1)`` applied to observations when
# transitions are stored; the second is the flattened per-row size. The factors
# are kept exactly as originally written (presumably grid width * height *
# feature channels -- TODO confirm against the env's featurization).
_OBS_SHAPES = {
    'mo1': [2, 5*5*26],
    'many_orders': [2, 5*5*26],
    'distant_tomato': [2, 7*5*26],
    'random3': [2, 8*5*20],
    'soup_coordination': [2, 11*5*26],
    'unident_s': [2, 9*5*20],
}


@app.route("/submit", methods=["GET", "POST"])
def submit():
    """Register the two player names and start a fresh replay buffer.

    Player names come from the form (POST) or the query string (GET/HEAD)
    and are stored in the globals ``name1``/``name2``. A new ``ReplayBuffer``
    sized for the loaded layout (capacity 400) replaces the global
    ``replay_buffer``.

    :raises ValueError: if the loaded layout has no known observation shape
        (previously this surfaced as an UnboundLocalError on ``obs_shape``).
    """
    global name1
    global name2
    global replay_buffer
    if request.method == "POST":
        name1 = request.form.get("name1")
        name2 = request.form.get("name2")
    else:
        name1 = request.args.get("name1")
        name2 = request.args.get("name2")

    # Initialise the ReplayBuffer for the current layout.
    try:
        obs_shape = list(_OBS_SHAPES[env_name])
    except KeyError:
        raise ValueError(
            'No replay-buffer observation shape known for layout %r' % (env_name,)
        ) from None
    replay_buffer = ReplayBuffer(obs_shape=obs_shape,
                                 action_shape=[2, 1],
                                 reward_shape=[2, 1],
                                 dones_shape=[2, 1],
                                 capacity=400,
                                 # NOTE(review): hard-coded device; confirm how
                                 # ReplayBuffer behaves on a machine without CUDA.
                                 device='cuda')

    img_stream = return_img_stream('static/overcooked2.png')
    return render_template('index.html', img_stream=img_stream,
                           timestep='timestep', score='score',
                           name1=name1, name2=name2)

def create_gif(image_list, gif_name, duration=1.0):
    '''
    Assemble a list of image files into an animated GIF.

    :param image_list: paths of the frame images, in playback order
    :param gif_name: output file name, including the .gif suffix
    :param duration: delay between frames, forwarded to imageio.mimsave
        (NOTE(review): the unit depends on the installed imageio version -- confirm)
    :return: None
    '''
    frames = [imageio.imread(frame_path) for frame_path in image_list]
    imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
   
def write_gif(gif_name=r"save.gif", pic_path=r"./static/img/", duration=0.5,
              video_path=r"./static/video/"):
    """Build a GIF from the numbered frames (``0.png``, ``1.png``, ...) in ``pic_path``.

    Only entries named ``<integer>.png`` are used, in numeric order. The
    previous version took the frame count from the total number of directory
    entries and rebuilt ``0.png .. N-1.png`` paths, so any stray file or gap
    in the numbering produced a path to a frame that did not exist.

    :param gif_name: output file name (with .gif suffix)
    :param pic_path: directory containing the numbered frames
    :param duration: delay between frames, forwarded to create_gif
    :param video_path: directory the GIF is written to
    """
    numbered = []
    for entry in os.listdir(pic_path):
        stem, ext = os.path.splitext(entry)
        if ext == '.png' and stem.isdecimal():
            numbered.append((int(stem), entry))
    numbered.sort()
    image_list = [os.path.join(pic_path, entry) for _, entry in numbered]
    create_gif(image_list, os.path.join(video_path, gif_name), duration)

def _render_game(img_path):
    """Render ``index.html`` showing the image at ``img_path``.

    Timestep, score and player names are read from the module-level globals
    at call time, so call this after those globals have been updated.
    """
    return render_template('index.html', img_stream=return_img_stream(img_path),
                           timestep=timestep, score=score,
                           name1=name1, name2=name2)


def _clear_frames(frame_dir=r"./static/img"):
    """Best-effort deletion of every entry in ``frame_dir``.

    A missing directory, or an entry that cannot be removed, is skipped
    instead of aborting the request.
    """
    try:
        entries = os.listdir(frame_dir)
    except OSError:
        return
    for entry in entries:
        try:
            os.remove(os.path.join(frame_dir, entry))
        except OSError:
            pass


@app.route("/test/", methods=["GET", "POST"])
def test():
    """Handle one action request from the two-player game page.

    The first two query-parameter *names* carry the players' actions as their
    last ``_``-separated field, parsed as an int:

    * both ``-1``: clear the frame directory, reset timestep/score and the
      environment, and show frame 0;
    * ``-2``: no new action from that player;
    * otherwise: record the player's action. Once both players have an action,
      step the environment, store the transition in the replay buffer and
      clear the pending actions. If either player is done, save the replay
      buffer as ``<name1>&<name2>_<index_id>`` under ``save_path`` and show
      the game-over image.
    """
    global timestep, actions, score, name1, name2, replay_buffer, index_id, obs
    params = list(request.args.keys())
    print(params[0], params[1])
    action1 = int(params[0].split('_')[-1])
    action2 = int(params[1].split('_')[-1])

    if action1 == -1 and action2 == -1:
        _clear_frames()
        timestep = 0
        score = 0
        actions = [-2, -2]
        obs, _, _ = env.reset()
        obs = np.stack(obs)
        # NOTE(review): assumes static/img/0.png exists after reset -- confirm
        # which component writes the frame images.
        return _render_game('static/img/' + str(timestep) + '.png')

    if action1 != -2:
        actions[0] = [action1]
    if action2 != -2:
        actions[1] = [action2]

    if actions[0] != -2 and actions[1] != -2:
        now_obs = copy.deepcopy(obs)
        timestep += 1
        next_obs, _share_obs, rewards, dones, _infos, _available = env.step(actions)
        next_obs = np.stack(next_obs)
        # Store the transition; both rows receive the same (last) reward.
        replay_buffer.add(now_obs.reshape(2, -1),
                          np.array(actions).reshape(2, -1),
                          np.array([rewards[-1], rewards[-1]]).reshape(2, -1),
                          next_obs.reshape(2, -1),
                          np.array(dones).reshape(2, -1))
        obs = copy.deepcopy(next_obs)
        # Clear pending actions and accumulate the reward on every step --
        # including the final one. Previously both were skipped when the
        # episode ended, so the game-over page showed a score missing the last
        # reward and stale actions could re-step the finished environment.
        actions = [-2, -2]
        score += rewards[1]
        if dones[0] or dones[1]:
            index_id += 1
            save_name = str(name1) + '&' + str(name2) + '_' + str(index_id)
            replay_buffer.save(save_path, save_name)
            return _render_game('static/game_over.png')
        return _render_game('static/img/' + str(timestep) + '.png')

    # Only one player has chosen so far: redisplay the current frame.
    return _render_game('static/img/' + str(timestep) + '.png')
    
if __name__ == '__main__':
    # Usage: python <this script> <layout>
    print('Layout:', sys.argv)
    if len(sys.argv) < 2:
        # Previously a missing argument crashed with IndexError.
        sys.exit('usage: python %s <layout>' % sys.argv[0])
    load_Env(sys.argv[1])
    app.run(port=8997)