import yaml
import numpy as np
import torch

# Large magnitude used as a "no obstacle" floor: DiffMap.compute_discrete_sdf
# starts its grid at -HIGH before taking the element-wise maximum over shapes.
HIGH = 1e6


class Rectangle(object):
    """Axis-aligned rectangle with a differentiable signed squared distance.

    Args:
        w: Full extent along the first coordinate.
        h: Full extent along the second coordinate.
        c: Centre of the rectangle, anything ``torch.tensor`` accepts
            (two coordinates, to match the half-extent vector).
        tensor_kwargs: Keyword arguments forwarded to ``torch.tensor``
            (e.g. ``device``/``dtype``). Defaults to CPU ``float32``.
    """

    def __init__(self, w, h, c, tensor_kwargs=None):
        if tensor_kwargs is None:
            # Build a fresh dict per instance: a dict literal in the signature
            # would be shared, so mutating one instance's ``tensor_kwargs``
            # would silently change the default for every later instance.
            tensor_kwargs = {"device": "cpu", "dtype": torch.float32}
        self.w = w
        self.h = h
        # Half extents; ``dist`` compares them against |p - c|.
        self.r = torch.tensor([w / 2, h / 2], **tensor_kwargs)
        self.c = torch.tensor(c, **tensor_kwargs)
        self.tensor_kwargs = tensor_kwargs

    def dist(self, p):
        """Return a signed squared distance from ``p`` to the rectangle edge.

        The sign convention is positive inside, zero on the boundary and
        negative outside.

        Args:
            p: Tensor of points whose last dimension holds the coordinates.

        Returns:
            Tensor with the last dimension of ``p`` reduced away. Inside the
            rectangle the value is the squared distance to the nearest edge;
            outside it is minus the squared Euclidean distance to the
            rectangle.
        """
        p_rel = p - self.c
        # Per-axis signed offset from the edges: negative inside, positive outside.
        q = (torch.abs(p_rel) - self.r)
        # Only axes on which the point lies outside contribute.
        dist_out = torch.sum(q.clamp(min=0)**2, dim=-1)
        # The largest (least negative) offset is the nearest edge; clamping at
        # zero makes this term vanish as soon as the point is outside.
        dist_in = q.max(dim=-1)[0].clamp(max=0)**2
        return dist_in - dist_out


class Circle(object):
    """Circle with a differentiable inside/outside measure.

    Args:
        r: Radius.
        c: Centre of the circle, anything ``torch.tensor`` accepts.
        tensor_kwargs: Keyword arguments forwarded to ``torch.tensor``
            (e.g. ``device``/``dtype``). Defaults to CPU ``float32``.
    """

    def __init__(self, r, c, tensor_kwargs=None):
        if tensor_kwargs is None:
            # Build a fresh dict per instance: a dict literal in the signature
            # would be shared, so mutating one instance's ``tensor_kwargs``
            # would silently change the default for every later instance.
            tensor_kwargs = {"device": "cpu", "dtype": torch.float32}
        self.r = r
        self.c = torch.tensor(c, **tensor_kwargs)
        self.tensor_kwargs = tensor_kwargs

    def dist(self, p):
        """Return ``r**2 - |p - c|**2`` for each point in ``p``.

        The value is positive inside the circle, zero on the boundary and
        negative outside. It is a difference of squares, not a true distance.

        Args:
            p: Tensor of points whose last dimension holds the coordinates.

        Returns:
            Tensor with the last dimension of ``p`` reduced away.
        """
        p_rel = p - self.c
        return self.r**2 - torch.sum(p_rel**2, dim=-1)


class DiffMap(object):
    """2-D map whose obstacles are analytic shapes exposing ``dist(points)``.

    Every shape's ``dist`` is expected to be positive inside the obstacle and
    negative outside; the map value at a point is the maximum over all shapes.

    Args:
        file: Optional path to a YAML map description. When given, it
            overrides ``width``, ``height``, ``origin`` and ``shapes``.
        width: Map extent along x.
        height: Map extent along y.
        origin: ``[x, y]`` of the map's lower limits. Defaults to ``[0, 0]``.
        shapes: Optional list of shape objects to start with.
        tensor_kwargs: Keyword arguments forwarded to tensor constructors
            (e.g. ``device``/``dtype``). Defaults to CPU ``float32``.

    Raises:
        AssertionError: If neither a file nor both ``width`` and ``height``
            are provided.
    """

    def __init__(self, file=None, width=None, height=None, origin=None, shapes=None,
                 tensor_kwargs=None):
        self.width = width
        self.height = height
        # Fresh objects per instance: list/dict literals in the signature would
        # be shared, so mutating them on one map would affect every later map.
        self.origin = [0, 0] if origin is None else origin
        self.shapes = shapes if shapes is not None else []
        if tensor_kwargs is None:
            tensor_kwargs = {"device": "cpu", "dtype": torch.float32}
        self.tensor_kwargs = tensor_kwargs

        if file is not None:
            self.load_from_file(file)

        assert self.width is not None and self.height is not None, "Must provide either width and height or map file."

        # [x_min, x_max, y_min, y_max]
        self.lims = [self.origin[0], self.origin[0] + self.width,
                     self.origin[1], self.origin[1] + self.height]

    def load_from_file(self, file_path):
        """Replace size, origin and shapes with the contents of a YAML file.

        The file must have a top-level ``data`` mapping with ``width``,
        ``height``, ``origin`` and ``obstacles``. Each obstacle carries a
        ``geometry`` of ``rectangle`` (``width``, ``height``, ``origin``) or
        ``circle`` (``radius``, ``origin``); other geometries are reported on
        stdout and skipped.

        Args:
            file_path: Path of the YAML file to read.
        """
        self.shapes = []

        with open(file_path, 'r') as f:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects
            # from tagged YAML. Switch to yaml.safe_load if map files can come
            # from an untrusted source.
            data = yaml.load(f, Loader=yaml.Loader)
        data = data["data"]

        self.width = data["width"]
        self.height = data["height"]
        self.origin = data["origin"]

        for val in data["obstacles"].values():
            if val["geometry"] == "rectangle":
                self.add_rectangle(val["width"], val["height"], val["origin"])
            elif val["geometry"] == "circle":
                self.add_circle(val["radius"], val["origin"])
            else:
                print("WARNING: Unknown obstacle type:", val["geometry"])

    def add_circle(self, radius, center):
        """Append a ``Circle`` built with this map's ``tensor_kwargs``."""
        self.shapes.append(Circle(radius, center, tensor_kwargs=self.tensor_kwargs))

    def add_rectangle(self, width, height, center):
        """Append a ``Rectangle`` built with this map's ``tensor_kwargs``."""
        self.shapes.append(Rectangle(width, height, center, tensor_kwargs=self.tensor_kwargs))

    def compute_binary_img(self, ppm=40):
        """Return a boolean grid that is True where the discrete SDF is >= 0.

        Args:
            ppm: Grid samples per map unit.
        """
        sdf = self.compute_discrete_sdf(ppm=ppm)
        return sdf >= 0

    def compute_discrete_sdf(self, ppm=40):
        """Evaluate the map on the regular grid produced by ``grid_pts``.

        Args:
            ppm: Grid samples per map unit.

        Returns:
            Tensor of shape ``(pix_h, pix_w)``: rows follow y with the largest
            y sample in row 0, columns follow x in increasing order. Cells not
            reached by any shape value above ``-HIGH`` hold ``-HIGH``.
        """
        pix_w, pix_h, pts = self.grid_pts(ppm=ppm)

        # grid_pts flattens a (pix_h, pix_w) meshgrid in row-major order, so
        # the per-point values must be folded back with that same shape;
        # (pix_w, pix_h) would scramble any non-square map.
        grid_shape = (pix_h, pix_w)

        sdf = torch.full(grid_shape, -HIGH, **self.tensor_kwargs)
        for shape in self.shapes:
            shape_sdf = shape.dist(pts).reshape(grid_shape)
            sdf = torch.maximum(shape_sdf, sdf)
        return sdf

    def grid_pts(self, ppm=40):
        """Build the flat list of grid sample points.

        Args:
            ppm: Grid samples per map unit.

        Returns:
            Tuple ``(pix_w, pix_h, pts)`` where ``pts`` has shape
            ``(pix_h * pix_w, 2)``. Points are ordered row by row with x
            increasing within a row and y decreasing from row to row.
        """
        pix_w, pix_h = int(ppm * self.width), int(ppm * self.height)
        # NOTE(review): np.linspace includes the upper limit by default, so
        # after the half-pixel shift the last sample lies 1 / (2 * ppm) beyond
        # the map edge and the spacing is extent / (n - 1) rather than 1 / ppm.
        # Confirm whether endpoint=False was intended.
        x = np.linspace(*self.lims[:2], pix_w) + 1. / (2 * ppm)
        y = np.linspace(*self.lims[2:], pix_h) + 1. / (2 * ppm)
        # Default 'xy' indexing: X and Y both have shape (pix_h, pix_w).
        X, Y = np.meshgrid(x, y)
        # Reversing the flattened Y flips the row order, so the first row
        # carries the largest y.
        pts = np.stack([X.reshape(-1), np.flip(Y.reshape(-1))], axis=-1)
        pts = torch.tensor(pts, **self.tensor_kwargs)

        return pix_w, pix_h, pts

    def eval_sdf(self, x):
        """Evaluate the map at arbitrary points.

        Args:
            x: Tensor of points whose last dimension holds the coordinates.

        Returns:
            Tensor with the last dimension of ``x`` reduced away, holding the
            maximum of every shape's ``dist`` at each point. A map without
            shapes returns ``-HIGH`` everywhere, matching the fill value of
            ``compute_discrete_sdf``.
        """
        if not self.shapes:
            # torch.stack rejects an empty list; an obstacle-free map is valid.
            return x.new_full(x.shape[:-1], -HIGH)

        dists = torch.stack([shape.dist(x) for shape in self.shapes])
        # TODO: Where this is positive, should be min, not max.
        return dists.max(dim=0)[0]
