import numpy as np
from CPSolver import CPSolver
from uniform_instance_gen import uni_instance_gen
from memory import Memory
import time
from tqdm import tqdm
from netTools import getJobFeature
from  schedule import Schedule
from params import configs


def generate_memory_data(num_instances=1000, n_jobs=5, m_machines=5, memory_size=100000, time_limit=100):
    """Generate randomly perturbed instances, solve them and collect the results.

    For every instance a fresh problem is drawn with ``uni_instance_gen``,
    some entries are zeroed out at random (tracked in ``ptMask``), the
    instance is solved with ``CPSolver.solve_blocking_job_shop`` and the job
    features plus the schedule's utilization are pushed into a ``Memory``.

    Args:
        num_instances: Number of instances to attempt.
        n_jobs: Number of rows (jobs) per instance.
        m_machines: Number of columns per row. The partial-zeroing step draws
            a column with ``np.random.randint(1, m_machines - 1)``, which
            needs ``m_machines >= 3`` whenever that step is triggered.
        memory_size: Capacity passed to ``Memory``.
        time_limit: Solver time limit for the first attempt; it is also the
            increment added before each retry.

    Returns:
        The ``Memory`` holding one entry per instance that was solved.
        Instances for which the solver never returned a schedule are skipped
        (and reported) instead of aborting the whole run.
    """
    # Once the accumulated limit exceeds this value, a failed retry is final.
    max_time_limit = 1000

    memory = Memory(num_machines=m_machines, capacity=memory_size)
    for instance_idx in tqdm(range(num_instances)):
        # NOTE(review): data[0] / data[1] are presumably the processing-time
        # and machine-assignment matrices -- confirm against uni_instance_gen.
        data = uni_instance_gen(n_jobs, m_machines, 1, 100)
        ptMask = np.zeros((n_jobs, m_machines))

        # With probability 0.03 blank a whole row and mark every entry -1.
        for row in range(n_jobs):
            if np.random.random() < 0.03:
                data[0][row] = [0] * m_machines
                data[1][row] = [0] * m_machines
                ptMask[row, :] = -1

        # With probability 0.2 try to zero the first `zero_ops` entries of a
        # row. Each data[1] value at the cut position may be claimed by at
        # most one row per instance; later rows hitting the same value are
        # left untouched.
        OCMachine = set()
        for row in range(n_jobs):
            if np.random.random() >= 0.2:
                continue
            zero_ops = np.random.randint(1, m_machines - 1)
            if data[0][row][zero_ops] == 0:
                # Row was blanked above; nothing to cut.
                continue
            if data[1][row][zero_ops] in OCMachine:
                continue
            OCMachine.add(data[1][row][zero_ops])
            data[0][row][:zero_ops] = [0] * zero_ops
            data[1][row][:zero_ops] = [0] * zero_ops
            ptMask[row, :zero_ops] = -1
            # Entry at the cut position gets a random integer in
            # [1, data[0][row][zero_ops]] (randint's upper bound is exclusive).
            ptMask[row, zero_ops] = np.random.randint(1, data[0][row][zero_ops] + 1)

        cpSolver = CPSolver()
        schedule = cpSolver.solve_blocking_job_shop(data, ptMask, time_limit=time_limit)
        temtime = time_limit
        # Retry with a growing time limit; always retry at least once, and
        # stop once the limit has passed max_time_limit.
        while schedule is None:
            temtime += time_limit
            schedule = cpSolver.solve_blocking_job_shop(data, ptMask, time_limit=temtime)
            if temtime > max_time_limit:
                break

        if schedule is None:
            # Previously this fell through to schedule.cal_utilization() and
            # crashed with AttributeError, losing all data generated so far.
            tqdm.write(
                "Instance %d skipped: no schedule found within time limit %s"
                % (instance_idx, temtime)
            )
            continue

        jobs_features = getJobFeature(data, ptMask, range(n_jobs))
        memory.push(jobs_features, schedule.cal_utilization())

    return memory


if __name__ == "__main__":
    # Script entry point: generate a data set sized by the `configs` values
    # and save it under a name that encodes those values.
    instance_count = configs.gen_instance_num
    job_count = configs.gen_job_num
    machine_count = configs.gen_machine_num
    solver_time_limit = configs.gen_time_limit

    memory = generate_memory_data(
        num_instances=instance_count,
        n_jobs=job_count,
        m_machines=machine_count,
        memory_size=200000,
        time_limit=solver_time_limit,
    )

    # File name layout: TrainData<instances>_<jobs>_<machines>_<time limit>
    name_parts = (instance_count, job_count, machine_count, solver_time_limit)
    memory.save('TrainData' + '_'.join(str(part) for part in name_parts))