import math
import numpy as np
from scipy.spatial.transform.rotation import Rotation
from home_robot.core.interfaces import Observations
from typing import Any, Dict, List, Optional, Tuple
from sklearn.cluster import DBSCAN
from scipy.spatial import ConvexHull
import cv2
import matplotlib.pyplot as plt
import torch

def xy_to_px_tensor(points: torch.Tensor, pixels_per_meter, map_size):
    """Convert (x, y) points to integer map pixel indices (torch version).

    The second coordinate is negated on a clone of ``points`` (the input
    tensor is not modified), then the result is handed to
    ``gps_to_px_tensor``.
    """
    negated_y = points.clone()
    negated_y[..., 1] = negated_y[..., 1].neg()
    return gps_to_px_tensor(negated_y, pixels_per_meter, map_size)

def gps_to_px_tensor(gps: torch.Tensor, pixels_per_meter: float, map_size: int):
    """Convert metric GPS coordinates to integer map pixel indices (torch version).

    Coordinates are scaled by ``pixels_per_meter``, shifted so the origin
    lands at ``map_size // 2``, rounded (``torch.round``), cast to int64, and
    the last axis is reversed, so the two output columns are swapped relative
    to ``gps``.
    """
    center = map_size // 2
    scaled = gps * pixels_per_meter + center
    rounded = scaled.round().long()
    return torch.flip(rounded, dims=[-1])

def xy_to_px(points: np.ndarray, pixels_per_meter, map_size):
    """Convert (x, y) points to integer map pixel indices.

    The second coordinate is negated on a copy of ``points`` (the caller's
    array is left untouched), then the result is handed to ``gps_to_px``.
    """
    negated_y = points.copy()
    negated_y[..., 1] = np.negative(negated_y[..., 1])
    return gps_to_px(negated_y, pixels_per_meter, map_size)

def px_to_xy(px, pixels_per_meter, map_size):
    """Convert map pixel indices back to (x, y) points.

    This is the inverse of ``xy_to_px`` up to rounding. Pixels are first
    converted with ``px_to_gps``, then the second coordinate of that fresh
    array is negated.
    """
    xy = px_to_gps(px, pixels_per_meter, map_size)
    xy[..., 1] *= -1
    return xy

def gps_to_px(gps: np.ndarray, pixels_per_meter: float, map_size: int):
    """Convert metric GPS coordinates to integer map pixel indices.

    Coordinates are scaled by ``pixels_per_meter``, shifted so the origin
    lands at ``map_size // 2``, and rounded with ``np.rint`` (ties go to the
    nearest even value). The result is cast to int32 and the last axis is
    reversed, so the two output columns are swapped relative to ``gps``.
    """
    center = map_size // 2
    pixels = np.rint(gps * pixels_per_meter + center).astype(np.int32)
    return np.flip(pixels, axis=-1)

def px_to_gps(px: np.ndarray, pixels_per_meter: float, map_size: int):
    """Convert map pixel indices back to metric GPS coordinates.

    This is the inverse of ``gps_to_px`` up to rounding. The ``map_size // 2``
    origin offset is removed, the result is divided by ``pixels_per_meter``,
    and the last axis is reversed to undo the column swap.
    """
    centered = px - map_size // 2
    return np.flip(centered / pixels_per_meter, axis=-1)

def obs_to_tf(obs: Observations):
    """Build a float32 4x4 camera-to-episodic homogeneous transform from ``obs``.

    Rotation: intrinsic 'ZYX' Euler angles built from three inputs:
    - yaw: ``obs.compass[0] - pi/2 + pan``.
    - pitch: ``-tilt``.
    - roll: 0.
    ``pan`` and ``tilt`` are read from ``obs.joint[8]`` and ``obs.joint[9]``.

    Translation: ``obs.camera_pose[:3, 3]`` with the signs of its x and y
    components flipped.
    """
    pan, tilt = obs.joint[[8, 9]]
    yaw = obs.compass[0] - np.pi / 2 + pan
    rotation = Rotation.from_euler('ZYX', [yaw, -tilt, 0]).as_matrix()

    transform = np.eye(4, dtype=np.float32)
    transform[:3, :3] = rotation
    transform[:3, 3] = obs.camera_pose[:3, 3]
    # Flip the sign of the x and y translation components.
    transform[:2, 3] *= -1
    return transform
    
def gen_poses_to_tf_om(pose: np.ndarray, obs: Observations, pitch: float = 0.5246):
    """Build a float32 4x4 camera-to-episodic transform for a candidate base pose.

    Args:
        pose: Only ``pose[0]``, ``pose[1]`` and ``pose[2]`` are read. ``pose[2]``
            is used as a yaw angle. Presumably this is a GPS-style
            (x, y, theta) pose; TODO confirm with callers.
        obs: Observation whose ``camera_pose[2, 3]`` supplies the z translation.
        pitch: Angle in radians for the second (intermediate Y) rotation of the
            intrinsic 'ZYX' Euler sequence. Defaults to 0.5246, the value that
            was previously hard-coded.

    Returns:
        float32 array of shape (4, 4).
    """
    # Post-multiplying by this matrix re-expresses the rotation so the new
    # +z axis is the rotated +x axis, the new +x axis is the rotated -y axis,
    # and the new +y axis is the rotated -z axis.
    axes_remap = np.array(
        [[0, 0, 1, 0], [-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 0, 1]], dtype=np.float32
    )
    tf_camera_to_episodic = np.eye(4, dtype=np.float32)
    tf_camera_to_episodic[:3, :3] = Rotation.from_euler(
        'ZYX', [pose[2] - np.pi / 2, pitch, 0]
    ).as_matrix()
    tf_camera_to_episodic = tf_camera_to_episodic @ axes_remap

    # Planar translation is (pose[1], -pose[0]); height comes from the observation.
    tf_camera_to_episodic[0, 3] = pose[1]
    tf_camera_to_episodic[1, 3] = -pose[0]
    tf_camera_to_episodic[2, 3] = obs.camera_pose[2, 3]

    return tf_camera_to_episodic


def gs_to_gps(obj_locations: torch.Tensor):
    """Convert a tensor of (x, y, ...) locations to a numpy array of (-y, x) pairs.

    The tensor is detached and moved to the CPU first, so gradient-tracking
    and GPU tensors are both accepted. Only the first two entries of the last
    axis are used.
    """
    as_numpy = obj_locations.detach().cpu().numpy()
    first, second = as_numpy[..., 0], as_numpy[..., 1]
    return np.stack((-second, first), axis=-1)

def detect_receptacles(points: np.ndarray, eps: float = 0.2, min_samples: int = 5, step: int = 10):
    """Detect distinct receptacles using DBSCAN clustering on downsampled 2D (x, y) data.

    Args:
        points: Array of shape (N, D). Only the first two columns are clustered.
        eps: DBSCAN neighbourhood radius, in the same units as ``points``.
        min_samples: DBSCAN minimum number of neighbours for a core point.
        step: Stride used to downsample ``points`` before clustering. It is
            forced to 1 when fewer than 100 points are given.

    Returns:
        A tuple with five entries:
        - labels: DBSCAN label for each row of ``points[:, :2][::step]``;
          -1 marks noise.
        - num_clusters: Number of non-noise clusters.
        - cluster_means: Array of shape (num_clusters, 2) holding each
          cluster's mean. Its shape stays (0, 2) when there are no clusters.
        - clusters: One array of downsampled points per cluster.
        - step: The stride actually used, so callers can map ``labels``
          back onto ``points``.
    """
    if len(points) < 100:
        # Too few points to afford downsampling.
        step = 1
    points_2d = points[:, :2][::step]

    if len(points_2d) == 0:
        # DBSCAN rejects empty input; report "no clusters" instead of raising.
        return np.empty(0, dtype=int), 0, np.empty((0, points_2d.shape[1])), [], step

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(points_2d).labels_
    num_clusters = len(set(labels)) - (1 if -1 in labels else 0)

    clusters = [points_2d[labels == cluster] for cluster in range(num_clusters)]
    cluster_means = [cluster_points.mean(axis=0) for cluster_points in clusters]

    # reshape keeps a 2D shape even when every point was labelled noise.
    means = np.array(cluster_means).reshape(-1, points_2d.shape[1])
    return labels, num_clusters, means, clusters, step

def get_poses_around_object(
    gps_object: np.ndarray,
    num_poses: int,
    radius: float,
    pixel_per_meter: float,
    navigable_map: Optional[np.ndarray],
    gps_agent: Optional[np.ndarray],
) -> np.ndarray:
    """Generate candidate (x, y, yaw) poses on a circle around an object.

    ``num_poses`` poses are spaced evenly on a circle of ``radius`` centred
    on ``gps_object[:2]``. Each pose's yaw is its angle on the circle minus
    pi. Assuming yaw shares the (x, y) angle convention, this means the pose
    faces the object.

    Args:
        gps_object: Object position; only the first two entries are read.
        num_poses: Number of evenly spaced candidate poses.
        radius: Circle radius, in the same units as ``gps_object``.
        pixel_per_meter: Map resolution passed to ``gps_to_px``.
        navigable_map: Optional 2D map whose truthy cells are navigable.
            ``shape[0]`` is passed to ``gps_to_px`` as the map size. Poses
            landing on non-navigable or out-of-map cells are dropped.
        gps_agent: Optional agent position. When given, the result is rotated
            cyclically (circular order preserved) so the pose closest to the
            agent comes first.

    Returns:
        Array of shape (K, 3) with rows (x, y, yaw), where K <= num_poses.
    """
    angles = np.linspace(0, 2 * np.pi, num_poses, endpoint=False)
    offsets = np.stack([
        radius * np.cos(angles),
        radius * np.sin(angles),
        angles,
    ], axis=1)  # (num_poses, 3)
    gps_poses = np.array([[gps_object[0], gps_object[1], -np.pi]]) + offsets

    if len(gps_poses) and navigable_map is not None:
        pxs = gps_to_px(gps_poses[:, :2], pixel_per_meter, navigable_map.shape[0])
        rows, cols = pxs[:, 0], pxs[:, 1]
        # Poses outside the map count as non-navigable. Indexing with them
        # directly would raise, or for negative indices silently wrap around
        # to the opposite edge of the map.
        in_bounds = (
            (rows >= 0) & (rows < navigable_map.shape[0])
            & (cols >= 0) & (cols < navigable_map.shape[1])
        )
        is_navigable = np.zeros(len(gps_poses), dtype=bool)
        # Cast to bool so a 0/1 integer map acts as a mask, not as row indices.
        is_navigable[in_bounds] = navigable_map[rows[in_bounds], cols[in_bounds]].astype(bool)
        gps_poses = gps_poses[is_navigable]

    if len(gps_poses) and gps_agent is not None:
        dists = np.linalg.norm(gps_agent[None, :2] - gps_poses[:, :2], axis=1)
        gps_poses = np.roll(gps_poses, -np.argmin(dists), axis=0)

    return gps_poses

def plot_clusters(obj_location: np.ndarray, cluster_means: np.ndarray):
    """Save a scatter plot of object points and cluster means to 'object_locations.png'.

    Object points (columns 0 and 1 of ``obj_location``) are drawn as red dots.
    Each cluster mean is drawn as a large blue 'x'. The figure is closed after
    saving.
    """
    fig, ax = plt.subplots()
    ax.scatter(obj_location[:, 0], obj_location[:, 1], c='r')
    mean_xs = [mean[0] for mean in cluster_means]
    mean_ys = [mean[1] for mean in cluster_means]
    ax.scatter(mean_xs, mean_ys, c='b', marker='x', s=200)
    ax.set_title('Object Locations')
    ax.set_xlabel('X coordinate')
    ax.set_ylabel('Y coordinate')
    ax.grid(True)
    fig.savefig('object_locations.png')
    plt.close(fig)

def get_bounds(map: np.ndarray, margin: int):
    """Return the padded bounding box of the nonzero cells of ``map``.

    The result is ``[row_start, row_stop, col_start, col_stop]``, where the
    stop values are exclusive. The box is grown by ``margin`` on every side
    and clipped to the map extent. If ``map`` has no nonzero cells, the full
    extent ``[0, H, 0, W]`` is returned.
    """
    if not np.any(map):
        return np.array([0, map.shape[0], 0, map.shape[1]])

    occupied_rows = np.nonzero(np.any(map, axis=1))[0]
    occupied_cols = np.nonzero(np.any(map, axis=0))[0]
    bounds = np.array([
        occupied_rows[0], occupied_rows[-1] + 1,
        occupied_cols[0], occupied_cols[-1] + 1,
    ])
    # Starts move back by `margin`, stops move forward by `margin`.
    bounds += np.array([-margin, margin, -margin, margin])

    height, width = map.shape[0], map.shape[1]
    return np.clip(bounds, 0, [height, height, width, width])