import torch
import numpy as np


class Line:
    """Infinite line in the plane, parameterised as ``p = p0 + t * v``.

    Query points may be tensors or anything ``torch.tensor`` accepts (lists, tuples, numpy
    arrays). Their last dimension holds the (x, y) coordinates; leading dimensions act as a batch.
    """

    def __init__(self, p0, v, device="cpu"):
        """p = p0 + t * v

        Args:
            p0: A point on the line, shape ``(2,)``.
            v: Direction of the line, shape ``(2,)``; must not be the zero vector.
            device: Device used when converting non-tensor inputs to tensors.
        """
        self.device = device
        self.v = self._totensor(v)
        assert self.v.ndim == 1 and self.v.shape[0] == 2, "v should be a 2D vector"
        # A zero direction does not define a line and would make every projection divide by zero.
        assert bool(torch.any(self.v != 0)), "v should be a non-zero vector"
        self.p0 = self._totensor(p0)
        assert self.p0.ndim == 1 and self.p0.shape[0] == 2, "p0 should be a 2D vector"

    def distance(self, point, norm_p=2, signed=False):
        """Distance from ``point`` to the line.

        The distance is the ``norm_p`` norm of ``point - project(point)`` (orthogonal projection).
        For ``norm_p != 2`` this coincides with the exact L-p distance for axis-aligned lines,
        but not for arbitrary orientations.

        If ``signed``, points to the left of the direction ``v`` (the counter-clockwise side)
        get a negative distance. Points to the right of ``v`` or exactly on the line get a
        non-negative one.
        """
        point = self._totensor(point)
        proj = self.project(point)
        d = torch.norm(point - proj, p=norm_p, dim=-1)
        if signed:
            # Out-of-place: torch.norm's backward reuses its output, so an in-place update of
            # ``d`` would break autograd.
            d = d * self.get_sign(point, proj)
        return d

    def project(self, point):
        """Orthogonal projection of ``point`` onto the line (broadcast shape of ``point``)."""
        point = self._totensor(point)
        t = self._batched_dot(self.v, (point - self.p0), keepdim=True) / (torch.norm(self.v, dim=-1) ** 2)
        return self.p0 + t * self.v

    def get_sign(self, point, proj):
        """Return -1 for points left of ``v`` (counter-clockwise side), +1 otherwise.

        ``proj`` is the projection of ``point`` onto the line. Points exactly on the line
        (zero cross product) are assigned +1.
        """
        normal = point - proj
        res = torch.sign(self._batched_cross(normal, self.v))
        res[res == 0] = 1
        return res

    def get_t(self, point, tol=1e-6):
        """Return ``t`` such that ``point = p0 + t * v``, with shape ``(..., 1)``.

        ``point`` must lie on the line. This is asserted with a two-sided check on the cross
        product ``(point - p0) x v``. The tolerance ``tol`` is scaled by ``1 + |point - p0| * |v|``
        so that floating-point round-off at larger coordinates does not trigger it.
        """
        point = self._totensor(point)
        diff = point - self.p0
        cross = self._batched_cross(diff, self.v)
        # |diff x v| = |diff| |v| sin(angle); it must be ~0 on either side of the line.
        scale = 1 + torch.norm(diff, dim=-1) * torch.norm(self.v)
        assert torch.all(torch.abs(cross) <= tol * scale), (
            f"Expected point to be on the line. Cross product: {cross.abs().max()}"
        )
        return self._batched_dot(diff, self.v, keepdim=True) / torch.dot(self.v, self.v)

    @staticmethod
    def _batched_dot(a, b, keepdim=False):
        """Dot product over the last dimension, broadcasting leading dimensions."""
        return torch.sum(a * b, dim=-1, keepdim=keepdim)

    @staticmethod
    def _batched_cross(a, b):
        """Scalar 2D cross product ``a_x * b_y - a_y * b_x`` over the last dimension."""
        return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]

    def _totensor(self, a):
        """Convert ``a`` to a floating-point tensor.

        Non-tensor inputs are created on ``self.device``; existing tensors keep their device.
        Integer/bool inputs (on any device, any width) are cast to float32 so that norms and
        divisions work.
        """
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a, device=self.device)
        if not (a.is_floating_point() or a.is_complex()):
            a = a.float()
        return a


class LineSegment(Line):
    """Finite segment ``p = start + t * (end - start)`` with ``t`` in ``[t_start, t_end] = [0, 1]``."""

    def __init__(self, start, end, device="cpu"):
        # ``_totensor`` reads ``self.device``, so it must be set before converting the endpoints.
        self.device = device
        start = self._totensor(start)
        end = self._totensor(end)
        super().__init__(p0=start, v=end - start, device=device)
        self.start = start
        self.end = end
        self.t_start = 0
        self.t_end = 1

    def distance(self, point, norm_p=2, signed=False):
        """Distance from ``point`` to the segment, measured with the ``norm_p`` norm.

        Let ``t`` be the parameter of the orthogonal projection of ``point`` onto the
        supporting line. If ``t < 0`` the distance is to ``start``, if ``t > 1`` to ``end``,
        otherwise to the projection itself.

        If ``signed``, the result is multiplied by the side-of-line sign of the supporting line
        (``get_sign``). That sign is negative to the left of the direction ``start -> end``.
        """
        point = self._totensor(point)
        proj = self.project(point)
        # Position along the segment (0 at start, 1 at end). It is computed straight from
        # ``point`` rather than re-derived from ``proj``, so it is free of the projection's
        # round-off and needs no on-line assertion.
        t = self._batched_dot(point - self.p0, self.v) / self._batched_dot(self.v, self.v)
        d_line = torch.norm(point - proj, p=norm_p, dim=-1)
        d_start = torch.norm(point - self.start, p=norm_p, dim=-1)
        d_end = torch.norm(point - self.end, p=norm_p, dim=-1)
        # Beyond either end, the closest point of the segment is that endpoint.
        d = torch.where(t < self.t_start, d_start, torch.where(t > self.t_end, d_end, d_line))
        if signed:
            d = d * self.get_sign(point, proj)
        return d


class Rectangle:
    """Axis-aligned rectangle ``[xa, xb] x [ya, yb]`` with a signed distance function.

    The four edges are traversed clockwise: left edge upward, top edge rightward, right edge
    downward, bottom edge leftward. Interior points therefore lie on the positive (right-hand)
    side of every edge's signed distance.
    """

    def __init__(self, xa, ya, xb, yb, device="cpu"):
        assert xa < xb and ya < yb
        # Keep the plain constructor arguments so __repr__ does not have to unpack tensors.
        self.xa, self.ya, self.xb, self.yb = xa, ya, xb, yb
        self.sides = {
            "left": LineSegment(start=(xa, ya), end=(xa, yb), device=device),
            "top": LineSegment(start=(xa, yb), end=(xb, yb), device=device),
            "right": LineSegment(start=(xb, yb), end=(xb, ya), device=device),
            "bottom": LineSegment(start=(xb, ya), end=(xa, ya), device=device),
        }

    def distance(self, point, norm_p=2):
        """Signed distance from ``point`` to the rectangle boundary.

        The magnitude is the distance to the nearest edge. The result is negated for points
        strictly on the interior side of all four edges, so it is negative inside and
        non-negative outside or on the boundary.
        """
        d_sides = torch.stack(
            [side.distance(point, norm_p=norm_p, signed=True) for side in self.sides.values()],
            dim=-1,
        )
        d = torch.abs(d_sides).min(dim=-1).values
        inside = torch.all(d_sides > 0, dim=-1)
        return torch.where(inside, -d, d)

    def __repr__(self):
        return f"Rectangle({self.xa}, {self.ya}, {self.xb}, {self.yb})"


class Checkerboard:
    """Union of the filled squares of an ``n x n`` checkerboard spanning ``[xa, xb] x [ya, yb]``.

    Row ``i`` (counted up from ``ya``) holds filled squares in the columns ``j`` with
    ``j % 2 == i % 2``. In particular, the square touching the ``(xa, ya)`` corner is filled.
    """

    def __init__(self, xa, ya, xb, yb, n, device="cpu"):
        assert n % 2 == 0
        width = (xb - xa) / n
        height = (yb - ya) / n
        # Lower-left corners of the filled squares, in row-major order.
        corners = [
            (xa + col * width, ya + row * height)
            for row in range(n)
            for col in range(row % 2, n, 2)
        ]
        self.rectangles = [
            Rectangle(x0, y0, x0 + width, y0 + height, device=device) for x0, y0 in corners
        ]

    def distance(self, point, norm_p=2):
        """Signed distance to the union of squares: the minimum over the per-square distances."""
        per_square = [rect.distance(point, norm_p=norm_p) for rect in self.rectangles]
        return torch.stack(per_square, dim=-1).min(dim=-1).values


def compute_distance_to_boundries(x, norm_p=2):
    """Signed distance from ``x`` to the 4x4 checkerboard on ``[-1, 1] x [-1, 1]``.

    Negative inside a filled square, non-negative elsewhere. The board is built on ``x``'s
    device when ``x`` is a tensor, otherwise on the CPU, and is rebuilt on every call.
    """
    board_device = "cpu"
    if isinstance(x, torch.Tensor):
        board_device = x.device
    board = Checkerboard(-1, -1, 1, 1, 4, device=board_device)
    return board.distance(x, norm_p=norm_p)


def compute_infraction_differentiable(x, norm_p=2):
    """Differentiable infraction score: the checkerboard signed distance with negatives zeroed.

    Points inside a filled square score 0. Other points score their distance to the nearest
    square.
    """
    signed_dist = compute_distance_to_boundries(x, norm_p=norm_p)
    return torch.clamp(signed_dist, min=0)

@torch.no_grad()
def compute_infraction(x, slack=None):
    """Boolean mask of infractions, using the L1 signed distance to the checkerboard.

    Without ``slack``, a point is an infraction when its distance is strictly positive
    (outside every filled square). With ``slack``, inside points closer than ``slack`` to the
    boundary are also flagged (distance > -slack).
    """
    threshold = 0 if slack is None else -slack
    return compute_distance_to_boundries(x, norm_p=1) > threshold

def get_dataset_min_distance(dataloader):
    """Smallest margin between an inside point and the checkerboard boundary over a dataset.

    Iterates ``(batch, _)`` pairs from ``dataloader``. Only points with a strictly negative
    signed distance (inside a filled square) contribute; the margin is the absolute value of
    that distance. Returns ``inf`` if no batch contains such a point.
    """
    min_distance = float("inf")
    for batch, _ in dataloader:
        distances = compute_distance_to_boundries(batch)
        if isinstance(distances, torch.Tensor):
            # detach() so the conversion also works when ``batch`` requires grad.
            distances = distances.detach().cpu().numpy()
        margins = -distances[distances < 0]
        # A batch with no inside points yields an empty array; .min() on it would raise.
        if margins.size == 0:
            continue
        min_distance = min(min_distance, margins.min())
    return min_distance


if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import matplotlib as mpl

    checkerboard = Checkerboard(-1, -1, 1, 1, 4)

    xx = np.linspace(-2, 2, 100)
    yy = np.linspace(-2, 2, 100)
    X, Y = np.meshgrid(xx, yy)
    # ``distance`` returns a torch tensor; hand matplotlib a plain numpy array instead.
    Z = checkerboard.distance(np.stack([X, Y], axis=-1), norm_p=2).detach().cpu().numpy()

    # TwoSlopeNorm expects plain scalars; the colormap is centred on the boundary (distance 0).
    vmin, vmax = float(Z.min()), float(Z.max())
    center = 0
    norm = mpl.colors.TwoSlopeNorm(vmin=vmin, vcenter=center, vmax=vmax)
    plt.contourf(X, Y, Z, cmap="bwr", norm=norm, levels=20)
    plt.colorbar()
    plt.show()