import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Plot 2-D action samples (columns 0 and 1 of each loaded array) from a toy
# dataset next to samples from several methods, and save two comparison rows:
#   others.pdf   -- Ground Truth vs. DiffCPS (T=15), BCQ, TD3+BC, BEAR
#   diff_res.pdf -- Ground Truth vs. Diffusion with T = 15 / 25 / 30 / 50
#
# NOTE(security): every .npy file is loaded with allow_pickle=True, which can
# execute arbitrary code from a crafted file. Only run on files you produced.

# Seaborn context used for both figures. The original script called
# sns.set_context("poster") and then sns.set_theme(...); set_theme() resets
# the context to its own default ("notebook"), so the poster call was a no-op.
# "notebook" reproduces what was actually rendered -- switch to "poster" if
# poster scaling was the real intent.
SNS_CONTEXT = "notebook"

# Apply the theme BEFORE creating any figure: axes read style rcParams when
# they are created, so the original ordering (subplots first, theme second)
# left the first figure's axes without the "white" style.
sns.set_theme(context=SNS_CONTEXT, style="white", font_scale=3.2)

# Currently unused by the plotting below.
axis_lim = 1.1


def _load_ground_truth(path):
    """Return the ``"actions"`` entry of a .npy file that stores a pickled dict."""
    return np.load(path, allow_pickle=True).item()["actions"]


def _load_actions(path):
    """Return the array stored in a .npy file (plotted via its columns 0 and 1)."""
    return np.load(path, allow_pickle=True)


def _plot_panel(ax, actions, title):
    """Scatter ``actions[:, 0]`` vs ``actions[:, 1]`` on ``ax`` with shared styling."""
    sns.scatterplot(
        x=actions[:, 0],
        y=actions[:, 1],
        edgecolor="royalblue",
        color="royalblue",
        s=8,
        legend=False,
        ax=ax,
    )
    # scatterplot may not set axis labels from bare arrays; blank them
    # explicitly so every panel looks the same.
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.tick_params(axis="both", labelsize=20)
    ax.set_title(title, fontsize=38, pad=20)


def _save_comparison(panels, out_path):
    """Draw one row of scatter panels and save it as a PDF.

    Args:
        panels: sequence of ``(actions, title)`` pairs, one per column.
        out_path: destination PDF path.

    The figure is closed after saving so it does not stay alive in pyplot's
    figure registry (the original only cleared it with ``plt.clf()``).
    """
    n_panels = len(panels)
    # squeeze=False keeps axs 2-D even for a single panel.
    fig, axs = plt.subplots(
        1, n_panels, figsize=(5.5 * n_panels, 5), squeeze=False
    )
    for ax, (actions, title) in zip(axs[0], panels):
        _plot_panel(ax, actions, title)
    fig.tight_layout()
    fig.savefig(out_path, format="pdf")
    plt.close(fig)


# Loaded once and reused as the first panel of both figures.
ground_truth = _load_ground_truth("toy.npy")

_save_comparison(
    [
        (ground_truth, "Ground Truth"),
        (_load_actions("diffcps15.npy"), "DiffCPS (T=15)"),
        (_load_actions("cave.npy"), "BCQ"),
        (_load_actions("mle.npy"), "TD3+BC"),
        (_load_actions("mmd.npy"), "BEAR"),
    ],
    "others.pdf",
)

_save_comparison(
    [
        (ground_truth, "Ground Truth"),
        (_load_actions("diff15.npy"), "Diffusion (T=15)"),
        (_load_actions("diff25.npy"), "Diffusion (T=25)"),
        (_load_actions("diff30.npy"), "Diffusion (T=30)"),
        (_load_actions("diff50.npy"), "Diffusion (T=50)"),
    ],
    "diff_res.pdf",
)

