import numpy as np
from matplotlib import pyplot as plt


def plot_lambda(qgl_func1, qgl_func2, policy, observations, next_observations,
                actions, *, ax=None, show=True):
    """Plot the g and lambda outputs of two qgl functions along a batch of steps.

    Each ``qgl_func`` is called as ``qgl_func(obs, act)`` and must return a
    3-tuple ``(q, g, l)``. Only ``g`` and ``l`` are plotted; ``q`` is
    discarded. ``policy(next_observations)`` must return a 2-tuple whose
    first element is the actions used to evaluate the qgl functions at
    ``next_observations``.

    Against the step index ``t = 0 .. observations.shape[0] - 1``, the
    following curves are drawn for ``i`` in (1, 2):

    * ``Lambda{i}``   -- ``l_i`` at ``(observations, actions)``
    * ``g_target{i}`` -- ``l_i * (l_i - next_l_i)``, where ``next_l_i`` is
      ``l_i`` at ``(next_observations, policy actions)``. The final step is
      dropped. NOTE(review): presumably the last row has no valid
      successor; confirm with the caller.
    * ``g{i}_hat``    -- ``g_i`` at ``(observations, actions)``

    NOTE(review): the ``g``/``l`` outputs are assumed to be array-like objects
    that matplotlib can plot and that support ``[:-1]`` slicing and
    elementwise arithmetic (e.g. NumPy arrays); confirm against callers.

    Args:
        qgl_func1: First callable ``(obs, act) -> (q, g, l)``.
        qgl_func2: Second callable ``(obs, act) -> (q, g, l)``.
        policy: Callable ``obs -> (actions, _)``.
        observations: Batch of observations; must expose ``.shape``.
        next_observations: Successor observations, passed to ``policy``.
        actions: Actions paired with ``observations``.
        ax: Matplotlib axes to draw on. Defaults to ``plt.gca()``, which
            matches drawing with the ``plt.*`` functions.
        show: If true (the default), call ``plt.show()`` after drawing.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()

    # Only the g and lambda heads are plotted; the q outputs are discarded.
    _, g1, l1 = qgl_func1(observations, actions)
    _, g2, l2 = qgl_func2(observations, actions)

    # Evaluate both functions at the successor observations under the
    # policy's actions. The call order is unchanged, in case policy or the
    # qgl functions consume randomness.
    next_actions, _ = policy(next_observations)
    _, _, next_l1 = qgl_func1(next_observations, next_actions)
    _, _, next_l2 = qgl_func2(next_observations, next_actions)

    t = np.arange(observations.shape[0])
    # The plot order is unchanged, so the line colors and legend order match
    # the original.
    ax.plot(t, l1, label='Lambda1')
    ax.plot(t, l2, label='Lambda2')
    ax.plot(t[:-1], l1[:-1] * (l1[:-1] - next_l1[:-1]), label='g_target1')
    ax.plot(t[:-1], l2[:-1] * (l2[:-1] - next_l2[:-1]), label='g_target2')
    ax.plot(t, g1, label='g1_hat')
    ax.plot(t, g2, label='g2_hat')

    ax.legend()
    if show:
        plt.show()
    return ax
