
import torch
import math

class Model(torch.nn.Module):
    """Reference (slow, pure-Python) pairwise rotated bird's-eye-view IoU.

    Has no parameters; all work happens in ``forward``, which delegates the
    per-pair geometry to ``box_iou_rotated``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, boxes_a, boxes_b):
        """Compute the IoU between every box in ``boxes_a`` and every box in ``boxes_b``.

        Args:
            boxes_a: (N, 5) tensor of boxes ``[x1, y1, x2, y2, angle]``.
            boxes_b: (M, 5) tensor in the same layout.

        Returns:
            iou_bev: (N, M) tensor on ``boxes_a.device`` with torch's default
            floating dtype; entry ``[i, j]`` is the IoU of ``boxes_a[i]`` and
            ``boxes_b[j]``.
        """
        n = boxes_a.shape[0]
        m = boxes_b.shape[0]

        # The geometry helpers work on plain Python floats, so pull both box
        # sets to the host once.
        host_a = boxes_a.cpu().tolist()
        host_b = boxes_b.cpu().tolist()

        rows = [
            [box_iou_rotated(box_a, box_b) for box_b in host_b]
            for box_a in host_a
        ]

        # Build the result on the host and move it in one transfer, instead of
        # writing N*M scalars into a (possibly CUDA) tensor one at a time.
        iou = torch.tensor(rows, dtype=torch.get_default_dtype())
        # torch.tensor([]) is 1-D, so reshape restores (N, M) when N == 0.
        return iou.reshape(n, m).to(boxes_a.device)

def get_init_inputs():
    """Return the positional constructor arguments for ``Model`` (it takes none)."""
    no_args = []
    return no_args

def get_inputs():
    """Return a random ``[boxes_a, boxes_b]`` pair with 10 and 20 boxes.

    Each row is ``[x1, y1, x2, y2, angle]``. ``x1``, ``y1`` and ``angle`` are
    drawn from ``[0, 10)``, and ``x2 = x1 + 2`` and ``y2 = y1 + 2``, so every
    box is a 2 x 2 square with ``x2 > x1`` and ``y2 > y1``.
    """
    def _random_boxes(count):
        boxes = torch.rand(count, 5) * 10
        boxes[:, 2] = boxes[:, 0] + 2  # x2 = x1 + width
        boxes[:, 3] = boxes[:, 1] + 2  # y2 = y1 + height
        return boxes

    # Draw boxes_a before boxes_b so the RNG stream is consumed in the same order.
    first = _random_boxes(10)
    second = _random_boxes(20)
    return [first, second]

# Helper Geometry Functions (Pure Python)

def rotate_point(cx, cy, angle, px, py):
    """Rotate point ``(px, py)`` about centre ``(cx, cy)`` by ``angle`` radians.

    Applies the standard rotation matrix ``[[cos, -sin], [sin, cos]]`` to the
    offset from the centre, so a positive angle turns counter-clockwise in an
    x-right / y-up frame. This "forward" rotation maps axis-aligned box corners
    into world coordinates. Passing ``-angle`` undoes a rotation by ``angle``,
    which brings a world point back into the box's axis-aligned frame.

    Returns:
        The rotated point as an ``(x, y)`` tuple.
    """
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    rel_x, rel_y = px - cx, py - cy
    return (
        rel_x * cos_a - rel_y * sin_a + cx,
        rel_x * sin_a + rel_y * cos_a + cy,
    )

def get_corners(box):
    """Return the four world-space corners of a rotated box.

    ``box`` is ``[x1, y1, x2, y2, angle]``. The axis-aligned corners are
    listed in the order (x1,y1), (x2,y1), (x2,y2), (x1,y2). Each one is then
    rotated about the box centre by ``angle`` with ``rotate_point``.
    """
    x1, y1, x2, y2, angle = box
    mid_x = (x1 + x2) / 2
    mid_y = (y1 + y2) / 2
    unrotated = ((x1, y1), (x2, y1), (x2, y2), (x1, y2))
    return [rotate_point(mid_x, mid_y, angle, px, py) for px, py in unrotated]

def box_area(box):
    """Return ``(x2 - x1) * (y2 - y1)`` for ``box = [x1, y1, x2, y2, angle]``.

    The angle is ignored because rotation does not change area. The result is
    signed, so it is negative if ``x2 < x1`` or ``y2 < y1``.
    """
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height

def cross_product(o, a, b):
    """Return the z-component of ``(a - o) x (b - o)``.

    The result is positive when o -> a -> b turns counter-clockwise, negative
    when it turns clockwise, and zero when the three points are collinear.
    """
    ax, ay = a[0] - o[0], a[1] - o[1]
    bx, by = b[0] - o[0], b[1] - o[1]
    return ax * by - ay * bx

def intersection(p1, p2, q1, q2):
    """Return the crossing point of segments p1-p2 and q1-q2, or ``None``.

    Uses the parametric form ``p1 + t*(p2 - p1)`` and ``q1 + u*(q2 - q1)``.
    A point is returned only when both ``t`` and ``u`` lie in ``[0, 1]``
    inclusive, so a touching endpoint counts. Segments whose direction
    determinant is below ``1e-8`` in magnitude (parallel or collinear) yield
    ``None``.
    """
    x1, y1 = p1[0], p1[1]
    x2, y2 = p2[0], p2[1]
    x3, y3 = q1[0], q1[1]
    x4, y4 = q2[0], q2[1]

    denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    if abs(denom) < 1e-8:
        # Parallel or collinear: no unique crossing point.
        return None

    t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom
    u = -((x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)) / denom

    within_both = 0 <= t <= 1 and 0 <= u <= 1
    if not within_both:
        return None
    return (x1 + t * (x2 - x1), y1 + t * (y2 - y1))

def is_inside(box, p):
    """Return True if point ``p`` lies within the rotated ``box``.

    The point is rotated by ``-angle`` about the box centre, which brings it
    into the box's axis-aligned frame. It is then compared against
    ``[x1, x2] x [y1, y2]`` with a ``1e-5`` tolerance on every side, so points
    on or just past an edge still count as inside.
    """
    x1, y1, x2, y2, angle = box
    centre_x = (x1 + x2) / 2
    centre_y = (y1 + y2) / 2

    local_x, local_y = rotate_point(centre_x, centre_y, -angle, p[0], p[1])

    tol = 1e-5
    within_x = x1 - tol <= local_x <= x2 + tol
    within_y = y1 - tol <= local_y <= y2 + tol
    return within_x and within_y

def poly_area(points):
    """Return the unsigned area of a simple polygon via the shoelace formula.

    ``points`` must be in boundary order, either clockwise or counter-clockwise.
    Fewer than three points yields ``0.0``.
    """
    count = len(points)
    if count < 3:
        return 0.0
    twice_signed = 0.0
    for idx, here in enumerate(points):
        nxt = points[(idx + 1) % count]
        # Accumulate as (acc + x_i*y_j) - x_j*y_i, matching the reference order.
        twice_signed = twice_signed + here[0] * nxt[1] - nxt[0] * here[1]
    return abs(twice_signed) / 2.0

def convex_hull_sort(points):
    """Order points by polar angle about their centroid, in place.

    This is enough to put the vertices of a convex polygon into boundary
    order. It does not compute a hull. The input list itself is sorted and
    returned; an empty input returns a new empty list.
    """
    if not points:
        return []
    count = len(points)
    mean_x = sum(pt[0] for pt in points) / count
    mean_y = sum(pt[1] for pt in points) / count

    def _polar_angle(pt):
        return math.atan2(pt[1] - mean_y, pt[0] - mean_x)

    points.sort(key=_polar_angle)
    return points

def box_overlap(box_a, box_b):
    """Return the area of intersection of two rotated boxes.

    The overlap of two rotated rectangles is convex. Its vertices are drawn
    from two kinds of candidates: every crossing between an edge of one box and
    an edge of the other, and every corner of one box that lies inside the
    other. Candidates closer than ``1e-5`` in both x and y to an earlier one
    are dropped. The rest are ordered by angle around their centroid and
    passed to the shoelace formula.
    """
    corners_a = get_corners(box_a)
    corners_b = get_corners(box_b)

    def _edges(poly):
        # Consecutive corner pairs, closing the quadrilateral back to corner 0.
        return [(poly[k], poly[(k + 1) % 4]) for k in range(4)]

    candidates = []
    for a_start, a_end in _edges(corners_a):
        for b_start, b_end in _edges(corners_b):
            hit = intersection(a_start, a_end, b_start, b_end)
            if hit is not None:
                candidates.append(hit)

    candidates.extend(c for c in corners_a if is_inside(box_b, c))
    candidates.extend(c for c in corners_b if is_inside(box_a, c))

    def _already_kept(pt, kept):
        return any(
            abs(pt[0] - q[0]) < 1e-5 and abs(pt[1] - q[1]) < 1e-5 for q in kept
        )

    vertices = []
    for pt in candidates:
        if not _already_kept(pt, vertices):
            vertices.append(pt)

    return poly_area(convex_hull_sort(vertices))

def box_iou_rotated(box_a, box_b):
    """Return intersection-over-union of two rotated boxes.

    Union is ``area_a + area_b - intersection``. If the union is below
    ``1e-8`` (degenerate boxes), ``0.0`` is returned instead of dividing.
    """
    overlap = box_overlap(box_a, box_b)
    union = box_area(box_a) + box_area(box_b) - overlap
    return 0.0 if union < 1e-8 else overlap / union
