diff --git a/compare_arch.py b/compare_arch.py
deleted file mode 100644
index 6ea1e4c..0000000
--- a/compare_arch.py
+++ /dev/null
@@ -1,128 +0,0 @@
-import torch
-import numpy as np
-import matplotlib.pyplot as plt
-import os
-
-
-
-
-N_list = [100, 300, 500]
-
-P_list = [[(6,1,1,0), (6,2,1,0), (6,4,1,0)], [(6,1,1,-1), (6,2,1,-1), (6,4,1,-1)], [(6,1,1,-2), (6,2,1,-2), (6,4,1,-2)]]
-
-P_list = [[(10,1,1,0), (10,2,1,0), (10,4,1,0)], [(10,1,1,-1), (10,2,1,-1), (10,4,1,-1)], [(10,1,1,-2), (10,2,1,-2), (10,4,1,-2)]]
-
-# N_list = [100, 200, 300]
-
-# P_list = [[(10,1,1,1), (10,2,1,2), (10,3,1,0)], [(10,1,1,-1), (10,2,1,-1), (10,3,1,-1)], [(10,1,1,-1), (10,2,1,-1), (10,3,1,-1)]]
-
-
-
-
-model_id = 6
-
-plt.yscale("log")
-
-for N, P_ in zip(N_list, P_list):
-    rew_list = []
-    P_list = []
-    group_name = f"SK_1T_N_{N}"
-    for p in P_:
-        T_in, D_int, T_D, SEED = p
-        outpath = f"out/m_{model_id}/{group_name}/T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"
-        info = np.loadtxt(f"{outpath}/info.txt")
-        print(info.shape)
-        rew_last = info[-1,1]
-
-        P = (T_in*D_int + D_int)*T_D
-        rew_list.append(rew_last)
-        P_list.append(P)
-        plt.annotate(f"{T_in}-{D_int} ({T_D})", (P, rew_last), textcoords="offset points", xytext=(5, -10))
-    plt.plot(P_list, rew_list, marker = "o", label = f"N = {N}")
-
-plt.xlabel("number of parameters")
-plt.ylabel("reward (success rate)")
-plt.legend()
-plt.show()
-plt.close()
-
-
-        
-
-model_id = 4
-
-plt.figure(figsize=(5,3.5))
-
-N = 100
-
-group_name = f"SK_1T_N_{N}"
-
-model_id_list = [4,6]
-model_labels = ["cNPIM", "dNPIM"]
-
-all_list_list = []
-
-for lab, model_id in zip(model_labels, model_id_list):
-    outpath = f"out/m_{model_id}/{group_name}/"
-            
-    dirs = os.listdir(outpath)
-    print(dirs)
-    dirs = [dir for dir in dirs if not dir.startswith(".") ]
-
-    rew_list = []
-    P_list = []
-    all_list = []
-
-
-    for dir in dirs:
-        tokens = dir.split("_")
-        print(dir)
-        T_in, T_D, D_int, SEED = int(tokens[1]), int(tokens[2]), int(tokens[4]), int(tokens[6])
-        outpath = f"out/m_{model_id}/{group_name}/T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"
-        if(os.path.exists(f"{outpath}/info.txt") and SEED == 400):
-            info = np.loadtxt(f"{outpath}/info.txt")
-            print(info.shape)
-            rew_last = info[-1,1]
-            rew_last = info[-1,1]
-
-            P = (T_in*D_int + D_int)*T_D
-            rew_list.append(rew_last)
-            P_list.append(P)
-
-            all_list.append( (P, rew_last, T_in, D_int, T_D))
-
-            #plt.annotate(f"{T_in}-{D_int} ({T_D})", (P, rew_last), textcoords="offset points", xytext=(5, -10))
-    all_list_list.append(all_list)
-
-    plt.scatter(P_list, rew_list, label = lab)
-
-
-plt.xlim((0,150))
-#plt.title(f"N={N}")
-
-plt.legend()
-
-plt.xlabel("number of parameters", fontsize = 12)
-plt.ylabel("reward (success rate)", fontsize = 12)
-plt.tight_layout()
-plt.show()
-plt.close()
-
-
-all_list_prev = []
-for all_list in all_list_list:
-    all_list = sorted(all_list, key = lambda tup: tup[0])
-    
-    print("\n\n")
-    print(" $T_c$ & $D$ & $M$ & total parameters & cNPIM & dNPIM \hline")
-    for (P, rew_last, T_in, D_int, T_D), rew2 in zip(all_list, all_list_prev):
-        
-        print(f"{T_in} & {D_int} & {T_D} & {P} & {rew_last:.3f} & {rew2:.3f} \\\\")
-
-
-    print(" & ".join([str(tup[0]+1) for tup in all_list]))
-
-    all_list_prev = [tup[1] for tup in all_list]
-
-
-
diff --git a/config.json b/config.json
index 08f8004..6be2a1e 100644
--- a/config.json
+++ b/config.json
@@ -1 +1 @@
-{"numb_epochs": 400, "R": 400, "B": 20, "rew_type": "succ"}
+{"numb_epochs": 200, "R": 400, "B": 20, "rew_type": "succ"}
diff --git a/device_config.py b/device_config.py
index 861b57f..30b7bcb 100644
--- a/device_config.py
+++ b/device_config.py
@@ -1 +1 @@
-device = "cuda"
\ No newline at end of file
+device = "cpu"
\ No newline at end of file
diff --git a/eval.py b/eval.py
index 501aa54..cbc9651 100644
--- a/eval.py
+++ b/eval.py
@@ -98,7 +98,7 @@ def calc_reward(E_opt, GE, rho):
     return (E_opt <= GE + 0)*1.0 + (E_opt >= GE*0.5)*(-0.0)
 
 
-R = 1000*10
+R = 100
 
 B = instance_setup.numb_inst_config
 
diff --git a/fig3a.py b/fig3a.py
new file mode 100644
index 0000000..a626bd6
--- /dev/null
+++ b/fig3a.py
@@ -0,0 +1,56 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+path1 = "./out/m_0/SK_1T_N_800/T_6_3_Di_3_S_-4/rew.txt"
+path1 = "./out/m_4/SK_1T_N_800/T_8_3_Di_2_S_-2/rew.txt"
+path1 = "./eval/m_4_SK_1T_N_%i/eval_mlp_scaling/rew_centered.txt"
+
+path2 = "./out/m_-2/SK_1T_N_800/from_aws/rew.txt"
+path2 = "./eval/m_0_SK_1T_N_500/eval_mlp_discrete/rew_centered.txt"
+path2 = "./eval/m_4_SK_1T_N_%i/eval_mlp_scaling_ft/rew_centered.txt"
+path3 = "./eval/m_-2_SK_1T_N_%i/eval_cac_scaling/rew_centered.txt"
+
+
+path_list = [path1, path2, path3]
+
+labels = [r"trained on $N=100$", r"fined tuned on $N=500$", "CAC (IM baseline)"]
+
+#path2 = "./eval/m_4_SK_1T_N_500/eval_mlp_cont/rew_centered.txt"
+N_list = [100,300,500,800]
+
+
+medians = [[] for p in path_list]
+p25 = [[] for p in path_list]
+p75 = [[] for p in path_list]
+
+
+for idx_N, N in enumerate(N_list):
+    T = 1*N
+    for idx_p, path in enumerate(path_list):
+        rew = np.loadtxt(path % N)
+        rew = np.maximum(0.0002, rew)
+        tts = T*np.log(0.01)/np.log(1 - rew)
+        medians[idx_p].append(np.median(tts))
+        p25[idx_p].append(np.percentile(tts, 25))
+        p75[idx_p].append(np.percentile(tts, 75))
+    
+
+plt.figure(figsize=(5,3.5))
+
+plt.xlabel("N (problem size)", fontsize = 12)
+plt.ylabel("median time to solution (TTS)", fontsize = 12)
+plt.yscale("log")
+for idx_p, (label, path) in enumerate(zip(labels, path_list)):
+    if(label.startswith("CAC")):
+        plt.plot(N_list, medians[idx_p], label = label, color = "black", dashes = [3,3], marker = "+")
+    else:
+        plt.plot(N_list, medians[idx_p], label = label, marker = "o")
+
+
+plt.legend()
+
+plt.tight_layout()
+
+plt.show()
+plt.close()
\ No newline at end of file
diff --git a/fig3a_commands.sh b/fig3a_commands.sh
new file mode 100644
index 0000000..12481e5
--- /dev/null
+++ b/fig3a_commands.sh
@@ -0,0 +1,57 @@
+#figure scaling
+
+#configure hyper-parameters
+echo '{"numb_epochs": 200, "R": 200, "B": 20, "rew_type": "succ"}' > config.json
+
+#tune initial parameters
+python main.py 8 3 3 0 SK_1T_N_100 4
+
+#save tuned parameters
+mv ./boot/m_4_T_8_3_Di_3/current/ ./boot/m_4_T_8_3_Di_3/v1
+
+#fine tune
+echo '{"numb_epochs": 100, "R": 200, "B": 20, "rew_type": "succ"}' > config.json
+python main.py 8 3 3 -1 SK_1T_N_500 4
+
+
+python eval.py 8 3 3 -1 SK_1T_N_500 4 SK_1T_N_800 eval_mlp_scaling_ft
+python eval.py 8 3 3 -1 SK_1T_N_500 4 SK_1T_N_500 eval_mlp_scaling_ft
+python eval.py 8 3 3 -1 SK_1T_N_500 4 SK_1T_N_300 eval_mlp_scaling_ft
+python eval.py 8 3 3 -1 SK_1T_N_500 4 SK_1T_N_100 eval_mlp_scaling_ft
+
+
+python eval.py 8 3 3 0 SK_1T_N_100 4 SK_1T_N_800 eval_mlp_scaling
+python eval.py 8 3 3 0 SK_1T_N_100 4 SK_1T_N_500 eval_mlp_scaling
+python eval.py 8 3 3 0 SK_1T_N_100 4 SK_1T_N_300 eval_mlp_scaling
+python eval.py 8 3 3 0 SK_1T_N_100 4 SK_1T_N_100 eval_mlp_scaling
+
+
+
+
+#CAC tuning
+#configure hyper-parameters
+echo '{"numb_epochs": 200, "R": 200, "B": 20, "rew_type": "succ"}' > config.json
+
+#tune initial parameters
+python main.py 5 1 3 0 SK_1T_N_100 -2
+
+#save tuned parameters
+mv ./boot/m_-2_T_5_3_Di_1/current/ ./boot/m_-2_T_5_3_Di_1/v1
+
+#fine tune
+echo '{"numb_epochs": 100, "R": 200, "B": 20, "rew_type": "succ"}' > config.json
+python main.py 5 1 3 -1 SK_1T_N_500 -2
+
+
+python eval.py 5 1 3 -1 SK_1T_N_800 -2 SK_1T_N_100 eval_cac_scaling
+python eval.py 5 1 3 -1 SK_1T_N_800 -2 SK_1T_N_300 eval_cac_scaling
+python eval.py 5 1 3 -1 SK_1T_N_800 -2 SK_1T_N_500 eval_cac_scaling
+python eval.py 5 1 3 -1 SK_1T_N_800 -2 SK_1T_N_800 eval_cac_scaling
+python eval.py 5 1 3 -1 SK_1T_N_800 -2 SK_1T_N_1000 eval_cac_scaling
+
+python fig3a.py
+
+
+
+
+
diff --git a/fig3c.py b/fig3c.py
new file mode 100644
index 0000000..d4ceff3
--- /dev/null
+++ b/fig3c.py
@@ -0,0 +1,135 @@
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import sys
+
+
+
+N_list = [100, 300, 500]
+
+
+
+P_list = [[],[],[]]
+# N_list = [100, 200, 300]
+
+# P_list = [[(10,1,1,1), (10,2,1,2), (10,3,1,0)], [(10,1,1,-1), (10,2,1,-1), (10,3,1,-1)], [(10,1,1,-1), (10,2,1,-1), (10,3,1,-1)]]
+if(len(sys.argv) > 1):
+    if(sys.argv[1] == '1'):
+        P_list = [[(6,1,1,0), (6,2,1,0), (6,4,1,0)], [(6,1,1,-1), (6,2,1,-1), (6,4,1,-1)], [(6,1,1,-2), (6,2,1,-2), (6,4,1,-2)]]
+    if(sys.argv[1] == '2'):
+        P_list = [[(10,1,1,401), (10,2,1,401), (10,4,1,401)], [(10,1,1,-401), (10,2,1,-401), (10,4,1,-401)], [(10,1,1,-402), (10,2,1,-402), (10,4,1,-402)]]
+
+#print(sys.argv, P_list)
+
+
+
+
+model_id = 4
+
+plt.yscale("log")
+
+for N, P_ in zip(N_list, P_list):
+    rew_list = []
+    P_list = []
+    group_name = f"SK_1T_N_{N}"
+    for p in P_:
+        T_in, D_int, T_D, SEED = p
+        outpath = f"out/m_{model_id}/{group_name}/T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"
+        info = np.loadtxt(f"{outpath}/info.txt")
+        print(info.shape)
+        rew_last = info[-1,1]
+
+        P = (T_in*D_int + D_int)*T_D
+        rew_list.append(rew_last)
+        P_list.append(P)
+        plt.annotate(f"{T_in}-{D_int} ({T_D})", (P, rew_last), textcoords="offset points", xytext=(5, -10))
+    plt.plot(P_list, rew_list, marker = "o", label = f"N = {N}")
+
+plt.xlabel("number of parameters")
+plt.ylabel("reward (success rate)")
+plt.legend()
+plt.show()
+plt.close()
+
+
+        
+
+model_id = 4
+
+plt.figure(figsize=(5,3.5))
+
+N = 100
+
+group_name = f"SK_1T_N_{N}"
+
+model_id_list = [4,6]
+model_labels = ["cNPIM", "dNPIM"]
+
+all_list_list = []
+
+for lab, model_id in zip(model_labels, model_id_list):
+    outpath = f"out/m_{model_id}/{group_name}/"
+    
+    dirs = os.listdir(outpath)
+    print(dirs)
+    dirs = [dir for dir in dirs if not dir.startswith(".") ]
+
+    rew_list = []
+    P_list = []
+    all_list = []
+
+
+    for dir in dirs:
+        tokens = dir.split("_")
+        print(dir)
+        T_in, T_D, D_int, SEED = int(tokens[1]), int(tokens[2]), int(tokens[4]), int(tokens[6])
+        outpath = f"out/m_{model_id}/{group_name}/T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"
+        if(os.path.exists(f"{outpath}/info.txt") and SEED == 400):
+            info = np.loadtxt(f"{outpath}/info.txt")
+            print(info.shape)
+            rew_last = info[-1,1]
+            rew_last = info[-1,1]
+
+            P = (T_in*D_int + D_int)*T_D
+            rew_list.append(rew_last)
+            P_list.append(P)
+            print(outpath, rew_last)
+            all_list.append( (P, rew_last, T_in, D_int, T_D))
+
+            #plt.annotate(f"{T_in}-{D_int} ({T_D})", (P, rew_last), textcoords="offset points", xytext=(5, -10))
+    all_list_list.append(all_list)
+
+    plt.scatter(P_list, rew_list, label = lab)
+
+
+plt.xlim((0,150))
+#plt.title(f"N={N}")
+
+plt.legend()
+
+plt.xlabel("number of parameters", fontsize = 12)
+plt.ylabel("reward (success rate)", fontsize = 12)
+plt.tight_layout()
+plt.show()
+plt.close()
+
+
+all_list_prev = []
+for all_list in all_list_list:
+
+    all_list = sorted(all_list, key = lambda tup: tup[0])
+    
+    print("\n\n")
+    print(" $T_c$ & $D$ & $M$ & total parameters & cNPIM & dNPIM \hline")
+    for (P, rew_last, T_in, D_int, T_D), rew2 in zip(all_list, all_list_prev):
+        
+        print(f"{T_in} & {D_int} & {T_D} & {P} & {rew_last:.3f} & {rew2:.3f} \\\\")
+
+
+    print(" & ".join([str(tup[0]+1) for tup in all_list]))
+
+    all_list_prev = [tup[1] for tup in all_list]
+
+
+
diff --git a/fig3c_commands.sh b/fig3c_commands.sh
new file mode 100644
index 0000000..e5cbe3b
--- /dev/null
+++ b/fig3c_commands.sh
@@ -0,0 +1,51 @@
+#figure arch comparison
+
+#configure hyper-parameters
+echo '{"numb_epochs": 400, "R": 400, "B": 20, "rew_type": "succ"}' > config.json
+
+#(seed 400 used)
+
+python main.py 4 1 1 400 SK_1T_N_100 4 
+python main.py 4 1 3 400 SK_1T_N_100 4 
+python main.py 4 3 1 400 SK_1T_N_100 4 
+python main.py 4 3 3 400 SK_1T_N_100 4 
+python main.py 4 3 5 400 SK_1T_N_100 4 
+
+python main.py 8 1 1 400 SK_1T_N_100 4 
+python main.py 8 1 3 400 SK_1T_N_100 4 
+python main.py 8 3 1 400 SK_1T_N_100 4 
+python main.py 8 3 3 400 SK_1T_N_100 4 
+python main.py 8 3 5 400 SK_1T_N_100 4 
+
+python main.py 12 1 1 400 SK_1T_N_100 4 
+python main.py 12 1 3 400 SK_1T_N_100 4 
+python main.py 12 3 1 400 SK_1T_N_100 4 
+python main.py 12 3 3 400 SK_1T_N_100 4 
+python main.py 12 3 5 400 SK_1T_N_100 4 
+
+
+
+python main.py 4 1 1 400 SK_1T_N_100 6 
+python main.py 4 1 3 400 SK_1T_N_100 6 
+python main.py 4 3 1 400 SK_1T_N_100 6 
+python main.py 4 3 3 400 SK_1T_N_100 6 
+python main.py 4 3 5 400 SK_1T_N_100 6 
+
+python main.py 8 1 1 400 SK_1T_N_100 6 
+python main.py 8 1 3 400 SK_1T_N_100 6 
+python main.py 8 3 1 400 SK_1T_N_100 6 
+python main.py 8 3 3 400 SK_1T_N_100 6 
+python main.py 8 3 5 400 SK_1T_N_100 6 
+
+python main.py 12 1 1 400 SK_1T_N_100 6 
+python main.py 12 1 3 400 SK_1T_N_100 6 
+python main.py 12 3 1 400 SK_1T_N_100 6 
+python main.py 12 3 3 400 SK_1T_N_100 6 
+python main.py 12 3 5 400 SK_1T_N_100 6 
+
+python fig3c.py
+
+
+
+
+
diff --git a/fig_arch_commands.sh b/fig_arch_commands.sh
new file mode 100644
index 0000000..2f017b7
--- /dev/null
+++ b/fig_arch_commands.sh
@@ -0,0 +1,46 @@
+#figure arch comparison
+
+#configure hyper-parameters
+echo '{"numb_epochs": 400, "R": 400, "B": 20, "rew_type": "succ"}' > config.json
+
+#(seed 401 used)
+
+python main.py 10 1 1 401 SK_1T_N_100 4 
+python main.py 10 2 1 401 SK_1T_N_100 4 
+python main.py 10 4 1 401 SK_1T_N_100 4 
+
+#save tuned parameters
+mv ./boot/m_4_T_10_1_Di_1/current/ ./boot/m_4_T_10_1_Di_1/v401
+mv ./boot/m_4_T_10_1_Di_2/current/ ./boot/m_4_T_10_1_Di_2/v401
+mv ./boot/m_4_T_10_1_Di_4/current/ ./boot/m_4_T_10_1_Di_4/v401
+
+
+echo '{"numb_epochs": 200, "R": 400, "B": 20, "rew_type": "succ"}' > config.json
+
+
+#fine tune for larger problem sizes
+
+python main.py 10 1 1 -401 SK_1T_N_300 4 
+python main.py 10 2 1 -401 SK_1T_N_300 4 
+python main.py 10 4 1 -401 SK_1T_N_300 4 
+
+#save tuned parameters
+mv ./boot/m_4_T_10_1_Di_1/current/ ./boot/m_4_T_10_1_Di_1/v402
+mv ./boot/m_4_T_10_1_Di_2/current/ ./boot/m_4_T_10_1_Di_2/v402
+mv ./boot/m_4_T_10_1_Di_4/current/ ./boot/m_4_T_10_1_Di_4/v402
+
+
+#fine tune for larger problem sizes
+
+python main.py 10 1 1 -402 SK_1T_N_500 4 
+python main.py 10 2 1 -402 SK_1T_N_500 4 
+python main.py 10 4 1 -402 SK_1T_N_500 4 
+
+
+
+python fig3c.py 2
+
+
+
+
+
diff --git a/main.py b/main.py
index 085fc78..01f1b51 100644
--- a/main.py
+++ b/main.py
@@ -334,7 +334,7 @@ for epoch in range(numb_epochs):
         p_L.data = p_L.data*(10/torch.clamp(torch.mean(p_L.data**2, axis = 1)**0.5, 10, 1000).reshape(-1,1))
         # print(torch.sum(p_L.data**2, axis = 1)**0.5)
         # print(torch.clamp(torch.sum(p_L.data**2, axis = 1)**0.5, 1, 1000))
-
+        
         p_x.grad.zero_()
         p_L.grad.zero_()
     
diff --git a/scaling.py b/scaling.py
deleted file mode 100644
index a626bd6..0000000
--- a/scaling.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import numpy as np
-import matplotlib.pyplot as plt
-
-
-path1 = "./out/m_0/SK_1T_N_800/T_6_3_Di_3_S_-4/rew.txt"
-path1 = "./out/m_4/SK_1T_N_800/T_8_3_Di_2_S_-2/rew.txt"
-path1 = "./eval/m_4_SK_1T_N_%i/eval_mlp_scaling/rew_centered.txt"
-
-path2 = "./out/m_-2/SK_1T_N_800/from_aws/rew.txt"
-path2 = "./eval/m_0_SK_1T_N_500/eval_mlp_discrete/rew_centered.txt"
-path2 = "./eval/m_4_SK_1T_N_%i/eval_mlp_scaling_ft/rew_centered.txt"
-path3 = "./eval/m_-2_SK_1T_N_%i/eval_cac_scaling/rew_centered.txt"
-
-
-path_list = [path1, path2, path3]
-
-labels = [r"trained on $N=100$", r"fined tuned on $N=500$", "CAC (IM baseline)"]
-
-#path2 = "./eval/m_4_SK_1T_N_500/eval_mlp_cont/rew_centered.txt"
-N_list = [100,300,500,800]
-
-
-medians = [[] for p in path_list]
-p25 = [[] for p in path_list]
-p75 = [[] for p in path_list]
-
-
-for idx_N, N in enumerate(N_list):
-    T = 1*N
-    for idx_p, path in enumerate(path_list):
-        rew = np.loadtxt(path % N)
-        rew = np.maximum(0.0002, rew)
-        tts = T*np.log(0.01)/np.log(1 - rew)
-        medians[idx_p].append(np.median(tts))
-        p25[idx_p].append(np.percentile(tts, 25))
-        p75[idx_p].append(np.percentile(tts, 75))
-    
-
-plt.figure(figsize=(5,3.5))
-
-plt.xlabel("N (problem size)", fontsize = 12)
-plt.ylabel("median time to solution (TTS)", fontsize = 12)
-plt.yscale("log")
-for idx_p, (label, path) in enumerate(zip(labels, path_list)):
-    if(label.startswith("CAC")):
-        plt.plot(N_list, medians[idx_p], label = label, color = "black", dashes = [3,3], marker = "+")
-    else:
-        plt.plot(N_list, medians[idx_p], label = label, marker = "o")
-
-
-plt.legend()
-
-plt.tight_layout()
-
-plt.show()
-plt.close()
\ No newline at end of file
