import itertools
import os
import time

import numpy as np

from functions.env import *
from functions.gen_instance import *
from functions.rob import *
from functions.plot import *
from functions.utils import brute_force, get_feasible_set, print_results, save_results

# ---------------------------------------------------------------------------
# Problem specification: every combination of the lists below is evaluated.
# ---------------------------------------------------------------------------
# (S, A, H) triples -- presumably number of states, number of actions and
# horizon; TODO confirm against gen_problem_instance.
SAH = [(3, 3, 5), (100, 10, 15)]
# (mD, mPC, mTC) triples forwarded unchanged to gen_problem_instance.
ms = [(1, 2, 2), (5, 15, 15)]
# Values of n_objs to try.
n_objss = [4]
# (n, N) pairs: n is reused for n1, n2, nD, nPC, nTC and N for
# N1, N2, ND, NPC, NTC in the experiment loop.
ns = [(50, 100), (500, 1000)]
# Number of random problem instances drawn per configuration.
n_envs = 20

# Hyperparameters forwarded to rob; their meaning is defined by rob itself.
alpha = 1e-2
K = 3000
s = 1000

# Seed NumPy's global RNG for reproducibility (assumes the instance
# generator draws from np.random -- verify).
np.random.seed(0)

# Output directory for the per-configuration .npy result files.
folder = 'results/'
os.makedirs(folder, exist_ok=True)

# Evaluate every (S, A, H) x n_objs x (mD, mPC, mTC) x (n, N) configuration.
# For each one, draw `n_envs` random problem instances and compare the
# brute-force ("exact") solution with the one returned by `rob`. Per-instance
# tuples (x_true, I_true, x, I) are collected in `results`, keyed by the
# configuration string, and each configuration is also saved to `folder`.
results = {}
# itertools.product iterates in the same order as the equivalent nested loops.
configs = itertools.product(SAH, n_objss, ms, ns)
for (S, A, H), n_objs, (mD, mPC, mTC), (n, N) in configs:
    # The same sample sizes are used for every component of the instance.
    n1 = n2 = nD = nPC = nTC = n
    N1 = N2 = ND = NPC = NTC = N

    start = time.perf_counter()
    # NOTE: this string is both the `results` key and the output file name,
    # so its exact format (e.g. "(3, 3, 5)(4,)(1, 2, 2)(50, 100)") is kept.
    setting = f"{(S, A, H)}{(n_objs,)}{(mD, mPC, mTC)}{(n, N)}"
    print('-' * 10, setting)

    res = []
    for env_idx in range(n_envs):
        if env_idx % 5 == 0:
            print(env_idx)  # coarse progress indicator

        # generate problem instance (estimated and exact "_ex" variants)
        (g, g_ex), (fD, fD_ex), (fPC, fPC_ex), (fTC, fTC_ex), r_star = gen_problem_instance(
            S, A, H, n_objs, n1, n2, nD, nPC, nTC,
            N1, N2, ND, NPC, NTC, mD, mPC, mTC, theta=H,
        )

        # Reference solution through the "exact" method. brute_force's first
        # and third outputs are presumably the upper/lower bounds M and m
        # (TODO confirm); the result is expressed as midpoint and half-width.
        fs = get_feasible_set(fD_ex, fPC_ex, fTC_ex)
        M_true, _, m_true, _ = brute_force(fs, g_ex)
        x_true = (M_true + m_true) / 2
        I_true = (M_true - m_true) / 2

        # Solution through rob; np.zeros(n_objs) is passed as its starting
        # point (assumed from position -- verify against rob's signature).
        x, I = rob(g, fD, fPC, fTC, S, A, H, n_objs, alpha, K, np.zeros(n_objs), s)

        res.append((x_true, I_true, x, I))

        print('true: x,I=', f"{x_true:.2f}", f"{I_true:.2f}",
              ', rob: x,I=', f"{x:.2f}", f"{I:.2f}")

    # Report wall time as whole minutes plus seconds (the old code printed
    # raw floats such as "2.0m 13.4567891s").
    minutes, seconds = divmod(time.perf_counter() - start, 60)
    print(f"Tot time: {int(minutes)}m {seconds:.1f}s")
    results[setting] = res

    # save (np.save appends the ".npy" extension)
    np.save(os.path.join(folder, setting), res)

# print
print_results(results)