from collections import defaultdict
from itertools import combinations

import numpy as np
import torch
from torch_geometric.data import Data
def getJobFeature(orginData, PTmask, curSelected):
    """Build one flat feature row per selected job.

    Every row holds, for each machine column ``j``, a triplet chosen by the
    value ``PTmask[job][j]``:

    * ``== -1`` -> ``(0, 0, 0)``
      (presumably a finished operation — TODO confirm with caller)
    * ``> 0``   -> ``(PTmask[job][j], orginData[1][job][j], 1)``
      (presumably remaining time of an in-progress operation)
    * otherwise -> ``(orginData[0][job][j], orginData[1][job][j], 0)``

    Args:
        orginData: indexable pair; ``orginData[0]`` supplies the values used
            in the "otherwise" branch and also fixes the machine count via
            ``len(orginData[0][0])``; ``orginData[1]`` supplies the second
            element of every non-zeroed triplet.
        PTmask: 2-D indexable ``PTmask[job][machine]``.
        curSelected: sequence of job indices to encode, in output order.

    Returns:
        A list with one list of ``3 * machine_count`` values per entry of
        ``curSelected``.
    """
    # The machine count comes from the first row of orginData[0], even when
    # no jobs are selected.
    machine_count = len(orginData[0][0])

    def _op_triplet(job, mach):
        """Return the three features of operation (job, mach)."""
        status = PTmask[job][mach]
        if status == -1:
            return (0, 0, 0)
        if status > 0:
            return (status, orginData[1][job][mach], 1)
        return (orginData[0][job][mach], orginData[1][job][mach], 0)

    return [
        [value for mach in range(machine_count) for value in _op_triplet(job, mach)]
        for job in curSelected
    ]


def build_jsp_graph(feature_matrix, M):
    """Convert a per-job feature matrix into a PyG graph of operations.

    Every row ``i`` of ``feature_matrix`` is read as ``M`` consecutive
    ``(t, m, f)`` triplets (the layout produced by ``getJobFeature``). Each
    triplet becomes one node with features ``[t, m, f]``. Nodes are numbered
    row-major: job ``i``, position ``k`` -> node ``i * M + k``.

    Edges:
      * Directed ``(k-1) -> k`` edges chain the operations of each job.
      * Every pair of operations whose ``int(m)`` values are equal is linked
        in both directions. ``m`` is presumably the machine id — confirm with
        the data producer.

    Args:
        feature_matrix: 2-D array-like indexed as ``feature_matrix[i, col]``
            with at least ``3 * M`` columns (assumed numpy / tensor).
        M: number of operations (triplets) read per row.

    Returns:
        ``torch_geometric.data.Data`` holding:
          * ``x``: float tensor of shape ``(n * M, 3)``;
          * ``edge_index``: long tensor of shape ``(2, num_edges)``, which is
            ``(2, 0)`` when no edges exist;
          * ``batch``: all-zero long tensor of length ``n * M``.
    """
    n = feature_matrix.shape[0]
    op_features = []
    edge_index = []
    # Insertion-ordered grouping of node ids by (integer) machine value.
    machine_to_ops = defaultdict(list)

    op_id = 0
    for i in range(n):
        for k in range(M):
            t = feature_matrix[i, 3 * k]
            m = feature_matrix[i, 3 * k + 1]
            f = feature_matrix[i, 3 * k + 2]
            op_features.append([t, m, f])

            # Link to the previous operation of the same job.
            if k > 0:
                edge_index.append([op_id - 1, op_id])

            machine_to_ops[int(m)].append(op_id)
            op_id += 1

    # Fully connect (both directions) operations sharing a machine value.
    for ops in machine_to_ops.values():
        for a, b in combinations(ops, 2):
            edge_index.append([a, b])
            edge_index.append([b, a])

    # reshape keeps x at (0, 3) instead of (0,) when there are no operations.
    x = torch.tensor(op_features, dtype=torch.float).reshape(-1, 3)
    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    else:
        # An empty list would yield shape (0,); PyG requires (2, num_edges).
        edge_index = torch.empty((2, 0), dtype=torch.long)
    batch = torch.zeros(x.size(0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, batch=batch)
