

from network.deepSetNet import DeepSetNet
from TRHOSolver import TRHOSolver
import torch
import numpy as np
from CPSolver import CPSolver
from netTools import getJobFeature
from schedule import Schedule
import os
# Allow a second copy of the Intel OpenMP runtime to be loaded into this process.
# NOTE(review): presumably a workaround for the "OMP: Error #15 ... already
# initialized" abort that can occur when torch and numpy each bundle the
# runtime; the setting may hide genuine library conflicts — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

class RHOSolver:
    """Rolling-horizon solver for the blocking job shop.

    Each step chooses a window of at most ``k`` jobs (``curSelect``) from the
    jobs that were never selected, solves that window with
    ``CPSolver.solve_blocking_job_shop``, advances time to the earliest job
    completion inside the window, stores the progress in ``PTmask`` and drops
    finished jobs from the window.

    State rebuilt by ``reset``:
        mask:      1 for every job that has entered the window at some point.
        curSelect: job indices currently in the window.
        PTmask:    same layout as ``data[0]``; -1 = finished entry,
                   0 = untouched entry, >0 = remaining time of a partly
                   processed entry.
    """

    def __init__(self, numofMachines, k=5, net=None):
        """Create a solver.

        Args:
            numofMachines: number of columns every instance passed to
                ``reset``/``solve`` must have.
            k: window size (maximum number of jobs solved together).
            net: network that scores candidate windows. When omitted, a fresh
                ``DeepSetNet()`` is created for this solver. A default created
                at function definition time would be shared by all solvers.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.PTmask = None
        self.curSelect = None
        self.mask = None
        self.numofJobs = None
        self.numofMachines = numofMachines
        self.CPSolver = CPSolver()
        if net is None:
            net = DeepSetNet()
        self.chooseNet = net.to(self.device)
        self.chooseNet.eval()
        self.k = k

    def reset(self, data):
        """Clear the selection state for instance ``data`` = [times, machines]."""
        assert len(data[0][0]) == self.numofMachines
        self.numofJobs = len(data[0])
        self.mask = np.zeros(self.numofJobs)
        self.curSelect = []
        self.PTmask = np.zeros((self.numofJobs, self.numofMachines))

    def saveNet(self, savePath='model/defaultModel'):
        """Save the scoring network's state_dict to ``savePath``."""
        torch.save(self.chooseNet.state_dict(), savePath)

    def loadNet(self, loadPath='model/defaultModel'):
        """Load a state_dict into the scoring network and put it in eval mode."""
        state = torch.load(loadPath, map_location=self.device)
        self.chooseNet.load_state_dict(state)
        self.chooseNet.to(self.device)
        self.chooseNet.eval()

    def _availableJobs(self):
        """Return, in index order, jobs never selected and not in the window."""
        in_window = set(self.curSelect)
        return [j for j in range(len(self.mask)) if self.mask[j] == 0 and j not in in_window]

    def _predictUtilization(self, data, selection):
        """Score the job set ``selection`` with the network and return a float.

        A batch axis of size 1 is prepended to the features before the
        forward pass. Every scoring path in this class shares this helper, so
        all of them feed the network the same input layout.
        """
        features = getJobFeature(data, self.PTmask, selection)
        with torch.no_grad():
            features = torch.tensor(features, dtype=torch.float, device=self.device)
            prediction = self.chooseNet(torch.unsqueeze(features, 0))
        return prediction.detach().cpu().item()

    def _netScores(self, data, candidates):
        """Return the predicted utilization of the window extended by each candidate."""
        return [self._predictUtilization(data, self.curSelect + [job]) for job in candidates]

    def randomChoose(self, k, randomSeed=2005):
        """Fill the window up to ``k`` jobs with a random sample of available jobs.

        NumPy's global RNG is re-seeded with ``randomSeed`` on every call, so
        the draw is reproducible. This also resets the RNG state seen by any
        other code using ``np.random``. When fewer jobs remain than needed,
        all of them are taken.
        """
        np.random.seed(randomSeed)
        available_jobs = self._availableJobs()
        num_additional = min(k - len(self.curSelect), len(available_jobs))
        if num_additional <= 0:
            return
        # One draw either fills the window or exhausts the pool. A retry loop
        # would spin forever once the pool is empty but the window is not full.
        additional_jobs = np.random.choice(available_jobs, size=num_additional, replace=False)
        self.curSelect.extend(int(j) for j in additional_jobs)
        self.mask[additional_jobs] = 1

    def greedyChoose(self, k, data):
        """Fill the window up to ``k`` jobs, adding the job with the best net score each time."""
        while len(self.curSelect) < k:
            available_jobs = self._availableJobs()
            if not available_jobs:
                # Nothing left to choose from.
                return
            scores = self._netScores(data, available_jobs)
            selected_job = available_jobs[int(np.argmax(scores))]
            self.curSelect.append(selected_job)
            self.mask[selected_job] = 1

    def CPGreedyChoose(self, k, data, time_limit=100):
        """Fill the window up to ``k`` jobs, scoring candidates by an actual CP solve.

        Each candidate window is solved with ``self.CPSolver`` (``time_limit``
        per solve). The candidate whose schedule has the highest
        ``cal_utilization()`` is added.
        """
        while len(self.curSelect) < k:
            available_jobs = self._availableJobs()
            if not available_jobs:
                # Nothing left to choose from.
                return
            scores = []
            for job in available_jobs:
                candidate = self.curSelect + [job]
                schedule = self.CPSolver.solve_blocking_job_shop(
                    [data[0][candidate], data[1][candidate]],
                    self.PTmask[candidate], time_limit=time_limit)
                scores.append(schedule.cal_utilization())
            selected_job = available_jobs[int(np.argmax(scores))]
            self.curSelect.append(selected_job)
            self.mask[selected_job] = 1

    def GuidedRandomGreedyChoose(self, k, data, eps=0.1, t=0.3, tol=0.05):
        """Fill the window with the guided randomized greedy rule (GRGC).

        Phase 1 builds a window ``Z`` with ``greedyChoose``. It then improves
        ``Z`` by single swaps (one window job for one available job). A swap is
        accepted only if the predicted utilization exceeds the current one by
        more than ``fz * eps / k``.

        Phase 2 rebuilds the window from scratch. While fewer than ``t * k``
        jobs are chosen, jobs of ``Z`` are avoided when possible. Each step
        picks uniformly (global NumPy RNG) among candidates scoring within
        ``tol`` of the best. ``Z`` is restored if its utilization is higher
        than the last phase-2 best score.

        Args:
            k: target window size.
            data: instance data, forwarded to ``getJobFeature``.
            eps: relative improvement required to accept a phase-1 swap.
            t: fraction of the window phase 2 fills while avoiding ``Z``.
            tol: score slack that defines the "good" candidates in phase 2.
        """
        # Jobs selected earlier (outside the current window) must stay masked
        # when the window is rebuilt, so remember the mask minus the window.
        base_mask = self.mask.copy()
        for job in self.curSelect:
            base_mask[job] = 0

        # Phase 1: greedy construction, then swap-based local search.
        self.greedyChoose(k, data)
        if not self.curSelect:
            # No job could be selected; there is nothing to improve.
            return
        fz = self._predictUtilization(data, self.curSelect)
        improved = True
        while improved:
            improved = False
            available_jobs = self._availableJobs()
            for out_job in list(self.curSelect):
                for in_job in available_jobs:
                    self.curSelect.remove(out_job)
                    self.curSelect.append(in_job)
                    new_utilization = self._predictUtilization(data, self.curSelect)
                    # Compare with the utilization of the window before the swap.
                    if new_utilization > fz + fz * eps / k:
                        self.mask[out_job] = 0
                        self.mask[in_job] = 1
                        fz = new_utilization
                        improved = True
                        break
                    # Rejected: undo the swap.
                    self.curSelect.remove(in_job)
                    self.curSelect.append(out_job)
                if improved:
                    break

        # Phase 2: randomized greedy reconstruction guided away from Z.
        Z = list(self.curSelect)
        self.curSelect = []
        self.mask = base_mask.copy()
        max_score = 0
        while len(self.curSelect) < k:
            candidates = self._availableJobs()
            if len(self.curSelect) < t * k:
                outside_z = [j for j in candidates if j not in Z]
                if outside_z:
                    candidates = outside_z
            if not candidates:
                # Fewer jobs left than k; keep what was chosen.
                break
            scores = self._netScores(data, candidates)
            max_score = max(scores)
            good = [i for i, s in enumerate(scores) if max_score - s <= tol]
            selected_job = candidates[np.random.choice(good)]
            self.curSelect.append(selected_job)
            self.mask[selected_job] = 1

        if fz > max_score:
            self.curSelect = list(Z)
            self.mask = base_mask.copy()
            for job in Z:
                self.mask[job] = 1

    def chooseJobs(self, originData, model='net', initChoose='random'):
        """Fill ``curSelect`` up to ``self.k`` jobs.

        An empty window is filled with the ``initChoose`` rule ('random',
        'greedy', 'CPGreedy', 'GRGC'). A non-empty one is topped up with the
        ``model`` rule ('random', 'net' -> network greedy, 'greedy' -> CP greedy).

        Raises:
            ValueError: for an unknown ``initChoose`` or ``model``.
        """
        k = self.k
        if len(self.curSelect) == 0:
            if initChoose == 'random':
                self.randomChoose(k)
            elif initChoose == 'greedy':
                self.greedyChoose(k, originData)
            elif initChoose == 'CPGreedy':
                self.CPGreedyChoose(k, originData)
            elif initChoose == 'GRGC':
                self.GuidedRandomGreedyChoose(k, originData)
            else:
                raise ValueError(f"Invalid initChoose: {initChoose}")
            return
        if model == 'random':
            self.randomChoose(k)
        elif model == 'net':
            self.greedyChoose(k, originData)
        elif model == 'greedy':
            self.CPGreedyChoose(k, originData)
        else:
            raise ValueError(f"Invalid model: {model}")
        return

    def solve(self, data, model='net', initChoose='random', PTmask=None, bws=True, time_limit=100,
              initsolver=False, returnSchedule=False, detail=False):
        """Schedule every job of ``data`` with the rolling-horizon scheme.

        Args:
            data: ``[processing_times, machine_data]``. Both are indexed with
                lists of job ids, so NumPy arrays are expected.
            model: rule used to top up a non-empty window (see ``chooseJobs``).
            initChoose: rule used to fill an empty window (see ``chooseJobs``).
            PTmask: optional initial progress mask. It is copied, so the
                caller's array is not modified.
            bws: forwarded to ``CPSolver.solve_blocking_job_shop``.
            time_limit: time limit of each window solve.
            initsolver: also run ``TRHOSolver`` on each window (only when
                there are at least 5 machines).
            returnSchedule: return the accumulated ``Schedule`` instead of the makespan.
            detail: currently unused.

        Returns:
            The accumulated ``Schedule`` when ``returnSchedule`` is true. Otherwise
            the schedule's makespan plus the remaining work of any job still in
            the window when the loop stops.

        Raises:
            AssertionError: if the CP solver returns an empty schedule for a window.
        """
        cpSolver = CPSolver()
        numofJobs = len(data[0])
        numofMachines = len(data[0][0])
        assert numofMachines == self.numofMachines
        self.reset(data)
        initRHOSolve = initsolver and numofMachines >= 5
        if PTmask is not None:
            # Copy: the progress updates below write into self.PTmask in place.
            self.PTmask = np.array(PTmask, copy=True)
        totalSchedule = Schedule(numofJobs, numofMachines)
        solvedCount = 0
        baseTime = 0

        while solvedCount < numofJobs:
            self.chooseJobs(originData=data, model=model, initChoose=initChoose)
            if initRHOSolve:
                # NOTE(review): this TRHOSolver result is immediately replaced by
                # the CP solve below and never used — confirm whether it was
                # meant as a warm start or an alternative.
                initSolver = TRHOSolver(numofMachines=numofMachines)
                initSolver.reset([data[0][self.curSelect], data[1][self.curSelect]])
                initSolver.solve([data[0][self.curSelect], data[1][self.curSelect]],
                                 PTmask=self.PTmask[self.curSelect], bws=bws, time_limit=time_limit)
            schedule = cpSolver.solve_blocking_job_shop([data[0][self.curSelect], data[1][self.curSelect]],
                                                        self.PTmask[self.curSelect], bws=bws, time_limit=time_limit)
            if len(schedule.record) == 0:
                # Explicit raise (not `assert`) so the check survives `python -O`.
                raise AssertionError("CP solver returned an empty schedule for the current window")

            # Records use window-local job indices; map them back to global ids.
            for localJob, machine, start, end in schedule.record:
                totalSchedule.add_record(int(self.curSelect[localJob]), machine, baseTime + start, baseTime + end)

            # Completion time of each window job = latest end among its records.
            jobCompletion = {}
            for localJob, _, _, end in schedule.record:
                job = self.curSelect[localJob]
                if job not in jobCompletion or end > jobCompletion[job]:
                    jobCompletion[job] = end
            earliestJob, earliestTime = min(jobCompletion.items(), key=lambda item: item[1])

            # Advance to earliestTime: work ending by then is finished (-1);
            # work straddling it keeps its remaining time.
            for localJob, machine, start, end in schedule.record:
                job = self.curSelect[localJob]
                if end - start == 0:
                    continue
                if end <= earliestTime:
                    self.PTmask[job][machine] = -1
                elif end > earliestTime > start:
                    # NOTE(review): uses the full time from data[0] rather than
                    # an existing partial PTmask value — confirm for resumed entries.
                    self.PTmask[job][machine] = data[0][job][machine] - earliestTime + start
            baseTime += earliestTime

            solvedCount += 1
            self.curSelect.remove(earliestJob)
            # Other window jobs whose last column is finished are complete as well.
            finished = [j for j in self.curSelect if self.PTmask[j][numofMachines - 1] == -1]
            for j in finished:
                self.curSelect.remove(j)
            solvedCount += len(finished)
            # Stop early only when at most one job is still in progress AND no job
            # is left unselected; otherwise unselected jobs would never be scheduled.
            if len(self.curSelect) <= 1 and not np.any(self.mask == 0):
                break

        makespan = totalSchedule.cal_makespan()
        remaining_time = 0
        for job in self.curSelect:
            for machine in range(numofMachines):
                progress = self.PTmask[job][machine]
                if progress > 0:
                    remaining_time += progress
                elif progress == 0:
                    remaining_time += data[0][job][machine]
        makespan = makespan + remaining_time
        if returnSchedule:
            return totalSchedule
        return makespan




if __name__ == "__main__":
    # Demo: solve the benchmark instance benchmark/ta/ta79.txt, using a trained
    # JobShopSetTransformer as the job-selection network.
    from network.JobShopSetTransformer import JobShopSetTransformer
    from params import configs

    model = JobShopSetTransformer(m=configs.gen_machine_num)
    # NOTE(review): gen_machine_num appears twice in the checkpoint name; confirm
    # whether one occurrence should be a job count.
    model_path = ('model/jsst' + str(configs.gen_instance_num) + '_' + str(configs.gen_machine_num)
                  + '_' + str(configs.gen_machine_num) + '_' + str(configs.gen_time_limit) + '.pth')
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # RHOSolver moves the network to its own device afterwards.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    instance_path = os.path.join(os.path.dirname(__file__), 'benchmark/ta/ta79.txt')
    all_mat = np.loadtxt(instance_path, delimiter='\t').astype(int)
    # The first 100 rows hold processing times and the remaining rows the machine
    # data. The +1 shift is presumably 0-based -> 1-based machine ids (confirm).
    time_mat = all_mat[:100, :]
    machine_mat = (all_mat[100:, :] + 1).astype(int)
    n_machines = time_mat.shape[1]
    data = [time_mat, machine_mat]

    # time_limit is a solve() argument; RHOSolver.__init__ does not accept it.
    solver = RHOSolver(numofMachines=n_machines, k=n_machines, net=model)
    makespan = solver.solve(data, time_limit=configs.gen_time_limit)
    print(f"makespan: {makespan}")