
import torch
import torch.nn.functional as F
import os
import sys
import numpy as np
import torch
import imageio
import json
from envs.Overcooked_Env_new import Overcooked_NEW
from envs.Overcooked_Env import Overcooked
from baseline_policy.ai_policy.policy_prefrence import Policy
from flask import Flask, render_template,request
# Flask application serving the Overcooked human-AI demo pages.
app = Flask(__name__)

# global variables
# NOTE(review): all game state below is module-level and is mutated by the
# route handlers without any locking, so the server effectively supports a
# single player / single game at a time — confirm this is intended.
env = None # Overcooked environment, created by load_ENV_and_Model()
policy = None # AI partner policy, created by load_ENV_and_Model()
actions = [-1,-1] # Joint action passed to env.step(): index 0 = AI agent, index 1 = human player
timestep = 0 # step counter of the game; also names the frame file static/img/<timestep>.png
score = 0 # running sum of rewards[1] returned by env.step()
name = 'name' # player name, set from the /cook request
obs = None # observations of both agents, combined with np.stack
rnn_state = None  # recurrent state fed to policy.act(); zeros of shape (2, 1, 64) after a reset
mask = None # mask fed to policy.act(); ones of shape (2, 1) after a reset

# # Overcooked layouts: unident_s, random3
# env = Overcooked(r'unident_s',seed=3,featurize_type=("ppo","ppo")) 
# # AI policy: load the model trained with the HSP algorithm
# policy = Policy(map_name=r'unident_s')
# ckpt_path = r"./baseline_policy/unident_s/hsp_adaptive_seed1.pt" 
# policy.load_checkpoint(ckpt_path)
# policy.prep_rollout()

    
# # Overcooked layouts: many_orders, soup_coordination, distant_tomato
# env = Overcooked_NEW(r'many_orders',seed=3,featurize_type=("ppo","ppo")) 
# # AI policy: load the model trained with the FCP algorithm
# policy = Policy(map_name=r'many_orders')
# ckpt_path = r"./baseline_policy/many_orders/fcp_adaptive_seed1.pt" 
# policy.load_checkpoint(ckpt_path)
# policy.prep_rollout()

def load_ENV_and_Model(env_name, ckpt_path):
    """Create the Overcooked environment and load the AI partner policy.

    Sets the module-level ``env`` and ``policy`` globals used by the routes.

    :param env_name: layout name. ``many_orders``, ``soup_coordination`` and
        ``distant_tomato`` use ``Overcooked_NEW``; ``unident_s`` and
        ``random3`` use ``Overcooked``.
    :param ckpt_path: path of the policy checkpoint to load.
    :raises ValueError: if ``env_name`` is not one of the supported layouts.
    """
    global env
    global policy
    new_layouts = ('many_orders', 'soup_coordination', 'distant_tomato')
    old_layouts = ('unident_s', 'random3')
    if env_name in new_layouts:
        env_cls = Overcooked_NEW
    elif env_name in old_layouts:
        env_cls = Overcooked
    else:
        # Fail fast: previously an unknown layout left ``env`` as None and the
        # server only crashed later, on the first /test/ request.
        raise ValueError(
            f"Unknown layout {env_name!r}; expected one of "
            f"{new_layouts + old_layouts}"
        )
    env = env_cls(env_name, seed=3, featurize_type=("ppo", "ppo"))
    policy = Policy(map_name=env_name)
    policy.load_checkpoint(ckpt_path)
    policy.prep_rollout()

@app.route('/')
def index():
    """Serve the landing (registration) page."""
    landing_template = 'register_baseline.html'
    return render_template(landing_template)
    
# Shows how Flask reads local images from the server and returns the image stream to the front-end for display
def return_img_stream(img_local_path):
    """
    Utility: read a local image file and return its bytes base64-encoded.

    :param img_local_path: local path of a single image file
    :return: the base64-encoded file contents as an ASCII ``str``
    """
    import base64
    with open(img_local_path, 'rb') as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()

@app.route("/cook",methods=["GET", "POST"])
def cook():
    """Render the game page for the registered player.

    The player name is taken from the ``name`` form field (POST) or the
    ``name`` query-string argument (GET). It is stored in the module-level
    ``name`` global, which the other views also render. If the request
    carries no name, the previously stored name is kept.
    """
    global name
    if request.method == "POST":
        submitted = request.form.get("name")
    else:
        # Previously the GET value was read into an unused local and discarded.
        submitted = request.args.get("name")
    if submitted is not None:
        name = submitted
    img_path = 'static/overcooked2.png'
    img_stream = return_img_stream(img_path)
    # NOTE(review): 'timestep' and 'score' are passed as literal strings,
    # presumably as placeholders before the first /test/ request — confirm
    # against index_baseline.html.
    return render_template('index_baseline.html',img_stream=img_stream, timestep = 'timestep', score='score', name=name)

def create_gif(image_list, gif_name, duration = 1.0):
    '''
    Assemble a sequence of image files into an animated GIF.

    :param image_list: paths of the frame images, in display order
    :param gif_name: output file name (str), including the .gif suffix
    :param duration: delay between frames, forwarded to imageio.mimsave
    :return: None
    '''
    frames = [imageio.imread(frame_path) for frame_path in image_list]
    imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
   
def write_gif(gif_name = r"save.gif",pic_path = r"./static/img/",duration = 0.5, video_dir = r"./static/video/"):
    """Combine the numbered frames ``<i>.png`` in ``pic_path`` into a GIF.

    :param gif_name: output file name, written inside ``video_dir``
    :param pic_path: directory holding frames named ``0.png``, ``1.png``, ...
    :param duration: delay between frames, forwarded to ``create_gif``
    :param video_dir: output directory; created if it does not exist
    """
    # Keep only files named "<int>.png" and order them numerically. The old
    # code assumed every entry of pic_path was a frame and built the names
    # 0.png .. (count-1).png, which broke on gaps or any unrelated file.
    numbered_frames = []
    for file_name in os.listdir(pic_path):
        stem, ext = os.path.splitext(file_name)
        if ext == '.png' and stem.isdecimal():
            numbered_frames.append((int(stem), file_name))
    numbered_frames.sort()
    image_list = [os.path.join(pic_path, file_name) for _, file_name in numbered_frames]
    if not image_list:
        # Nothing was recorded; an empty GIF cannot be written.
        return
    os.makedirs(video_dir, exist_ok=True)
    create_gif(image_list, os.path.join(video_dir, gif_name), duration)

def _parse_action():
    """Return the action id encoded in the first query-string key.

    The id is the last ``_``-separated part of the key. When the request has
    no key, or the key does not end in an integer, -2 is returned, which
    means "just re-render the current frame".
    """
    params = list(request.args.keys())
    if not params:
        return -2
    try:
        return int(params[0].split('_')[-1])
    except ValueError:
        return -2


def _clear_frames(save_dir=r"./static/img"):
    """Best-effort removal of every file in ``save_dir`` (frames of the previous game)."""
    try:
        for file_name in os.listdir(save_dir):
            os.remove(os.path.join(save_dir, file_name))
    except OSError:
        # Missing directory or undeletable file: start the new game anyway.
        pass


def _render_page(img_path):
    """Render index_baseline.html showing ``img_path`` and the current game state."""
    return render_template('index_baseline.html', img_stream=return_img_stream(img_path),
                           timestep=timestep, score=score, name=name)


@app.route("/test/",methods=["GET", "POST"])
def test():
    """Advance the human-AI game by one step and render the resulting frame.

    The action id comes from the first query-string key (see ``_parse_action``):

    * ``-1``: start a new game. Delete old frames and reset the env, RNN
      state, mask, timestep and score.
    * ``-2``: no environment step; re-render the current frame.
    * otherwise: the human's action for agent 1. The AI policy picks agent 0's
      action and the environment is stepped once.
    """
    global timestep
    global actions
    global score
    global obs
    global rnn_state
    global mask

    action = _parse_action()

    if action == -1:
        _clear_frames()
        rnn_state = np.zeros((2, 1, 64), dtype=np.float32)
        mask = np.ones((2, 1), dtype=np.float32)
        obs, _, _ = env.reset()
        obs = np.stack(obs)  # combine both agents' observations along axis 0
        timestep = 0
        score = 0  # previously never reset, so scores carried over between games
        return _render_page('static/img/' + str(timestep) + '.png')

    if action != -2:
        timestep += 1
        actions[1] = [action]

        # Forward pass of the AI policy; move the observation to the GPU when
        # one is available, otherwise pass it through unchanged.
        if torch.cuda.is_available():
            obs = torch.Tensor(obs).cuda()
        ai_action, rnn_state = policy.act(obs, rnn_state, mask)
        ai_action = ai_action.cpu().numpy()
        print('ai_action:', ai_action)
        actions[0] = ai_action[0]
        obs, _, rewards, dones, _, _ = env.step(actions)
        obs = np.stack(obs)
        # Count the reward of every step, including the terminal one and
        # steps whose frame is missing (both were dropped before).
        score += np.array(rewards)[1]
        actions = [-2, -2]

        if dones[0] == True or dones[1] == True:
            try:
                write_gif('mep' + '.gif')
            except Exception as err:  # GIF export is best-effort
                print('write_gif failed:', err)
            env.reset()
            return _render_page('static/game_over.png')

        try:
            return _render_page('static/img/' + str(timestep) + '.png')
        except OSError:
            # No frame was written for this step: fall back to the previous one.
            timestep = timestep - 1
    return _render_page('static/img/' + str(timestep) + '.png')
    
if __name__ == '__main__':
    # Demo1: python app_Baseline.py many_orders ./baseline_policy/many_orders/fcp_adaptive_seed1.pt
    # Demo2: python app_Baseline.py unident_s ./baseline_policy/unident_s/hsp_adaptive_seed1.pt
    if len(sys.argv) < 3:
        # Previously a missing argument surfaced as a bare IndexError traceback.
        sys.exit('usage: python app_Baseline.py <layout_name> <checkpoint_path>')
    load_ENV_and_Model(sys.argv[1],sys.argv[2])
    app.run(port=8689)