

import functools
import inspect
import logging
from typing import Callable, List, Tuple

import numpy as np
import torch

# Module-level logger, named after this module so logging configuration can target it.
logger = logging.getLogger(name=__name__)




def autocast(func: Callable) -> Callable:
    """Decorate a ``TensorWrapper`` method so NumPy array arguments become tensors.

    The decorated function's first argument must be a ``TensorWrapper`` instance
    (``self``) or a ``TensorWrapper`` subclass (``cls``, for classmethods). Every
    ``np.ndarray`` among the remaining positional *and keyword* arguments is
    converted with ``torch.from_numpy`` and moved to the device/dtype of the
    wrapped tensor. When no tensor is wrapped yet (``_data is None``) or the first
    argument is a class, arrays are placed on the CPU and keep their NumPy dtype.
    All other arguments are passed through unchanged.

    Raises:
        ValueError: if the first argument is neither a ``TensorWrapper`` instance
            nor a ``TensorWrapper`` subclass.
    """

    @functools.wraps(func)
    def wrap(self, *args, **kwargs):
        device = torch.device("cpu")
        dtype = None  # None keeps the dtype produced by torch.from_numpy
        if isinstance(self, TensorWrapper):
            if self._data is not None:
                device = self.device
                dtype = self.dtype
        elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
            raise ValueError(self)

        def _cast(arg):
            # Convert NumPy arrays only; anything else is forwarded untouched.
            if isinstance(arg, np.ndarray):
                return torch.from_numpy(arg).to(device=device, dtype=dtype)
            return arg

        cast_args = [_cast(arg) for arg in args]
        # Keyword arguments used to be rejected outright; accept and cast them too.
        cast_kwargs = {key: _cast(value) for key, value in kwargs.items()}
        return func(self, *cast_args, **cast_kwargs)

    return wrap


class TensorWrapper:
    """Wrapper for PyTorch tensors.

    The wrapped tensor is stored in ``_data``. Most methods forward to the
    matching ``torch.Tensor`` method and re-wrap the result in ``self.__class__``,
    so subclasses keep their type. ``shape`` reports every dimension of ``_data``
    except the last one.
    """

    _data = None

    @autocast
    def __init__(self, data: torch.Tensor):
        """Wrap ``data``; NumPy array arguments are converted to tensors by ``autocast``."""
        self._data = data

    @property
    def shape(self) -> torch.Size:
        """Shape of the underlying tensor without its last dimension."""
        return self._data.shape[:-1]

    @property
    def device(self) -> torch.device:
        """Get the device of the underlying tensor."""
        return self._data.device

    @property
    def dtype(self) -> torch.dtype:
        """Get the dtype of the underlying tensor."""
        return self._data.dtype

    def __getitem__(self, index) -> "TensorWrapper":
        """Index the underlying tensor and wrap the result in the same class."""
        return self.__class__(self._data[index])

    def __setitem__(self, index, item):
        """Assign ``item`` into the underlying tensor at ``index``.

        ``item`` may be another ``TensorWrapper`` (its wrapped tensor is used) or
        any object exposing a ``.data`` tensor, such as a ``torch.Tensor``. The value
        is read through ``.data``, so gradients do not flow back into ``item``.
        """
        if isinstance(item, TensorWrapper):
            # TensorWrapper defines no `.data` attribute; unwrap it explicitly.
            item = item._data
        self._data[index] = item.data

    def to(self, *args, **kwargs):
        """Return a copy moved/cast with ``torch.Tensor.to`` (same arguments)."""
        return self.__class__(self._data.to(*args, **kwargs))

    def cpu(self):
        """Move the underlying tensor to the CPU."""
        return self.__class__(self._data.cpu())

    def cuda(self):
        """Move the underlying tensor to the GPU."""
        return self.__class__(self._data.cuda())

    def pin_memory(self):
        """Pin the underlying tensor to memory."""
        return self.__class__(self._data.pin_memory())

    def float(self):
        """Cast the underlying tensor to float."""
        return self.__class__(self._data.float())

    def double(self):
        """Cast the underlying tensor to double."""
        return self.__class__(self._data.double())

    def detach(self):
        """Detach the underlying tensor."""
        return self.__class__(self._data.detach())

    def numpy(self):
        """Convert the underlying tensor to a numpy array (detached, on the CPU)."""
        return self._data.detach().cpu().numpy()

    def new_tensor(self, *args, **kwargs):
        """Forward to ``torch.Tensor.new_tensor`` on the wrapped tensor (returns a plain tensor)."""
        return self._data.new_tensor(*args, **kwargs)

    def new_zeros(self, *args, **kwargs):
        """Forward to ``torch.Tensor.new_zeros`` on the wrapped tensor (returns a plain tensor)."""
        return self._data.new_zeros(*args, **kwargs)

    def new_ones(self, *args, **kwargs):
        """Forward to ``torch.Tensor.new_ones`` on the wrapped tensor (returns a plain tensor)."""
        return self._data.new_ones(*args, **kwargs)

    def new_full(self, *args, **kwargs):
        """Forward to ``torch.Tensor.new_full`` on the wrapped tensor (returns a plain tensor)."""
        return self._data.new_full(*args, **kwargs)

    def new_empty(self, *args, **kwargs):
        """Forward to ``torch.Tensor.new_empty`` on the wrapped tensor (returns a plain tensor)."""
        return self._data.new_empty(*args, **kwargs)

    def unsqueeze(self, *args, **kwargs):
        """Unsqueeze the underlying tensor and re-wrap it.

        Dimension indices refer to ``_data`` (including its last dimension).
        """
        return self.__class__(self._data.unsqueeze(*args, **kwargs))

    def squeeze(self, *args, **kwargs):
        """Squeeze the underlying tensor and re-wrap it.

        Dimension indices refer to ``_data`` (including its last dimension).
        """
        return self.__class__(self._data.squeeze(*args, **kwargs))

    @classmethod
    def stack(cls, objects: List, dim=0, *, out=None):
        """Stack the wrapped tensors of ``objects`` along ``dim`` and wrap the result."""
        data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
        return cls(data)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Route ``torch.stack`` to :meth:`stack`; other torch functions are unsupported."""
        if kwargs is None:
            kwargs = {}
        return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented


class EuclideanManifold:
    """Flat (Euclidean) manifold whose plus operator is ordinary addition."""

    @staticmethod
    def J_plus(x: torch.Tensor) -> torch.Tensor:
        """Jacobian of :meth:`plus` w.r.t. ``delta``: a ``(D, D)`` identity.

        ``D`` is ``x.shape[-1]``. The identity takes ``x``'s dtype and device and is
        not broadcast over ``x``'s leading dimensions.
        """
        dim = x.shape[-1]
        return torch.eye(dim, dtype=x.dtype, device=x.device)

    @staticmethod
    def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """Return ``x + delta``."""
        updated = x + delta
        return updated


class SphericalManifold:
    """Sphere manifold parametrized locally through a Householder reflection.

    Points are vectors ``x`` of shape (..., D); :meth:`plus` preserves their norm,
    and updates ``delta`` have D-1 components.
    """

    @staticmethod
    def householder_vector(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the Householder reflection that maps ``x`` onto the last axis.

        The reflection is ``H = I - beta * v v^T`` with ``v[..., -1] == 1``. It sends
        ``x`` to ``||x|| * e_last`` and, being an involution, ``e_last`` to
        ``x / ||x||`` (up to the small regularization described below).

        Args:
            x: Tensor of shape (..., D).

        Returns:
            ``(v, beta)`` with ``v`` of shape (..., D) and ``beta`` of shape (...).
        """
        sigma = torch.sum(x[..., :-1] ** 2, -1)
        xpiv = x[..., -1]
        norm = torch.norm(x, dim=-1)
        # When x is (nearly) aligned with the last axis, sigma ~ 0, vpiv below becomes
        # ~0, and x[..., :-1] / vpiv is ~0/0; regularize sigma and warn.
        if torch.any(sigma < 1e-7):
            sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma)
            logger.warning("sigma < 1e-7")

        # Both branches equal xpiv - norm mathematically; the second form avoids
        # catastrophic cancellation when xpiv > 0.
        vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm))
        beta = 2 * vpiv**2 / (sigma + vpiv**2)
        v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1)
        return v, beta

    @staticmethod
    def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """Apply ``H = I - beta * v v^T`` to ``y`` (shape (..., D)) without forming ``H``."""
        return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None]

    @classmethod
    def J_plus(cls, x: torch.Tensor) -> torch.Tensor:
        """Plus operator Jacobian.

        Returns the first D-1 columns of the Householder matrix of ``x``, shape
        (..., D, D-1). These columns are orthogonal to ``x``. They equal the Jacobian
        of :meth:`plus` w.r.t. ``delta`` at ``delta = 0`` divided by ``||x||``; the two
        coincide exactly for unit-norm ``x``.
        """
        v, beta = cls.householder_vector(x)
        H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v)
        H = H + torch.eye(H.shape[-1]).to(H)
        return H[..., :-1]  # J

    @classmethod
    def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """Move ``x`` along the sphere by the update ``delta``.

        ``delta`` is mapped to the unit sphere as
        ``[sinc(|delta|) * delta, cos(|delta|)]`` (a point near the last axis).
        That point is reflected by the Householder matrix that takes the last axis
        to ``x / ||x||``, then scaled by ``||x||``.

        Args:
            x: Points of shape (..., D); they need not have unit norm.
            delta: Updates of shape (..., D-1).

        Returns:
            Updated points of shape (..., D) with the same norm as ``x``.
        """
        eps = 1e-7
        # x is not assumed to have unit norm: factor its norm out and re-apply it.
        nx = torch.norm(x, dim=-1, keepdim=True)
        nd = torch.norm(delta, dim=-1, keepdim=True)

        # make sure we don't divide by zero in backward as torch.where computes grad for both
        # branches
        nd_ = torch.where(nd < eps, nd + eps, nd)
        sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_)

        # cos is applied to last dim instead of first (the last axis is the Householder pivot)
        exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1)

        v, beta = cls.householder_vector(x)
        return nx * cls.apply_householder(exp_delta, v, beta)


@torch.jit.script
def J_vecnorm(vec: torch.Tensor) -> torch.Tensor:
    """Compute the Jacobian of ``vec / ||vec||`` with respect to ``vec``.

    Args:
        vec: Tensor of shape (..., D).

    Returns:
        Tensor of shape (..., D, D) equal to ``I / ||vec|| - vec vec^T / ||vec||^3``.
        For zero-norm vectors, a norm of 1e-6 is used instead to avoid dividing by zero.
    """
    D = vec.shape[-1]
    norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1)  # (..., 1, 1)

    # Regularize only the zero-norm entries so each Jacobian is independent of the
    # other vectors in the batch (previously every norm was shifted if any was zero).
    norm_x = torch.where(norm_x == 0, norm_x + 1e-6, norm_x)

    xxT = torch.einsum("...i,...j->...ij", vec, vec)  # (..., D, D)
    identity = torch.eye(D, device=vec.device, dtype=vec.dtype)  # (D, D)

    return identity / norm_x - (xxT / norm_x**3)  # (..., D, D)


@torch.jit.script
def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    """Compute the jacobian of the focal2fov function.

    Evaluates ``-4 h / (4 focal^2 + h^2)`` element-wise, with broadcasting. This is
    the derivative of ``2 * atan(h / (2 * focal))`` with respect to ``focal``
    (presumably the focal2fov mapping used elsewhere; confirm).
    """
    denominator = 4 * focal**2 + h**2
    return -4 * h / denominator


@torch.jit.script
def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
    """Compute the Jacobian of the up projection w.r.t. ``uv`` or ``abc``.

    The two Jacobians are consistent with a projection of the form
    ``(a - c * u, b - c * v)``; this is presumably the forward function, so confirm
    against its definition.

    Args:
        uv: Tensor of shape (..., 2).
        abc: Tensor whose last dimension holds ``(a, b, c)``.
        wrt: ``"uv"`` or ``"abc"``.

    Returns:
        For ``"uv"``: ``-c * I_2`` expanded to ``uv.shape[:-1] + (2, 2)``.
        For ``"abc"``: shape ``uv.shape[:-1] + (2, 3)`` with rows ``[1, 0, -u]``
        and ``[0, 1, -v]``.

    Raises:
        ValueError: for any other ``wrt`` (surfaced through TorchScript).
    """
    if wrt == "uv":
        eye2 = torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2))
        # c receives three trailing singleton dims. This assumes uv has exactly one more
        # leading dim than abc (e.g. uv (B, N, 2) with abc (B, 3)); TODO confirm.
        c = abc[..., 2][..., None, None, None]
        return -c * eye2
    elif wrt == "abc":
        jac = uv.new_zeros(uv.shape[:-1] + (2, 3))
        jac[..., 0, 0] = 1
        jac[..., 1, 1] = 1
        # Third column holds (-u, -v).
        jac[..., 2] = -uv
        return jac
    else:
        raise ValueError(f"Unknown wrt: {wrt}")
