from RHOSolver import RHOSolver
from uniform_instance_gen import uni_instance_gen
from network.deepSubmudularFunc import DSFDeepSet
import numpy as np
import torch
import os
# Let a second copy of the Intel OpenMP runtime load instead of aborting.
# NOTE(review): presumably needed because several installed packages bundle
# their own libiomp on the original machine; confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Initial-solution strategies compared by this script, in the order they are
# run; each value is passed to ``RHOSolver.solve`` as ``initChoose``.
INIT_CHOICES = ('random', 'greedy', 'CPGreedy', 'GRGC')


def solve_all_inits(solver, data, model):
    """Solve one instance once per strategy in ``INIT_CHOICES``.

    ``solver.reset(data)`` is called before every ``solve`` so each strategy
    starts from a freshly reset solver instead of the previous run's state.

    Args:
        solver: an ``RHOSolver`` (anything exposing ``reset`` and ``solve``).
        data: the instance, passed unchanged to ``reset`` and ``solve``.
        model: forwarded as ``solve(..., model=model)``; this script uses
            ``'greedy'`` and ``'net'``.

    Returns:
        dict mapping each ``initChoose`` string to whatever ``solve``
        returned for it, in ``INIT_CHOICES`` order.
    """
    results = {}
    for init in INIT_CHOICES:
        solver.reset(data)
        results[init] = solver.solve(data, model=model, initChoose=init)
    return results


if __name__ == '__main__':
    use_random = True
    if use_random:
        # Evaluate on freshly generated uniform instances.
        num_instances = 100

        model = DSFDeepSet()
        solver = RHOSolver(numofMachines=5, k=7, net=model)
        solver.loadNet('model/dsf75.pth')

        per_init = {init: [] for init in INIT_CHOICES}
        for i in range(num_instances):
            # Seed with the instance index so every run regenerates the same
            # instances (assumes uni_instance_gen draws from np.random).
            np.random.seed(i)
            times, machines = uni_instance_gen(10, 5, 1, 100)
            data = (times, machines)

            # NOTE(review): a trained net is loaded above, yet this branch
            # solves with model='greedy' (the benchmark branch uses
            # model='net'); confirm which was intended.
            for init, result in solve_all_inits(solver, data, 'greedy').items():
                per_init[init].append(result)

        # Keep the original per-strategy names for interactive inspection.
        all_net_random_init = per_init['random']
        all_net_greedy_init = per_init['greedy']
        all_net_cpgreedy_init = per_init['CPGreedy']
        all_net_grgc_init = per_init['GRGC']

        # Report the results; they were previously computed and discarded.
        for init in INIT_CHOICES:
            print(f'{init}: {per_init[init]}')
    else:
        # Evaluate on the single la01 benchmark instance.
        # allow_pickle=True unpickles the file: only use trusted benchmark data.
        data = np.load('benchmark/la/la01.npy', allow_pickle=True)
        data = data[0]

        model = DSFDeepSet(input_dim=15)
        # map_location='cpu' lets a checkpoint saved from a GPU run load on a
        # CPU-only machine; load_state_dict then copies the tensors into the
        # model's own parameters.
        model.load_state_dict(torch.load('model/dsf75.pth', map_location='cpu'))
        # NOTE(review): k=5 and input_dim=15 differ from the random branch
        # (k=7, default input_dim) although both load dsf75.pth; confirm.
        solver = RHOSolver(numofMachines=5, k=5, net=model)

        results = solve_all_inits(solver, data, 'net')
        net_random_init = results['random']
        net_greedy_init = results['greedy']
        net_cpgreedy_init = results['CPGreedy']
        net_grgc_init = results['GRGC']

        for init in INIT_CHOICES:
            print(f'{init}: {results[init]}')
