import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

class Schedule:
    """Start/end times of the operations of a job-shop schedule.

    Every recorded operation is stored twice and the two views are kept in sync:

    * ``record``      -- list of ``(job_id, operation_id, start_time, end_time)``
      tuples in insertion order;
    * ``record_dict`` -- mapping ``(job_id, operation_id) -> (start_time, end_time)``.

    ``check`` and ``fixRecord`` index operations ``0 .. num_machines - 1`` per
    job. NOTE(review): presumably every job visits every machine exactly once
    (classic JSSP) -- confirm with callers.
    """

    def __init__(self, num_jobs, num_machines):
        self.num_jobs = num_jobs
        self.num_machines = num_machines
        self.record = []  # [(job_id, operation_id, start_time, end_time), ...]
        self.record_dict = {}  # {(job_id, operation_id): (start_time, end_time)}

    def _update_record(self, job_id, operation_id, start_time, end_time):
        """Overwrite the times of an already-recorded operation in both views."""
        self.record_dict[(job_id, operation_id)] = (start_time, end_time)
        for idx, entry in enumerate(self.record):
            if entry[0] == job_id and entry[1] == operation_id:
                self.record[idx] = (job_id, operation_id, start_time, end_time)
                break

    def add_record(self, job_id, operation_id, start_time, end_time):
        """Record that an operation ran from ``start_time`` to ``end_time``.

        Zero-length intervals are ignored. If the operation is already
        recorded, the existing entry is extended instead of duplicated: its
        original start time is kept and its end time becomes ``end_time``
        (presumably for operations reported in several consecutive pieces).
        """
        if end_time - start_time == 0:
            return
        key = (job_id, operation_id)
        if key in self.record_dict:
            old_start_time = self.record_dict[key][0]
            self._update_record(job_id, operation_id, old_start_time, end_time)
            return
        self.record.append((job_id, operation_id, start_time, end_time))
        self.record_dict[key] = (start_time, end_time)

    def cal_makespan(self):
        """Return the latest end time over all recorded operations (0 if none)."""
        return max((end for _, end in self.record_dict.values()), default=0)

    def cal_utilization(self):
        """Return total busy time divided by the schedule's time span.

        The span runs from the earliest recorded start to the latest recorded
        end. The ratio is not divided by ``num_machines``, so it can exceed 1
        when machines work in parallel. Returns 0.0 for an empty schedule or a
        non-positive span.
        """
        if not self.record:
            return 0.0
        total_work_time = sum(end - start for _, _, start, end in self.record)
        span = max(r[3] for r in self.record) - min(r[2] for r in self.record)
        if span <= 0:
            return 0.0
        return total_work_time / span

    def plotSchedule(self, orginData, savePath=None):
        """Draw a Gantt chart of the recorded operations, one row per job.

        Args:
            orginData: problem data; ``orginData[1][j]`` is the machine
                sequence of job ``j`` (operation ``i`` runs on machine
                ``orginData[1][j][i]``).
            savePath: when given, the figure is saved there instead of shown.

        Jobs whose first operation is not recorded are skipped. Individual
        unrecorded operations (partial schedules) are skipped as well. The
        figure is always closed afterwards, so repeated saving does not
        accumulate open figures.
        """
        fig, ax = plt.subplots(figsize=(12, 6))
        colors = plt.cm.tab20(np.linspace(0, 1, self.num_machines + 10))
        for j in range(self.num_jobs):
            if (j, 0) not in self.record_dict:
                continue
            for i, machine in enumerate(orginData[1][j]):
                entry = self.record_dict.get((j, i))
                if entry is None:
                    continue  # operation not scheduled yet
                start, end = entry
                duration = end - start
                if duration == 0:
                    continue
                # Plain int so numpy/tensor scalars index colors and print cleanly.
                machine = int(machine)
                rect = patches.Rectangle(
                    (start, j - 0.4),
                    duration,
                    0.8,
                    facecolor=colors[machine],
                    edgecolor='black',
                    alpha=0.7
                )
                ax.add_patch(rect)
                ax.text(start + duration / 2, j, f'M{machine}',
                        ha='center', va='center', color='black', fontweight='bold')
        ax.set_ylim(-0.5, self.num_jobs - 0.5)
        ax.set_xlim(0, self.cal_makespan() * 1.1)
        ax.set_xlabel('Time')
        ax.set_ylabel('Job')
        ax.set_title('Job Shop Scheduling Gantt Chart')
        ax.grid(True, linestyle='--', alpha=0.7)
        # NOTE(review): legend lists machines 0..num_machines-1; confirm machine
        # ids are 0-based in the data.
        legend_elements = [patches.Patch(facecolor=colors[m], edgecolor='black', alpha=0.7,
                                         label=f'Machine {m}') for m in range(self.num_machines)]
        ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.15, 1))
        if savePath is not None:
            fig.savefig(savePath)
        else:
            plt.show()
        plt.close(fig)

    def check(self, data):
        """Return True iff every operation's recorded duration equals its processing time.

        ``data[0][i][j]`` is the processing time of operation ``j`` of job ``i``.
        An operation that has not been recorded makes the check fail.
        """
        time_mat = data[0]
        for i in range(self.num_jobs):
            for j in range(self.num_machines):
                entry = self.record_dict.get((i, j))
                if entry is None or entry[1] - entry[0] != time_mat[i][j]:
                    return False
        return True

    def fixRecord(self, data):
        """Left-shift recorded operations in a single pass, in order of start time.

        Args:
            data: ``(time_mat, machine_mat)``. ``time_mat[j][o]`` is the
                processing time and ``machine_mat[j][o]`` the machine of
                operation ``o`` of job ``j``.

        For each operation, a lower bound is computed from two sources:

        * the end of the job's previous operation (if recorded);
        * a bound derived from the operation's predecessor on the same machine
          (by original start time).

        If the operation currently starts later than that bound, it is moved
        to start at the bound. Its end becomes bound + ``time_mat`` duration.
        Operations already at or before the bound are left untouched.
        """
        time_mat, machine_mat = data
        if not self.record_dict:
            return

        # Snapshot of each machine's operations ordered by their original start
        # time, plus each operation's (machine, position) in that order.
        machine_orders = {}
        for (job_id, op_id), (start, _) in self.record_dict.items():
            machine = int(machine_mat[job_id][op_id])
            machine_orders.setdefault(machine, []).append((job_id, op_id, start))
        position = {}
        for machine, order in machine_orders.items():
            order.sort(key=lambda item: item[2])
            for idx, (job_id, op_id, _) in enumerate(order):
                position[(job_id, op_id)] = (machine, idx)

        all_ops = sorted(
            ((job_id, op_id, start)
             for (job_id, op_id), (start, _) in self.record_dict.items()),
            key=lambda item: item[2],
        )
        for job_id, op_id, cur_start in all_ops:
            prev_job_end = 0
            if op_id >= 1 and (job_id, op_id - 1) in self.record_dict:
                prev_job_end = self.record_dict[(job_id, op_id - 1)][1]

            machine_bound = 0
            machine, idx = position[(job_id, op_id)]
            if idx > 0:
                prev_job, prev_op, _ = machine_orders[machine][idx - 1]
                pred_end = self.record_dict[(prev_job, prev_op)][1]
                succ_op = prev_op + 1
                if succ_op < self.num_machines:
                    # NOTE(review): the bound is the start of the predecessor
                    # job's *next* operation, which is stricter than the
                    # predecessor's own end. Preserved as-is -- confirm intent.
                    # Falls back to the predecessor's end if that next
                    # operation is not recorded yet.
                    succ = self.record_dict.get((prev_job, succ_op))
                    machine_bound = succ[0] if succ is not None else pred_end
                else:
                    machine_bound = pred_end

            required_start = max(prev_job_end, machine_bound)
            if cur_start > required_start:
                duration = time_mat[job_id][op_id]
                self._update_record(job_id, op_id, required_start,
                                    required_start + duration)