
import torch
import torch.nn.functional as F
import os
import numpy as np
import torch
import imageio
import json
from envs.Overcooked_Env_new import Overcooked_NEW
from hsp.envs.overcooked_new.script_agent import SCRIPT_AGENTS
from hsp.envs.overcooked_new.src.overcooked_ai_py.mdp.actions import Action, Direction
from flask import Flask, render_template,request
from bc_policy.BCAgent import BCAgent
# Flask application object; the route decorators in this file register their handlers on it.
app = Flask(__name__)

# global variables (module-level state shared by every request)
actions = [-1,-1] # Joint action handed to env.step: slot 0 is filled with the AI policy's action, slot 1 with the action taken from the request
timestep = 0 # Index of the current frame; used to build the path static/img/<timestep>.png
score = 0 # Running total of rewards[1] accumulated by the /test/ handler
name = 'name' # Name shown on the game page; replaced by the "name" form field posted to /cook
obs = None # Latest stacked observations returned by the environment (None until the first reset)
rnn_state = None  # Set to a zero array of shape (2, 1, 64) on reset; not read anywhere else in this file
mask = None # Set to an array of ones of shape (2, 1) on reset; not read anywhere else in this file
pot1_num = 0 # Counters for the number of ingredients in each pot; not used by the active code path in test()
pot2_num = 0
pot3_num = 0
now_inst = 0 # Index of the skill policy currently selected from the instruction list in test()
action_list = [] # History of AI actions; not used by the active code path in test()
    
# load: Overcooked_NEW
env = Overcooked_NEW(r'mo1',seed=3,featurize_type=("ppo","ppo"))

# load: skill
# Parameters passed to every BCAgent constructed below
params = {
    'env': 'mo1',
    'lr': 0.001,
    'gamma': 0.95,
    'obs_dim': None,
    'action_dim': 6,
    'hidden_dim': 64,
}
model_path = r"./bc_policy/ManyOrders"


def _load_skill(checkpoint):
    """Create a BCAgent, move it to CUDA, load its weights and switch it to eval mode.

    :param checkpoint: checkpoint file path relative to ``model_path``
    :return: the ready-to-use agent
    """
    skill = BCAgent(params).cuda()
    weights = torch.load(os.path.join(model_path, checkpoint))
    skill.load_state_dict(weights)
    skill.eval()
    return skill


# Pick up an onion and put it into the first pot
Many_orders_place_onion_in_pot1 = _load_skill(r"Many_orders_place_onion_in_pot1/agent_0_22000.th")
# Pick up an onion and put it into the second pot
Many_orders_place_onion_in_pot2 = _load_skill(r"Many_orders_place_onion_in_pot2/agent_0_35000.th")
# Pick up an onion and put it into the third pot
Many_orders_place_onion_in_pot3 = _load_skill(r"Many_orders_place_onion_in_pot3/agent_0_50000.th")
# Pick up a tomato and put it into the first pot
Many_orders_place_tomato_in_pot1 = _load_skill(r"Many_orders_place_tomato_in_pot1/agent_0_49000.th")
# Pick up a tomato and put it into the second pot
Many_orders_place_tomato_in_pot2 = _load_skill(r"Many_orders_place_tomato_in_pot2/agent_0_27000.th")
# Pick up a tomato and put it into the third pot
Many_orders_place_tomato_in_pot3 = _load_skill(r"Many_orders_place_tomato_in_pot3/agent_0_9000.th")
# Deliver the soup from the first pot to the serving counter
# (previously used checkpoint: Many_orders_deliver_soup_use_pot1/agent_0_33000.th)
Many_orders_deliver_soup_use_pot1 = _load_skill(r"Many_orders_deliver_soup_use_pot1_aug/agent_0_80000.th")
# Deliver the soup from the second pot to the serving counter
# (previously used checkpoint: Many_orders_deliver_soup_use_pot2/agent_0_35000.th)
Many_orders_deliver_soup_use_pot2 = _load_skill(r"Many_orders_deliver_soup_use_pot2_aug/agent_0_40000.th")
# Deliver the soup from the third pot to the serving counter
# (previously used checkpoint: Many_orders_deliver_soup_use_pot3/agent_0_86000.th)
Many_orders_deliver_soup_use_pot3 = _load_skill(r"Many_orders_deliver_soup_use_pot3_aug/agent_0_90000.th")


@app.route('/')
def index():
    """Serve the site root by rendering the ``register_baseline.html`` template."""
    register_page = render_template('register_baseline.html')
    return register_page
    
# Shows how Flask can read an image stored locally on the server and hand it to the front end as an image stream
def return_img_stream(img_local_path):
    """
    Utility function:
    read a local image file and return its contents as a base64 string
    :param img_local_path: path of a single image file on the local disk
    :return: the file's bytes, base64-encoded and decoded to ``str``
    """
    from base64 import b64encode

    with open(img_local_path, 'rb') as img_f:
        raw_bytes = img_f.read()
    # Encoding happens after the file is closed; only the bytes are needed.
    return b64encode(raw_bytes).decode()

@app.route("/cook",methods=["GET", "POST"])
def cook():
    """Record the submitted player name (if any) and show the game start page.

    The name is looked up in the form body for POST requests and in the query
    string otherwise. The module-level ``name`` is only replaced when the
    request actually carries a ``name`` value, so a request without one keeps
    the previously stored name instead of overwriting it with ``None``.

    :return: the rendered ``index_baseline.html`` page showing
        ``static/overcooked2.png``
    """
    global name
    # Previously the GET value was stored in an unused local and discarded.
    source = request.form if request.method == "POST" else request.args
    submitted_name = source.get("name")
    if submitted_name is not None:
        name = submitted_name

    img_stream = return_img_stream('static/overcooked2.png')
    # The literal strings 'timestep' and 'score' are what this page passes to
    # the template for those two fields (no game is running yet).
    return render_template('index_baseline.html',img_stream=img_stream, timestep = 'timestep', score='score', name=name)

def create_gif(image_list, gif_name, duration = 1.0):
    '''
    Read the listed images and save them as one animated GIF.

    :param image_list: paths of the images that make up the animation, in frame order
    :param gif_name: string, name of the GIF file to write, including the .gif suffix
    :param duration: interval between images, forwarded to ``imageio.mimsave``
    :return: None
    '''
    loaded_frames = [imageio.imread(frame_path) for frame_path in image_list]
    imageio.mimsave(gif_name, loaded_frames, 'GIF', duration=duration)
   
def write_gif(gif_name = r"save.gif",pic_path = r"./static/img/",duration = 0.5):
    """Build a GIF from the numbered frames in *pic_path*.

    Frames are the directory entries named ``<integer>.png``; they are used in
    ascending numeric order. Entries that do not match that pattern are
    ignored, and gaps in the numbering are tolerated, so a stray file in the
    directory no longer makes the function ask for a frame that does not
    exist. For a directory holding exactly ``0.png`` .. ``N-1.png`` the frame
    list is the same as before.

    :param gif_name: file name of the GIF, written under ``./static/video/``
    :param pic_path: directory that holds the frame images
    :param duration: interval between frames, forwarded to ``create_gif``
    :raises OSError: if *pic_path* cannot be listed
    """
    numbered_frames = []
    for entry in os.listdir(pic_path):
        stem, ext = os.path.splitext(entry)
        # ASCII digits only, so int() below cannot fail on exotic characters.
        if ext == '.png' and stem.isascii() and stem.isdigit():
            numbered_frames.append((int(stem), entry))
    numbered_frames.sort()

    image_list = [os.path.join(pic_path, entry) for _, entry in numbered_frames]
    create_gif(image_list, os.path.join(r"./static/video/",gif_name), duration)

def _parse_action():
    """Return the integer action encoded in the first query-string key.

    The text after the last underscore of the key (or the whole key when it
    has no underscore) is parsed as an int. A request with no query key, or
    whose suffix is not an integer, yields -2, which ``test`` treats as
    "no action" instead of failing with an unhandled exception.
    """
    keys = list(request.args.keys())
    try:
        return int(keys[0].split('_')[-1])
    except (IndexError, ValueError):
        return -2


def _clear_frames(save_dir):
    """Remove every entry of *save_dir*; failures are ignored on purpose."""
    try:
        for entry in os.listdir(save_dir):
            os.remove(save_dir + '/' + entry)
    except OSError:
        # Best effort: a missing directory or an undeletable file must not
        # prevent the game from being reset.
        pass


def _render_game_page(img_path):
    """Render ``index_baseline.html`` with the image at *img_path* and the current timestep, score and name."""
    return render_template('index_baseline.html',img_stream=return_img_stream(img_path), timestep=timestep, score=score, name=name)


@app.route("/test/",methods=["GET", "POST"])
def test():
    """Handle one interaction of the game page.

    The action code is taken from the first query-string key:

    * ``-1``: delete the stored frames, reset the environment and show frame 0.
    * ``-2`` (or a missing/malformed key): change nothing, re-show the current frame.
    * anything else: use the value as the action in slot 1 of the joint
      action, let the currently selected skill policy choose the action for
      slot 0 from ``obs[0]``, and step the environment once.

    After a step, the page shows ``static/game_over.png`` when either entry of
    ``dones`` is true (a GIF of the frames is written first), otherwise the
    frame ``static/img/<timestep>.png``. If that frame cannot be read, the
    timestep counter is rolled back and the previous frame is shown.

    :return: the rendered ``index_baseline.html`` page
    """
    global timestep
    global actions
    global score
    global obs
    global rnn_state
    global mask
    global now_inst

    action = _parse_action()

    if action == -1:
        # Clear the frames left over from the previous game.
        _clear_frames(r"./static/img")
        rnn_state = np.zeros((2, 1, 64), dtype=np.float32)
        mask = np.ones((2, 1), dtype=np.float32)
        obs, _, _ = env.reset()
        obs = np.stack(obs) # combine the observations along axis 0
        timestep = 0
        # NOTE(review): score and now_inst are not reset here, so they carry
        # over into the next game -- confirm whether that is intended.
        return _render_game_page('static/img/'+str(timestep)+'.png')

    if action != -2:
        timestep += 1
        actions[1] = [action]

        # Downstream policy: delivery. The skills are used in this order and
        # the selection moves on after each step that reports a delivery.
        #
        # Sequences used by earlier, now disabled, variants of this step:
        #   * right-hand pot only: place_tomato_in_pot3, deliver_soup_use_pot3,
        #     deliver_soup_use_pot2; pot3_num was incremented on steps where
        #     infos['pot'][2] > 0 and the selection advanced when it reached 3
        #     or on a delivery; the AI action was replaced by
        #     np.random.randint(0,6) when the last five recorded AI actions
        #     were all the same.
        #   * cooking with tomatoes: place_tomato_in_pot3, _pot2, _pot1.
        #   * cooking with onions: place_onion_in_pot1, _pot2, _pot3.
        #     Both cooking variants incremented potK_num on steps where
        #     infos['pot'][K-1] > 0 and advanced (resetting that counter)
        #     whenever a counter reached 3.
        instruct_list = [
            Many_orders_deliver_soup_use_pot1,
            Many_orders_deliver_soup_use_pot2,
            Many_orders_deliver_soup_use_pot3,
        ]
        ai_action = instruct_list[now_inst].act(obs[0],True)
        actions[0] = [ai_action]
        obs, _, rewards, dones, infos, _ = env.step(actions)
        obs = np.stack(obs)
        if infos['delivery'][0] > 0:
            now_inst = now_inst + 1
        if now_inst >= len(instruct_list):
            now_inst = 0
        print('now_inst:',now_inst)

        # Explicit comparison kept as in the original: the element type of
        # dones is defined by the environment and is not visible here.
        if dones[0] == True or dones[1] == True:
            try:
                write_gif('mep.gif')
            except Exception as exc:
                # The GIF is optional; report the failure but still show the
                # game-over page.
                print('write_gif failed:', exc)
            # Keep obs in sync with the freshly reset environment so that a
            # later step does not feed the policy a stale observation.
            obs, _, _ = env.reset()
            obs = np.stack(obs)
            return _render_game_page('static/game_over.png')

        try:
            img_stream = return_img_stream('static/img/'+str(timestep)+'.png')
        except OSError:
            # Frame for this step is not readable: roll the counter back and
            # fall through to re-show the previous frame.
            # NOTE(review): the reward of this step is not added to score in
            # this case (unchanged from the original behaviour).
            timestep = timestep - 1
        else:
            actions = [-2,-2]
            score += np.array(rewards)[1]
            return render_template('index_baseline.html',img_stream=img_stream, timestep= timestep, score=score, name=name)

    return _render_game_page('static/img/'+str(timestep)+'.png')
    
# When this file is executed as a script, start the Flask server on port 8689.
if __name__ == '__main__':
    app.run(port=8689)