""" Multi-Agent Task Allocation Network
author: Anonymous
"""
import os
import sys
import inspect

# Prepend the parent of this file's directory to sys.path, presumably so the
# project-local imports below (models.*, scheduling.*) resolve regardless of
# the directory the script is launched from. Must run before those imports.
# NOTE(review): inspect.currentframe() can return None on Python
# implementations without stack-frame support; __file__ is the usual
# portable equivalent here.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import torch
import torch.nn as nn
import dgl

from models.graph.hetnet import HetNet
from scheduling.environment import SchedulingEnvironment
from scheduling.agent import Agent
from scheduling.task import Task

from models.graph_model import GraphModel

class GraphModelWithEdge(GraphModel):
    """GraphModel variant that declares which edge types carry edge features.

    Construction is delegated entirely to ``GraphModel``. This subclass only
    adds :meth:`set_edge_features`, which records, for every edge type that
    has edge features, the node type the feature is associated with.
    """

    def __init__(self, num_heads=8, num_layers=4, outputs=None, task_assignment_node=True,
                 wait_time_constraints=True, model_type='hetnet', device=None):
        """Forward all configuration unchanged to ``GraphModel.__init__``.

        Args:
            num_heads: forwarded to ``GraphModel`` (presumably the number of
                attention heads -- confirm against the base class).
            num_layers: forwarded to ``GraphModel``.
            outputs: forwarded to ``GraphModel``. When omitted (``None``), a
                fresh ``['agent', 'task']`` list is built for this instance.
            task_assignment_node: forwarded to ``GraphModel``; also selects
                the edge-type set in :meth:`set_edge_features`.
            wait_time_constraints: forwarded to ``GraphModel``; also read by
                :meth:`set_edge_features`.
            model_type: forwarded to ``GraphModel`` as a keyword argument.
            device: forwarded to ``GraphModel`` as a keyword argument.
        """
        # Build the default list per call: a list literal used as a default
        # argument is created once and shared by every instance, so any
        # mutation of it (e.g. by the base class) would leak across models.
        if outputs is None:
            outputs = ['agent', 'task']
        super().__init__(num_heads, num_layers, outputs, task_assignment_node,
                         wait_time_constraints, model_type=model_type, device=device)

    def set_edge_features(self):
        """Populate ``self.edge_features`` for the edge types that carry features.

        ``self.edge_features`` maps each edge-type name to the name of the
        node type the edge feature is taken from (its source node type).
        ``'travel'`` is always present. The agent/task edges depend on
        ``self.task_assignment_node``: task-assignment edges when truthy,
        direct agent<->task edges otherwise. The ``'wait_time'`` and
        ``'depend_on'`` entries depend on ``self.wait_time_constraints``.
        Both attributes are presumably set by ``GraphModel.__init__``.
        """
        self.edge_features = {
            'travel': 'task',
        }
        if self.task_assignment_node:
            self.edge_features['agent_to_task_assignment'] = 'agent'
            self.edge_features['task_assignment_to_agent'] = 'task_assignment'
            self.edge_features['task_assignment_to_task_assignment'] = 'task_assignment'
        else:
            self.edge_features['agent_to_task'] = 'agent'
            self.edge_features['task_to_agent'] = 'task'

        # NOTE(review): wait_time_constraints defaults to the bool True, so an
        # `is not None` test still adds these entries when False is passed.
        # Kept as-is to stay consistent with GraphModel (not visible here);
        # confirm whether a truthiness check was intended.
        if self.wait_time_constraints is not None:
            self.edge_features['wait_time'] = 'task'
            self.edge_features['depend_on'] = 'task'
        