import matplotlib.pyplot as plt
import os, sys
from posixpath import basename
import numpy as np
import cv2
import json
import argparse
from prettytable import PrettyTable
from PIL import Image, ImageDraw, ImageFont
from glob import glob
import pdb

# Outline colour per object class: png2pngwithBox indexes this with an object's
# ``object_class`` value. (Name is a misspelling of "PALETTE"; kept because the
# drawing code refers to it.)
PALLTE = np.array(
    [
        [255, 255, 255],
        [128, 0, 0],
        [0, 128, 0],
        [0, 0, 128],
        [128, 128, 0],
        [128, 0, 128],
        [0, 128, 128],
        [128, 128, 128],
        [255, 0, 0],
        [0, 255, 0],
        [0, 0, 255],
    ]
)

# Font used to render the info tables; the commented alternative loads it
# from the working directory instead.
TTF_FILE = "/usr/share/fonts/truetype/freefont/FreeMono.ttf"
# TTF_FILE = "FreeMono.ttf"

# Object-class ids as they appear in the JSON ``object_class`` field, plus the
# reverse lookup.
name_to_id_mapping = {"player": 0, "enemy": 1, "pong": 2}
id_to_name_mapping = dict(zip(name_to_id_mapping.values(), name_to_id_mapping.keys()))

# 9-element action vectors mapped to readable names, plus the reverse lookup.
action_to_string_mapping = {
    (0, 0, 0, 0, 1, 0, 0, 0, 0): "up",
    (0, 0, 0, 0, 0, 1, 0, 0, 0): "down",
    (0, 0, 0, 0, 0, 0, 0, 0, 0): "noop",
}
string_to_action_mapping = dict(zip(action_to_string_mapping.values(), action_to_string_mapping.keys()))

# Maximum number of frames each processing stage handles.
DEBUG_COUNT = 1600

# Class-id mapping from another environment, kept for reference:
# ID_NAME_MAPPING = {0: "protagonist", 1:"bullet-of-protagonist", 2:"ring-object", 3:"egg", 4:"bullet-of-enemy"}
def png2pngwithBox(env):
    """Draw every annotated object's bounding box onto the raw frames of ``env``.

    For each ``data/<env>-Json/*json`` file (sorted, at most ``DEBUG_COUNT``),
    the matching frame ``data/<env>/<name>.png`` is opened. Every object in the
    JSON gets a 2 px outline in its ``PALLTE[object_class]`` colour (alpha 127).
    The result is written to ``data/<env>-withBox/<name>.png``.

    A frame is written even when its JSON lists no objects, so the withBox
    directory keeps one image per JSON file and stays aligned with withTable.
    """
    json_dir = os.path.join("data/{}-Json".format(env))
    out_dir = os.path.join("data/{}-withBox".format(env))
    os.makedirs(out_dir, exist_ok=True)
    json_path_list = sorted(glob(os.path.join(json_dir, "*json")))

    for json_path in json_path_list[:DEBUG_COUNT]:
        with open(json_path) as f:
            env_dict = json.load(f)
        img_path = json_path.replace("-Json", "").replace(".json", ".png")

        # Context manager closes the underlying file handle once the frame is saved.
        with Image.open(img_path) as img:
            # One Draw object per frame; "RGBA" makes the semi-transparent outline blend.
            draw = ImageDraw.Draw(img, "RGBA")
            for val in env_dict.values():
                # NOTE(review): the +32 presumably shifts JSON y-coordinates into the
                # saved frame's coordinate system — confirm against the data exporter.
                top_left = (val["topleft_x"], val["topleft_y"] + 32)
                bottom_right = (val["bottomright_x"], val["bottomright_y"] + 32)
                # Convert numpy ints to plain ints before handing them to PIL.
                r, g, b = (int(c) for c in PALLTE[val["object_class"]])
                draw.rectangle((top_left, bottom_right), outline=(r, g, b, 127), width=2)
            # Save once per frame, after all boxes are drawn.
            img.save(os.path.join(out_dir, os.path.basename(img_path)))

def find_action(env_dict):
    """Return the ``action`` of the first object whose class id is the player's.

    Objects are scanned in dict order. None is returned when no object has the
    player class id.
    """
    wanted_class = name_to_id_mapping["player"]
    matches = (
        entry["action"]
        for entry in env_dict.values()
        if entry["object_class"] == wanted_class
    )
    return next(matches, None)

def find_reward(env_dict):
    """Return the first ``total_reward`` value carried by any object in ``env_dict``.

    None is returned when no object has a ``total_reward`` key.
    """
    rewards = (entry["total_reward"] for entry in env_dict.values() if "total_reward" in entry)
    return next(rewards, None)

def find_case(env_dict):
    """Return the first ``case`` value carried by any object in ``env_dict``.

    None is returned when no object has a ``case`` key.
    """
    for entry in env_dict.values():
        if "case" not in entry:
            continue
        return entry["case"]
    return None

def _format_info_table(reward, action, case):
    """Render reward / action / case as a two-column PrettyTable string."""
    tab = PrettyTable()
    tab.field_names = ["Meaning", "Real-time Value"]
    separator = ['-------------------------', '------------------']
    tab.add_row(['Reward', '{}'.format(reward)])
    tab.add_row(separator)
    tab.add_row(['Action', '{}'.format(action)])
    tab.add_row(separator)
    tab.add_row(['Case', '{}'.format(case)])
    return str(tab)


def json2pngwithTable(env):
    """Render the per-frame reward/action/case table of ``env`` as PNG images.

    For each ``data/<env>-Json/*json`` file (sorted, at most ``DEBUG_COUNT``),
    the reward, player action and case are looked up. A ``case`` of -1 is shown
    as 0. The resulting table is drawn in white monospace text on a 460x429
    black image, saved as ``data/<env>-withTable/<name>.png``.
    """
    json_dir = os.path.join("data/{}-Json".format(env))
    out_dir = os.path.join("data/{}-withTable".format(env))
    os.makedirs(out_dir, exist_ok=True)
    json_path_list = sorted(glob(os.path.join(json_dir, "*json")))

    # The font is the same for every frame, so load it once.
    font = ImageFont.truetype(TTF_FILE, 15, encoding='utf-8')
    canvas_size = (460, 429)  # (width, height), the order PIL expects
    margin = 5  # text offset from the top-left corner, in pixels

    for json_path in json_path_list[:DEBUG_COUNT]:
        with open(json_path) as f:
            env_dict = json.load(f)

        action = find_action(env_dict)
        reward = find_reward(env_dict)
        case = find_case(env_dict)
        if case == -1:
            case = 0
        tab_info = _format_info_table(reward, action, case)

        # The output size is fixed, so draw straight onto a black canvas of that
        # size. Text measurement is not needed, and ImageDraw.multiline_textsize
        # no longer exists in Pillow >= 10.
        im = Image.new('RGB', canvas_size, (0, 0, 0))
        draw = ImageDraw.Draw(im, 'RGB')
        draw.multiline_text((margin, margin), tab_info, fill=(255, 255, 255), font=font)
        out_name = os.path.basename(json_path).replace(".json", ".png")
        im.save(os.path.join(out_dir, out_name), "png", dpi=(600, 600))
            
def put_on_canvas(canvas, arr, loc):
    """Add ``arr`` into ``canvas`` with its top-left corner at ``loc`` = (row, col).

    The pixel values are summed into the canvas region in place, and the same
    canvas object is returned. For integer arrays the sum saturates at the
    dtype's limits, so overlapping pastes clip to white instead of wrapping
    around (uint8 200 + 200 gives 255, not 144).

    If ``arr`` does not fit inside the canvas at ``loc``, a message is printed
    and the canvas is returned unchanged.
    """
    try:
        x0, y0 = loc
        h, w = arr.shape[0], arr.shape[1]
        # Basic slicing gives a view, so writes below land in ``canvas``.
        region = canvas[x0:x0 + h, y0:y0 + w]
        if np.issubdtype(canvas.dtype, np.integer) and np.issubdtype(arr.dtype, np.integer):
            limits = np.iinfo(canvas.dtype)
            # Widen before adding so the sum cannot overflow, then clamp back.
            total = region.astype(np.int64) + arr  # ValueError if arr does not fit
            region[...] = np.clip(total, limits.min, limits.max)
        else:
            region += arr
    except ValueError:
        print('canvas is tooooo small!!!!')

    return canvas

def get_bbox_arr(L, d, color):
    """Build an L x L x 3 uint8 image holding a hollow square frame.

    The outer ``d`` pixels on every side take ``color`` and the interior stays
    black. Skipping the interior reset would yield a solid square instead.

    Args:
        L: side length of the square, in pixels.
        d: border thickness, in pixels.
        color: three channel values applied to every border pixel.
    """
    frame = np.zeros([L, L, 3]).astype(np.uint8)
    frame[...] = np.array(color).reshape(1, 1, 3)
    # Clear everything that is at least ``d`` pixels away from every edge.
    frame[d:L - d, d:L - d] = 0
    return frame

def _imread_or_raise(path):
    """Read ``path`` with cv2.imread; raise FileNotFoundError instead of returning None."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError("could not read image: {}".format(path))
    return img


def generate_video(env, video_path):
    """Compose the prepared per-frame images of ``env`` into an MP4 at ``video_path``.

    Each 900x1400 output frame combines three images:
      * the boxed game frame from ``data/<env>-withBox``, resized to 450x450 and
        placed at (row 10, col 10);
      * the matching table image from ``data/<env>-withTable``, placed at
        (row 450, col 10);
      * the case image ``data/<env>-Case/<case>.png``, resized to 800x600 (w x h)
        and placed at (row 50, col 500). ``case`` is read from the matching
        ``data/<env>-Json`` file, and a missing or <= 0 value selects case 0.

    Table and box images are paired by file name. At most ``DEBUG_COUNT``
    frames are written, at 100 fps with the 'mp4v' codec.

    Raises:
        FileNotFoundError: if no table/box pairs exist or an image cannot be read.
        RuntimeError: if the video file cannot be opened for writing.
    """
    table_dir = os.path.join("data/{}-withTable".format(env))
    box_dir = os.path.join("data/{}-withBox".format(env))
    table_paths = sorted(glob(os.path.join(table_dir, "*.png")))
    box_by_name = {os.path.basename(p): p for p in glob(os.path.join(box_dir, "*.png"))}

    # Pair by file name. With zip() over two sorted lists, a single missing
    # image would shift every later frame out of sync.
    frame_pairs = []
    for table_path in table_paths:
        box_path = box_by_name.get(os.path.basename(table_path))
        if box_path is None:
            print("no box image for {}, skipping".format(table_path))
            continue
        frame_pairs.append((table_path, box_path))
    frame_pairs = frame_pairs[:DEBUG_COUNT]
    if not frame_pairs:
        raise FileNotFoundError("no matching frames in {} and {}".format(table_dir, box_dir))

    H_canvas, W_canvas = 900, 1400
    loc_env_frame = [10, 10]
    loc_state = [450, 10]
    loc_case = [50, 500]
    case_resize_to = (800, 600)  # cv2.resize dsize is (width, height)
    box_resize_to = (450, 450)
    fps = 100

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Open the writer exactly once; frame size is (width, height).
    videowriter = cv2.VideoWriter(video_path, fourcc, fps, (W_canvas, H_canvas))
    if not videowriter.isOpened():
        raise RuntimeError("could not open video writer for {}".format(video_path))

    # Only a handful of distinct case images exist, so read and resize each once.
    case_cache = {}
    try:
        for table_path, box_path in frame_pairs:
            json_path = table_path.replace("-withTable", "-Json").replace(".png", ".json")
            with open(json_path) as f:
                env_dict = json.load(f)
            case = find_case(env_dict)
            if case is None or int(case) <= 0:
                case = 0
            if case not in case_cache:
                case_image_path = os.path.join("data/{}-Case/{}.png".format(env, case))
                case_cache[case] = cv2.resize(_imread_or_raise(case_image_path), case_resize_to)
            case_arr = case_cache[case]

            img_box = cv2.resize(_imread_or_raise(box_path), box_resize_to)
            img_table = _imread_or_raise(table_path)

            canvas = np.zeros((H_canvas, W_canvas, img_box.shape[2]), dtype=np.uint8)
            canvas = put_on_canvas(canvas=canvas, arr=img_box, loc=loc_env_frame)
            canvas = put_on_canvas(canvas=canvas, arr=img_table, loc=loc_state)
            canvas = put_on_canvas(canvas=canvas, arr=case_arr, loc=loc_case)
            # TODO (video 2): paste a "traffic light" frame around the env image,
            # e.g. put_on_canvas(canvas, get_bbox_arr(...), loc) with a 20 px border.

            # Write each frame immediately: one frame is ~3.8 MB, so buffering
            # DEBUG_COUNT of them would take several GB of memory.
            videowriter.write(canvas)
    finally:
        videowriter.release()
        

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Render annotated frames of an env into a video.")
    parser.add_argument("--env", default="Pong-Atari2600", help="env name")
    parser.add_argument(
        "--out_dir",
        default="/home/zhiwen/projects/SymReL_original/data/videos/video#1",
        help="directory in which <env>.mp4 is written",
    )
    parser.add_argument(
        "--with_box", action="store_true",
        help="regenerate data/<env>-withBox images before building the video",
    )
    parser.add_argument(
        "--with_table", action="store_true",
        help="regenerate data/<env>-withTable images before building the video",
    )
    args = parser.parse_args()

    env = args.env
    save_path = os.path.join(args.out_dir, "{}.mp4".format(env))
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # The preprocessing stages are opt-in; by default only the video is built
    # from previously generated images.
    if args.with_box:
        print("==================== conducting png2pngwithBox ====================")
        png2pngwithBox(env)
    if args.with_table:
        print("==================== conducting json2pngwithTable ====================")
        json2pngwithTable(env)
    print("==================== conducting generate_video ====================")
    generate_video(env, save_path)