import matplotlib.pyplot as plt
import os, sys
from posixpath import basename
import numpy as np
import cv2
import json
import argparse
from prettytable import PrettyTable
from PIL import Image, ImageDraw, ImageFont
from glob import glob
import pdb
from itertools import zip_longest

# Outline colours (RGB) indexed by an annotation's ``object_class`` id; png2pngwithBox
# draws each bounding box in PALLTE[object_class].
PALLTE = np.array([[255, 255, 255], [128, 0, 0], [0, 128, 0], [0, 0, 128], [128, 128, 0],
                   [128, 0, 128], [0, 128, 128], [128, 128, 128], [255, 0, 0], [0, 255, 0], [0, 0, 255]])

# TrueType font used by json2pngwithTable. The default is the Debian/Ubuntu FreeMono
# location; set the TTF_FILE environment variable (e.g. TTF_FILE=FreeMono.ttf) to use
# a font file elsewhere.
TTF_FILE = os.environ.get("TTF_FILE", "/usr/share/fonts/truetype/freefont/FreeMono.ttf")

# Class ids found in the ``object_class`` field of the per-object JSON annotations.
name_to_id_mapping = {"player": 0, "point": 1, "balloon": 2}
id_to_name_mapping = {v: k for k, v in name_to_id_mapping.items()}

# 9-element 0/1 action vectors -> readable action names. Index 4/5/6/7 select
# up/down/left/right and index 8 adds "shoot"; presumably one flag per controller
# button — confirm against the environment's action space.
action_to_string_mapping = {
    (0, 0, 0, 0, 1, 0, 0, 0, 0): "up",
    (0, 0, 0, 0, 1, 0, 0, 0, 1): "up_shoot",
    (0, 0, 0, 0, 0, 1, 0, 0, 0): "down",
    (0, 0, 0, 0, 0, 1, 0, 0, 1): "down_shoot",
    (0, 0, 0, 0, 0, 0, 1, 0, 0): "left",
    (0, 0, 0, 0, 0, 0, 0, 1, 0): "right",
    (0, 0, 0, 0, 0, 0, 0, 0, 0): "noop",
}
string_to_action_mapping = {v: k for k, v in action_to_string_mapping.items()}

# ID_NAME_MAPPING = {0: "protagonist", 1:"bullet-of-protagonist", 2:"ring-object", 3:"egg", 4:"bullet-of-enemy"}
def png2pngwithBox(env):
    """Draw every annotated object's bounding box onto its frame and save a copy.

    For each ``data/<env>-Json/<name>.json`` (a dict mapping object keys to dicts
    with ``topleft_x``, ``topleft_y``, ``bottomright_x``, ``bottomright_y`` and
    ``object_class``), the matching frame ``data/<env>/<name>.png`` is opened,
    every box is outlined (2 px, alpha 127) in ``PALLTE[object_class]``, and the
    result is written once to ``data/<env>-withBox/<name>.png``. Frames whose
    annotation lists no objects are copied unchanged.
    """
    json_dir = "data/{}-Json".format(env)
    tmp_dir = "data/{}-withBox".format(env)
    os.makedirs(tmp_dir, exist_ok=True)
    for json_path in glob(os.path.join(json_dir, "*json")):
        with open(json_path) as f:
            env_dict = json.load(f)
        img_path = json_path.replace("-Json", "").replace(".json", ".png")
        with Image.open(img_path) as img:
            # One drawing context per image (the original created one per object and
            # deleted it outside the loop, crashing on object-less annotations).
            draw = ImageDraw.Draw(img, "RGBA")
            for val in env_dict.values():
                # PALLTE holds NumPy integers; PIL expects plain ints in colour tuples.
                r, g, b = (int(c) for c in PALLTE[val["object_class"]])
                draw.rectangle(
                    ((val["topleft_x"], val["topleft_y"]),
                     (val["bottomright_x"], val["bottomright_y"])),
                    outline=(r, g, b, 127),
                    width=2,
                )
            # Save once per frame rather than once per object.
            img.save(os.path.join(tmp_dir, os.path.basename(img_path)))

def find_action(env_dict):
    """Return the ``"action"`` of the first object whose class is the player.

    *env_dict* maps object keys to per-object dicts, each expected to carry an
    ``"object_class"`` entry (a missing one raises KeyError). Objects are scanned in
    the dict's iteration order; None is returned when no object has the player's
    class id from ``name_to_id_mapping``.
    """
    wanted_class = name_to_id_mapping["player"]
    player_actions = (
        entry["action"]
        for entry in env_dict.values()
        if entry["object_class"] == wanted_class
    )
    return next(player_actions, None)

def find_reward(env_dict):
    """Return the ``"total_reward"`` of the first per-object dict that has one.

    Entries of *env_dict* are scanned in iteration order; None is returned when
    no entry carries a ``"total_reward"`` key.
    """
    rewards = (
        entry["total_reward"]
        for entry in env_dict.values()
        if "total_reward" in entry.keys()
    )
    return next(rewards, None)

def find_case2(env_dict):
    """Return the ``"case"`` value of the first per-object dict that has one.

    Used for annotations where the case is stored inside one of the object
    entries. Returns None when no entry carries a ``"case"`` key.
    """
    for key in env_dict:
        entry = env_dict[key]
        if "case" not in entry.keys():
            continue
        return entry["case"]
    return None

def find_case1(env_dict):
    """Return the case index stored under the top-level ``"action"`` key.

    Used for annotations whose case lives at the top level of the dict rather
    than inside an object entry; a missing key raises KeyError.
    """
    case_value = env_dict["action"]
    return case_value

def json2pngwithTable(env):
    """Render a REWARD / ACTION / CASE table image for every frame annotation of *env*.

    For each ``data/<env>-Json/<name>.json`` the reward (``find_reward``), the
    player's action (``find_action``) and the case (``find_case2``) are formatted
    with PrettyTable, drawn in white on a black background and saved as PNG
    (dpi 600) to ``data/<env>-withTable/<name>.png``. The output is padded so
    that its height/width ratio is at least that of the source frame
    ``data/<env>/<name>.png``.

    Raises:
        FileNotFoundError: if a source frame cannot be read.
    """
    json_dir = "data/{}-Json".format(env)
    tmp_dir = "data/{}-withTable".format(env)
    os.makedirs(tmp_dir, exist_ok=True)
    space = 5  # padding (px) around the table text
    # Loaded once and reused for every frame.
    font = ImageFont.truetype(TTF_FILE, 15, encoding='utf-8')
    separator = ['-------------------------', '------------------']

    for json_path in glob(os.path.join(json_dir, "*json")):
        img_path = json_path.replace("-Json", "").replace(".json", ".png")
        frame = cv2.imread(img_path)
        if frame is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError("could not read frame image: {}".format(img_path))
        h, w = frame.shape[:2]
        h_div_w_ratio = h / w

        with open(json_path) as f:
            env_dict = json.load(f)
        action = find_action(env_dict)
        reward = find_reward(env_dict)
        # The original called an undefined `find_case`; find_case2 matches the
        # per-object annotation layout that find_action/find_reward read.
        case = find_case2(env_dict)

        tab = PrettyTable()
        tab.field_names = ["Meaning", "Real-time Value"]
        tab.add_row(['REWARD', ' '])
        tab.add_row(['reward', '{}'.format(reward)])
        tab.add_row(list(separator))
        tab.add_row(['ACTION', ' '])
        tab.add_row(['action', '{}'.format(action)])
        tab.add_row(list(separator))
        tab.add_row(['CASE', ' '])
        tab.add_row(['case', '{}'.format(case)])
        tab_info = str(tab)

        text_w, text_h = _multiline_text_size(tab_info, font)
        w_new = text_w + space * 2
        h_new = max(int(h_div_w_ratio * w_new), text_h + space * 2)
        # PIL sizes are (width, height); start from a black canvas of the final size.
        im_new = Image.new('RGB', (w_new, h_new), (0, 0, 0))
        draw = ImageDraw.Draw(im_new, 'RGB')
        draw.multiline_text((space, space), tab_info, fill=(255, 255, 255), font=font)
        out_name = os.path.basename(json_path).replace(".json", ".png")
        im_new.save(os.path.join(tmp_dir, out_name), "png", dpi=(600, 600))


def _multiline_text_size(text, font):
    """Return the (width, height) in pixels of *text* drawn with *font* at the origin."""
    draw = ImageDraw.Draw(Image.new('RGB', (1, 1)))
    if hasattr(draw, "multiline_textbbox"):
        # Pillow >= 8.0; multiline_textsize was removed in Pillow 10.
        _left, _top, right, bottom = draw.multiline_textbbox((0, 0), text, font=font)
        return int(right), int(bottom)
    return draw.multiline_textsize(text, font=font)
            
def put_on_canvas(canvas, arr, loc):
    """Add *arr* onto *canvas* in place with its top-left corner at *loc*.

    Args:
        canvas: image array to draw on; modified in place and also returned.
        arr: image patch whose trailing (channel) dimensions match *canvas*.
        loc: ``(row, col)`` of the patch's top-left corner. It may be negative or
            push the patch past the canvas edge; the outside part is cropped.

    Pixels are combined by addition, which composites correctly when one of the
    two is black. On integer canvases the sum saturates at the dtype's range
    instead of wrapping (uint8 200 + 100 gives 255, not 44). A warning is printed
    when the patch had to be cropped or could not be broadcast onto the canvas.

    Returns:
        The same *canvas* object.
    """
    x0, y0 = int(loc[0]), int(loc[1])
    h, w = arr.shape[0], arr.shape[1]
    # Intersection of the patch with the canvas, in canvas coordinates.
    top, left = max(x0, 0), max(y0, 0)
    bottom, right = min(x0 + h, canvas.shape[0]), min(y0 + w, canvas.shape[1])
    if bottom <= top or right <= left:
        print('canvas is tooooo small!!!!')
        return canvas
    if (top, left, bottom, right) != (x0, y0, x0 + h, y0 + w):
        print('canvas is tooooo small!!!!')

    patch = arr[top - x0:bottom - x0, left - y0:right - y0]
    region = canvas[top:bottom, left:right]  # basic slicing -> view into canvas
    try:
        if np.issubdtype(canvas.dtype, np.integer):
            limits = np.iinfo(canvas.dtype)
            summed = region.astype(np.int64) + patch
            region[...] = np.clip(summed, limits.min, limits.max)
        else:
            region += patch
    except ValueError:
        # Shapes that cannot broadcast (e.g. channel mismatch): best effort, skip.
        print('canvas is tooooo small!!!!')
    return canvas

def get_bbox_arr(H, W, d, color):
    """Return an ``(H, W, 3)`` uint8 image containing a hollow rectangular frame.

    The outermost *d* rows and columns on every side are painted with *color*
    (a 3-element colour); the interior stays black. Painting the whole array
    instead would give a solid box.
    """
    frame = np.zeros((H, W, 3), dtype=np.uint8)
    paint = np.asarray(color).reshape(1, 1, 3)
    every = slice(None)
    edges = (
        (slice(None, d), every),      # top band
        (slice(H - d, None), every),  # bottom band (H - d keeps d == 0 a no-op)
        (every, slice(None, d)),      # left band
        (every, slice(W - d, None)),  # right band
    )
    for rows, cols in edges:
        frame[rows, cols] = paint
    return frame


def get_lights_arr(box_resize_to, case, max_case, radius=17):
    """Draw a row of *max_case* indicator lights, highlighting the active case.

    Returns a black uint8 array of shape
    ``(box_resize_to[0] // 5, box_resize_to[1], 3)`` with *max_case* filled
    circles evenly spaced along its horizontal centre line. Light ``int(case)``
    is drawn with colour (0, 255, 0) and radius ``radius + 8``; every other
    light uses (0, 0, 255) and radius *radius*. In OpenCV's BGR channel order
    these are green and red. If *case* is None (e.g. ``find_case2`` found no
    "case" entry), no light is highlighted instead of raising TypeError.

    Args:
        box_resize_to: ``[height, width]`` of the frame panel the strip belongs to.
        case: index of the active light, or None.
        max_case: number of lights to draw.
        radius: radius in pixels of an inactive light (default 17, as before).
    """
    arr = np.zeros((box_resize_to[0] // 5, box_resize_to[1], 3), dtype=np.uint8)
    arr_h, arr_w = arr.shape[:2]
    active = None if case is None else int(case)
    spacing = arr_w // max_case if max_case > 0 else 0
    for i in range(max_case):
        centre = (int(spacing * (i + 0.5)), arr_h // 2)  # cv2 points are (x, y)
        if i == active:
            cv2.circle(arr, centre, radius + 8, (0, 255, 0), -1)
        else:
            cv2.circle(arr, centre, radius, (0, 0, 255), -1)
    return arr

def get_text_arr(box_resize_to, text="", color=(255, 255, 255), font_size=0.5):
    """Render one line of *text* onto a black caption strip.

    The strip is a uint8 array of shape
    ``(box_resize_to[0] // 5, box_resize_to[1], 3)``. The text is drawn with
    ``cv2.FONT_HERSHEY_SIMPLEX`` at scale *font_size*, in *color*, with
    thickness 1, and the bottom-left corner of the string is at (x=0, y=20).
    """
    strip_shape = (box_resize_to[0] // 5, box_resize_to[1], 3)
    strip = np.zeros(strip_shape, dtype=np.uint8)
    text_origin = (0, 20)
    cv2.putText(strip, text, text_origin, cv2.FONT_HERSHEY_SIMPLEX, font_size, color, 1)
    return strip
# get_text_arr([500, 500], text="CNN Teacher with its observation(top) \\ and action(bottom)")
# exit(0)
# get_lights_arr([450, 450], 1, 6)

def _frame_json_path(frame_path, env):
    """Return ``data/<env>-Json/<stem>.json``, the annotation file of a frame of *env*."""
    stem = os.path.splitext(os.path.basename(frame_path))[0]
    return os.path.join("data/{}-Json".format(env), stem + ".json")


def _draw_panel(canvas, frame_path, json_path, case_finder, frame_loc, box_resize_to,
                border_arr, border_d, text_arr, text_loc, max_case):
    """Composite one run's panel (frame, case lights, border, caption) onto *canvas*."""
    img = cv2.imread(frame_path)
    if img is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError("could not read frame image: {}".format(frame_path))
    # box_resize_to is [height, width]; cv2.resize expects dsize=(width, height).
    img = cv2.resize(img, (box_resize_to[1], box_resize_to[0]))
    with open(json_path) as f:
        env_dict = json.load(f)
    lights = get_lights_arr(box_resize_to, case_finder(env_dict), max_case=max_case)
    lights_loc = [frame_loc[0] + box_resize_to[0], frame_loc[1]]  # directly under the frame
    border_loc = [frame_loc[0] - border_d, frame_loc[1] - border_d]
    canvas = put_on_canvas(canvas=canvas, arr=img, loc=frame_loc)
    canvas = put_on_canvas(canvas=canvas, arr=lights, loc=lights_loc)
    canvas = put_on_canvas(canvas=canvas, arr=border_arr, loc=border_loc)
    canvas = put_on_canvas(canvas=canvas, arr=text_arr, loc=text_loc)
    return canvas


def generate_video(env1, env2, video_path, fps=100, max_frames=1800):
    """Render a side-by-side comparison video of two recorded runs.

    Frames are read from ``data/<env1>/*.png`` (left panel) and
    ``data/<env2>/*.png`` (right panel) and paired in sorted filename order.
    Each frame's case is read from its annotation ``data/<env>-Json/<name>.json``
    (``find_case1`` for env1, ``find_case2`` for env2) and shown as a row of
    lights under the frame, followed by a caption. When one run has fewer
    frames, its panel stays black for the remaining frames. Frames are written
    to the video as they are rendered, so memory use stays constant.

    Args:
        env1: directory name under ``data/`` for the left panel.
        env2: directory name under ``data/`` for the right panel.
        video_path: output ``.mp4`` path (mp4v codec).
        fps: frame rate of the written video (default 100, as before).
        max_frames: stop after this many frames (default 1800, as before);
            None renders every frame.

    Raises:
        IOError: if OpenCV cannot open a writer for *video_path*.
        FileNotFoundError: if a frame image or its JSON annotation is missing.
    """
    env1_frames = sorted(glob(os.path.join("data/{}".format(env1), "*.png")))
    env2_frames = sorted(glob(os.path.join("data/{}".format(env2), "*.png")))

    H_canvas, W_canvas = 900, 1300
    box_resize_to = [550, 550]  # [height, width] every frame is resized to
    strip_h = box_resize_to[0] // 5  # height of the get_lights_arr / get_text_arr strips
    env_box_d = 10  # border thickness in pixels
    color_env_box = [30, 100, 80]
    max_case = 6

    loc_env1_frame = [10, 10]
    loc_env2_frame = [10, 680]

    # Overlays that are identical in every frame are built once, outside the loop.
    border = get_bbox_arr(H=box_resize_to[0] + 2 * env_box_d + strip_h,
                          W=box_resize_to[1] + 2 * env_box_d,
                          d=env_box_d, color=color_env_box)
    text_env1 = get_text_arr(box_resize_to, text="CNN Teacher with its observation(top) and action(bottom)", font_size=0.57)
    text_env2 = get_text_arr(box_resize_to, text="Distilled Symbolic Policy with its observation(top) and action(bottom)", font_size=0.47)
    text_row_offset = box_resize_to[0] + box_resize_to[0] // 4
    loc_env1_text = [loc_env1_frame[0] + text_row_offset, loc_env1_frame[1] + env_box_d]
    loc_env2_text = [loc_env2_frame[0] + text_row_offset, loc_env2_frame[1]]

    panels = (
        (env1, find_case1, loc_env1_frame, text_env1, loc_env1_text),
        (env2, find_case2, loc_env2_frame, text_env2, loc_env2_text),
    )

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    videowriter = cv2.VideoWriter(video_path, fourcc, fps, (W_canvas, H_canvas))
    if not videowriter.isOpened():
        raise IOError("could not open video writer for {}".format(video_path))
    try:
        for count, frame_paths in enumerate(zip_longest(env1_frames, env2_frames), start=1):
            if max_frames is not None and count > max_frames:
                break
            canvas = np.zeros((H_canvas, W_canvas, 3), dtype=np.uint8)
            for frame_path, (env, case_finder, frame_loc, text_arr, text_loc) in zip(frame_paths, panels):
                if frame_path is None:  # this run has no more frames; leave its panel black
                    continue
                canvas = _draw_panel(canvas, frame_path, _frame_json_path(frame_path, env),
                                     case_finder, frame_loc, box_resize_to,
                                     border, env_box_d, text_arr, text_loc, max_case)
            # Write immediately: buffering ~3.5 MB per frame for up to 1800 frames
            # would need several GB of RAM.
            videowriter.write(canvas)
    finally:
        videowriter.release()
        

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--env1", default="AstroRoboSasa-Nes-PPO", help="env name")
    parser.add_argument("--env2", default="AstroRoboSasa-Nes", help="env name")
    # The default keeps the original output location; override it on other machines.
    parser.add_argument("--save_dir",
                        default="/home/zhiwen/projects/SymReL_original/data/videos/video#2",
                        help="directory the comparison .mp4 is written to")
    args = parser.parse_args()

    env1 = args.env1
    env2 = args.env2
    # The video is named after the concatenated env names: "<env1><env2>.mp4".
    save_path = os.path.join(args.save_dir, "{}.mp4".format(env1 + env2))
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    # Optional per-environment preprocessing passes:
    # print("==================== conducting png2pngwithBox ====================")
    # png2pngwithBox(env1)
    # print("==================== conducting json2pngwithTable ====================")
    # json2pngwithTable(env1)
    print("==================== conducting generate_video ====================")
    generate_video(env1, env2, save_path)