import math
import numpy as np


CAR_LENGTH = 4.084
CAR_WIDTH = 1.85
MAP_METER = 49.75
GRID_SIZE = 0.5
BEV_H = 200
BEV_W = 200
# ego-system coordinates to BEV feature map coordinates

def location_to_pixel_coordinate(x, y):
    x_min, x_max, y_min, y_max = -50, 50, -50, 50
    H, W = BEV_H, BEV_W
    X = int(W * (x - x_min + GRID_SIZE/2.) / (x_max - x_min)) # TODO: double check this. Also check if there is a huge difference between this and the original code
    Y = int(H * (y - y_min + GRID_SIZE/2.) / (y_max - y_min))
    valid = True
    if X < 0 or X >= W or Y < 0 or Y >= H:
        valid = False
    return X, Y, valid


def pixel_coordinate_to_location(X, Y):
    # A tensor based on pixel-coordinates would be of the form (...,200,200) -> (...,y,x)
    # In the matrix, the x pixel is represented by the second '200' entry
    # The y pixel is represented by the first '200' entry
    x_min, x_max, y_min, y_max = -50, 50, -50, 50 # TODO: maybe this should be 49.75?
    H, W = BEV_H, BEV_W
    
    # Add 0.5 to get the center of the pixel
    x = x_min + (X + 0.5) * (x_max - x_min) / W
    y = y_min + (Y + 0.5) * (y_max - y_min) / H
    
    valid = True
    if x < x_min or x > x_max or y < y_min or y > y_max:
        valid = False

    return x, y, valid


def rotate_bbox(x, y, dx, dy, theta):
    # Step 1: Calculate the center of the box
    cx, cy = x, y

    # Step 2: Calculate the coordinates of the corners relative to the center
    corners = [(dx / 2, dy / 2), (dx / 2, -dy / 2), (-dx / 2, -dy / 2), (-dx / 2, dy / 2)]

    # Step 3: Rotate each corner
    rotated_corners = []
    for px, py in corners:
        x_prime = px * math.cos(theta) + py * math.sin(theta)
        y_prime = - px * math.sin(theta) + py * math.cos(theta)
        rotated_corners.append((x_prime, y_prime))            

    # Step 4 Translate the rotated corners back
    final_corners = [[cx + x_prime, cy + y_prime] for x_prime, y_prime in rotated_corners]

    return np.array(final_corners)

if __name__ == "__main__":
    x, y = 10, 10
    X, Y, valid = location_to_pixel_coordinate(x, y)

    print(f"Location ({x}, {y}) is mapped to pixel coordinate ({X}, {Y})")

    x, y, valid = pixel_coordinate_to_location(X, Y)

    print(f"Pixel coordinate ({X}, {Y}) is mapped to location ({x}, {y})")