import pickle as pkl
import pandas
import numpy as np
import io
import os
import rosbag
from PIL import Image
import cv2

# Output size for tartan frames, given as (width, height): process_tartan_img passes it
# as ``output_resolution`` to ros_to_numpy, which forwards it to cv2.resize's ``dsize``.
IMAGE_SIZE = (160, 120)


def process_uw_img(msg):
    """
    Open the encoded image bytes carried in ``msg.data`` as a PIL image.

    The payload is handed to ``Image.open`` through an in-memory stream, so it
    must be in a file format Pillow can recognise.
    """
    encoded_bytes = np.frombuffer(msg.data, dtype=np.uint8)
    stream = io.BytesIO(encoded_bytes)
    return Image.open(stream)


def process_tartan_img(msg):
    """
    Convert a raw ROS image message into a PIL image resized to IMAGE_SIZE.

    ``ros_to_numpy`` returns a channels-first float32 array which, for
    encodings it treats as 8-bit, is scaled to [0, 1]; it is scaled back to
    [0, 255] here before building the PIL image.
    """
    img = ros_to_numpy(msg, output_resolution=IMAGE_SIZE) * 255
    # Round rather than truncate: the uint8 -> float32 / 255 -> * 255 round
    # trip may land just below an integer (e.g. 254.99998), which a bare
    # astype(np.uint8) would truncate to the next lower value. Clip so that
    # out-of-range values (e.g. from non-8-bit encodings, which ros_to_numpy
    # does not rescale) saturate instead of wrapping around.
    img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
    # channels-first (C, H, W) -> channels-last (H, W, C), the layout cv2/PIL expect
    img = np.moveaxis(img, 0, -1)
    # Swap the first and third channels (COLOR_RGB2BGR and COLOR_BGR2RGB perform
    # the same swap).
    # NOTE(review): Image.fromarray treats 3-channel uint8 data as RGB, so this
    # yields correct colours only if the source encoding is BGR-ordered — confirm.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return Image.fromarray(img)


# Dataset name -> function that turns one image message into a PIL image.
process_img_func = dict(uw=process_uw_img, tartan=process_tartan_img)


def process_images(im_list: list, img_process_func) -> list:
    """
    Apply ``img_process_func`` to each image message, preserving order.

    Intended to turn the messages of an image topic into a list of PIL images;
    the accepted message type (e.g. sensor_msgs/Image or a compressed image)
    depends on the ``img_process_func`` supplied.
    """
    return [img_process_func(message) for message in im_list]


def process_odom(odom_list: list) -> dict:
    """
    Extract planar position and yaw from odometry messages (nav_msgs/Odometry).

    Args:
        odom_list: messages exposing ``pose.pose.position`` (x, y) and
            ``pose.pose.orientation`` (x, y, z, w quaternion).

    Returns:
        dict with
          - ``"position"``: array of shape (N, 2) holding (x, y) per message;
          - ``"yaw"``: array of shape (N,) with yaw in radians.
        An empty ``odom_list`` yields shapes (0, 2) and (0,).
    """
    xy = []
    yaws = []
    for odom in odom_list:
        pose = odom.pose.pose
        xy.append([pose.position.x, pose.position.y])
        q = pose.orientation
        yaws.append(quat_to_yaw(q.x, q.y, q.z, q.w))
    # reshape keeps the (N, 2) layout even with no messages; np.array([])
    # alone would have shape (0,).
    positions = np.array(xy).reshape(-1, 2)
    return {"position": positions, "yaw": np.array(yaws)}


def quat_to_yaw(
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray,
    w: np.ndarray,
    angular_offset: float = 0,
) -> np.ndarray:
    """
    Return the yaw of quaternion(s) (x, y, z, w), shifted by ``angular_offset``.

    Yaw is the counterclockwise rotation about z, in radians. Inputs may be
    scalars or equally-shaped arrays; every operation is elementwise.
    """
    siny_cosp = 2.0 * (w * z + x * y)
    cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
    return np.arctan2(siny_cosp, cosy_cosp) + angular_offset


def ros_to_numpy(
    msg, nchannels=3, empty_value=None, output_resolution=None, aggregate="none"
):
    """
    Convert a raw ROS image message into a channels-first numpy array.

    Args:
        msg: message exposing ``data``, ``encoding``, ``height`` and ``width``.
        nchannels: number of interleaved channels in ``msg.data``; the buffer
            is reshaped to (height, width, nchannels).
        empty_value: if not None, pixels whose absolute value is close to this
            value are replaced by the 99th percentile of the remaining pixels.
        output_resolution: (width, height) passed to ``cv2.resize`` with
            ``INTER_AREA``; defaults to the message's (width, height).
        aggregate: "none" keeps channels separate; "littleendian" /
            "bigendian" fold the channels into one base-256 number per pixel
            (first / last channel least significant, respectively).

    Returns:
        Array of shape (C, H, W) where C is ``nchannels``, or 1 when
        aggregated. Encodings containing "8" are read as uint8 and returned as
        float32 divided by 255 (or 255**nchannels when aggregated); other
        encodings are read as float32 and returned without rescaling.
    """
    if output_resolution is None:
        output_resolution = (msg.width, msg.height)

    # NOTE(review): heuristic — any encoding containing "8" (rgb8, bgr8, mono8,
    # 8UC3, ...) is read as uint8; everything else is read as float32.
    is_rgb = "8" in msg.encoding
    dtype = np.uint8 if is_rgb else np.float32
    data = np.frombuffer(msg.data, dtype=dtype).copy()

    data = data.reshape(msg.height, msg.width, nchannels)

    # `is not None` so an explicit empty_value of 0 is honoured instead of ignored.
    if empty_value is not None:
        mask = np.isclose(abs(data), empty_value)
        valid = data[~mask]
        # np.percentile fails on an empty selection; leave data untouched if
        # every pixel matched empty_value.
        if valid.size:
            data[mask] = np.percentile(valid, 99)

    data = cv2.resize(
        data,
        dsize=(output_resolution[0], output_resolution[1]),
        interpolation=cv2.INTER_AREA,
    )
    # cv2.resize drops a trailing channel axis of size 1; restore it so the
    # per-channel indexing below works for any nchannels.
    if data.ndim == 2:
        data = data[:, :, np.newaxis]

    if aggregate in ("littleendian", "bigendian") and is_rgb:
        # Widen before multiplying by 256**i: under NumPy 2 promotion rules
        # uint8 * 256 raises OverflowError (Python int out of range for uint8).
        data = data.astype(np.int64)

    if aggregate == "littleendian":
        data = sum([data[:, :, i] * (256**i) for i in range(nchannels)])
    elif aggregate == "bigendian":
        data = sum([data[:, :, -(i + 1)] * (256**i) for i in range(nchannels)])

    if data.ndim == 2:
        # aggregated single-plane result -> (1, H, W)
        data = np.expand_dims(data, axis=0)
    else:
        data = np.moveaxis(data, 2, 0)  # Switch to channels-first

    if is_rgb:
        # NOTE(review): the largest aggregated value is 256**nchannels - 1,
        # not 255**nchannels, so aggregated outputs can slightly exceed 1.0.
        # Kept as-is to preserve existing scaling — confirm the intended range.
        data = data.astype(np.float32) / (
            255.0 if aggregate == "none" else 255.0**nchannels
        )

    return data
