import matplotlib.pyplot as plt
import numpy as np
import scipy.linalg


def rk_accuracy(order, theta=1024, rk_order=8, lookahead=0, norm_ord=None):
    """Measure how well a truncated Taylor series approximates ``expm(A)``.

    Builds an ``order x order`` state matrix with entries

        A[i, j] = (2*i + 1) / theta * (-1 if i < j else (-1) ** (i - j + 1))

    (this appears to be the Legendre Memory Unit / delay-network matrix),
    discretizes it exactly with ``scipy.linalg.expm(A)`` using a unit time
    step, and compares that against the truncated Taylor polynomial
    ``sum_{k=0}^{rk_order} A**k / k!``.

    Parameters
    ----------
    order : int
        Size of the square state matrix.
    theta : float
        Divisor applied to every row of ``A``; larger values shrink ``A``
        and typically make the truncated series more accurate.
    rk_order : int
        Highest power of ``A`` kept in the Taylor polynomial.
    lookahead : int
        Number of steps to compare. ``0`` and ``1`` both compare a single
        step; ``k > 1`` compares the ``k``-th matrix powers of the exact and
        approximate propagators.
    norm_ord : None, int, float or str
        Forwarded as ``ord`` to ``np.linalg.norm``. ``None`` (default) is
        the Frobenius norm; ``np.inf`` is the max absolute row sum.

    Returns
    -------
    float
        Norm of the difference between the exact and approximate
        propagators.

    Raises
    ------
    TypeError
        If ``lookahead`` is not an integer (other than values equal to 0/1).
    ValueError
        If ``lookahead`` is negative.
    """
    Q = np.arange(order, dtype=np.float64)
    R = (2 * Q + 1)[:, None] / theta
    col, row = np.meshgrid(Q, Q)
    A = np.where(row < col, -1, (-1.0) ** (row - col + 1)) * R
    Ad = scipy.linalg.expm(A)

    # Taylor polynomial I + A + A^2/2! + ... + A^rk_order/rk_order!, with each
    # term built from the previous one so no explicit powers/factorials are needed.
    A_rk = np.eye(order)
    term = np.eye(order)
    for k in range(1, rk_order + 1):
        term = np.dot(A, term) / k
        A_rk += term

    if not lookahead or lookahead == 1:
        diff = A_rk - Ad
    else:
        # Explicit checks instead of `assert` (which is stripped under -O);
        # numpy integer scalars are accepted as well as Python ints.
        if not isinstance(lookahead, (int, np.integer)):
            raise TypeError(f"lookahead must be an integer, got {lookahead!r}")
        if lookahead < 0:
            # matrix_power would silently compute inverses for negative powers.
            raise ValueError(f"lookahead must be non-negative, got {lookahead}")
        diff = np.linalg.matrix_power(Ad, lookahead) - np.linalg.matrix_power(
            A_rk, lookahead
        )

    return np.linalg.norm(diff, ord=norm_ord)


# Figure 1, left panel: error vs. state order for several theta values at a
# fixed truncation order (rk_order).
plt.figure(figsize=(12, 6))
orders = np.round(np.logspace(np.log10(3), np.log10(200), 51)).astype(int)

plt.subplot(121)
rk_order = 4
thetas = [10, 100, 1000]
for theta in thetas:
    # Pass rk_order explicitly: otherwise rk_accuracy uses its own default
    # (8) and the curves would not match the "rk_order = 4" title below.
    accuracies = [
        rk_accuracy(order, theta=theta, rk_order=rk_order) for order in orders
    ]
    plt.loglog(orders, accuracies, label=f"theta = {theta:0.0f}")

plt.xlabel("order")
plt.ylabel("norm")
plt.title(f"rk_order = {rk_order}")
plt.legend()

# Figure 1, right panel: hold theta fixed and compare truncation orders.
ax = plt.subplot(122)
theta = 1000
rk_orders = [2, 4, 8]
for rk_order in rk_orders:
    accuracies = [
        rk_accuracy(n, theta=theta, rk_order=rk_order) for n in orders
    ]
    ax.loglog(orders, accuracies, label=f"rk = {rk_order}")

ax.set_xlabel("order")
ax.set_ylabel("norm")
ax.set_title(f"theta = {theta:0.0f}")
ax.legend()

# Figure 2: multi-step (lookahead) error vs. state order, one panel per
# truncation order, all at theta = 512 with a shared y-range.
plt.figure()
orders = np.round(np.logspace(np.log10(3), np.log10(200), 21)).astype(int)

rk_orders = [4, 8]
# These do not change between panels, so set them once up front.
lookaheads = [1, 2, 5, 10, 20, 50, 100]
theta = 512
for i, rk_order in enumerate(rk_orders):
    ax = plt.subplot(1, len(rk_orders), i + 1)

    for lookahead in lookaheads:
        accuracies = [
            rk_accuracy(n, theta=theta, rk_order=rk_order, lookahead=lookahead)
            for n in orders
        ]
        ax.loglog(orders, accuracies, label=f"lookahead = {lookahead}")

    ax.set_ylim([1e-15, 1e8])
    ax.set_xlabel("order")
    ax.set_ylabel("norm")
    ax.set_title(f"rk_order = {rk_order}")
    ax.legend()

plt.show()
