import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import wandb, os, torch, imageio, pickle
from utils.logger import logger
import cv2, glob

def visual_ood_states_reachability(algo_obj, step, mode='image', suffix='online'):
    """Visualize reachability predictions on one randomly drawn OOD trajectory.

    A random pickle file is chosen from ``algo_obj._ood_traj_files`` and a random
    demo inside it. The demo's observations (and actions, when
    ``reach_discriminator_input`` is set) are scored by ``algo_obj._reachable`` and
    passed through a sigmoid. The scores are then either rendered as a per-step
    video (``mode='video'``) or split into reachable / unreachable states and
    plotted as a static image.

    Args:
        algo_obj: algorithm object providing ``_ood_traj_files``, ``_config``,
            ``_preprocess_data`` and ``_reachable``.
        step: training step, embedded in the output file name.
        mode: ``'video'`` for a video; any other value produces an image.
        suffix: free-form tag appended to the output file name.

    Returns:
        In video mode, the ``wandb.Video`` returned by
        ``visualize_reachability_traj_video``. Otherwise, the ``wandb.Image``
        returned by ``visualize_reachability``, or ``None`` when no step is
        predicted reachable.
    """
    # Pick one OOD trajectory at random: first a file, then a demo inside it.
    file_index = np.random.randint(len(algo_obj._ood_traj_files), size=1).item()
    ood_path = algo_obj._ood_traj_files[file_index]
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # _ood_traj_files at trusted files.
    with open(ood_path, "rb") as f:
        demos = pickle.load(f)
    if not isinstance(demos, list):
        demos = [demos]
    demo_indx = np.random.randint(len(demos), size=1).item()
    demo = demos[demo_indx]

    # Observations include the terminal state: one more observation than actions.
    assert len(demo["obs"]) == len(demo["actions"]) + 1

    obs = [item["ob"] for item in demo["obs"]]
    ood_data = {
        "ob": {"ob": np.array(obs[:-1])},
        "ob_next": {"ob": np.array(obs[1:])},
    }
    if algo_obj._config.reach_discriminator_input:
        ood_data["ac"] = {"ac": np.array([item["ac"] for item in demo["actions"]])}

    # Inference only: avoid building an autograd graph for the discriminator pass.
    with torch.no_grad():
        ood_data_o = algo_obj._preprocess_data(ood_data, key="ob")
        if algo_obj._config.reach_discriminator_input:
            ood_data_ac = algo_obj._preprocess_data(ood_data, key="ac")
            ood_logit, _ = algo_obj._reachable(ood_data_o, ood_data_ac)
        else:  # input s' only
            ood_logit, _ = algo_obj._reachable(ood_data_o)
        ood_output = torch.sigmoid(ood_logit)

    # Second '_'-separated field of the full path (directory names included).
    update_ite_indx = ood_path.split('_')[1]
    fname = "ood_{}_{:05d}_step_{:011d}_{}".format(update_ite_indx, demo_indx, step, suffix)

    if mode == 'video':
        return visualize_reachability_traj_video(algo_obj, ood_data["ob"]["ob"][:, :], ood_output, fname + '.mp4')

    # Threshold the scores, then take the mode over dim 0 (presumably an ensemble
    # axis of _reachable's output -- TODO confirm) to get one flag per step.
    ood_reachable_indx = torch.mode(ood_output.squeeze(dim=-1) > algo_obj._config.reachability_threshold, dim=0).values
    if not ood_reachable_indx.sum() > 0:
        return None

    ood_obs = ood_data["ob"]["ob"]
    if algo_obj._config.backwards_relabelling:
        # Everything up to and including the last reachable step counts as reachable.
        last_reachable = torch.max(torch.nonzero(ood_reachable_indx))
        ood_reachable_data = ood_obs[:last_reachable+1, :]
        ood_unreachable_data = ood_obs[last_reachable+1:, :]
    else:
        # Only the individually reachable steps count as reachable.
        _ood_reachable_indx = torch.nonzero(ood_reachable_indx).cpu().numpy().squeeze()
        _ood_unreachable_indx = torch.nonzero(~ood_reachable_indx).cpu().numpy().squeeze()
        ood_reachable_data = ood_obs[_ood_reachable_indx, :]
        ood_unreachable_data = ood_obs[_ood_unreachable_indx, :]

    return visualize_reachability(algo_obj, ood_reachable_data, step, fname, ood_unreachable_data, ood=True, wall=True)

from matplotlib.colors import LinearSegmentedColormap
# Three-stop colormap used for reachability scores in [0, 1]:
# skyblue at the low end, white in the middle, pink at the high end.
color_list = ['skyblue', 'white', 'pink']
my_cmap = LinearSegmentedColormap.from_list(name="", colors=color_list)

def visualize_reachability_traj_video(algo_obj, ood_traj, ood_reachability, fname, wall=True):
    """Render an OOD trajectory step by step, coloured by reachability, as an mp4.

    Frame ``t`` shows:
      * the maze layout, when *wall* is set;
      * the target-task expert positions and the goal;
      * the first ``t + 1`` trajectory positions, coloured by score on ``my_cmap``
        over [0, 1].
    The raw score of step ``t`` is written in a strip under the plot.

    Args:
        algo_obj: algorithm object providing ``_dataset``, ``_config`` and ``layout``.
        ood_traj: array of trajectory observations; columns 0:2 are plotted as
            the (x, y) position.
        ood_reachability: torch tensor of scores, one per trajectory step after
            ``squeeze()``.
        fname: output file name inside ``algo_obj._config.ood_visual_dir``.
        wall: draw the maze layout underneath the points.

    Returns:
        A ``wandb.Video`` wrapping the saved file.
    """
    # atleast_1d keeps a single-step trajectory indexable after squeeze().
    ood_reachability = np.atleast_1d(ood_reachability.squeeze().detach().cpu().numpy())
    traj_data = ood_traj[:, 0:2]

    # expert_traj = algo_obj._data_dataset[algo_obj._config.target_task_index_in_demo_path]._data
    expert_traj = algo_obj._dataset._data  # TODO: add if else for different algo (GAIL vs. DVD)
    # The third '-'-separated field of the env name selects the goal colour, which
    # maps to the goal's (x, y) slice inside the observation vector.
    color = algo_obj._config.env.split('-')[2]
    dict_ = {'red': 4, 'blue': 6, 'magenta': 8, 'yellow': 10}
    goal_pos = expert_traj[0]['ob']['ob'][dict_[color]:dict_[color]+2]
    expert = np.stack([item['ob']['ob'][0:2] for item in expert_traj], axis=0)

    # Expert positions followed by the goal, drawn as one scatter; the offsets
    # align observation coordinates with the pcolormesh cells.
    x = np.concatenate((expert[:, 0:1], np.expand_dims(goal_pos[0:1], axis=0)), axis=0) + 0.65
    y = np.concatenate((expert[:, 1:2], np.expand_dims(goal_pos[1:2], axis=0)), axis=0) + 0.75
    colors = np.concatenate(([1 for _ in range(len(expert))], [12]), axis=0)
    size = np.concatenate(([5 for _ in range(len(expert))], [100]), axis=0)

    maze = None
    if wall:
        def parse_maze(maze_str):
            # Rows are separated by a literal backslash; each tile maps to a grey level.
            lines = maze_str.strip().split('\\')
            width, height = len(lines), len(lines[0])
            maze_arr = np.zeros((width, height), dtype=np.int32)
            for w in range(width):
                for h in range(height):
                    tile = lines[w][h]
                    if tile == '#':
                        maze_arr[w][h] = 0
                    elif tile == 'G':
                        maze_arr[w][h] = 255
                    elif tile == ' ' or tile == 'O' or tile == '0':
                        maze_arr[w][h] = 64
                    elif tile == 'D':
                        maze_arr[w][h] = 200
                    else:
                        raise ValueError('Unknown tile type: %s' % tile)
            return maze_arr

        maze = parse_maze(algo_obj.layout)
        # flip and rotate maze so that the maze layout matches the dots in scatter plot
        maze = np.flip(maze, axis=1)
        maze = np.rot90(maze, k=1, axes=(0, 1))

    frames = []
    # TODO: change the colormap? viridis/Set3/magma/RdBu
    for ite in range(len(traj_data)):
        # Without a wall layout there is nothing to draw underneath
        # (this used to raise NameError).
        if maze is not None:
            plt.pcolormesh(maze)
        plt.scatter(x, y, c=colors, s=size, cmap='Set3')
        ss = plt.scatter(traj_data[:ite+1, 0:1] + 0.65, traj_data[:ite+1, 1:2] + 0.75, c=ood_reachability[:ite+1], s=20, cmap=my_cmap, vmin=0, vmax=1)
        plt.colorbar(ss)
        plt.legend(['target task demos', 'ood_traj', 'goal'], bbox_to_anchor=(0.5, -0.05), loc='upper center', ncol=3)

        # remove white padding
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
        plt.axis('off')
        plt.axis('image')

        # Render the canvas and grab its pixels. buffer_rgba() replaces the
        # deprecated np.fromstring(canvas.tostring_rgb()) pair and already has the
        # true (height, width, 4) pixel shape.
        fig = plt.gcf()
        fig.canvas.draw()
        img = np.asarray(fig.canvas.buffer_rgba())[:, :, :3]

        # Append a 50-pixel black strip and write the current score into it.
        # concatenate copies, so img no longer references the canvas buffer.
        h, w = img.shape[:2]
        img = np.concatenate([img, np.zeros((50, w, 3), dtype=np.uint8)], 0)
        text = 'reachability: {}'.format(ood_reachability[ite])
        cv2.putText(img, text, (10, h+50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        plt.close()
        frames.append(img)

    path = os.path.join(algo_obj._config.ood_visual_dir, fname)
    imageio.mimsave(path, frames, fps=15.0)
    video = wandb.Video(path, fps=15, format="mp4")

    return video

def visualize_reachability(algo_obj, reachable_data, step, suffix, _unreach=None, ood=False, wall=True):
    """Scatter reachable / unreachable states over the maze and save the figure.

    Two input modes:
      * ``ood=True``: *reachable_data* and *_unreach* are arrays of raw
        observations, and a single 1-D observation is also accepted. Columns 0:2
        are plotted. *_unreach* may be ``None``, meaning no unreachable states.
        The image goes to ``ood_visual_dir/<suffix>.png``.
      * ``ood=False``: *reachable_data* is an iterable of buffer items (read as
        ``item['ob'][0]['ob']``). The unreachable points are a batch freshly
        sampled from ``algo_obj._unreachable_buffer``. The image goes to
        ``pretrain_dir``.

    Both modes also draw the target-task expert positions and the goal.

    Returns:
        A ``wandb.Image`` of the saved figure.
    """
    if ood:
        if reachable_data.ndim == 1:
            reachable_data = np.expand_dims(reachable_data, axis=0)
        reach = reachable_data[:, 0:2]
        if _unreach is None:
            # The signature allows omitting the unreachable set; plot none
            # instead of failing on ``None.ndim``.
            unreach = np.zeros((0, 2))
        else:
            if _unreach.ndim == 1:
                _unreach = np.expand_dims(_unreach, axis=0)
            unreach = _unreach[:, 0:2]
    else:
        reach = np.stack([item['ob'][0]['ob'][0:2] for item in reachable_data], axis=0)

        visual_unreachable_data, _, _ = algo_obj._unreachable_buffer.sample(algo_obj._config.batch_size)
        visual_unreachable_data_ = visual_unreachable_data['ob']['ob'][:]  # {'ob':{'ob': bs x 12 array of states}}
        unreach = np.stack([item[0:2] for item in visual_unreachable_data_], axis=0)

    #expert_traj = algo_obj._data_dataset[algo_obj._config.target_task_index_in_demo_path]._data
    expert_traj = algo_obj._dataset._data  # TODO: add if else for different algo (GAIL vs. DVD)
    # The third '-'-separated field of the env name selects the goal's (x, y) slice.
    color = algo_obj._config.env.split('-')[2]
    dict_ = {'red': 4, 'blue': 6, 'magenta': 8, 'yellow': 10}
    goal_pos = expert_traj[0]['ob']['ob'][dict_[color]:dict_[color]+2]
    expert = np.stack([item['ob']['ob'][0:2] for item in expert_traj], axis=0)

    # plot reach, unreach, and expert in different colors
    plt.clf()
    if wall:
        def parse_maze(maze_str):
            # Rows are separated by a literal backslash; each tile maps to a grey level.
            lines = maze_str.strip().split('\\')
            width, height = len(lines), len(lines[0])
            maze_arr = np.zeros((width, height), dtype=np.int32)
            for w in range(width):
                for h in range(height):
                    tile = lines[w][h]
                    if tile == '#':
                        maze_arr[w][h] = 0
                    elif tile == 'G':
                        maze_arr[w][h] = 255
                    elif tile == ' ' or tile == 'O' or tile == '0':
                        maze_arr[w][h] = 64
                    elif tile == 'D':
                        maze_arr[w][h] = 200
                    else:
                        raise ValueError('Unknown tile type: %s' % tile)
            return maze_arr

        maze = parse_maze(algo_obj.layout)
        # flip and rotate maze so that the maze layout matches the dots in scatter plot
        maze = np.flip(maze, axis=1)
        maze = np.rot90(maze, k=1, axes=(0, 1))
        plt.pcolormesh(maze)

    if ood:
        fname = suffix + '.png'
        path = os.path.join(algo_obj._config.ood_visual_dir, fname)
    else:
        fname = "outer_{:09d}_inner_{:09d}_step_{:011d}_{}.png".format(algo_obj._pretrain_outer_loop,
                                                                        algo_obj._pretrain_inner_loop, step, suffix)

        path = os.path.join(algo_obj._config.pretrain_dir, fname)

    x = np.concatenate((expert[:, 0:1], reach[:, 0:1], unreach[:, 0:1], np.expand_dims(goal_pos[0:1], axis=0)),
                        axis=0)
    y = np.concatenate((expert[:, 1:2], reach[:, 1:2], unreach[:, 1:2], np.expand_dims(goal_pos[1:2], axis=0)),
                        axis=0)
    # Offsets align observation coordinates with the pcolormesh cells.
    x = x + 0.65
    y = y + 0.75

    # legend_elements() emits one handle per distinct colour value in ascending
    # order (1 expert, 5 unreach, 8 reach, 12 goal), so the labels follow that order.
    colors = np.concatenate(([1 for _ in range(len(expert))], [8 for _ in range(len(reach))], [5 for _ in range(len(unreach))], [12]), axis=0)
    if unreach.shape[0] != 0:
        text_labels = ['target task demos', 'unreach', 'new re-label', 'goal'] # for ood visualization, unreach is the rest of the ood trajectory except the re-label part
    else:
        text_labels = ['target task demos', 'new re-label', 'goal']
    size = np.concatenate(([5 for _ in range(len(expert))], [10 for _ in range(len(reach))], [10 for _ in range(len(unreach))], [100]), axis=0)
    ss = plt.scatter(x, y, c=colors, s=size, cmap='Set3')
    plt.axis('square')
    plt.legend(ss.legend_elements()[0], text_labels, bbox_to_anchor=(0.5, -0.05), loc='upper center', ncol=6)
    plt.savefig(path)
    im = wandb.Image(path)
    return im

def visualize_reachability_colormap(algo_obj, step, suffix='colormap'):
    """Plot the mean predicted reachability of every free maze tile as a colour map.

    For every non-wall tile ``(w, h)`` of ``algo_obj.layout``, nine synthetic
    observations are scored, one per direction in {-1, 0, 1}^2, and the sigmoid
    outputs are averaged. Wall tiles are set to -0.1 so they sit below every
    probability. The figure is saved in ``pretrain_dir``.

    Returns:
        A ``wandb.Image`` of the saved figure.
    """
    lines = algo_obj.layout.strip().split('\\')
    width, height = len(lines), len(lines[0])
    directions = [(x_direction, y_direction) for x_direction in range(-1, 2) for y_direction in range(-1, 2)]
    n_dirs = len(directions)  # 9

    maze_points = []
    maze_ac = []
    for w in range(width):
        for h in range(height):
            if lines[w][h] == '#':
                continue
            for x_direction, y_direction in directions:
                # Entries 4:12 are presumably the goal (x, y) pairs of the red,
                # blue, magenta and yellow targets -- TODO confirm against the env.
                maze_points.append([w, h, x_direction, y_direction, 4.0, 16.0, 4.0, 7.0, 16.0, 7.0, 16.0, 13.0])
                # Paired action query for the reach_discriminator_input path, which
                # previously received an empty array.
                # NOTE(review): confirm the action should mirror the direction.
                maze_ac.append([x_direction, y_direction])

    maze_points = {'ob': {'ob': np.array(maze_points)}}
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        maze_points_o = algo_obj._preprocess_data(maze_points, key="ob")
        if algo_obj._config.reach_discriminator_input:
            maze_ac = {'ac': {'ac': np.array(maze_ac)}}
            maze_ac_o = algo_obj._preprocess_data(maze_ac, key="ac")
            maze_logit, _ = algo_obj._reachable(maze_points_o, maze_ac_o)
        else:
            maze_logit, _ = algo_obj._reachable(maze_points_o)
        maze_output = torch.sigmoid(maze_logit)
    maze_output = maze_output.squeeze().detach().cpu().numpy()

    # Consume the predictions in the same tile / direction order they were built in,
    # averaging all nine directions of a tile (previously each overwrote the last).
    maze_output_array = np.zeros((width, height))
    _count = 0
    for w in range(width):
        for h in range(height):
            if lines[w][h] == '#':
                maze_output_array[w][h] = -0.1
            else:
                maze_output_array[w][h] = np.mean(maze_output[_count:_count + n_dirs])
                _count += n_dirs

    # flip and rotate so the layout matches the orientation of the scatter plots
    maze_output_array = np.flip(maze_output_array, axis=1)
    maze_output_array = np.rot90(maze_output_array, k=1, axes=(0, 1))
    plt.clf()
    plt.pcolormesh(maze_output_array)
    plt.colorbar()
    fname = "outer_{:09d}_inner_{:09d}_step_{:011d}_{}.png".format(algo_obj._pretrain_outer_loop,
                                                                    algo_obj._pretrain_inner_loop, step, 'colormap_' + suffix)
    path = os.path.join(algo_obj._config.pretrain_dir, fname)
    plt.savefig(path)
    im = wandb.Image(path)
    return im


def visualize_reachability_all(algo_obj, suffix, wall=True): 
    """Plot every state in the reachable and unreachable buffers as two images.

    Both images draw the maze layout, the target-task expert positions and the
    goal. The first overlays all reachable-buffer states, the second all
    unreachable-buffer states. They are saved in ``pretrain_dir`` as
    ``all_reach_<suffix>.png`` and ``all_unreach_<suffix>.png``.

    Args:
        algo_obj: algorithm object providing ``_unreachable_buffer``,
            ``_reachable_buffer``, ``_data_dataset``, ``_config`` and ``layout``.
        suffix: tag used in both file names.
        wall: not referenced by this function; the maze is always drawn.

    Returns:
        ``(im_reach, im_unreach)``, two ``wandb.Image`` objects.
    """
    # visualize all the states in the reachable and unreachable buffer in two images
    # Positions (columns 0:2) of every stored unreachable state.
    visual_unreachable_data = algo_obj._unreachable_buffer._buffer
    unreach = []
    for item in visual_unreachable_data['ob']:
        unreach.append(item[0]['ob'][0:2])
    unreach = np.stack(unreach, axis=0)

    # Positions (columns 0:2) of every stored reachable state.
    visual_reachable_data = algo_obj._reachable_buffer._buffer
    reach = []
    for item in visual_reachable_data['ob']:
        reach.append(item[0]['ob'][0:2])
    reach = np.stack(reach, axis=0)

    # NOTE(review): the sibling visualizers read algo_obj._dataset._data instead;
    # confirm which source is correct for the current algorithm.
    expert_traj = algo_obj._data_dataset[algo_obj._config.target_task_index_in_demo_path]._data
    # The third '-'-separated field of the env name selects the goal's (x, y) slice.
    color = algo_obj._config.env.split('-')[2]
    dict_ = {'red': 4, 'blue': 6, 'magenta': 8, 'yellow': 10}
    goal_pos = expert_traj[0]['ob']['ob'][dict_[color]:dict_[color]+2]
    expert = []
    for item in expert_traj:
        expert.append(item['ob']['ob'][0:2])
    expert = np.stack(expert, axis=0)

    # plot reach, unreach, and expert in different colors
    plt.clf()
    def parse_maze(maze_str):
        # Rows are separated by a literal backslash; walls -> 0, 'G' -> 255,
        # free cells -> 64, 'D' -> 200.
        lines = maze_str.strip().split('\\')
        width, height = len(lines), len(lines[0])
        maze_arr = np.zeros((width, height), dtype=np.int32)
        for w in range(width):
            for h in range(height):
                tile = lines[w][h]
                if tile == '#':
                    maze_arr[w][h] = 0
                elif tile == 'G':
                    maze_arr[w][h] = 255
                elif tile == ' ' or tile == 'O' or tile == '0':
                    maze_arr[w][h] = 64
                elif tile == 'D':
                    maze_arr[w][h] = 200
                else:
                    raise ValueError('Unknown tile type: %s' % tile)
        return maze_arr

    maze = parse_maze(algo_obj.layout)
    # flip and rotate maze so that the maze layout matches the dots in scatter plot
    maze = np.flip(maze, axis=1)
    maze = np.rot90(maze, k=1, axes=(0, 1))
    plt.pcolormesh(maze)


    fname_reach, fname_unreach = "all_reach_{}.png".format(suffix), "all_unreach_{}.png".format(suffix)
    path_reach, path_unreach = os.path.join(algo_obj._config.pretrain_dir, fname_reach), os.path.join(algo_obj._config.pretrain_dir, fname_unreach)

    # Expert + reachable + goal coordinates; offsets align points with maze cells.
    x_reach = np.concatenate((expert[:, 0:1], reach[:, 0:1], np.expand_dims(goal_pos[0:1], axis=0)), axis=0)
    y_reach = np.concatenate((expert[:, 1:2], reach[:, 1:2], np.expand_dims(goal_pos[1:2], axis=0)), axis=0)
    x_reach = x_reach + 0.65
    y_reach = y_reach + 0.75

    # Expert + unreachable + goal coordinates, same offsets.
    x_unreach = np.concatenate((expert[:, 0:1], unreach[:, 0:1], np.expand_dims(goal_pos[0:1], axis=0)), axis=0)
    y_unreach = np.concatenate((expert[:, 1:2], unreach[:, 1:2], np.expand_dims(goal_pos[1:2], axis=0)), axis=0)
    x_unreach = x_unreach + 0.65
    y_unreach = y_unreach + 0.75

    # First image: reachable states. legend_elements() orders handles by colour
    # value (1 expert, 8 reach, 12 goal), matching text_labels.
    colors = np.concatenate(([1 for _ in range(len(expert))], [8 for _ in range(len(reach))], [12]), axis=0)
    text_labels = ['target task demos', 'reach', 'goal'] 
    size = np.concatenate(([5 for _ in range(len(expert))], [10 for _ in range(len(reach))], [100]), axis=0)
    ss = plt.scatter(x_reach, y_reach, c=colors, s=size, cmap='Set3')
    plt.axis('square')
    plt.legend(ss.legend_elements()[0], text_labels, bbox_to_anchor=(0.5, -0.05), loc='upper center', ncol=3)
    plt.savefig(path_reach)
    im_reach = wandb.Image(path_reach)

    # Second image: same figure cleared and redrawn with the unreachable states.
    colors = np.concatenate(([1 for _ in range(len(expert))], [8 for _ in range(len(unreach))], [12]), axis=0)
    text_labels = ['target task demos', 'unreach', 'goal']
    size = np.concatenate(([5 for _ in range(len(expert))], [10 for _ in range(len(unreach))], [100]), axis=0)
    plt.clf()
    plt.pcolormesh(maze)
    ss = plt.scatter(x_unreach, y_unreach, c=colors, s=size, cmap='Set3')
    plt.axis('square')
    plt.legend(ss.legend_elements()[0], text_labels, bbox_to_anchor=(0.5, -0.05), loc='upper center', ncol=3)
    plt.savefig(path_unreach)
    im_unreach = wandb.Image(path_unreach)
    return im_reach, im_unreach


### Visualization tools for trainer.py
# from PIL import Image
# import seaborn as sn
def create_heatmap(trainer_obj, rollout, step, target_taskID=None):
    """
    Creates a heatmap of the rollouts.

    The predicted reward of every (ob, ac) transition is accumulated into the maze
    cell ``(int(ob[0]), int(ob[1]))``. The start cell is then set just below the
    minimum and the target cell just above the maximum so both stand out. The
    figure is saved under ``record_dir``, rotated by 180 degrees and logged to
    wandb as ``test_ep/heatmap``.

    Previously this relied on ``sn`` (seaborn) and ``Image`` (PIL), whose imports
    were commented out, so every call raised NameError. The plot is now drawn
    with matplotlib and rotated with numpy/imageio.
    """
    maze_env = trainer_obj._env.env.env.env
    heatmap = np.zeros(maze_env.maze_arr.shape)
    device = trainer_obj._config.device
    # len(ob) = len(ac) + 1; zip stops at the shorter list, so every action is visited once.
    for ob, ac in zip(rollout["ob"], rollout["ac"]):
        ob_t = {"ob": torch.tensor(ob["ob"], dtype=torch.float32).to(device)}
        ac_t = {"ac": torch.tensor(ac["ac"], dtype=torch.float32).to(device)}

        # no_grad: .numpy() refuses tensors that require grad (if _predict_reward returns one).
        with torch.no_grad():
            reward = trainer_obj._agent._predict_reward(ob_t, ac_t, target_taskID)["rew"].cpu().numpy().squeeze()
        heatmap[int(ob["ob"][0]), int(ob["ob"][1])] += reward
        print("step:%d, pos:(%d, %d), reward:%f" % (step, int(ob["ob"][0]), int(ob["ob"][1]), reward))

    # set the start and goal positions in the heatmap
    target_pos = maze_env._target
    values = np.sort(heatmap.flatten())
    lowest, highest = values[0], values[-1]
    margin = (highest - lowest) / 10.0
    # agent position in the maze isn't precise, but doesn't effect a lot as a whole
    start = rollout["ob"][0]["ob"]
    heatmap[int(start[0]), int(start[1])] = lowest - margin
    heatmap[int(target_pos[0]), int(target_pos[1])] = highest + margin
    print(heatmap)

    # Clear the current figure. Otherwise, the heatmap will have multiple legends
    plt.clf()
    # Matplotlib equivalent of seaborn.heatmap: row 0 at the top, plus a colorbar.
    ax = plt.gca()
    mesh = ax.pcolormesh(heatmap)
    ax.invert_yaxis()
    plt.colorbar(mesh, ax=ax)

    fname = "{}_step_{:011d}_{}.png".format(trainer_obj._config.env, step, trainer_obj._config.num_eval,)
    path = os.path.join(trainer_obj._config.record_dir, fname)
    plt.gcf().savefig(path)
    # Rotate by 180 degrees so the image matches the orientation of the recorded video.
    imageio.imwrite(path, np.ascontiguousarray(np.rot90(imageio.imread(path), 2)))

    wandb.log({"test_ep/heatmap": wandb.Image(path)}, step=step)

def visualize_rollout(trainer_obj, rollout, step, target_taskID=None):
    """Scatter the rollout's visited positions coloured by predicted reward, then log to wandb.

    Each transition's ``(ob[0], ob[1])`` is plotted (the same position columns
    ``create_heatmap`` uses) and coloured by ``_predict_reward``. The figure is
    saved under ``record_dir``, rotated by 180 degrees and logged as
    ``test_ep/visualized_rollouts``.

    Previously the whole observation list was scattered against the whole action
    list, which raises a size mismatch whenever their dimensions differ. The
    rotation also relied on ``Image`` (PIL), which was never imported.
    """
    device = trainer_obj._config.device
    xs, ys, rewards = [], [], []
    # len(ob) = len(ac) + 1; zip stops at the shorter list, so every action is visited once.
    for ob, ac in zip(rollout["ob"], rollout["ac"]):
        ob_t = {"ob": torch.tensor(ob["ob"], dtype=torch.float32).to(device)}
        ac_t = {"ac": torch.tensor(ac["ac"], dtype=torch.float32).to(device)}

        # no_grad: .numpy() refuses tensors that require grad (if _predict_reward returns one).
        with torch.no_grad():
            reward = trainer_obj._agent._predict_reward(ob_t, ac_t, target_taskID)["rew"].cpu().numpy().squeeze()
        rewards.append(reward)
        xs.append(ob["ob"][0])
        ys.append(ob["ob"][1])

    fname = "{}_step_{:011d}_visual{}.png".format(trainer_obj._config.env, step, trainer_obj._config.num_eval,)
    path = os.path.join(trainer_obj._config.record_dir, fname)
    # Start from an empty figure so earlier plots do not bleed into this one.
    plt.clf()
    plt.scatter(xs, ys, c=rewards, cmap='Blues')
    plt.savefig(path)
    # figure is rotated by 180 degrees, which is the same as the video
    imageio.imwrite(path, np.ascontiguousarray(np.rot90(imageio.imread(path), 2)))

    wandb.log({"test_ep/visualized_rollouts": wandb.Image(path)}, step=step)

# Custom sort key function
import re
def extract_number(filename):
    """Sort key: the integer between ``s`` and ``_repeat`` in *filename*, else 0.

    >>> extract_number("heatmap_dvd_s12_repeat3.png")
    12
    """
    found = re.search(r's(\d+)_repeat', filename)
    return int(found.group(1)) if found else 0

def make_gif(path, video_p, algo='dvd', format="mp4", is_sigmoid=False):
    """Stitch per-step heatmap images in *path* into animations saved in *video_p*.

    Frames are the files whose names start with ``heatmap_dvd_s`` (and, for
    ``algo == "reachable_gail"``, also ``heatmap_reach_s`` / ``heatmap_total_s``).
    They are ordered by ``extract_number``. Outputs are named
    ``heatmap_<kind>[_sig].<format>``, where *format* is ``"mp4"`` or ``"gif"``.
    """
    entries = os.listdir(path)
    tag = '_sig' if is_sigmoid else ''

    def frames_for(prefix):
        # Images whose name starts with *prefix*, in step order.
        matching = sorted((name for name in entries if name.startswith(prefix)), key=extract_number)
        return [imageio.imread(os.path.join(path, name)) for name in matching]

    def save(kind, frames):
        imageio.mimsave(os.path.join(video_p, f'heatmap_{kind}{tag}.{format}'), frames)

    save('dvd', frames_for("heatmap_dvd_s"))

    if algo == "reachable_gail":
        reach_frames = frames_for("heatmap_reach_s")
        total_frames = frames_for("heatmap_total_s")
        save('reach', reach_frames)
        save('total', total_frames)

# Visualize heatmaps whose cells are split into directional sectors (4 or 8 directions)
def plot_square_sector(num_sectors, ax, center, size, angle_start, delta_angle, color):
    """Fill one directional sector of a square cell on *ax*.

    The square of side *size* centred at *center* is split around its centre
    into *num_sectors* slices (4 or 8). Slices are numbered clockwise from the
    top-edge midpoint, assuming y grows upward:
      * 4 sectors: each slice is a quadrant (angle 0 is the top-right quadrant);
      * 8 sectors: each slice is the triangle between an edge midpoint and the
        adjacent corner.

    Args:
        num_sectors: 4 or 8.
        ax: matplotlib Axes to draw on.
        center: (x, y) centre of the cell.
        size: side length of the cell.
        angle_start: ``k * 360 / num_sectors`` for sector ``k``.
        delta_angle: kept for interface compatibility; the slice is fully
            determined by ``num_sectors`` and ``angle_start``.
        color: fill colour passed to ``patches.Polygon``.

    Raises:
        ValueError: for an unsupported *num_sectors* (previously an
            UnboundLocalError) or an *angle_start* that is not a sector start.
    """
    def square_sector_vertices(center, size, angle_start, delta_angle):
        half = size / 2
        cx, cy = center[0], center[1]
        # Points on the square's boundary, clockwise from the top-edge midpoint.
        boundary = [
            (cx, cy + half), (cx + half, cy + half), (cx + half, cy), (cx + half, cy - half),
            (cx, cy - half), (cx - half, cy - half), (cx - half, cy), (cx - half, cy + half),
        ]
        if num_sectors == 4:
            index = {0.0: 0, 90.0: 1, 180.0: 2, 270.0: 3}.get(angle_start)
            if index is None:
                raise ValueError("angle_start must be 0, 90, 180, or 270")
            # Quadrant: edge midpoint -> corner -> next edge midpoint.
            return [center] + [boundary[(2 * index + k) % 8] for k in range(3)]
        if num_sectors == 8:
            index = {0.0: 0, 45.0: 1, 90.0: 2, 135.0: 3, 180.0: 4, 225.0: 5, 270.0: 6, 315.0: 7}.get(angle_start)
            if index is None:
                raise ValueError("angle_start must be 0, 45, 90, 135, 180, 225, 270, or 315")
            # Triangle between two neighbouring boundary points.
            return [center, boundary[index], boundary[(index + 1) % 8]]
        raise ValueError("num_sectors must be 4 or 8")

    vertices = square_sector_vertices(center, size, angle_start, delta_angle)
    polygon = patches.Polygon(vertices, closed=True, color=color)
    ax.add_patch(polygon)

def draw_heatmap_sectors(ax, data, num_sectors, vmin, vmax, available_loc, drawe_arrow=True):
    """Draw an (n, m) grid of cells, each split into *num_sectors* coloured sectors.

    ``data[i, j, k]`` colours sector ``k`` of the cell whose lower-left corner is
    ``(i, j)``. The first axis of *data* runs along x and the second along y.
    Colours use viridis normalised to [vmin, vmax]. When *drawe_arrow* is set,
    each location ``(i, j)`` in *available_loc* also gets an arrow towards its
    highest-valued sector.

    Grid lines and axis limits now follow the same x = first-axis convention as
    the sectors; they were previously swapped, which misdrew non-square grids.
    """
    n, m, sectors = data.shape
    assert sectors == num_sectors, "Data third dimension must match number of sectors"

    delta_angle = 360 / num_sectors
    cell_size = 1  # Full size of each cell

    # Normalize data for color mapping
    norm = plt.Normalize(vmin, vmax)
    cmap = plt.cm.viridis

    for i in range(n):
        for j in range(m):
            center = (i + 0.5, j + 0.5)  # center of the cell
            for k in range(num_sectors):
                plot_square_sector(num_sectors, ax, center, cell_size, k * delta_angle, delta_angle, cmap(norm(data[i, j, k])))

    # Grid lines: x spans the n cells of axis 0, y the m cells of axis 1.
    for y_line in range(m + 1):
        ax.plot([0, n], [y_line, y_line], 'k-')
    for x_line in range(n + 1):
        ax.plot([x_line, x_line], [0, m], 'k-')

    if drawe_arrow:
        # Index of the highest-valued sector in every cell.
        directions = np.argmax(data, axis=2)

        # (dx, dy) arrow vector for each sector index.
        if num_sectors == 4:
            direction_map = {
                0: (0, 0.5),   # Up
                1: (0.5, 0),   # Right
                2: (0, -0.5),  # Down
                3: (-0.5, 0)   # Left
            }
        elif num_sectors == 8:
            direction_map = {
                0: (0, 0.5),    # Up
                1: (0.5, 0.5),  # Up-right
                2: (0.5, 0),    # Right
                3: (0.5, -0.5), # Down-right
                4: (0, -0.5),   # Down
                5: (-0.5, -0.5),# Down-left
                6: (-0.5, 0),   # Left
                7: (-0.5, 0.5)  # Up-left
            }
        for loc in available_loc:
            dx, dy = direction_map[directions[loc[0], loc[1]]]
            # The shaft starts on the cell edge opposite the direction and ends
            # at the cell centre: start = (loc + 0.5) - (dx, dy).
            ax.arrow(loc[0] + 0.5 - dx, loc[1] + 0.5 - dy, dx, dy, head_width=0.3, head_length=0.3, fc='black', ec='black')

    ax.set_xlim(0, n)
    ax.set_ylim(0, m)
    ax.set_aspect('equal')

def visualize_heatmap_8sections(repeat_time, data, output_path, x, y, colors, size, available_loc):
    """Save a sector heatmap with overlay points and a reward colorbar.

    Args:
        repeat_time: number of sectors per cell, forwarded as ``num_sectors``.
        data: (n, m, repeat_time) array of sector values.
        output_path: where the PNG (dpi=300) is written.
        x, y, colors, size: overlay scatter drawn with the ``Set3`` colormap.
        available_loc: cells that get a direction arrow.
    """
    plt.clf()
    fig, ax = plt.subplots()
    lo = np.min(data)
    hi = np.max(data)
    draw_heatmap_sectors(ax, data, repeat_time, lo, hi, available_loc)
    plt.scatter(x, y, c=colors, s=size, cmap='Set3')

    # Stand-alone mappable so the colorbar shows the sector scale, not the scatter's.
    mappable = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=lo, vmax=hi), cmap=plt.cm.viridis)
    colorbar = plt.colorbar(mappable, ax=ax, orientation='vertical')
    colorbar.set_label('Reward Value')
    fig.savefig(output_path, dpi=300)
    plt.close(fig)  # release the figure's memory

def sigmoid(z):
    """Return the logistic function ``1 / (1 + exp(-z))``, element-wise for arrays."""
    denominator = 1 + np.exp(-z)
    return 1 / denominator

def heatmap_arrows(data, output_path, x, y, colors, size, available_loc, grid_shape=(21, 21)):
    """Save a per-cell heatmap with arrows pointing to each cell's best neighbour.

    *data* is reshaped to ``grid_shape + (-1,)``, and each cell's value is drawn
    as a uniform 4-sector square. For every location in *available_loc*, an
    arrow points to the 8-connected neighbour (also in *available_loc*) with the
    largest value; ties go to the first neighbour scanned. Cells without an
    available neighbour get no arrow. Previously they got a spurious up-right
    arrow from leftover loop variables.

    Args:
        data: per-cell values; must hold ``grid_shape[0] * grid_shape[1]`` cells.
        output_path: where the PNG (dpi=300) is written.
        x, y, colors, size: overlay scatter drawn with the ``Set3`` colormap.
        available_loc: iterable of (i, j) cell locations (tuples, lists or array rows).
        grid_shape: (rows, cols) of the cell grid; defaults to the previous 21x21.
    """
    plt.clf()
    fig, ax = plt.subplots()

    # draw heatmap with arrows
    min_rew, max_rew = np.min(data), np.max(data)
    data = data.reshape(tuple(grid_shape) + (-1,))
    # Repeat the value into 4 identical sectors so each cell renders as one colour.
    data = np.repeat(data, 4, axis=2)
    draw_heatmap_sectors(ax, data, 4, min_rew, max_rew, available_loc, drawe_arrow=False)
    plt.scatter(x, y, c=colors, s=size, cmap='Set3')

    # Hashable lookup: O(1) membership that also works when locations are lists or
    # numpy rows (a tuple is never "in" a list of lists, and "in" on an array is elementwise).
    available = {(loc[0], loc[1]) for loc in available_loc}
    neighbour_offsets = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0)]

    for loc in available_loc:
        xx, yy = loc[0], loc[1]
        best_value, best_step = -np.inf, None
        for dx, dy in neighbour_offsets:
            if (xx + dx, yy + dy) in available:
                value = data[xx + dx, yy + dy, 0]
                if value > best_value:
                    best_value, best_step = value, (dx, dy)
        if best_step is None:
            # No available neighbour with a comparable value: nothing to point at.
            continue

        dx, dy = best_step[0] * 0.5, best_step[1] * 0.5
        # The shaft starts on the cell edge opposite the direction and ends at the
        # cell centre: start = (loc + 0.5) - (dx, dy).
        ax.arrow(xx + 0.5 - dx, yy + 0.5 - dy, dx, dy, head_width=0.3, head_length=0.3, fc='black', ec='black')

    sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=plt.Normalize(vmin=min_rew, vmax=max_rew))
    cbar = plt.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Reward Value')
    fig.savefig(output_path, dpi=300)
    plt.close(fig)  # Close the plot to free up memory

if __name__ == "__main__":
    # Rebuild the heatmap animations of one finished reachable_gail run:
    # frames are read from <run>/visual and the videos are written to <run>.
    video_p = "./log/maze2d-large-blue-v3.reachable_gail.final_heatmap.123"
    path = video_p + "/visual"
    make_gif(path, video_p, algo='reachable_gail', format="mp4")