
import numpy as np
import matplotlib.pyplot as plt
import os

datasets = ["iris", "seeds", "compas"]
names = ["Iris", "Seeds", "COMPAS"]
eps_list = [0.80, 0.80, 0.50]
lossf_name = "cross_entropy"
attack_norm = float("inf")


def load_dataset_data(dataset_name, prefix, eta_cap=0.5):
    """Load one dataset's saved results, truncated at the first eta above a cap.

    Reads ``results/{dataset_name}{prefix}_data.npy.npz`` and unpacks its
    entries positionally, so the archive must hold exactly seven arrays in
    this order: etas, losses, lossopt, losses_adv, losses_orig_adv,
    weight_sims, pattern_sims.

    Args:
        dataset_name: Leading part of the results file name.
        prefix: String inserted between ``dataset_name`` and ``_data.npy.npz``.
        eta_cap: Every per-eta series is cut just before the first eta that
            is strictly greater than this value.

    Returns:
        A dict with keys ``etas``, ``losses``, ``lossopt``, ``losses_adv``,
        ``losses_orig_adv``, ``weight_sims`` and ``pattern_sims``. All values
        except ``lossopt`` are sliced to the same leading length; ``lossopt``
        is returned exactly as stored.

    Raises:
        OSError: If the results file cannot be opened.
        ValueError: If the archive does not contain exactly seven entries.
    """
    path = f"results/{dataset_name}{prefix}_data.npy.npz"
    # NOTE(review): allow_pickle=True lets the archive execute arbitrary code
    # on load; only use this with result files produced by trusted runs.
    # The context manager closes the underlying zip file once the arrays
    # have been read into memory.
    with np.load(path, allow_pickle=True) as archive:
        arrays = [archive[key] for key in archive]
    etas, losses, lossopt, losses_adv, losses_orig_adv, weight_sims, pattern_sims = arrays

    # Index of the first eta above the cap. When no eta exceeds the cap the
    # whole series is kept (falling back to 0 here would discard everything).
    cutoff = next(
        (i for i, eta in enumerate(etas) if eta > eta_cap),
        len(etas),
    )

    return {
        "etas": etas[:cutoff],
        "losses": losses[:cutoff],
        "lossopt": lossopt,
        "losses_adv": losses_adv[:cutoff],
        "losses_orig_adv": losses_orig_adv[:cutoff],
        "weight_sims": weight_sims[:cutoff],
        "pattern_sims": pattern_sims[:cutoff],
    }


# Make sure the output directory used by the savefig calls below exists;
# exist_ok=True makes re-running the script harmless.
os.makedirs("figures/final", exist_ok=True)

# # Weight Similarity

# plt.plot(etas, weight_sims)

# plt.title(f"{dataset} Dataset")
# plt.xlabel("Eta")
# plt.ylabel("Norm of weights times cosine similarity")

# plt.savefig(f'figures/final/{dataset}{prefix}_weight_similarity.png')
# plt.clf()

# Pattern Similarity (side-by-side datasets)

# One panel per dataset, laid out in a single row. squeeze=False keeps the
# returned axes 2-D even when there is only one dataset, so row 0 can always
# be iterated and indexed.
fig, axes = plt.subplots(1, len(datasets), figsize=(18, 5), squeeze=False)
axes = axes[0]
for ax, dataset, eps, name in zip(axes, datasets, eps_list, names):
    # Results files are keyed by loss name, attack norm and this dataset's eps.
    prefix = f"_{lossf_name}_{attack_norm}_{eps}"
    data = load_dataset_data(dataset, prefix)
    ax.plot(data["etas"], data["pattern_sims"])
    ax.set_title(name, size=18)
    ax.tick_params(axis='both', labelsize=14)
# Only the leftmost panel carries the y-axis label; the x-label and title are
# set once on the figure.
axes[0].set_ylabel("Pattern Similarity", size=18)
fig.supxlabel("Eta", size=18)
fig.suptitle("Pattern Similarity vs Eta", size=24)

fig.tight_layout()
# NOTE: `prefix` is whatever the last loop iteration left behind, so the file
# name carries only the final dataset's eps even though the panels may use
# different values. Kept as-is so the output path does not change.
fig.savefig(f"figures/final/{'_'.join(datasets)}{prefix}_pattern_similarity.png")
# Close the figure (clf() would only clear it and leave it open) so it is
# released before the next figure is created.
plt.close(fig)

# # Loss

# plt.plot(etas, losses, label="Adversarial Model")
# plt.plot([etas[0], etas[-1]], [lossopt, lossopt], label="Optimal Model")

# plt.title(f"{dataset} Dataset")
# plt.xlabel("Eta")
# plt.ylabel("Loss of Adversarial Model on Original Dataset")
# plt.legend()

# plt.savefig(f'figures/final/{dataset}{prefix}_loss.png')
# plt.clf()

# Adversarial Loss (side-by-side datasets)

# One panel per dataset, laid out in a single row. squeeze=False keeps the
# returned axes 2-D even when there is only one dataset, so row 0 can always
# be iterated and indexed.
fig, axes = plt.subplots(1, len(datasets), figsize=(18, 5), squeeze=False)
axes = axes[0]
for ax, dataset, eps, name in zip(axes, datasets, eps_list, names):
    # Results files are keyed by loss name, attack norm and this dataset's eps.
    prefix = f"_{lossf_name}_{attack_norm}_{eps}"
    data = load_dataset_data(dataset, prefix)
    ax.plot(data["etas"], data["losses_adv"], label="Adversarial Model")
    ax.plot(data["etas"], data["losses_orig_adv"], label="Optimal Model")
    ax.set_title(name, size=18)
    ax.tick_params(axis='both', labelsize=14)
    # Every panel gets its own legend (the previous `if True or ...` guard
    # was always true, so this is the behaviour the script already had).
    ax.legend(fontsize=14)

# Only the leftmost panel carries the y-axis label; the x-label and title are
# set once on the figure.
axes[0].set_ylabel("Transfer Adversarial Loss", size=18)
fig.supxlabel("Eta", size=18)
fig.suptitle("Transferred Adversarial Loss vs Eta", size=24)

fig.tight_layout()
# NOTE: `prefix` is whatever the last loop iteration left behind, so the file
# name carries only the final dataset's eps even though the panels may use
# different values. Kept as-is so the output path does not change.
fig.savefig(f"figures/final/{'_'.join(datasets)}{prefix}_advloss.png")
# Close the figure (clf() would only clear it and leave it open) so its
# resources are released.
plt.close(fig)
