from typing import Callable

import deepxde as dde
import deepxde.backend as bkd
import torch


class Transformed(dde.geometry.Geometry):
    """Transformed geometry.

    A transformed geometry that is generated by mapping
    a reference domain to global coordinates.

    Points are sampled in the reference domain and transformed
    to global coordinates. Membership queries (``inside``,
    ``on_boundary``) map global points back to the reference domain
    and delegate to the reference geometry.

    Args:
        ref: The reference geometry.
        to_global: The mapping from local to global coordinates. It is
            called with a backend tensor and should return a backend tensor.
        to_local: The mapping from global to local coordinates. It is
            called with a backend tensor and should return a backend tensor.
    """

    ref: dde.geometry.Geometry
    # The user-supplied mappings are stored with a trailing underscore so they
    # do not shadow the ``to_global`` / ``to_local`` wrapper methods below.
    to_global_: Callable
    to_local_: Callable

    def __init__(self, ref, to_global, to_local):
        self.ref = ref
        self.to_global_ = to_global
        self.to_local_ = to_local

        # Placeholders: ``ref.bbox`` is the bounding box of the *reference*
        # domain (local coordinates), not of the transformed domain, and the
        # diameter is passed as NaN because it is not computed here.
        super().__init__(ref.dim, ref.bbox, torch.nan)

    @staticmethod
    def _apply(fn, z):
        """Apply ``fn`` to ``z`` while preserving whether ``z`` is a tensor.

        Non-tensor input (e.g. a NumPy array) is converted with
        ``bkd.as_tensor`` before calling ``fn``, and the result is converted
        back with ``bkd.to_numpy``. Tensor input is passed to ``fn`` as-is and
        its result is returned unchanged.
        """
        was_tensor = bkd.is_tensor(z)
        if not was_tensor:
            z = bkd.as_tensor(z)

        out = fn(z)

        if not was_tensor:
            out = bkd.to_numpy(out)

        return out

    def to_global(self, x):
        """Transform points from local to global coordinates.

        The result has the same kind as ``x``: a backend tensor for tensor
        input, otherwise the output of ``bkd.to_numpy``.
        """
        return self._apply(self.to_global_, x)

    def to_local(self, y):
        """Transform points from global to local coordinates.

        The result has the same kind as ``y``: a backend tensor for tensor
        input, otherwise the output of ``bkd.to_numpy``.
        """
        return self._apply(self.to_local_, y)

    def inside(self, x):
        """Return whether the points are inside the geometry.

        ``x`` is in global coordinates; it is mapped to the reference domain
        and tested there.
        """
        return self.ref.inside(self.to_local(x))

    def on_boundary(self, x):
        """Return if x is on the boundary of the geometry.

        ``x`` is in global coordinates; it is mapped to the reference domain
        and tested there.
        """
        return self.ref.on_boundary(self.to_local(x))

    def random_points(self, n, random="pseudo"):
        """Return random points, sampled in the reference domain."""
        x = self.ref.random_points(n, random)
        return self.to_global(x)

    def uniform_points(self, n):
        """Return uniform points of the reference domain, mapped to global."""
        x = self.ref.uniform_points(n)
        return self.to_global(x)

    def random_boundary_points(self, n, random="pseudo"):
        """Return random boundary points, sampled in the reference domain."""
        x = self.ref.random_boundary_points(n, random)
        return self.to_global(x)

    def uniform_boundary_points(self, n):
        """Return uniform boundary points of the reference domain, mapped."""
        x = self.ref.uniform_boundary_points(n)
        return self.to_global(x)
