import torch
import functorch as ft
from matplotlib import pyplot as plt


@torch.jit.script
def q(x, N: int = 1):
    """Straight-through quantizer with a tanh surrogate gradient.

    Forward value: ``clamp(round(x), -N, N)``. It is reconstructed as
    ``tanh(x) + (target - tanh(x))``, so it may differ from the exact integer
    by floating-point rounding. ``torch.round`` rounds halves to even.

    Backward: the correction term is computed under ``no_grad``, so the
    gradient reaching ``x`` is that of ``tanh(x)``, i.e. ``1 - tanh(x)**2``.

    Args:
        x: input tensor (any shape).
        N: symmetric clamp bound for the rounded value.
    """
    smooth = torch.tanh(x)
    with torch.no_grad():
        target = torch.clamp(torch.round(x), min=-N, max=N)
        correction = target - smooth
    return smooth + correction


def f_L(w, wq_score, M, lamda):
    """Objective evaluated at the quantized scores.

    Computes ``0.5 * ||M q(wq_score) - M w||^2 + lamda * ||q(wq_score)||_1``.
    For 1-D ``w`` / ``wq_score`` the result is a 0-dim tensor.
    """
    quantized = q(wq_score)
    data_term = 0.5 * ((M @ quantized - M @ w) ** 2).sum()
    sparsity_term = lamda * torch.norm(quantized, 1)
    return data_term + sparsity_term


@torch.jit.script
def f_L_batched(w, wq_score, M, lamda: float):
    """Row-wise ``f_L`` for a batch of problems that share the same ``M``.

    Args:
        w, wq_score: ``(B, n)`` tensors, one problem per row.
        M: ``(k, n)`` matrix, broadcast across the batch.
        lamda: weight of the L1 penalty.

    Returns:
        ``(B,)`` tensor of objective values.
    """
    M_b = M.unsqueeze(0)                         # (1, k, n)
    wq = q(wq_score.unsqueeze(2))                # (B, n, 1)
    delta = wq - w.unsqueeze(2)                  # (B, n, 1)
    fit = 0.5 * ((M_b @ delta) ** 2).sum(1)      # (B, 1)
    penalty = lamda * torch.norm(wq, 1, dim=1)   # (B, 1)
    return (fit + penalty).squeeze(1)            # (B,)


def f_L_pure(w, wq, M, lamda):
    """``f_L`` without the quantizer: ``0.5 * ||M wq - M w||^2 + lamda * ||wq||_1``."""
    residual = M @ wq - M @ w
    data_term = 0.5 * (residual ** 2).sum()
    return data_term + lamda * torch.norm(wq, 1)


def f_y(x, y, M, lamda):
    """Hand-derived gradient of ``f_L(x, y, M, lamda)`` with respect to ``y``.

    ``q`` back-propagates the gradient of ``tanh``, so the chain rule scales
    ``M^T M (q(y) - x) + lamda * sign(q(y))`` elementwise by
    ``1 - tanh(y)**2``.
    """
    q_y = q(y)
    slope = 1 - torch.tanh(y) ** 2
    inner_grad = M.t() @ (M @ (q_y - x)) + lamda * torch.sign(q_y)
    return slope * inner_grad


def f_y_pure(w, wq, M, lamda):
    """(Sub)gradient of ``f_L_pure`` w.r.t. ``wq``: ``M^T M (wq - w) + lamda * sign(wq)``."""
    data_grad = M.t() @ (M @ (wq - w))
    return data_grad + lamda * torch.sign(wq)


def f_yy(x, y, M, lamda):
    """Hand-derived Hessian of ``f_L`` with respect to ``y`` (``n x n``).

    ``H = T^T T + diag(d)`` where

    * ``T = M diag(qp)`` with ``qp = 1 - tanh(y)**2`` (the surrogate slope of ``q``);
    * ``d = -2 qp tanh(y) * (M^T M (q(y) - x) + lamda sign(q(y)))``, which is
      the derivative of ``qp`` times the inner gradient. ``sign`` contributes
      no second-order term.

    Args:
        x, y: ``(n,)`` vectors.
        M: ``(k, n)`` matrix.
        lamda: L1 weight.
    """
    t = torch.tanh(y)
    qp = 1 - t ** 2
    q_y = q(y)
    # Column scaling by broadcasting instead of ``M @ torch.diag(qp)``.
    # Building T costs O(k n) instead of O(k n^2) and needs no n x n temporary.
    T = M * qp[None]
    d = -2 * qp * t * (M.t() @ (M @ (q_y - x)) + lamda * torch.sign(q_y))
    result = T.t() @ T
    # Add diag(d) in place rather than materialising a second n x n matrix.
    result.diagonal().add_(d)
    return result


def f_yy_pure(w, wq, M, lamda):
    """Hessian of ``f_L_pure`` w.r.t. ``wq``: ``M^T M``. The L1 term adds nothing almost everywhere."""
    gram = M.T @ M
    return gram


def f_yy_inv(x, y, M, lamda: float):
    """Inverse of the ``f_yy`` Hessian ``T^T T + diag(d)`` via the Woodbury identity.

    Uses ``D^{-1} - D^{-1} T^T (I_k + T D^{-1} T^T)^{-1} T D^{-1}``, so only a
    ``k x k`` linear system is solved instead of inverting an ``n x n``
    matrix. Every entry of ``d`` must be nonzero: ``1 / d`` is taken without
    any epsilon guard.

    Returns:
        ``(n, n)`` tensor.
    """
    tanh_y = torch.tanh(y)
    slope = 1 - tanh_y ** 2
    q_y = q(y)
    grad_core = M.t() @ (M @ (q_y - x)) + lamda * torch.sign(q_y)
    diag_term = -2 * slope * tanh_y * grad_core       # n
    d_inv = 1 / diag_term                             # n

    T = M * slope[None]                               # k x n
    T_dinv = T * d_inv                                # k x n  (T D^{-1})
    capacitance = T_dinv @ T.T                        # k x k
    capacitance.diagonal().add_(1)                    # I_k + T D^{-1} T^T

    correction = torch.linalg.solve(capacitance, T_dinv)  # k x n
    R = -T.T @ correction                             # n x n
    R.diagonal().add_(1)                              # I - T^T C^{-1} T D^{-1}
    return d_inv[:, None] * R                         # n x n


def f_xy(w, wq, M, lamda):
    """Mixed derivative of ``f_y`` with respect to ``x``: ``-diag(1 - tanh(wq)^2) M^T M``.

    ``wq`` plays the role of ``y``. ``w`` and ``lamda`` are accepted for
    signature compatibility but not used.
    """
    slope = 1 - torch.tanh(wq) ** 2
    scaled_neg_Mt = slope.unsqueeze(1) * (-M.T)
    return scaled_neg_Mt @ M


def f_xy_pure(w, wq, M, lamda):
    """Mixed derivative of ``f_y_pure`` with respect to ``w``: ``-M^T M``."""
    return torch.neg(M.T @ M)


def v_dgdx_product(v, x, y, M, lamda: float):
    """Vector-Jacobian product ``v @ dg/dx`` with ``dg/dx = -f_yy^{-1} f_xy``.

    ``f_xy = -diag(qp) M^T M``, so this equals ``v @ H^{-1} diag(qp) M^T M``.
    ``H^{-1}`` is applied through the Woodbury factorisation
    ``D^{-1} - D^{-1} T^T C^{-1} T D^{-1}`` with ``T = M diag(qp)`` and
    ``C = I_k + T D^{-1} T^T``. Only a ``k x k`` system is solved and no
    ``n x n`` inverse is formed.

    Args:
        v: ``(n,)`` vector or ``(m, n)`` stack of row vectors.
        x, y: ``(n,)`` vectors.
        M: ``(k, n)`` matrix.
        lamda: L1 weight.

    Returns:
        ``(n,)`` or ``(m, n)``, following the leading shape of ``v``.
    """
    tanh_y = torch.tanh(y)                                   # n
    slope = 1 - tanh_y ** 2                                  # n
    q_y = q(y)
    grad_core = M.t() @ (M @ (q_y - x)) + lamda * torch.sign(q_y)
    d_inv = 1 / (-2 * slope * tanh_y * grad_core)            # n

    T = M * slope[None]                                      # k x n
    T_dinv = T * d_inv                                       # k x n
    capacitance = T_dinv @ T.T                               # k x k
    capacitance.diagonal().add_(1)

    # D^{-1} diag(qp) M^T term.
    direct = v @ ((d_inv * slope)[:, None] * M.T)
    # Woodbury correction term D^{-1} T^T C^{-1} T D^{-1} diag(qp) M^T.
    rhs = torch.linalg.solve(capacitance, T_dinv @ (slope[:, None] * M.T))  # k x k
    correction = v @ (d_inv[:, None] * T.T) @ rhs
    return (direct - correction) @ M


@torch.jit.script
def v_dgdx_product_batched(
        v,  # (B, n)
        x,  # (B, n)
        y,  # (B, n)
        M,  # (k, n)
        lamda: float,
        eps: float = 1e-8):
    """Batched ``v @ dg/dx``: one independent problem per row, shared ``M``.

    Performs the same Woodbury-based computation as ``v_dgdx_product`` for a
    1-D ``v`` per batch row, except that ``eps`` is added to the diagonal
    term ``d`` before it is inverted.

    Returns:
        ``(B, n)`` tensor.
    """
    vb = v.unsqueeze(2)      # (B, n, 1)
    xb = x.unsqueeze(2)      # (B, n, 1)
    yb = y.unsqueeze(2)      # (B, n, 1)
    Mb = M.unsqueeze(0)      # (1, k, n)
    Mt = Mb.mT               # (1, n, k)

    tanh_y = torch.tanh(yb)  # (B, n, 1)
    slope = 1 - tanh_y ** 2  # (B, n, 1)
    q_y = q(yb)
    grad_core = Mt @ (Mb @ (q_y - xb)) + lamda * torch.sign(q_y)
    # NOTE(review): eps is added unconditionally, so it does not keep a
    # negative d away from zero (d == -eps gives 1/0). Confirm that a plain
    # +eps shift is the intended regularisation.
    d_inv = 1 / (-2 * slope * tanh_y * grad_core + eps)   # (B, n, 1)

    T = Mb * slope.mT                      # (B, k, n)
    T_dinv = T * d_inv.mT                  # (B, k, n)
    capacitance = T_dinv @ T.mT            # (B, k, k)
    capacitance.diagonal(dim1=1, dim2=2).add_(1)

    v_row = vb.mT                                                   # (B, 1, n)
    direct = v_row @ ((d_inv * slope) * Mt)                         # (B, 1, k)
    rhs = torch.linalg.solve(capacitance, T_dinv @ (slope * Mt))    # (B, k, k)
    correction = v_row @ (d_inv * T.mT) @ rhs                       # (B, 1, k)
    return ((direct - correction) @ Mb).squeeze(1)                  # (B, n)


@torch.jit.script
def v_dgdx_product_batched_2(
        v,  # (B, n)
        x,  # (B, n)
        y,  # (B, n)
        M,  # (k, n)
        lamda: float,
        eps: float = 1e-8):
    """Batched ``v @ (dg/dx)^T``: one independent problem per row, shared ``M``.

    ``(dg/dx)^T = M^T M diag(qp) H^{-1}``, using the symmetry of
    ``H = T^T T + diag(d)``. The row vector ``v^T M^T M diag(qp) D^{-1}`` is
    formed first, and the Woodbury correction
    ``- (...) T^T (I_k + T D^{-1} T^T)^{-1} T D^{-1}`` is then subtracted.
    ``eps`` is added to ``d`` before it is inverted.

    Returns:
        ``(B, n)`` tensor.
    """
    vb = v.unsqueeze(2)      # (B, n, 1)
    xb = x.unsqueeze(2)      # (B, n, 1)
    yb = y.unsqueeze(2)      # (B, n, 1)
    Mb = M.unsqueeze(0)      # (1, k, n)

    tanh_y = torch.tanh(yb)  # (B, n, 1)
    slope = 1 - tanh_y ** 2  # (B, n, 1)
    q_y = q(yb)
    grad_core = Mb.mT @ (Mb @ (q_y - xb)) + lamda * torch.sign(q_y)
    # NOTE(review): as in v_dgdx_product_batched, +eps does not keep a
    # negative d away from zero. Confirm that this is intended.
    d_inv = 1 / (-2 * slope * tanh_y * grad_core + eps)   # (B, n, 1)

    T = Mb * slope.mT                # (B, k, n)
    scaled = T * d_inv.mT            # (B, k, n)  T D^{-1}

    row = (vb.mT @ Mb.mT @ Mb) * (slope * d_inv).mT   # (B, 1, n)

    capacitance = scaled @ T.mT      # (B, k, k)
    capacitance.diagonal(dim1=1, dim2=2).add_(1)
    row = row - row @ T.mT @ torch.linalg.solve(capacitance, scaled)

    return row.squeeze(1)            # (B, n)


# Test correctness
def test_derivatives():
    """Check the hand-derived derivatives above against functorch autodiff.

    Each comparison prints its label and the ``torch.allclose`` result, as
    before. The function also asserts at the end that every comparison
    passed, so a wrong derivative raises ``AssertionError`` instead of only
    printing ``False``. Inputs come from a local fixed-seed generator, so
    failures are reproducible and the global RNG state is not modified.
    """
    f_y_automatic = ft.grad(f_L, argnums=1)
    f_yy_automatic = ft.jacrev(f_y_automatic, argnums=1)
    f_xy_automatic = ft.jacrev(f_y_automatic, argnums=0)

    def dgdx_automatic(x, y, M, lamda):
        # -(f_yy)^{-1} f_xy, computed with dense autodiff matrices.
        yy = f_yy_automatic(x, y, M, lamda)
        xy = f_xy_automatic(x, y, M, lamda)
        return -torch.linalg.solve(yy, xy)

    def v_dgdx_product_automatic(v, x, y, M, lamda):
        return v @ dgdx_automatic(x, y, M, lamda)

    def v_dgdx_product_automatic_2(v, x, y, M, lamda):
        return v @ dgdx_automatic(x, y, M, lamda).T

    f_L_batched_automatic = ft.vmap(f_L, (0, 0, None, None))
    v_dgdx_product_batched_automatic = ft.vmap(v_dgdx_product_automatic, (0, 0, 0, None, None))
    v_dgdx_product_batched_automatic_2 = ft.vmap(v_dgdx_product_automatic_2, (0, 0, 0, None, None))

    k = 50
    n_in = 500
    lamda = 0.1
    device = 'cpu'  # 'cuda'
    # Local generator: reproducible inputs without reseeding the global RNG.
    gen = torch.Generator(device=device).manual_seed(0)

    def randn(*shape):
        return torch.randn(*shape, dtype=torch.double, device=device, generator=gen)

    failures = []

    def check(label, hand, auto):
        ok = torch.allclose(hand, auto)
        print(ok)
        if not ok:
            failures.append(label)

    M = randn(k, n_in)
    x = randn(n_in)
    y = randn(n_in)

    print(f_L(x, y, M, lamda))

    print("f_y")
    check("f_y", f_y(x, y, M, lamda), f_y_automatic(x, y, M, lamda))

    print("f_yy")
    check("f_yy", f_yy(x, y, M, lamda), f_yy_automatic(x, y, M, lamda))

    print("f_yy_inv")
    check("f_yy_inv", f_yy_inv(x, y, M, lamda),
          torch.linalg.inv(f_yy_automatic(x, y, M, lamda)))

    print("f_xy")
    check("f_xy", f_xy(x, y, M, lamda), f_xy_automatic(x, y, M, lamda))

    print("dgdx")
    v = torch.eye(n_in, dtype=torch.double, device=device)
    check("dgdx", v_dgdx_product(v, x, y, M, lamda),
          v_dgdx_product_automatic(v, x, y, M, lamda))

    print("v * dgdx")
    v = randn(n_in)
    check("v * dgdx", v_dgdx_product(v, x, y, M, lamda),
          v_dgdx_product_automatic(v, x, y, M, lamda))

    print("v * dgdx (batched)")
    B = 10
    v = randn(B, n_in)
    x = randn(B, n_in)
    y = randn(B, n_in)
    hand = v_dgdx_product_batched(v, x, y, M, lamda, eps=0.0)
    auto = v_dgdx_product_batched_automatic(v, x, y, M, lamda)
    print(hand.shape, auto.shape)
    check("v * dgdx (batched)", hand, auto)

    print("v * dgdx.T (batched 2)")
    v = randn(B, n_in)
    x = randn(B, n_in)
    y = randn(B, n_in)
    check("v * dgdx.T (batched 2)",
          v_dgdx_product_batched_2(v, x, y, M, lamda, eps=0.0),
          v_dgdx_product_batched_automatic_2(v, x, y, M, lamda))

    print("f_L (batched)")
    hand = f_L_batched(x, y, M, lamda)
    auto = f_L_batched_automatic(x, y, M, lamda)
    print(hand.shape, auto.shape)
    check("f_L (batched)", hand, auto)

    assert not failures, f"hand-written derivatives disagree with autodiff: {failures}"


if __name__ == "__main__":
    # Run the hand-vs-autodiff derivative checks when executed as a script.
    test_derivatives()
