import torch
from time import time
import matplotlib.pyplot as plt


def compute_overlap_area_torch(X, Y):
    """Return the integral of ``min(F_X(t), 1 - F_Y(t))`` over all real ``t``.

    ``F_X`` and ``F_Y`` are the empirical CDFs of the samples in ``X`` and
    ``Y`` (``F(t)`` = fraction of samples ``<= t``). Both are step functions
    that only change at sample values, so the integral is computed exactly as
    a sum over the intervals between consecutive distinct sample values.
    Outside ``[min, max]`` of the pooled samples the integrand is zero
    (``F_X = 0`` on the left, ``1 - F_Y = 0`` on the right).

    Args:
        X: Samples of the first distribution. A tensor of any shape (it is
            flattened) or anything ``torch.as_tensor`` accepts.
        Y: Samples of the second distribution, same conventions as ``X``.
            Placed on ``X``'s device.

    Returns:
        The area as a Python float; 0.0 when all samples share one value.

    Raises:
        ValueError: If ``X`` or ``Y`` contains no samples.
    """
    X = torch.as_tensor(X).reshape(-1)
    Y = torch.as_tensor(Y, device=X.device).reshape(-1)
    n, m = X.numel(), Y.numel()
    if n == 0 or m == 0:
        # An empty sample has no CDF (count / 0 would give NaN).
        raise ValueError(
            f"X and Y must be non-empty (got {n} and {m} samples)"
        )

    # Work in one dtype so cat/searchsorted see matching tensors.
    common = torch.promote_types(X.dtype, Y.dtype)
    X = torch.sort(X.to(common)).values
    Y = torch.sort(Y.to(common)).values

    # Both CDFs are constant on every [lefts[i], rights[i]).
    breakpoints = torch.cat((X, Y)).unique(sorted=True)
    lefts = breakpoints[:-1]
    rights = breakpoints[1:]

    # Evaluate at the left endpoint: right=True counts samples <= lefts[i],
    # which is exactly the CDF value on the whole interval. (A midpoint
    # (left + right) / 2 can overflow to +/-inf for large magnitudes, or
    # round onto `right` for adjacent floats, selecting the wrong step.)
    F_x = torch.searchsorted(X, lefts, right=True).float() / n
    F_y = torch.searchsorted(Y, lefts, right=True).float() / m

    min_vals = torch.minimum(F_x, 1 - F_y)

    # Skip intervals whose width is not finite (e.g. next to an infinite
    # sample value), which would otherwise produce inf/NaN.
    widths = rights - lefts
    mask = torch.isfinite(widths)
    return (min_vals[mask] * widths[mask]).sum().item()
