""" Multi-Agent Task Allocation Network
author: Anonymous
"""
import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import torch
import torch.nn as nn
import dgl

from models.graph.hetnet import HetNet
from scheduling.environment import SchedulingEnvironment
from scheduling.agent import Agent
from scheduling.task import Task

from models.graph_model_with_edge import GraphModelWithEdge

class GraphModelWithEdgeAttention(GraphModelWithEdge):
    """Edge-featured graph model with attention enabled on a fixed set of edge types.

    The only behavior this subclass adds on top of ``GraphModelWithEdge`` is
    ``set_attention_features``, which fills ``self.attention_features``.
    """

    def __init__(self, num_heads=8, num_layers=4, outputs=None, task_assignment_node=True, wait_time_constraints=True, model_type='hetnet', device=None):
        """Forward all configuration to ``GraphModelWithEdge``.

        Args:
            num_heads: Passed through to the base class (presumably the number
                of attention heads — TODO confirm).
            num_layers: Passed through to the base class.
            outputs: Passed through to the base class. ``None`` (the default)
                selects ``['agent', 'task']``. A fresh list is built on every
                call so instances never share one default list object.
            task_assignment_node: Passed through; also chooses which
                agent/task edge types get attention in
                ``set_attention_features``.
            wait_time_constraints: Passed through; when not ``None`` the
                ``'wait_time'`` and ``'depend_on'`` edges get attention.
            model_type: Passed through to the base class by keyword.
            device: Passed through to the base class by keyword.
        """
        if outputs is None:
            outputs = ['agent', 'task']
        super().__init__(num_heads, num_layers, outputs, task_assignment_node, wait_time_constraints, model_type=model_type, device=device)

    def set_attention_features(self):
        """Populate ``self.attention_features``.

        Keys are edge-type names and values are node-type names. For the
        ``X_to_Y`` entries the value is ``X``, so the value presumably names
        the source node type whose features drive that edge's attention —
        TODO confirm against ``GraphModelWithEdge``.

        Reads ``self.task_assignment_node`` and ``self.wait_time_constraints``.
        These are not set in this class, so presumably the base-class
        constructor sets them.
        """
        self.attention_features = {
            'travel': 'task',
            # Disabled edge types, kept for reference:
            # 'agent_to_state': 'agent',
            # 'task_to_state': 'task'
        }
        if self.task_assignment_node:
            # Agents and tasks are linked through intermediate
            # 'task_assignment' nodes.
            self.attention_features['agent_to_task_assignment'] = 'agent'
            self.attention_features['task_assignment_to_agent'] = 'task_assignment'
            self.attention_features['task_assignment_to_task_assignment'] = 'task_assignment'
        else:
            # Agents and tasks are linked directly.
            self.attention_features['agent_to_task'] = 'agent'
            self.attention_features['task_to_agent'] = 'task'

        # NOTE(review): this is an ``is not None`` test, so passing
        # ``wait_time_constraints=False`` (the parameter defaults to True)
        # still enables these edge types — confirm that is intended.
        if self.wait_time_constraints is not None:
            self.attention_features['wait_time'] = 'task'
            self.attention_features['depend_on'] = 'task'


class GraphModelWithEdgeAttentionSimple(GraphModelWithEdge):
    """Reduced-attention variant of the edge-featured graph model.

    Compared with ``GraphModelWithEdgeAttention``, this variant never attends
    over ``'travel'``, ``'task_assignment_to_agent'`` or ``'depend_on'``
    edges. It also does not expose a ``model_type`` argument, so the base
    class's default is used.
    """

    def __init__(self, num_heads=8, num_layers=4, outputs=None, task_assignment_node=True, wait_time_constraints=True, device=None):
        """Forward all configuration to ``GraphModelWithEdge``.

        Args:
            num_heads: Passed through to the base class (presumably the number
                of attention heads — TODO confirm).
            num_layers: Passed through to the base class.
            outputs: Passed through to the base class. ``None`` (the default)
                selects ``['agent', 'task']``. A fresh list is built on every
                call so instances never share one default list object.
            task_assignment_node: Passed through; also chooses which
                agent/task edge types get attention in
                ``set_attention_features``.
            wait_time_constraints: Passed through; when not ``None`` the
                ``'wait_time'`` edges get attention.
            device: Passed through to the base class by keyword.
        """
        if outputs is None:
            outputs = ['agent', 'task']
        # ``device`` must be passed by keyword. ``GraphModelWithEdge`` also
        # accepts ``model_type``, which the sibling attention class passes
        # ahead of ``device``. A sixth positional argument would therefore
        # likely bind to ``model_type`` and silently drop the device.
        super().__init__(num_heads, num_layers, outputs, task_assignment_node, wait_time_constraints, device=device)

    def set_attention_features(self):
        """Populate ``self.attention_features`` with a reduced edge set.

        Keys are edge-type names and values are node-type names. For the
        ``X_to_Y`` entries the value is ``X``, so the value presumably names
        the source node type whose features drive that edge's attention —
        TODO confirm against ``GraphModelWithEdge``.

        Reads ``self.task_assignment_node`` and ``self.wait_time_constraints``,
        which are presumably set by the base-class constructor.
        """
        self.attention_features = {
            # Disabled edge types, kept for reference:
            # 'travel': 'task',
            # 'agent_to_state': 'agent',
            # 'task_to_state': 'task'
        }
        if self.task_assignment_node:
            self.attention_features['agent_to_task_assignment'] = 'agent'
            # Deliberately omitted in this simple variant:
            # self.attention_features['task_assignment_to_agent'] = 'task_assignment'
            self.attention_features['task_assignment_to_task_assignment'] = 'task_assignment'
        else:
            self.attention_features['agent_to_task'] = 'agent'
            self.attention_features['task_to_agent'] = 'task'

        # NOTE(review): this is an ``is not None`` test, so passing
        # ``wait_time_constraints=False`` (the parameter defaults to True)
        # still enables this edge type — confirm that is intended.
        if self.wait_time_constraints is not None:
            self.attention_features['wait_time'] = 'task'
            # Deliberately omitted in this simple variant:
            # self.attention_features['depend_on'] = 'task'
        