import os
import time
from functions.env import *
from functions.gen_instance import *
from functions.rob import *
from functions.plot import *
from functions.utils import brute_force, get_feasible_set, print_results, save_results

# Experiment grid: the sweep below runs every combination of these entries.
SAH = [(20, 5, 12)]            # (S, A, H) triples
ms = [(3, 7, 7)]               # (mD, mPC, mTC) triples
n_objss = [20, 50]             # values taken by n_objs
ns = [(50, 100), (500, 1000)]  # (n, N) pairs
n_envs = 20                    # problem instances generated per configuration

# Hyperparameters forwarded to rob
alpha = 0.01
K = 3000
s = 1000

# Output directory for the saved result arrays (created if it is missing)
folder = 'results2/'
os.makedirs(folder, exist_ok=True)

# Fix NumPy's global RNG seed so that runs are reproducible
np.random.seed(0)

# Sweep over every problem configuration, run rob on n_envs freshly generated
# instances per configuration, and save the outcomes.
# `results` maps each setting label to a list holding one (delta_J, x, I)
# tuple per generated instance.
results = {}
for (S, A, H) in SAH:
    for n_objs in n_objss:
        for (mD, mPC, mTC) in ms:
            for (n, N) in ns:
                # The same pair (n, N) is used for all five size arguments
                # expected by gen_problem_instance.
                n1 = n2 = nD = nPC = nTC = n
                N1 = N2 = ND = NPC = NTC = N

                # perf_counter is monotonic, so the measured duration cannot be
                # distorted by system clock adjustments during a long run.
                init = time.perf_counter()

                # Label identifying this configuration; also used as file name.
                setting = str((S, A, H)) + str((n_objs,)) + str((mD, mPC, mTC)) + str((n, N))
                print('-' * 10, setting)

                res = []
                for env_idx in range(n_envs):
                    # coarse progress indicator every 5 instances
                    if env_idx % 5 == 0:
                        print(env_idx)

                    # generate problem instance
                    # NOTE(review): the *_ex values look like exact counterparts
                    # of the first element of each pair - confirm against
                    # gen_problem_instance.
                    (g, g_ex), (fD, fD_ex), (fPC, fPC_ex), (fTC, fTC_ex), r_star = gen_problem_instance(
                        S, A, H, n_objs, n1, n2, nD, nPC, nTC,
                        N1, N2, ND, NPC, NTC, mD, mPC, mTC, theta=H
                    )

                    # compute target: inner product of r_star with the
                    # difference between the first two entries of g_ex
                    delta_J = np.dot(r_star, g_ex[0] - g_ex[1])

                    # compute solution through rob
                    x, I = rob(g, fD, fPC, fTC, S, A, H, n_objs, alpha, K, np.zeros(n_objs), s)

                    # append
                    res.append((delta_J, x, I))

                    print('rob: delta_J,x,I=', f"{delta_J:.2f}", f"{x:.2f}", f"{I:.2f}")

                # Report the elapsed time as whole minutes plus seconds; the
                # previous float floor-division printed values such as "2.0m".
                curr = time.perf_counter() - init
                minutes, seconds = divmod(curr, 60)
                print(f"Tot time: {int(minutes)}m {seconds:.2f}s")
                results[setting] = res

                # save (np.save appends the ".npy" extension to the file name)
                np.save(folder + setting, res)