import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

# Use seaborn's "whitegrid" style for every plot drawn below.
sns.set_style("whitegrid")

def get_filename(dir):
    """Return the names of all entries in directory ``dir``.

    The result is a fresh list of bare entry names (not full paths), with
    files and sub-directories alike. The order is whatever ``os.listdir``
    reports, which is arbitrary.

    NOTE: the parameter name shadows the ``dir`` builtin. It is kept as-is
    so keyword callers keep working.
    """
    return list(os.listdir(dir))

# Directories holding one CSV log per training run for the "6a" experiment.
comix_dir = "/home/wangdongzi/Desktop/RDHNet/exp_data/pp/6a/comix/"
hpn_dir = "/home/wangdongzi/Desktop/RDHNet/exp_data/pp/6a/hpn/"
# NOTE(review): this path points at the "3a" experiment while the others use
# "6a" -- confirm before re-enabling the dimenet curve below.
# dimenet_dir = "/home/wangdongzi/Desktop/RDHNet/exp_data/pp/3a/dimenet/"

# Maximum number of data rows kept from each run's CSV.
length = 1000


def _load_runs(directory, max_rows):
    """Load and stack the first two columns of every run log in ``directory``.

    Each regular file in ``directory`` is parsed with ``pd.read_csv``.
    Sub-directories are skipped, and files are read in sorted name order
    so the result is reproducible. Only the first ``max_rows`` data rows
    and the first two columns of each file are kept.

    Returns:
        A 2-D numpy array with the kept rows of all runs stacked vertically.

    Raises:
        ValueError: if ``directory`` contains no regular files.
    """
    names = sorted(
        name
        for name in get_filename(directory)
        if os.path.isfile(os.path.join(directory, name))
    )
    if not names:
        # np.vstack([]) would otherwise fail with an unhelpful message.
        raise ValueError(f"no run files found in {directory!r}")
    runs = [
        # nrows limits parsing to the rows we keep; iloc selects the first
        # two columns before conversion, so other (e.g. non-numeric)
        # columns cannot force an object-dtype array.
        pd.read_csv(os.path.join(directory, name), nrows=max_rows)
        .iloc[:, :2]
        .to_numpy()
        for name in names
    ]
    return np.vstack(runs)


comix_data = _load_runs(comix_dir, length)
hpn_data = _load_runs(hpn_dir, length)
# dimenet_data = _load_runs(dimenet_dir, length)

# Column 0 is plotted on the "step" axis and column 1 on the "return" axis.
# Because runs are stacked, each x value appears once per run. By default,
# seaborn's lineplot aggregates those into a mean line with an error band.
sns.lineplot(x=comix_data[:, 0], y=comix_data[:, 1], label="comix", color="olivedrab")
sns.lineplot(x=hpn_data[:, 0], y=hpn_data[:, 1], label="comix_HPN", color="orangered")
# sns.lineplot(x=dimenet_data[:, 0], y=dimenet_data[:, 1], label="comix_RDHNet", color="purple")

plt.title("6 predator vs 2 prey")
plt.xlabel("step")
plt.ylabel("return")
# NOTE(review): the output name says "3 algorithm", but only two curves are
# drawn while the dimenet curve is commented out.
plt.savefig("3 algorithm 6a.pdf", dpi=400)

print("ok")