import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys



# System sizes N; the i-th entry is paired with the i-th group of P_list.
N_list = [100, 300, 500]

# One group of (T_in, D_int, T_D, SEED) run configurations per system size.
# Every group stays empty unless a preset is selected on the command line.
P_list = [[] for _ in N_list]

# Earlier configuration, kept for reference:
# N_list = [100, 200, 300]
# P_list = [[(10,1,1,1), (10,2,1,2), (10,3,1,0)], [(10,1,1,-1), (10,2,1,-1), (10,3,1,-1)], [(10,1,1,-1), (10,2,1,-1), (10,3,1,-1)]]

# Optional first command-line argument: '1' or '2' selects a preset, any other
# value (or no argument at all) leaves the groups empty.
preset = sys.argv[1] if len(sys.argv) > 1 else None
if preset == '1':
    # T_in = 6, T_D = 1, D_int in {1, 2, 4}; one seed per system size.
    P_list = [[(6, D_int, 1, seed) for D_int in (1, 2, 4)] for seed in (0, -1, -2)]
elif preset == '2':
    # T_in = 10, T_D = 1, D_int in {1, 2, 4}; one seed per system size.
    P_list = [[(10, D_int, 1, seed) for D_int in (1, 2, 4)] for seed in (401, -401, -402)]

#print(sys.argv, P_list)




# Figure 1: final reward versus parameter count for the runs of model
# `model_id`, one curve per system size N (log-scaled y axis).
model_id = 4

plt.yscale("log")

for N, P_ in zip(N_list, P_list):
    # Points of the curve for this N. These lists have their own names so that
    # the P_list configuration being iterated over is not rebound mid-loop.
    rew_list = []
    n_params_list = []
    group_name = f"SK_1T_N_{N}"
    for T_in, D_int, T_D, SEED in P_:
        outpath = f"out/m_{model_id}/{group_name}/T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"
        # ndmin=2 keeps a single-row info.txt two-dimensional, so the
        # [-1, 1] lookup below does not fail with "too many indices".
        info = np.loadtxt(f"{outpath}/info.txt", ndmin=2)
        print(info.shape)
        # Second column of the last row is the value plotted as the reward.
        rew_last = info[-1, 1]

        # Value plotted on the "number of parameters" axis.
        P = (T_in*D_int + D_int)*T_D
        rew_list.append(rew_last)
        n_params_list.append(P)
        plt.annotate(f"{T_in}-{D_int} ({T_D})", (P, rew_last), textcoords="offset points", xytext=(5, -10))
    plt.plot(n_params_list, rew_list, marker="o", label=f"N = {N}")

plt.xlabel("number of parameters")
plt.ylabel("reward (success rate)")
plt.legend()
plt.show()
plt.close()


        

# Figure 2: scatter of final reward versus parameter count at N = 100 for the
# models in model_id_list. Only runs with SEED == 400 that have an info.txt
# are used.
plt.figure(figsize=(5,3.5))

N = 100

group_name = f"SK_1T_N_{N}"

model_id_list = [4,6]
model_labels = ["cNPIM", "dNPIM"]

# One list per model, in model_labels order, of
# (P, rew_last, T_in, D_int, T_D) tuples; read again by the table code below.
all_list_list = []

for lab, model_id in zip(model_labels, model_id_list):
    group_path = f"out/m_{model_id}/{group_name}/"

    # os.listdir returns entries in arbitrary order; sorting makes the run
    # order (and therefore the order inside all_list) reproducible.
    dirs = sorted(os.listdir(group_path))
    print(dirs)
    dirs = [name for name in dirs if not name.startswith(".")]

    rew_list = []
    P_list = []
    all_list = []

    for run_dir in dirs:
        print(run_dir)
        # Run directories are named T_<T_in>_<T_D>_Di_<D_int>_S_<SEED>.
        tokens = run_dir.split("_")
        try:
            T_in, T_D, D_int, SEED = (int(tokens[i]) for i in (1, 2, 4, 6))
        except (IndexError, ValueError):
            # Stray files or differently named directories are ignored
            # instead of aborting the whole script.
            print(f"skipping {run_dir}: name does not match T_<T_in>_<T_D>_Di_<D_int>_S_<SEED>")
            continue
        if SEED != 400:
            continue

        outpath = f"out/m_{model_id}/{group_name}/T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"
        if not os.path.exists(f"{outpath}/info.txt"):
            continue

        # ndmin=2 keeps a single-row info.txt two-dimensional, so the
        # [-1, 1] lookup below does not fail with "too many indices".
        info = np.loadtxt(f"{outpath}/info.txt", ndmin=2)
        print(info.shape)
        # Second column of the last row is the value plotted as the reward.
        rew_last = info[-1, 1]

        # Value plotted on the "number of parameters" axis.
        P = (T_in*D_int + D_int)*T_D
        rew_list.append(rew_last)
        P_list.append(P)
        print(outpath, rew_last)
        all_list.append((P, rew_last, T_in, D_int, T_D))

        #plt.annotate(f"{T_in}-{D_int} ({T_D})", (P, rew_last), textcoords="offset points", xytext=(5, -10))
    all_list_list.append(all_list)

    plt.scatter(P_list, rew_list, label=lab)


plt.xlim((0,150))
#plt.title(f"N={N}")

plt.legend()

plt.xlabel("number of parameters", fontsize=12)
plt.ylabel("reward (success rate)", fontsize=12)
plt.tight_layout()
plt.show()
plt.close()


# Print LaTeX table rows that compare each model in all_list_list with the
# model before it. Each entry of all_list_list is a list of
# (P, rew_last, T_in, D_int, T_D) tuples. For the first model there is no
# predecessor, so only the header and the final "P+1" line are printed.
all_list_prev = []
# Rewards of the preceding model, keyed by configuration (T_in, D_int, T_D),
# so that rows are matched by configuration and not by list position.
prev_rew_by_cfg = {}
for all_list in all_list_list:

    # Sort by parameter count; the configuration breaks ties so the row order
    # does not depend on the order in which the runs were collected.
    all_list = sorted(all_list, key=lambda tup: (tup[0], tup[2], tup[3], tup[4]))

    print("\n\n")
    # "\\hline" prints the same text as the former "\hline" without relying on
    # an invalid escape sequence.
    # NOTE(review): the header has no "\\" row terminator before \hline, which
    # LaTeX tabular normally requires -- confirm before pasting into a document.
    print(" $T_c$ & $D$ & $M$ & total parameters & cNPIM & dNPIM \\hline")
    for P, rew_last, T_in, D_int, T_D in all_list:
        rew2 = prev_rew_by_cfg.get((T_in, D_int, T_D))
        if rew2 is None:
            # No run with this configuration in the preceding model.
            continue
        # Preceding model first, current model second, matching the
        # "cNPIM & dNPIM" header (model_labels lists cNPIM before dNPIM).
        print(f"{T_in} & {D_int} & {T_D} & {P} & {rew2:.3f} & {rew_last:.3f} \\\\")


    print(" & ".join([str(tup[0]+1) for tup in all_list]))

    # Same value the original loop left in all_list_prev.
    all_list_prev = [tup[1] for tup in all_list]
    prev_rew_by_cfg = {(tup[2], tup[3], tup[4]): tup[1] for tup in all_list}



