

import torch
import numpy as np
from CPSolver import CPSolver
from schedule import Schedule
import os
import time
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

class TRHOSolver:

    def __init__(self,numofMachines,width=10):
        self.width = min(width,numofMachines)
        self.PTmask = None
        self.curSelect = None
        self.mask = None
        self.numofJobs = None
        self.numofMachines = numofMachines
        self.CPSolver = CPSolver()

    def reset(self,data):
        assert len(data[0][0]) == self.numofMachines
        self.numofJobs = len(data[0])
        self.mask = np.zeros(self.numofJobs)
        self.curSelect = []
        self.PTmask = np.zeros((self.numofJobs,self.numofMachines))
    

    def solve(self,data,PTmask = None,bws =True,time_limit = 100, global_time_limit = None):
        cpSolver = CPSolver()
        numofJobs = len(data[0])
        numofMachines = len(data[0][0])
        self.reset(data)
        if PTmask is not None:
            self.PTmask = PTmask
        totalSchedule = Schedule(numofJobs,numofMachines)
        solveCount = 0
        baseTime = 0
        solveTime = 0
        global_start_time = time.time() if global_time_limit is not None else None

        while solveCount < numofJobs:
            if global_time_limit is not None:
                elapsed = time.time() - global_start_time
                remaining = global_time_limit - elapsed
                if remaining <0:
                    cp_time_limit = 1
                cp_time_limit = max(1, int(min(time_limit, remaining)))
            else:
                cp_time_limit = time_limit

            schedule = cpSolver.solve_blocking_job_shop(data, self.PTmask, bws=bws, time_limit=cp_time_limit)
            clip = (20 +80 * solveTime/numofJobs)/100
            clipTime = int(clip * schedule.cal_makespan())
            minTime = 0
            for tem_job, machine, start_time, end_time in schedule.record:
 
                if start_time != end_time and machine == numofMachines-1 and end_time > clipTime:
                    minTime = min(minTime,end_time)
            clipTime = max(clipTime,minTime + 1)
            for record in schedule.record:
                if  record[2] >= clipTime or record[3] == record[2]:
                    continue
                totalSchedule.add_record(int(record[0]),record[1],baseTime+record[2],baseTime+record[3])
            elapsed = time.time() - global_start_time
            remaining = global_time_limit - elapsed
            if solveCount / numofJobs >= 0.6 or remaining <0  :
                for record in schedule.record:
                    if  record[3] == record[2]:
                        continue
                    totalSchedule.add_record(int(record[0]),record[1],baseTime+record[2],baseTime+record[3])
                break
            for tem_job, machine, start_time, end_time in schedule.record:
                job = tem_job
                if (end_time - start_time) == 0:
                    continue
                if end_time <= clipTime:
                    self.PTmask[job][machine] = -1
                elif end_time > clipTime > start_time:
                    self.PTmask[job][machine] = data[0][job][machine] - clipTime + start_time
                
            
            baseTime = baseTime + clipTime
            solveCount = 0
            for job in range(numofJobs):
                flag = True
                for machine in range(numofMachines):
                    if self.PTmask[job][machine] != -1:
                        flag = False
                        break
                if flag:
                    solveCount += 1
            solveTime = solveTime + 1

        makespan = totalSchedule.cal_makespan()
        return totalSchedule




if __name__ == "__main__":
    la_instances = [f'la{i:02d}' for i in range(1, 41)]
    
    for instance_name in la_instances:
        instance_path = f'benchmark/la/{instance_name}.npy'
        
        if not os.path.exists(instance_path):
            continue
            
        try:
            data = np.load(instance_path, allow_pickle=True)
            data = data[0]
            solver = TRHOSolver(numofMachines=len(data[0][0]))
            solver.reset(data)
            
            start_time = time.time()
            schedule = solver.solve(data, bws=False, time_limit=10, global_time_limit=10)
            end_time = time.time()
            run_time = end_time - start_time
            schedule.fixRecord(data)

        except Exception as e:
            continue
    