import torch
from utility import uninterleave, plot_orderbook
from matplotlib import pyplot as plt
# Render all figure text in a serif font (Times New Roman).
plt.rcParams.update({"font.family": "Times New Roman"})

def plot_stuff(load_path, title="", *, map_location=None):
    """Load a saved order-book tensor, uninterleave it, and plot it.

    Args:
        load_path: Path to a file written with ``torch.save``.
        title: Title forwarded to ``plot_orderbook``.
        map_location: Forwarded to ``torch.load``. The default ``None`` loads
            tensors onto the device they were saved from. Pass ``"cpu"`` to
            open GPU-saved files on a machine without CUDA.
    """
    # squeeze() drops every size-1 dimension, not just a leading batch axis.
    # NOTE(review): presumably the file holds a single sample with singleton
    # batch/channel axes. TODO: confirm the layout uninterleave() expects.
    orderbook_sample = torch.load(load_path, map_location=map_location).squeeze()
    uninterleaved = uninterleave(orderbook_sample)
    plot_orderbook(uninterleaved, title)

# Outputs of the attack run. The single-sample files below all live in this
# directory and share one filename prefix.
# NOTE(review): absolute, machine-specific path. Adjust for your machine.
DEFAULT_OUT_DIR = '/home/user/Desktop/trading_attacks/default_out'
SAMPLE_PREFIX = f'{DEFAULT_OUT_DIR}/778_fooled_target=2'

plot_stuff(f'{SAMPLE_PREFIX}_diffUnprop.pth', "Unpropagated Perturbation")
plot_stuff(f'{SAMPLE_PREFIX}_diffProp.pth', "Propagated Perturbation")
plot_stuff(f'{SAMPLE_PREFIX}_clean_inputs.pth', "Unperturbed Order Book")
plot_stuff(f'{SAMPLE_PREFIX}_perturbed_inputs.pth', "Perturbed Order Book")
plot_stuff(f'{DEFAULT_OUT_DIR}/universal_perturbation.pth', "Targeted Universal Perturbation")

###
# FOR THE PAPER
###

# Universal perturbations: each file is plotted in the order listed and
# titled with its own filename stem.
UNIVERSAL_DIR = 'universal_perturbations'
UNIVERSAL_STEMS = (
    'universal_f_MLP_target0_detectReg',
    'universal_f_MLP_target0',
    'universal_f_linear_target0',
    'universal_f_linear_target0_detectReg',
    # LSTM variants are currently disabled; uncomment to plot them too.
    # 'universal_f_LSTM_target0',
    # 'universal_f_LSTM_target0_detectReg',
)
for stem in UNIVERSAL_STEMS:
    plot_stuff(f'{UNIVERSAL_DIR}/{stem}.pth', stem)

# Display the figures produced by the plot_stuff calls above.
plt.show()
