import itertools
from typing import Tuple, Any, Union, List
import functools
import operator
import torch
import numpy as np
# import geoopt


# Bound used by `cosh`/`sinh` below: their inputs are (leakily) clamped to
# [-max_norm, max_norm] before the hyperbolic function is applied. With 85 the
# results stay finite in float32 (cosh(85) ~ 4.1e36 < float32 max ~ 3.4e38).
max_norm = 85
# Small constant shared by the numerics below: the gradient leak factor in
# `LeakyClamp.backward`, the margin (4 * eps) kept inside (-1, 1) by `Atanh`,
# and the lower bound 1 + eps applied by `Acosh`.
eps = 1e-8

# Names exported by `from <this module> import *`.
# NOTE(review): the helpers defined further down (norm_normalization,
# norm_standardization, is_on_lorentz_manifold, cosh, sinh, sqrt, clamp, atanh,
# acosh and the autograd Functions) are not listed, so star-imports do not
# expose them.
__all__ = [
    "copy_or_set_",
    "strip_tuple",
    "size2shape",
    "make_tuple",
    "broadcast_shapes",
    # geoopt-dependent helpers; their definitions are commented out below.
    # "ismanifold",
    # "canonical_manifold",
    "list_range",
    "idx2sign",
    "drop_dims",
    "canonical_dims",
    "sign",
    "prod",
    "clamp_abs",
    "sabs",
]

# Min-max normalization.
def norm_normalization(data):
    """
    Rescale ``data`` linearly so its global minimum maps to 0 and its maximum to 1.

    Computes ``(data - min) / (max - min)`` where min and max are taken over
    *all* elements (no axis), unlike ``norm_standardization`` which works per
    column.

    Parameters
    ----------
    data : array_like
        Numeric values.

    Returns
    -------
    array
        Same shape as ``data`` with a floating-point dtype. If every element is
        equal (zero range) the result is all zeros instead of the NaNs a 0/0
        division would produce.
    """
    lo = np.min(data)
    value_range = np.max(data) - lo
    if value_range == 0:
        # Constant input: every shifted value is 0, so dividing by 1 yields zeros
        # (with the same float dtype a real division would give) instead of NaN.
        value_range = 1
    return (data - lo) / value_range
 

# Z-score standardization.
def norm_standardization(data):
    """
    Standardize ``data`` column-wise to zero mean and unit (population) std.

    Mean and standard deviation are computed along axis 0 (``np.std`` default
    ``ddof=0``), i.e. per column for 2-D input.

    Parameters
    ----------
    data : array_like
        Numeric values; rows are samples when 2-D.

    Returns
    -------
    array
        ``(data - mean) / std``. Columns with zero standard deviation (constant
        columns) come out as 0 rather than NaN/inf from a division by zero.
    """
    mu = np.mean(data, axis=0)
    sigma = np.std(data, axis=0)
    # Scale constant columns by 1 so they map to 0 instead of dividing by zero.
    sigma = np.where(sigma == 0, 1, sigma)
    return (data - mu) / sigma

def is_on_lorentz_manifold(x, k, atol: float = 1e-01):
    """
    Check that every row of ``x`` lies on the Lorentz (hyperboloid) manifold.

    Each row is read as ``(x_0, x_1, ..., x_d)`` with ``x_0`` the time-like
    coordinate. The row passes when its Minkowski squared norm
    ``sum(x_1..x_d ** 2) - x_0 ** 2`` is ``torch.isclose`` to ``-k``.

    Parameters
    ----------
    x : torch.Tensor
        Points, one per row (indexed as ``x[:, 0]`` / ``x[:, 1:]``).
    k : torch.Tensor or float
        Curvature parameter in the caller's convention; rows must satisfy
        ``<x, x>_L ~= -k``.
    atol : float
        Absolute tolerance passed to ``torch.isclose`` (default ``1e-1``, the
        previously hard-coded value).

    Returns
    -------
    bool
        Always ``True``; failure is reported by raising.

    Raises
    ------
    ValueError
        If any row is off the manifold. The message reports how many rows fail
        and the Minkowski norms of the first few of them.
    """
    spatial = torch.sum(x[:, 1:] ** 2, dim=1)
    time = x[:, 0] ** 2
    results = spatial - time
    # as_tensor lets callers pass a plain float; torch.isclose needs a tensor.
    target = -torch.as_tensor(k, dtype=results.dtype, device=results.device)
    all_equal = torch.isclose(results, target, atol=atol)

    if not torch.all(all_equal):
        # Raise instead of os._exit(0): a hard exit skipped all cleanup and
        # reported success to the parent process.
        bad_rows = torch.nonzero(~all_equal).flatten()
        shown = bad_rows[:5]
        raise ValueError(
            f"vector is not in lorentz manifold {k}: {bad_rows.numel()} of "
            f"{results.numel()} rows violate <x, x>_L == -k (atol={atol}); "
            f"rows {shown.tolist()} have <x, x>_L = {results[shown].tolist()}"
        )
    return True


def copy_or_set_(dest: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    """
    Write :code:`source` into :code:`dest` in place and return :code:`dest`.

    :code:`set_` re-points :code:`dest` at the storage of :code:`source`, which
    is the cheaper operation, but it also makes :code:`dest` adopt the strides
    of :code:`source`. That can break code relying on the layout of
    :code:`dest`; the problem was reported for matrix manifolds in
    `this issue <https://github.com/geoopt/geoopt/issues/70>`_. So :code:`set_`
    is used only when the strides already agree. Otherwise the values are
    copied with :code:`copy_`, which leaves the layout of :code:`dest` untouched.

    Parameters
    ----------
    dest : torch.Tensor
        Tensor that receives the data (modified in place).
    source : torch.Tensor
        Tensor holding the new data.

    Returns
    -------
    torch.Tensor
        :code:`dest` itself, after the in-place update.
    """
    strides_match = dest.stride() == source.stride()
    return dest.set_(source) if strides_match else dest.copy_(source)


def strip_tuple(tup: Tuple) -> Union[Tuple, Any]:
    """Unwrap a one-element tuple to its sole element; return any other tuple as is."""
    return tup[0] if len(tup) == 1 else tup


def make_tuple(obj: Union[Tuple, List, Any]) -> Tuple:
    """
    Coerce ``obj`` to a tuple.

    Tuples are returned unchanged, lists are converted element-wise, and any
    other object is wrapped as a one-element tuple.
    """
    if isinstance(obj, tuple):
        return obj
    if isinstance(obj, list):
        return tuple(obj)
    return (obj,)


def prod(items):
    """Multiply the elements of ``items`` left to right; the empty product is 1."""
    result = 1
    for item in items:
        result = result * item
    return result


def sign(x):
    return torch.sign(x.sign() + 0.5)


def sabs(x, eps: float = 1e-15):
    """Return ``|x| + eps`` (the addition is done in place on the fresh ``abs`` result)."""
    magnitude = torch.abs(x)
    magnitude.add_(eps)
    return magnitude


def clamp_abs(x, eps: float = 1e-15):
    """
    Push ``x`` away from zero by ``eps`` while keeping its sign.

    Computes ``sign(x) * (|x| + eps)`` using this module's ``sign`` (zero is
    treated as positive) and ``sabs``, so finite results have magnitude at
    least ``eps``.
    """
    direction = sign(x)
    magnitude = sabs(x, eps=eps)
    return direction * magnitude


def idx2sign(idx: int, dim: int, neg: bool = True):
    """
    Re-express an axis index with a uniform sign, which helps when broadcasting.

    Parameters
    ----------
    idx : int
        Axis index, positive or negative.
    dim : int
        Number of dimensions the index refers to.
    neg : bool
        If True, return a negative index; otherwise a non-negative one.

    Returns
    -------
    int
        With ``neg``: ``idx`` unchanged when already negative, else
        ``(idx + 1) % -(dim + 1)`` (e.g. ``idx2sign(0, 3) == -3``).
        Without ``neg``: ``idx % dim`` (e.g. ``idx2sign(-1, 3, neg=False) == 2``).
    """
    if not neg:
        return idx % dim
    if idx < 0:
        return idx
    return (idx + 1) % -(dim + 1)


def drop_dims(tensor: torch.Tensor, dims: List[int]):
    """
    Squeeze several axes of ``tensor`` (``torch.squeeze`` takes only one at a time).

    NOTE(review): the index correction below assumes ``dims`` is sorted
    ascending and non-negative; confirm at call sites.
    """
    # Every squeeze shifts the later axes one position to the left, so offset
    # each requested index by the number of squeezes already performed.
    for n_dropped, dim in enumerate(dims):
        tensor = tensor.squeeze(dim - n_dropped)
    return tensor


def list_range(end: int):
    """Return ``[0, 1, ..., end - 1]`` as a list of ints (empty when ``end <= 0``)."""
    # An explicit comprehension rather than list(range(...)); presumably kept
    # friendly to TorchScript like the typed helpers around it — confirm if changing.
    return [d for d in range(end)]


def canonical_dims(dims: List[int], maxdim: int):
    """Map every axis index in ``dims`` to its non-negative form via ``idx2sign(..., neg=False)``."""
    return [idx2sign(idx, maxdim, neg=False) for idx in dims]


def size2shape(*size: Union[Tuple[int], int]) -> Tuple[int]:
    """
    Normalize a ``*size`` argument list to a shape tuple.

    Accepts both ``f(2, 3)`` and ``f((2, 3))`` / ``f([2, 3])`` call styles: a
    single positional argument is unwrapped first, then coerced to a tuple.
    """
    unwrapped = strip_tuple(size)
    return make_tuple(unwrapped)


def broadcast_shapes(*shapes: Tuple[int]) -> Tuple[int]:
    """
    Apply numpy broadcasting rules to shapes.

    Shapes are aligned on their trailing dimensions, and missing leading
    dimensions count as 1. In every position a size of 1 stretches to match
    the others, and all remaining sizes must be equal. As in NumPy, a size of 0
    is an ordinary size: it broadcasts only with 1 or 0.

    Parameters
    ----------
    *shapes : tuple of int
        Any number of shapes (tuples, lists or ``torch.Size``).

    Returns
    -------
    tuple of int
        The broadcast shape; ``()`` when no shapes are given.

    Raises
    ------
    ValueError
        If the shapes are not broadcast-compatible.
    """
    result = []
    for dims in itertools.zip_longest(*map(reversed, shapes), fillvalue=1):
        dim: int = 1
        for d in dims:
            if d == 1:
                # Size-1 axes broadcast against anything.
                continue
            if dim == 1:
                # First non-1 size in this position. It may be 0; the previous
                # "take the max" rule silently dropped zero-size dims.
                dim = d
            elif d != dim:
                raise ValueError("Shapes can't be broadcasted")
        result.append(dim)
    return tuple(reversed(result))

#
# def ismanifold(instance, cls):
#     """
#     Check if interface of an instance is compatible with given class.
#
#     Parameters
#     ----------
#     instance : geoopt.Manifold
#         check if a given manifold is compatible with cls API
#     cls : type
#         manifold type
#
#     Returns
#     -------
#     bool
#         comparison result
#     """
#     if not issubclass(cls, geoopt.manifolds.Manifold):
#         raise TypeError(
#             "`cls` should be a subclass of geoopt.manifolds.Manifold")
#     if not isinstance(instance, geoopt.manifolds.Manifold):
#         return False
#     else:
#         # this is the case to care about, Scaled class is a proxy, but fails instance checks
#         while isinstance(instance, geoopt.Scaled):
#             instance = instance.base
#         return isinstance(instance, cls)


# def canonical_manifold(manifold: "geoopt.Manifold"):
#     """
#     Get a canonical manifold.
#
#     If a manifold is wrapped with Scaled. Some attributes may not be available. This should help if you really need them.
#
#     Parameters
#     ----------
#     manifold : geoopt.Manifold
#
#     Returns
#     -------
#     geoopt.Maniflold
#         an unwrapped manifold
#     """
#     while isinstance(manifold, geoopt.Scaled):
#         manifold = manifold.base
#     return manifold


def cosh(x: torch.Tensor) -> torch.Tensor:
    """Hyperbolic cosine of ``x`` after a leaky clamp of the input to ``[-max_norm, max_norm]``."""
    bounded = clamp(x, min=-max_norm, max=max_norm)
    return bounded.cosh()


def sinh(x: torch.Tensor) -> torch.Tensor:
    """Hyperbolic sine of ``x`` after a leaky clamp of the input to ``[-max_norm, max_norm]``."""
    bounded = clamp(x, min=-max_norm, max=max_norm)
    return bounded.sinh()


def sqrt(x: torch.Tensor) -> torch.Tensor:
    """Square root of ``x`` with the argument leakily clamped to at least 1e-9 first."""
    bounded = clamp(x, min=1e-9)  # Smaller epsilon due to precision around x=0.
    return bounded.sqrt()


class LeakyClamp(torch.autograd.Function):
    """
    Clamp whose gradient leaks through the clamped region.

    The forward pass is ``torch.clamp(x, min, max)``. In the backward pass,
    elements that were inside ``[min, max]`` receive the incoming gradient
    unchanged. Clamped elements receive it scaled by the module-level ``eps``
    instead of having it zeroed.
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, min: float, max: float) -> torch.Tensor:
        with torch.no_grad():
            in_range = (x >= min) & (x <= max)
            ctx.save_for_backward(in_range)
            return torch.clamp(x, min=min, max=max)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        (in_range,) = ctx.saved_tensors
        keep = in_range.type_as(grad_output)
        grad_x = grad_output * keep + grad_output * (1 - keep) * eps
        # The bounds are plain floats, so they receive no gradient.
        return grad_x, None, None


def clamp(x: torch.Tensor, min: float = float("-inf"), max: float = float("+inf")) -> torch.Tensor:
    """
    Clamp ``x`` to ``[min, max]`` using ``LeakyClamp``, so clamped elements still get a small gradient.

    Either bound may be omitted (it defaults to the matching infinity) for a one-sided clamp.
    """
    clamped = LeakyClamp.apply(x, min, max)
    return clamped


class Atanh(torch.autograd.Function):
    """
    Numerically stable arctanh that never returns NaNs for non-NaN input.

    The input is first clamped to ``[-1 + 4*eps, 1 - 4*eps]`` (``eps`` is the
    module-level constant) so neither logarithm is evaluated at 0.
    Returns atanh(x) = arctanh(x) = 0.5*(log1p(x) - log1p(-x)).
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
        x = clamp(x, min=-1. + 4 * eps, max=1. - 4 * eps)
        ctx.save_for_backward(x)
        # log1p keeps full precision for |x| << 1. Forming 1 + x first rounds it
        # to 1 and returns 0 (e.g. atanh(1e-10) came out as 0 in float32).
        res = (torch.log1p(x) - torch.log1p(-x)).mul_(0.5)
        return res

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
        x, = ctx.saved_tensors
        # d/dx atanh(x) = 1 / (1 - x^2), evaluated at the clamped input.
        return grad_output / (1 - x**2)


def atanh(x: torch.Tensor) -> torch.Tensor:
    """
    Numerically stable arctanh that never returns NaNs for non-NaN input.

    Delegates to ``Atanh``, which clamps ``x`` to ``[-1 + 4*eps, 1 - 4*eps]``
    before evaluating, so inputs with |x| >= 1 are mapped just inside the domain.

    :param x: The input tensor.
    :return: 0.5 * (log(1 + x) - log(1 - x)) of the clamped input, i.e. atanh(x).
    """
    return Atanh.apply(x)


class Acosh(torch.autograd.Function):
    """
    Numerically stable arccosh that never returns NaNs.

    The forward pass clamps ``x`` to at least ``1 + eps`` and evaluates
    ``log(x + sqrt(x^2 - 1))``. It uses this module's ``sqrt``, which clamps its
    argument to at least 1e-9, so the saved root is never zero and the
    backward division is always defined.
    """

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            clamped = clamp(x, min=1 + eps)
            root = sqrt(clamped * clamped - 1.)
            ctx.save_for_backward(root)
            return torch.log(clamped + root)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:
        (root,) = ctx.saved_tensors
        # d/dx acosh(x) = 1 / sqrt(x^2 - 1), which is exactly 1 / root.
        return grad_output / root


def acosh(x: torch.Tensor) -> torch.Tensor:
    """
    Numerically stable arccosh that never returns NaNs.

    Delegates to ``Acosh``: ``x`` is clamped to at least ``1 + eps`` and the
    square-root argument ``x^2 - 1`` to at least 1e-9 before taking the log.

    :param x: The input tensor.
    :return: log(x + sqrt(x^2 - 1)) evaluated on the clamped values.
    """
    result = Acosh.apply(x)
    return result
