import torch
from math import gamma, sqrt, log


@torch.jit.script
def lorentz2klein(x, c):
    """Project Lorentz (hyperboloid) points onto the Beltrami-Klein ball.

    The spatial coordinates ``x[..., 1:]`` are divided by the time coordinate
    ``x[..., :1]`` and by ``sqrt(-c)``; the time coordinate is dropped.

    Args:
        x: points whose first entry along the last axis is the time coordinate.
        c: negative curvature tensor; ``-c`` must be positive for a real result.

    Returns:
        Tensor with one fewer entry along the last axis than ``x``.
    """
    time = x[..., :1]
    spatial = x[..., 1:]
    return spatial / time / (-c).sqrt()


@torch.jit.script
def klein2lorentz(x, c):
    """Lift Beltrami-Klein points back onto the Lorentz hyperboloid.

    A Klein point ``K`` (with ``|K|^2 < 1 / -c``) maps to
    ``[1 / sqrt(-c), K] / sqrt(1 + c * |K|^2)``, which satisfies
    ``x0^2 - |x_spatial|^2 = -1 / c``.

    Args:
        x: Klein coordinates along the last axis.
        c: negative curvature tensor.

    Returns:
        Tensor with one more entry (the time coordinate, first) along the last axis.
    """
    k = -c
    sq_norm = x.pow(2).sum(dim=-1, keepdim=True)
    lorentz_factor = (1 - sq_norm * k).sqrt()
    time = torch.ones_like(sq_norm) / k.sqrt()
    return torch.cat([time, x], dim=-1) / lorentz_factor


def klein_mobius_addition(u, v, c=-1.0):
    """Einstein (relativistic velocity) addition ``u (+) v`` in the Klein model.

    Works on the Beltrami-Klein ball of radius ``1 / sqrt(-c)`` (the ball that
    ``lorentz2klein`` maps onto).  With ``k = -c`` and
    ``gamma_u = 1 / sqrt(1 - k |u|^2)``::

        u (+) v = [u + v / gamma_u + k * gamma_u / (1 + gamma_u) * <u, v> * u]
                  / (1 + k <u, v>)

    The default ``c = -1`` gives the unit-ball formula.

    Args:
        u, v: Klein points, coordinates along the last axis.
        c: negative curvature, a Python float or a tensor broadcastable
            against the reduced (``keepdim=True``) inner products.

    Returns:
        Tensor broadcast from ``u`` and ``v`` holding ``u (+) v``.
    """
    k = -c
    uv_dot = (u * v).sum(dim=-1, keepdim=True)
    gamma_u = 1 / (1 - k * u.pow(2).sum(dim=-1, keepdim=True)).sqrt()
    # gamma / (1 + gamma) equals 1 - 1 / (1 + gamma); k rescales the
    # unit-ball formula to a ball of radius 1 / sqrt(k).
    numerator = u + v / gamma_u + k * (gamma_u / (1 + gamma_u)) * uv_dot * u
    return numerator / (1 + k * uv_dot)


def halfplane_distance(x, y, keepdim=False):
    """Geodesic distance between points of the upper half-plane.

    Points are ``(mu, sigma)`` pairs along the last axis with ``sigma > 0``
    (not checked).  Uses ``d = 2 * asinh(|x - y| / (2 * sqrt(sigma1 * sigma2)))``,
    which equals ``log((|x - y*| + |x - y|) / (|x - y*| - |x - y|))``.
    Here ``y*`` is ``y`` reflected across the real axis.  This form avoids
    the catastrophic cancellation in the denominator for distant points.

    Args:
        x, y: tensors whose last axis holds ``(mu, sigma)``.
        keepdim: if True, keep the reduced last axis with size 1.

    Returns:
        Distances with the last axis removed (or kept with size 1).
    """
    mu1, mu2 = x[..., 0], y[..., 0]
    sigma1, sigma2 = x[..., 1], y[..., 1]

    euclid = ((mu1 - mu2).pow(2) + (sigma1 - sigma2).pow(2)).sqrt()
    # sqrt each sigma separately so sigma1 * sigma2 cannot over/underflow.
    dist = 2 * torch.asinh(euclid / (2 * sigma1.sqrt() * sigma2.sqrt()))
    if keepdim:
        dist = dist.unsqueeze(-1)
    return dist


def fisher_rao_distance(x, y, keepdim=False):
    """Fisher-Rao distance between univariate normal distributions.

    With ``x = (mu, sigma)`` along the last axis, the Fisher metric
    ``(dmu^2 + 2 dsigma^2) / sigma^2`` is ``2x`` the half-plane metric
    in ``(mu / sqrt(2), sigma)``.  The distance is therefore
    ``sqrt(2) * halfplane_distance((mu1 / sqrt(2), sigma1), (mu2 / sqrt(2), sigma2))``.

    Args:
        x, y: tensors whose last axis holds ``(mu, sigma)`` with ``sigma > 0``.
        keepdim: if True, keep the reduced last axis with size 1.

    Returns:
        Distances with the last axis removed (or kept with size 1).
    """
    mu1, mu2 = x[..., 0] / sqrt(2), y[..., 0] / sqrt(2)
    sigma1, sigma2 = x[..., 1], y[..., 1]

    x_ = torch.stack([mu1, sigma1], dim=-1)
    y_ = torch.stack([mu2, sigma2], dim=-1)
    dist = sqrt(2) * halfplane_distance(x_, y_)
    # Apply keepdim here so the result shape does not depend on how the
    # helper treats its own keepdim argument.
    if keepdim:
        dist = dist.unsqueeze(-1)
    return dist


@torch.jit.script
def halfplane2disk(x, c):
    """Map half-plane points ``(a, b)`` into the Poincare disk.

    For ``c = -1`` this is the Cayley transform ``z -> (z - i) / (z + i)``
    with ``z = a + i b``, returned as ``(Re, Im)``.  For general ``c`` (``k = -c``)
    the result equals the unit Cayley image of ``(sqrt(k) a, b)`` divided by
    ``sqrt(k)``.

    Args:
        x: tensor whose last axis (size 2) holds ``(a, b)``; ``b`` is
            expected to be positive (not checked).
        c: negative curvature tensor.

    Returns:
        Tensor whose last axis (size 2) holds the disk coordinates.
    """
    k = -c
    root_k = k.sqrt()
    re = x[..., 0]
    im = x[..., 1]
    scale = k * re.pow(2) + (im + 1).pow(2)
    first = root_k * re.pow(2) + (im.pow(2) - 1) / root_k
    second = -2 * re
    return torch.stack([first, second], dim=-1) / scale.unsqueeze(-1)


@torch.jit.script
def disk2halfplane(x, c):
    """Map Poincare-disk points ``(p, q)`` to the upper half-plane.

    This is the inverse of ``halfplane2disk``.  For ``c = -1`` it is the
    inverse Cayley transform ``w -> i (1 + w) / (1 - w)`` with ``w = p + i q``.

    Args:
        x: tensor whose last axis (size 2) holds disk coordinates
            (``|x|^2 < 1 / -c``).
        c: negative curvature tensor.

    Returns:
        Tensor whose last axis (size 2) holds half-plane coordinates ``(a, b)``.
    """
    k = -c
    root_k = k.sqrt()
    p = x[..., 0]
    q = x[..., 1]
    scale = (root_k * p - 1).pow(2) + k * q.pow(2)
    unscaled = torch.stack([-2 * q, 1 - (p.pow(2) + q.pow(2)) * k], dim=-1)
    return unscaled / scale.unsqueeze(-1)


@torch.jit.script
def disk2lorentz(x, c):
    """Lift Poincare-ball points onto the Lorentz hyperboloid.

    With ``k = -c`` and ``s = k |x|^2``, the result is
    ``[(1 + s) / (sqrt(k) (1 - s)), 2 x / (1 - s)]``.  It satisfies
    ``x0^2 - |x_spatial|^2 = 1 / k``.

    Args:
        x: ball coordinates along the last axis (``|x|^2 < 1 / k``).
        c: negative curvature tensor.

    Returns:
        Tensor with the time coordinate prepended along the last axis.
    """
    k = -c
    sq = x.pow(2).sum(dim=-1, keepdim=True)
    ksq = sq * k
    conformal = 1 - ksq
    time = (1 + ksq) / (k.sqrt() * conformal)
    space = 2 * x / conformal
    return torch.cat([time, space], dim=-1)


@torch.jit.script
def lorentz2disk(x, c):
    """Project Lorentz points onto the Poincare ball (stereographic projection).

    Returns ``x_spatial / (sqrt(-c) * x0 + 1)``; the time coordinate ``x0``
    is ``x[..., :1]`` and is dropped.

    Args:
        x: hyperboloid points, time coordinate first along the last axis.
        c: negative curvature tensor.

    Returns:
        Tensor with one fewer entry along the last axis than ``x``.
    """
    k = -c
    projection_denominator = k.sqrt() * x[..., :1] + 1
    return x[..., 1:] / projection_denominator


@torch.jit.script
def lorentz2halfplane(x: torch.Tensor, c: torch.Tensor, log: bool = False) -> torch.Tensor:
    """Map 2-D Lorentz points ``(t, a, b)`` to half-plane coordinates.

    Inverse of ``halfplane2lorentz``: with ``k = -c`` the result is
    ``(-b / (sqrt(k) (t - a)), 1 / (sqrt(k) (t - a)))``.

    ``log`` must be annotated as ``bool``.  TorchScript types unannotated
    parameters as ``Tensor``, so a bare ``log=False`` default cannot be
    compiled and a Python bool could never be passed.

    Args:
        x: tensor whose last axis holds ``(t, a, b)``.
        c: negative curvature tensor.
        log: if True, return the log of the second coordinate, computed
            directly as ``-0.5 log(k) - log(t - a)`` rather than from the ratio.

    Returns:
        Tensor whose last axis (size 2) holds the half-plane coordinates.
    """
    k = -c
    t, a, b = x[..., 0], x[..., 1], x[..., 2]
    gap = t - a
    scaled_gap = k.sqrt() * gap
    first = -b / scaled_gap
    if log:
        second = -0.5 * k.log() - gap.log()
    else:
        second = 1 / scaled_gap
    return torch.stack([first, second], dim=-1)


@torch.jit.script
def halfplane2lorentz(x, c):
    """Lift half-plane points ``(a, b)`` onto the 2-D Lorentz hyperboloid.

    With ``k = -c``, ``t = (1 + k a^2 + b^2) / (2 sqrt(k) b)``,
    ``x1 = (-1 + k a^2 + b^2) / (2 sqrt(k) b)`` and ``x2 = -a / b``.
    These satisfy ``t^2 - x1^2 - x2^2 = 1 / k``.

    Args:
        x: tensor whose last axis holds ``(a, b)`` with ``b > 0``.
        c: negative curvature tensor.

    Returns:
        Tensor whose last axis (size 3) holds ``(t, x1, x2)``.
    """
    k = -c
    a = x[..., 0]
    b = x[..., 1]
    ka2 = k * a.pow(2)
    b2 = b.pow(2)
    denom = 2 * k.sqrt() * b
    time = (1 + ka2 + b2) / denom
    first = (-1 + ka2 + b2) / denom
    second = -a / b
    return torch.stack([time, first, second], dim=-1)


@torch.jit.script
def lorentz_expmap0(z, c):
    """Exponential map at the hyperboloid origin ``(1 / sqrt(-c), 0, ..., 0)``.

    With ``alpha = sqrt(-c) |z|`` the result is
    ``[cosh(alpha) / sqrt(-c), sinh(alpha) / alpha * z]``.  A zero tangent
    vector maps to the origin instead of producing ``sinh(0) / 0 = NaN``.

    Args:
        z: spatial tangent vectors at the origin, coordinates along the last axis.
        c: negative curvature tensor.

    Returns:
        Hyperboloid points with the time coordinate prepended along the last axis.
    """
    root_k = (-c).sqrt()
    alpha = z.pow(2).sum(dim=-1, keepdim=True).sqrt() * root_k
    # Where alpha == 0 the spatial part is 0 * (finite factor) = 0, so any
    # finite stand-in works; using 1 keeps the factor and its gradient
    # finite.  NaN alpha (invalid c) is not replaced and still propagates.
    safe_alpha = torch.where(alpha == 0, torch.ones_like(alpha), alpha)
    time = alpha.cosh() / root_k
    space = safe_alpha.sinh() / safe_alpha * z
    return torch.concat([time, space], dim=-1)


@torch.jit.script
def lorentz_logmap0(y, c):
    """Logarithmic map at the hyperboloid origin (inverse of the origin exp map).

    With ``beta = sqrt(-c) * y0`` the tangent vector is
    ``arccosh(beta) / sqrt(beta^2 - 1) * y_spatial``.  The factor's limit
    as ``beta -> 1`` is 1, so the origin maps to the zero vector.  Rounding
    can leave ``beta`` slightly below 1, so it is clamped to ``>= 1`` to
    avoid ``arccosh`` returning NaN.

    Args:
        y: hyperboloid points, time coordinate first along the last axis.
        c: negative curvature tensor.

    Returns:
        Tangent vectors with one fewer entry along the last axis than ``y``.
    """
    root_k = (-c).sqrt()
    beta = (root_k * y[..., :1]).clamp_min(1.0)
    denom = (beta.pow(2) - 1).sqrt()
    at_origin = denom == 0
    # Substitute a safe denominator so the unused branch of the where() is
    # finite; NaN inputs still fall through and propagate.
    safe_denom = torch.where(at_origin, torch.ones_like(denom), denom)
    factor = torch.where(at_origin, torch.ones_like(denom), torch.arccosh(beta) / safe_denom)
    return factor * y[..., 1:]

