import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams
import seaborn as sns
import cupy as cp

# Global matplotlib font settings for the figures produced by this script.
rcParams.update({"font.weight": "normal"})
# Keyword arguments for rendering text in Times New Roman
# (NOTE(review): not referenced in the visible code — confirm before removing).
csfont = dict(fontname="Times New Roman")


def jaccard_score(class1_samples, class2_samples, tol=1e-8):
    """Return the fraction of ``class1_samples`` with a near-exact match in ``class2_samples``.

    For every sample in ``class1_samples``, compute the squared Euclidean
    distance to its nearest neighbour in ``class2_samples``. The sample
    counts as matched when that distance is below ``tol``. The result is
    the mean of the 0/1 match flags, so it lies in [0, 1].

    Args:
        class1_samples: Samples indexed along the first axis.
        class2_samples: Reference samples indexed along the first axis. Each
            reference sample must broadcast against a ``class1`` sample.
        tol: Squared-distance threshold below which two samples are treated
            as identical. Defaults to ``1e-8``.

    Returns:
        0-d ``cp`` array holding the fraction of matched ``class1`` samples.

    Raises:
        ValueError: If either input has no samples.
    """
    if len(class1_samples) == 0 or len(class2_samples) == 0:
        raise ValueError("At least one class has no samples")

    class2 = cp.asarray(class2_samples)
    n_ref = len(class2)

    flags = []
    for sample in class1_samples:
        diff = cp.asarray(sample - class2)
        # Sum the squared differences over all non-leading axes, one sum per
        # reference sample. Taking the min over every element (as before)
        # would count a single matching coordinate as a full match. The
        # reshape also covers 1-D inputs, where each entry is one sample.
        sq_dist = cp.sum(diff.reshape(n_ref, -1) ** 2, axis=1)
        flags.append(1 if cp.min(sq_dist) < tol else 0)

    # Mean of the match flags = fraction of class1 samples found in class2.
    return cp.mean(cp.array(flags))


def score(file_path, truth_path="toy.npy"):
    """Score the policy actions stored in ``file_path`` against the reference actions.

    Args:
        file_path: Path to a ``.npy`` file holding the policy's actions.
        truth_path: Path to a pickled ``.npy`` file containing a dict whose
            ``"actions"`` entry holds the reference actions. Defaults to
            ``"toy.npy"``.

    Returns:
        The value of ``jaccard_score(policy_actions, truth)``: the fraction of
        policy actions that match some reference action.

    Note:
        Both files are loaded with ``allow_pickle=True``, which can execute
        arbitrary code. Only use trusted files.
    """
    truth_dict = np.load(truth_path, allow_pickle=True).item()
    actions = truth_dict["actions"]
    # The stored actions presumably are a torch tensor (the original code
    # called ``.numpy()``); a plain ndarray is accepted as-is.
    if hasattr(actions, "numpy"):
        actions = actions.numpy()
    truth = cp.array(actions)

    policy_actions = cp.array(np.load(file_path, allow_pickle=True))
    return jaccard_score(policy_actions, truth)


def create_df(path, name, T):
    """Build a two-row results DataFrame for one algorithm at ``T`` diffusion steps.

    Row 0 holds ``score(path)`` for ``name`` at ``T``. Row 1 is an anchor row
    for the same algorithm with ``T=0`` and ``score=0``. Every call adds its
    own anchor row, so concatenating several frames for one algorithm
    repeats it.
    """
    # Copy the 0-d CuPy score back to host memory for pandas.
    host_score = score(path).get()
    result_row = pd.DataFrame({"Algorithm": [name], "T": [T], "score": host_score})
    anchor_row = pd.DataFrame({"Algorithm": [name], "T": 0, "score": 0})
    return pd.concat([result_row, anchor_row], ignore_index=True)


# Diffusion-step counts evaluated for each algorithm.
DIFFUSION_STEPS = (5, 15, 25, 30, 50)
# Algorithm label -> filename prefix; results are read from "<prefix><T>.npy".
ALGORITHM_FILES = {"DiffCPS": "diffcps", "DQL": "dql", "SfBC": "sfbc"}
# Explicit per-algorithm colors. The previous positional list gave SfBC
# "purple", the same color as the AWR reference line below.
PALETTE = {"DiffCPS": "red", "DQL": "forestgreen", "SfBC": "mediumblue"}
# AWR is drawn as a fixed dashed reference line at this hard-coded score.
AWR_SCORE = 0.432
AWR_COLOR = "purple"

# Build all per-(algorithm, T) frames first and concatenate once. The old
# approach grew an initially empty DataFrame, and concatenating empty frames
# is deprecated in pandas 2.x (FutureWarning, possible object dtypes).
df = pd.concat(
    [
        create_df(f"{prefix}{steps}.npy", algorithm, steps)
        for algorithm, prefix in ALGORITHM_FILES.items()
        for steps in DIFFUSION_STEPS
    ],
    ignore_index=True,
)

plt.figure(figsize=(15, 10))

# set_theme also resets the plotting context (to "notebook" by default), so
# a separate set_context("poster") call before it had no effect.
sns.set_theme(style="white", font_scale=3.2)

# Draw the AWR line before the lineplot. seaborn builds its legend inside
# lineplot(), so a labelled artist added afterwards never appears in it
# (NOTE(review): relies on seaborn collecting labelled axes artists — verify
# the "AWR" entry shows up with the installed seaborn version).
plt.axhline(y=AWR_SCORE, color=AWR_COLOR, linestyle="--", label="AWR")

sns.lineplot(
    data=df,
    x="T",
    y="score",
    hue="Algorithm",
    palette=PALETTE,
    marker="o",
    markersize=10,
)
# An empty title is kept for its padding above the axes.
plt.title("", fontsize=39, pad=20)
plt.xlabel("Diffusion steps")
plt.ylabel("Score")
plt.tight_layout()
plt.savefig("toy_lineplot.pdf")
