import cv2
from scipy.ndimage import gaussian_filter
import numpy as np
import torch

from matplotlib import pyplot as plt


def plot_heatmap(
    map_img,
    energies,
    positions,
    device,
    normalize_cnt=False,
    normalize_max=True,
    scatter_pos=None,
    acc=True,
    alpha=0.8,
    scatter_scale=10,
    blur=True,
    wall_color=(101, 113, 145),
    scatter_color=(255, 0, 0),
    color_map=None,
    blur_sigma=2,
):
    """Render per-position energies as a heat map painted onto a map image.

    Pure-black pixels of ``map_img`` are replaced by the heat map, pure-white
    pixels are painted with ``wall_color``; every other pixel is kept. The
    result is resized to 500x500 with nearest-neighbour interpolation and
    optional markers are alpha-blended on top.

    Args:
        map_img: Image array indexed as ``(row, col, channel)``; it is
            compared against 3-element colours, so three channels are
            expected. The input array is not modified.
        energies: Tensor with one value per row of ``positions``, or ``None``
            to use a value of 1 for every position.
        positions: Tensor indexed as ``positions[:, 0]`` (row) and
            ``positions[:, 1]`` (column), or ``None`` for an all-zero heat
            map. Coordinates are clamped to the map bounds and truncated to
            integers.
        device: Torch device on which the heat map is accumulated.
        normalize_cnt: Divide each cell by the number of positions that fell
            into it (only meaningful together with ``acc=True``).
        normalize_max: Scale the heat map so that its maximum becomes 1.
        scatter_pos: Optional ``(row, col)`` marker positions in map
            coordinates; reshaped to ``(-1, 2)``.
        acc: Accumulate energies that land in the same cell instead of
            overwriting them.
        alpha: Opacity of each marker.
        scatter_scale: Marker radius in output pixels.
        blur: Apply a Gaussian blur to the heat map.
        wall_color: Colour painted over white pixels.
        scatter_color: Colour of the markers.
        color_map: OpenCV colour-map id, or ``None`` for a grey-scale map.
        blur_sigma: Standard deviation of the Gaussian blur.

    Returns:
        The rendered 500x500 image as a new array.
    """
    out_size = 500
    # shape[:2] is (rows, cols); keep the axes apart so non-square maps work.
    n_rows, n_cols = map_img.shape[:2]

    if positions is None:
        # Nothing to draw: build the empty map directly on the host so this
        # path also works when `device` is a GPU.
        hm = np.zeros((n_rows, n_cols), dtype=np.float32)
    else:
        heatmap = torch.zeros((n_rows, n_cols), device=device)

        # Clamp every axis against its own extent before truncating.
        rows = positions[:, 0].clamp(min=0, max=n_rows - 1).long().to(device)
        cols = positions[:, 1].clamp(min=0, max=n_cols - 1).long().to(device)

        if energies is None:
            values = torch.ones(rows.shape[0], dtype=heatmap.dtype, device=device)
        else:
            # Only used for plotting, so no autograd graph is needed.
            values = energies.detach().reshape(-1).to(device=device, dtype=heatmap.dtype)

        heatmap.index_put_((rows, cols), values, accumulate=acc)

        if normalize_cnt:
            # Count visits from zero so the quotient is the true per-cell
            # mean; unvisited cells are divided by 1 to avoid 0/0.
            counts = torch.zeros_like(heatmap)
            counts.index_put_((rows, cols), torch.ones_like(values), accumulate=acc)
            heatmap = heatmap / counts.clamp(min=1)

        hm = heatmap.cpu().numpy()

        if blur:
            hm = gaussian_filter(hm, sigma=blur_sigma)

        if normalize_max:
            peak = hm.max()
            # An all-zero map would otherwise turn into NaNs.
            if peak > 0:
                hm = hm / peak

        hm = 255 * hm

    # Clip before the cast so out-of-range values saturate instead of
    # wrapping around.
    gray = np.clip(hm, 0, 255).astype(np.uint8)

    if color_map is not None:
        colored = cv2.applyColorMap(gray, color_map)
    else:
        colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
    # OpenCV produces BGR; swap the channel order so the heat map matches the
    # other colours used here (presumably RGB -- TODO confirm against the
    # producer of map_img).
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    free_mask = np.all(map_img == [0, 0, 0], axis=-1)
    wall_mask = np.all(map_img == [255, 255, 255], axis=-1)

    # Work on a copy so the caller's image is left untouched.
    canvas = map_img.copy()
    canvas[free_mask] = colored[free_mask]
    canvas[wall_mask] = wall_color

    canvas = cv2.resize(canvas, (out_size, out_size), interpolation=cv2.INTER_NEAREST)

    if scatter_pos is not None:
        # Output pixels per map cell along each axis.
        scale_x = out_size / n_cols
        scale_y = out_size / n_rows
        for pos in scatter_pos.reshape((-1, 2)):
            overlay = canvas.copy()
            # cv2 expects (x, y) == (col, row); aim at the centre of the cell.
            center = (
                int(pos[1] * scale_x + scale_x / 2),
                int(pos[0] * scale_y + scale_y / 2),
            )
            cv2.circle(overlay, center, scatter_scale, scatter_color, -1)
            canvas = cv2.addWeighted(overlay, alpha, canvas, 1 - alpha, 0)

    return canvas


def plot_trajectory(
    env,
    trajectory,
    goal_pos=None,
    scatter_scale=5,
    goal_color=(0, 255, 255),
    start_color=(0, 255, 0),
    agent_color=(255, 0, 0),
    wall_color=(101, 113, 145),
):
    """Draw a trajectory (and optional goals) on top of the environment map.

    Args:
        env: Object whose ``render_map()`` returns the map image; the image
            is compared against a 3-element colour, so three channels are
            expected.
        trajectory: Sequence of ``(row, col)`` positions, or ``None``. The
            first position is highlighted with ``start_color``.
        goal_pos: Optional sequence of ``(row, col)`` goal positions.
        scatter_scale: Marker radius in pixels.
        goal_color: Colour of the goal markers.
        start_color: Colour of the start marker.
        agent_color: Colour of the trajectory markers.
        wall_color: Colour painted over pure-white map pixels.

    Returns:
        A new image with the markers drawn; the array returned by
        ``env.render_map()`` is not modified.
    """
    map_img = env.render_map()
    wall_mask = np.all(map_img == [255, 255, 255], axis=-1)

    # Draw on a copy so a map cached by the environment is not recoloured.
    overlay = map_img.copy()
    overlay[wall_mask] = wall_color

    if goal_pos is not None:
        for pos in goal_pos:
            # cv2 expects (x, y) == (col, row).
            cv2.circle(overlay, (int(pos[1]), int(pos[0])), scatter_scale, goal_color, -1)

    if trajectory is not None and len(trajectory) > 0:
        for pos in trajectory:
            cv2.circle(overlay, (int(pos[1]), int(pos[0])), scatter_scale, agent_color, -1)
        # Drawn last so the start marker stays visible on top of the path.
        start = trajectory[0]
        cv2.circle(overlay, (int(start[1]), int(start[0])), scatter_scale, start_color, -1)

    return overlay


def plot_goal(env, goal_abstraction, buffer, goal, goal_position=None, blur_sigma=2, num_samples=10000):
    """Visualise the energy of ``goal`` against states sampled from ``buffer``.

    Transitions are sampled from the buffer, their observations are encoded
    with ``goal_abstraction.goal_encoder``, and the sigmoid of
    ``goal_abstraction.subset_energy.energy`` between every encoded sample and
    ``goal`` is rendered as a heat map at the sampled positions.

    Args:
        env: Object whose ``render_map()`` returns the map image.
        goal_abstraction: Object providing ``goal_encoder.encode`` and
            ``subset_energy.energy``.
        buffer: Replay buffer; ``sample`` must return a mapping with
            ``"observation"`` and ``"position"`` entries.
        goal: Goal tensor; its device is used for sampling and plotting.
        goal_position: Optional position(s) marked on the rendered image.
        blur_sigma: Standard deviation of the heat-map blur.
        num_samples: Number of transitions to sample; ``goal`` is repeated
            this many times to pair with the samples.

    Returns:
        The image produced by ``plot_heatmap``.
    """
    device = goal.device
    # NOTE(review): the second argument is presumably the sequence length per
    # sample -- confirm against the buffer implementation.
    transitions = buffer.sample(num_samples, 1, to_device=device)
    encoded = goal_abstraction.goal_encoder.encode(transitions["observation"]).detach()

    # One copy of the goal per sampled transition.
    tiled_goal = goal.repeat((num_samples, 1))

    # Detached: the energies are only used for plotting.
    energies = torch.sigmoid(goal_abstraction.subset_energy.energy(encoded, tiled_goal)).detach()

    return plot_heatmap(
        device=device,
        blur_sigma=blur_sigma,
        map_img=env.render_map(),
        energies=energies,
        scatter_scale=5,
        positions=transitions["position"].reshape((-1, 2)),
        scatter_pos=goal_position,
        scatter_color=[0, 255, 255],
        color_map=cv2.COLORMAP_INFERNO,
        normalize_cnt=True,
    )
