import numpy as np
import torch
import trimesh
import cv2
import itertools
import flow_vis
import copy
import matplotlib.pyplot as plt
from PIL import Image
from scipy.spatial.transform import Rotation

from typing import Any, List, Tuple, Dict
from plot.utils.processing import create_ego_box3d, project_box3d_to_bev, project_box3d_to_image, create_face_vertices, create_ego_box3d_rotated
from plot.utils.geometry import geotrf


# Homogeneous 4x4 transform that negates the Y and Z axes, i.e. diag(1, -1, -1, 1).
# Used as the scene transform when point clouds are handed to trimesh for display.
Cam2Trimesh = np.diag([1, -1, -1, 1])

# The same Y/Z axis flip as a separate array; composed with camera poses in
# add_scene_cam (presumably the OpenCV -> OpenGL camera convention change).
OPENGL = np.diag([1, -1, -1, 1])

# Matplotlib colour names cycled per query in BEV scatter plots.
colors_discrete = [
    "blue", "red", "green", "purple",
    "orange", "cyan", "magenta", "brown",
    "lime", "pink", "yellow", "teal",
    "gold", "indigo", "gray", "black",
]


# RGB colours (0-255) cycled per camera in visualize_cams.
CAM_COLORS = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 0, 255),
    (255, 204, 0),
    (0, 204, 204),
    (128, 255, 255),
    (255, 128, 255),
    (255, 255, 128),
    (0, 0, 0),
    (128, 128, 128),
]


class ColorMap2d:
    """Look up RGB colours for normalised 2-D coordinates in a 2-D colormap image.

    A coordinate ``(x, y)`` maps to pixel column ``int(x * (width - 1))`` and
    row ``int(y * (height - 1))`` of the image; out-of-range coordinates are
    clamped to the image border.
    """

    def __init__(self, filename=None):
        """Load the colormap image.

        Args:
            filename: image path passed to ``plt.imread``. Required in practice
                despite the ``None`` default.
        """
        self._colormap_file = filename
        img = plt.imread(self._colormap_file)
        # plt.imread yields floats in [0, 1] for PNGs but uint8 for other
        # formats; only rescale floats so uint8 input does not overflow.
        if np.issubdtype(img.dtype, np.floating):
            img = img * 255
        img = img.astype(np.uint8)
        # Normalise to exactly three channels: replicate grayscale, drop alpha.
        if img.ndim == 2:
            img = np.repeat(img[:, :, None], 3, axis=2)
        self._img = img[:, :, :3]

        self._height = self._img.shape[0]
        self._width = self._img.shape[1]

    def __call__(self, X):
        """Map coordinates to colours.

        Args:
            X: 2-D array whose first two columns are ``(x, y)`` pairs,
                nominally in [0, 1].

        Returns:
            uint8 array of shape (N, 3), one RGB colour per row of *X*.

        Raises:
            ValueError: if *X* is not two-dimensional.
        """
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(f"expected a 2-D array of (x, y) pairs, got shape {X.shape}")
        # astype truncates toward zero, matching int() on each scalar.
        xp = np.clip(((self._width - 1) * X[:, 0]).astype(np.int64), 0, self._width - 1)
        yp = np.clip(((self._height - 1) * X[:, 1]).astype(np.int64), 0, self._height - 1)
        return self._img[yp, xp].astype(np.uint8, copy=False)
    

# Loaded ColorMap2d instances keyed by image path, so the PNG is read once.
_COLORMAP_2D_CACHE = {}


def get_2d_colors(xys, H, W, colormap_path="./alltracker/utils/bremm.png"):
    """Colour 2-D pixel coordinates by their position in the image plane.

    Args:
        xys: array of shape (N, 2) holding ``(x, y)`` pixel coordinates.
        H: image height used to normalise ``y``.
        W: image width used to normalise ``x``.
        colormap_path: 2-D colormap image loaded with ``ColorMap2d``.

    Returns:
        uint8 array of shape (N, 3), one RGB colour per point.

    Raises:
        ValueError: if *xys* does not have shape (N, 2).
    """
    import sys

    # Kept for backward compatibility (other code may expect the alltracker
    # checkout on sys.path); only add it once instead of on every call.
    if "alltracker" not in sys.path:
        sys.path.insert(0, "alltracker")

    xys = np.asarray(xys)
    if xys.ndim != 2 or xys.shape[1] != 2:
        raise ValueError(f"xys must have shape (N, 2), got {xys.shape}")

    colormap = _COLORMAP_2D_CACHE.get(colormap_path)
    if colormap is None:
        colormap = ColorMap2d(colormap_path)
        _COLORMAP_2D_CACHE[colormap_path] = colormap

    # astype returns a new array, so the caller's xys is never modified.
    normalized = xys.astype(np.float32)
    # max(..., 1) avoids dividing by zero for 1-pixel-wide/high images.
    normalized[:, 0] /= float(max(W - 1, 1))
    normalized[:, 1] /= float(max(H - 1, 1))
    return colormap(normalized)


def generate_distinct_colors(n_colors):
    """Return ``n_colors`` RGB colours from matplotlib's qualitative colormaps.

    Colours are taken in order from 'Set1', 'Set2', 'Set3', 'tab20' and
    'Paired'. If more colours are requested than those colormaps provide, the
    sequence repeats, so the result always has exactly ``n_colors`` rows.

    Args:
        n_colors: number of colours wanted; negative values are treated as 0.

    Returns:
        int array of shape (n_colors, 3) with channel values in 0..255.
    """
    n_colors = max(int(n_colors), 0)
    colormaps = ['Set1', 'Set2', 'Set3', 'tab20', 'Paired']
    colors = []
    for cmap_name in colormaps:
        cmap = plt.get_cmap(cmap_name)
        colors.extend(cmap(i) for i in range(cmap.N))
        if len(colors) >= n_colors:
            break
    # Wrap around so callers indexing rgbs[i] for i < n_colors never fail.
    colors = [colors[i % len(colors)] for i in range(n_colors)]
    # RGBA floats in [0, 1] -> RGB in [0, 255]; astype(int) truncates.
    # reshape keeps the (0, 3) shape when no colours are requested.
    colors_rgb = np.array([c[:3] for c in colors], dtype=float).reshape(-1, 3) * 255
    return colors_rgb.astype(int)


# Fixed RGB palette indexed (modulo its length) by get_color. Some entries
# repeat, e.g. (255, 0, 0) and (0, 0, 142) each appear twice.
_GET_COLOR_PALETTE = (
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (230, 150, 140),
    (70, 70, 70), (102, 102, 156), (0, 0, 90), (190, 153, 153),
    (244, 35, 232), (150, 120, 90), (220, 220, 0), (107, 142, 35),
    (180, 165, 180), (152, 251, 152), (70, 130, 180), (220, 20, 60),
    (255, 0, 0), (250, 170, 160), (0, 0, 142), (0, 0, 70),
    (150, 100, 100), (0, 60, 100), (0, 0, 110), (153, 153, 153),
    (0, 80, 100), (250, 170, 30), (0, 0, 230), (119, 11, 32),
    (0, 0, 142),
)


def get_color(ind, hex=False):
    """Return the palette colour for integer index *ind*, wrapping around the palette.

    Args:
        ind: any integer; it is reduced modulo the palette length.
        hex: when true, return a ``'#rrggbb'`` string instead of a tuple.

    Returns:
        An ``(r, g, b)`` tuple of ints in 0..255, or its hex string form.
    """
    rgb = _GET_COLOR_PALETTE[ind % len(_GET_COLOR_PALETTE)]
    if not hex:
        return rgb
    red, green, blue = rgb
    return f"#{red:02x}{green:02x}{blue:02x}"
    

def visualize_pcd(pts, colors=None, mask=None):
    """Display a point cloud, or a list of them, in an interactive trimesh viewer.

    Args:
        pts: array of 3-D points (reshaped to (-1, 3) for display), or a list
            of such arrays, which are concatenated along axis 0.
        colors: per-point RGB colours matching *pts*; for list input, a list of
            per-cloud colour arrays. When None, list input gets one solid
            ``get_color(i)`` colour per cloud and array input is solid green.
        mask: optional index/boolean mask applied to points and colours after
            any concatenation.
    """
    if isinstance(pts, list):
        if colors is None:
            # One solid colour per cloud, broadcast over that cloud's shape.
            colors = [
                np.full(np.shape(cloud), get_color(i), dtype=np.uint8)
                for i, cloud in enumerate(pts)
            ]
        pts = np.concatenate(pts, axis=0)
        colors = np.concatenate(colors, axis=0)

    if colors is None:
        colors = np.full(np.shape(pts), (0, 255, 0), dtype=np.uint8)
    if mask is not None:
        pts = pts[mask]
        colors = colors[mask]
    scene = trimesh.Scene()
    pcd = trimesh.PointCloud(pts.reshape(-1, 3), colors.reshape(-1, 3))
    # Cam2Trimesh negates Y and Z before display.
    scene.add_geometry(pcd, transform=Cam2Trimesh)
    scene.show(line_settings={"point_size": 3})


def draw_point_tracks(image, points, visibility, colors, radius=2, conf_thr=0.1, inds=None):
    """Draw 2-D track points onto *image* in place as OpenCV circles.

    Args:
        image: image array drawn into in place.
        points: (N, 2) array of (x, y) pixel coordinates; rounded to int.
        visibility: length-N values; points with visibility > 0.5 are drawn
            filled, the others as hollow circles.
        colors: length-N colours; each entry is converted to a tuple of ints.
        radius: circle radius in pixels.
        conf_thr: currently unused (confidence filtering is disabled).
        inds: optional index array / mask selecting which points to draw.
    """
    if inds is not None:
        points = points[inds]
        visibility = visibility[inds]
        colors = colors[inds]

    # Iterate over the *filtered* points; counting before filtering made the
    # loop index past the end of the selected subset.
    pixels = np.asarray(points).round().astype(int)
    for i, xy in enumerate(pixels):
        color = tuple(map(int, colors[i]))
        thickness = -1 if visibility[i] > 0.5 else 1  # -1 = filled, 1 = hollow
        cv2.circle(image, (int(xy[0]), int(xy[1])), radius, color, thickness)


def draw_box2d(image: np.ndarray, box: list, color: tuple = (0, 0, 255), thickness: int = 2):
    """Return a copy of *image* with an axis-aligned rectangle drawn on it.

    Drawing happens on a channel-reversed copy, so *color* is interpreted in
    reversed channel order (BGR for an RGB input; the default (0, 0, 255)
    then appears red). The input image is not modified.

    Args:
        image: (H, W, C) image, converted to uint8.
        box: ``[x_min, y_min, x_max, y_max]`` in pixels; values are rounded to int.
        color: rectangle colour in reversed channel order.
        thickness: line thickness in pixels.

    Returns:
        C-contiguous uint8 image in the original channel order.
    """
    reversed_img = np.ascontiguousarray(image[..., ::-1].astype(np.uint8))
    # cv2 requires integer pixel coordinates.
    x0, y0, x1, y1 = (int(round(float(v))) for v in box[:4])
    cv2.rectangle(reversed_img, (x0, y0), (x1, y1), color, thickness, lineType=cv2.LINE_AA)
    # Return a contiguous array: a negative-stride view is rejected by later
    # cv2 drawing calls on the result.
    return np.ascontiguousarray(reversed_img[..., ::-1])


# Category name -> colour/class index. Synonyms (e.g. "Pedestrian"/"Person")
# share one index. NOTE(review): index 8 is unused (Vase=7, Boat=9) — confirm
# whether that gap is intentional.
class2idx = {
    name: index
    for index, names in (
        (0, ("Car",)),
        (1, ("Pedestrian", "Person")),
        (2, ("Motorcycle", "Cycle", "Bicycle")),
        (3, ("Carpet",)),
        (4, ("Table",)),
        (5, ("Sofa",)),
        (6, ("Basket",)),
        (7, ("Vase",)),
        (9, ("Boat",)),
        (10, ("Bus",)),
    )
    for name in names
}

def visualize_box3d_on_image(image: np.ndarray, boxes: List[Any], idxs: List[Any],
                             intrinsic: np.ndarray, rots=None, labels=None, save_fig=None):
    """Draw projected 3-D boxes onto *image* and show or save the figure.

    Args:
        image: image array; boxes are drawn into it in place before display.
        boxes: per-object box parameters, unpacked into ``create_ego_box3d_rotated``.
        idxs: per-object identifiers forwarded to ``draw_3d_box_with_label``.
        intrinsic: camera intrinsics passed to ``project_box3d_to_image``.
        rots: per-object rotation, passed as the last argument of
            ``create_ego_box3d_rotated``. Required whenever *boxes* is non-empty.
        labels: optional per-object text; when given, each box gets a label tag.
        save_fig: output path; when None the figure is shown interactively.

    Raises:
        ValueError: if *boxes* is non-empty and *rots* is None.
    """
    # Fail before creating a figure rather than crashing mid-loop.
    if rots is None and len(boxes) > 0:
        raise ValueError("visualize_box3d_on_image requires `rots` (one rotation per box)")

    fig = plt.figure(figsize=(15, 10))
    ax_img = fig.add_subplot(1, 1, 1)
    ax_img.set_xticks([])
    ax_img.set_yticks([])

    for obj_idx, box in enumerate(boxes):
        color = get_color(obj_idx)
        box3d = create_ego_box3d_rotated(*box, rots[obj_idx])  # (8, 3)
        # Project the 3-D box into the image and draw its faces.
        box3d_proj = project_box3d_to_image(box3d, intrinsic)
        verts2d = create_face_vertices(box3d_proj)
        if labels is not None:
            draw_3d_box_with_label(image, idxs[obj_idx], verts2d, color, labels[obj_idx], thickness=2)
        else:
            draw_3d_box(image, verts2d, color, thickness=2)

    ax_img.imshow(image)
    fig.tight_layout(pad=4, h_pad=None, w_pad=None)
    if save_fig is not None:
        fig.savefig(str(save_fig), dpi=150)
        plt.close(fig)
    else:
        plt.show()


def draw_3d_box_with_label(im, idx, verts, color=(0, 200, 200), label=None, thickness=1):
    """Draw a projected 3-D box and, if *label* is given, a filled text tag, in place.

    Args:
        im: image array drawn into in place.
        idx: object identifier; currently unused (kept for API compatibility).
        verts: (>=16, 2) projected vertices, consumed in groups of four
            starting at rows 0, 4, 8 and 12.
        color: line colour and tag background colour.
        label: tag text (converted with ``str``); when None only the box is drawn.
        thickness: line thickness in pixels.
    """
    # Three segments per group of four vertices; the 4th->1st segment of each
    # group is not drawn here.
    for face_idx in [0, 4, 8, 12]:
        for i in range(3):
            v1 = verts[face_idx+i]
            v2 = verts[face_idx+i+1]
            cv2.line(im, (int(v1[0]), int(v1[1])), (int(v2[0]), int(v2[1])), color, thickness, cv2.LINE_AA)

    if label is None:
        # cv2.getTextSize/putText cannot handle None.
        return
    label = str(label)

    x_min, y_min = np.min(verts, axis=0)
    x_max, y_max = np.max(verts, axis=0)
    x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_thickness = 1
    text_size = cv2.getTextSize(label, font, font_scale, font_thickness)[0]

    text_x = x_min
    # Put the tag above the box when there is room, otherwise below its top edge.
    text_y = y_min - 10 if y_min - 10 > text_size[1] else y_min + text_size[1] + 10
    # Snap to the left border when the box's left edge is outside the image.
    if text_x < 0 or text_x > im.shape[1]:
        text_x = 0

    # Filled background rectangle behind the text.
    cv2.rectangle(im, (text_x, text_y - text_size[1] - 5),
                  (text_x + text_size[0] + 5, text_y + 5), color, -1)

    cv2.putText(im, label, (text_x + 5, text_y), font, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)
    # draw_transparent_polygon(im, verts[0:4, :], blend=0.7, color=color)


def visualize_bev_and_box3d(image: np.ndarray, boxes: List[float], movings: List[bool],
                            intrinsic: np.ndarray, points: List[np.ndarray] = None,
                            labels=None,
                            save_fig=None):
    """Show projected 3-D boxes on *image* next to a bird's-eye-view (X-Z) plot.

    Args:
        image: image array; boxes are drawn into it in place.
        boxes: per-object box parameters unpacked into ``create_ego_box3d``.
            Elements [0] and [2] are used as the BEV arrow origin and [-1] as
            its heading angle (presumably x, z and yaw — TODO confirm).
        movings: per-object flags; a heading arrow is drawn for truthy entries.
            Pass None to skip all arrows.
        intrinsic: camera intrinsics passed to ``project_box3d_to_image``.
        points: optional per-object point arrays; columns 0 and 2 are plotted in BEV.
        labels: unused; kept for API compatibility.
        save_fig: output path; when None the figure is shown interactively.
    """
    arrow_length = 1.5

    fig = plt.figure(figsize=(15, 10))
    ax_img = fig.add_subplot(1, 2, 1)
    ax_bev = fig.add_subplot(1, 2, 2)
    ax_img.set_xticks([])
    ax_img.set_yticks([])
    ax_bev.set_xlim(-25, 25)
    ax_bev.set_ylim(0, 80)
    ax_bev.set_aspect('equal')
    ax_bev.set_xlabel("X (meters) (Left ⬅️➡️ Right)", fontsize=12)
    ax_bev.set_ylabel("Z (meters) (Near ⬆️⬇️ Far)", fontsize=12)
    ax_bev.set_title("Bird-Eye View - All Queries & Clouds", fontsize=14, fontweight='bold')
    ax_bev.grid(True)

    for obj_idx, box in enumerate(boxes):
        color = get_color(obj_idx)
        color_float = tuple(c / 255 for c in color)  # matplotlib wants 0..1
        box3d = create_ego_box3d(*box)  # (8, 3)
        box_bev = project_box3d_to_bev(box3d)   # (4, 2)
        box_bev = np.vstack((box_bev, box_bev[0]))  # repeat first corner to close the outline

        # Points, box outline and heading in BEV.
        if points is not None:
            ax_bev.scatter(points[obj_idx][:, 0], points[obj_idx][:, 2], s=1, color=color_float, alpha=0.8)

        ax_bev.plot(box_bev[:, 0], box_bev[:, 1], color=color_float, linewidth=2)

        if movings is not None and movings[obj_idx]:
            heading = -box[-1]
            x2 = box[0] + arrow_length * np.cos(heading)
            y2 = box[2] + arrow_length * np.sin(heading)
            ax_bev.annotate("", xytext=(box[0], box[2]), xy=(x2, y2),
                            arrowprops=dict(facecolor="white", shrink=0, headwidth=6, headlength=6, width=3))

        # Project the 3-D box into the image.
        box3d_proj = project_box3d_to_image(box3d, intrinsic)
        verts2d = create_face_vertices(box3d_proj)
        draw_3d_box(image, verts2d, color, thickness=2)

    ax_img.imshow(image)
    fig.tight_layout(pad=4, h_pad=None, w_pad=None)
    if save_fig is not None:
        fig.savefig(str(save_fig), dpi=150)
        plt.close(fig)
    else:
        plt.show()
    

def visualize_pseudo_lidar_bev(points, gt_centers=None, save_path=None, limit_axes=False):
    """Scatter per-query point clouds in bird's-eye view (X against Z).

    Args:
        points: sequence of per-query point arrays; columns 0 and 2 are plotted.
        gt_centers: optional array of centres; columns 0 and 2 are drawn as
            large white circles.
        save_path: output path; when None the figure is shown interactively.
        limit_axes: fix the view to x in [-45, 40] and z in [0, 85].
    """
    fig = plt.figure(figsize=(12, 12))
    num_queries = len(points)

    for query_idx, obj_points in enumerate(points):
        # Colours repeat after len(colors_discrete) queries.
        color = colors_discrete[query_idx % len(colors_discrete)]
        plt.scatter(obj_points[:, 0], obj_points[:, 2], s=1, c=color, alpha=1.0,
                    label=f"Query {query_idx}")

    has_gt = gt_centers is not None
    if has_gt:
        plt.scatter(gt_centers[:, 0], gt_centers[:, 2], s=200, facecolors='white', edgecolors='black',
                    linewidths=2, alpha=1.0, label="GT centers")

    plt.title("Bird-Eye View - All Queries & Clouds", fontsize=14, fontweight='bold')
    plt.xlabel("X (meters) (Left ⬅️➡️ Right)", fontsize=12)
    plt.ylabel("Z (meters) (Near ⬆️⬇️ Far)", fontsize=12)

    if limit_axes:
        plt.xlim([-45, 40])
        plt.ylim([0, 85])

    plt.grid(True)
    # Artists are labelled above; with nothing plotted, skip the legend to
    # avoid matplotlib's "no artists with labels" warning.
    if num_queries > 0 or has_gt:
        plt.legend(fontsize=12, loc="upper right", markerscale=2)

    if save_path is not None:
        plt.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()


def bev_plot_with_direction(points, centroid, vectors, principal_axis):
    """Plot 2-D points, their centroid and two direction arrows, then show the figure.

    Columns 0 and 1 of *vectors* are drawn as arrows from the centroid. With
    ``principal_axis == 0`` the first arrow is red and the second blue;
    otherwise the colours are swapped.
    """
    plt.scatter(points[:, 0], points[:, 1], 10, c='tab:gray', zorder=1)
    arrow_colors = ('r', 'b') if principal_axis == 0 else ('b', 'r')
    for column, arrow_color in enumerate(arrow_colors):
        plt.quiver(centroid[0], centroid[1], vectors[0, column], vectors[1, column],
                   zorder=2, scale=5.0, color=arrow_color)

    plt.scatter(centroid[0], centroid[1], 200, c='g', zorder=3)
    plt.gca().set_aspect('equal')
    plt.show()

def draw_3d_box(im, verts, color=(0, 200, 200), thickness=1):
    """Draw the edges of a projected 3-D box onto *im* in place.

    *verts* is consumed in groups of four 2-D points starting at rows 0, 4, 8
    and 12. Within each group consecutive points are joined (three segments);
    the segment from the fourth point back to the first is not drawn.
    """
    for start in range(0, 16, 4):
        for k in range(start, start + 3):
            head, tail = verts[k], verts[k + 1]
            p0 = (int(head[0]), int(head[1]))
            p1 = (int(tail[0]), int(tail[1]))
            cv2.line(im, p0, p1, color, thickness, cv2.LINE_AA)






def draw_transparent_polygon(im, verts, blend=0.5, color=(0, 255, 255)):
    """Blend *color* into *im* (in place) inside the polygon of the first four *verts*.

    For each covered pixel, channels 0-2 become
    ``pixel * blend + (1 - blend) * color``; a larger *blend* keeps more of
    the original image.
    """
    inside = get_polygon_grid(im, verts[:4, :])
    tint_weight = 1 - blend
    for channel in range(3):
        im[inside, channel] = im[inside, channel] * blend + tint_weight * color[channel]


def get_polygon_grid(im, poly_verts):
    """Return a boolean (H, W) mask marking pixels of *im* inside *poly_verts*.

    Each pixel is tested at its integer (column, row) coordinate, with (0, 0)
    at the top-left, using ``matplotlib.path.Path.contains_points``.
    """
    from matplotlib.path import Path

    n_rows, n_cols = im.shape[0], im.shape[1]
    row_idx, col_idx = np.mgrid[0:n_rows, 0:n_cols]
    pixel_xy = np.column_stack((col_idx.ravel(), row_idx.ravel()))
    inside = Path(poly_verts).contains_points(pixel_xy)
    return inside.reshape((n_rows, n_cols))


def visualize_cams(poses, focal, cam_size=0.05):
    """Open a trimesh viewer with one camera marker per pose in *poses*.

    Marker colours cycle through ``CAM_COLORS``; *focal* and *cam_size* are
    forwarded to ``add_scene_cam`` (the latter as ``screen_width``).
    """
    scene = trimesh.Scene()
    for pose, cam_color in zip(poses, itertools.cycle(CAM_COLORS)):
        add_scene_cam(scene, pose, cam_color, focal, screen_width=cam_size)
    scene.show()

def add_scene_cam(scene, pose_c2w, edge_color, focal=None, screen_width=0.03):
    """Add a camera marker (a thin four-sided cone mesh) to a trimesh scene.

    The cone is positioned with *pose_c2w*, rendered double-sided, coloured
    *edge_color*, and added to *scene*.

    Args:
        scene: scene the marker mesh is added to via ``add_geometry``.
        pose_c2w: 4x4 pose matrix (presumably camera-to-world, per the name).
        edge_color: RGB colour written into the mesh's face colours.
        focal: focal length. Required despite the ``None`` default, because
            ``focal / 1.1`` fails on None.
        screen_width: size of the marker in scene units.
    """
    # Fake image size derived from focal. Since H == W, the aspect ratio below is 1
    # and `height` evaluates to max(screen_width / 10, 1.1 * screen_width).
    H = W = focal / 1.1

    # create fake camera
    height = max( screen_width/10, focal * screen_width / H )
    width = screen_width * 0.5**0.5  # passed as the cone's radius below
    rot45 = np.eye(4)
    # 45-degree rotation about z (angle in radians, from_euler's default unit).
    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
    rot45[2, 3] = -height  # set the tip of the cone = optical center
    aspect_ratio = np.eye(4)
    aspect_ratio[0, 0] = W/H
    # Full placement: cone-local -> axis flip (OPENGL) -> pose.
    transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
    cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)

    # this is the camera mesh
    # Three copies of the cone vertices: original, scaled by 0.95, and rotated
    # 2 degrees about z. Triangles spanning the copies form thin "edges".
    rot2 = np.eye(4)
    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
    # NOTE(review): geotrf is assumed to apply a 4x4 transform to an (N, 3) point array — confirm in plot.utils.geometry.
    vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
    vertices = geotrf(transform, vertices)
    faces = []
    for face in cam.faces:
        # Skip faces that use vertex 0. NOTE(review): this is presumably the
        # cone's apex or base-centre vertex — confirm against trimesh.creation.cone.
        if 0 in face:
            continue
        a, b, c = face
        # Matching vertex indices in the scaled (2) and rotated (3) copies.
        a2, b2, c2 = face + len(cam.vertices)
        a3, b3, c3 = face + 2*len(cam.vertices)

        # add 3 pseudo-edges
        faces.append((a, b, b2))
        faces.append((a, a2, c))
        faces.append((c2, b, c))

        faces.append((a, b, b3))
        faces.append((a, a3, c))
        faces.append((c3, b, c))

    # no culling
    # Append every triangle with reversed winding so it is visible from both sides.
    faces += [(c, b, a) for a, b, c in faces]

    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
    cam.visual.face_colors[:, :3] = edge_color
    scene.add_geometry(cam)


