""" Multi-Agent Task Allocation Network
Default Model with No Edge Features or Attention Based Edges
author: Anonymous
"""
import inspect
import logging
import os
import sys

# Locate this file's directory and its parent, then prepend the parent to
# sys.path so the project-local `models` / `scheduling` packages imported
# below can be resolved.
currentdir = os.path.dirname(
    os.path.abspath(
        inspect.getfile(inspect.currentframe())
    )
)
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import numpy as np
import torch
import torch.nn as nn
import dgl

from models.graph.hetnet import HetNet, HetResNet
from models.graph.hgt import HGT
from models.graph.hgt_edge import HGTEdge
from models.graph.hgt_edge_resnet import HGTEdgeRes
from scheduling.environment import SchedulingEnvironment
from scheduling.agent import Agent
from scheduling.task import Task

# Hidden feature width that GraphModel assigns to every node and edge type.
HIDDEN_DIM_SIZE: int = 16

class GraphModel(nn.Module):
    """Heterogeneous-graph scheduler for multi-agent task allocation.

    An environment observation is turned into a DGL heterograph made of
    ``agent`` and ``task`` nodes, a single global ``state`` node, optional
    ``task_assignment`` nodes and one ``<output>_select`` node type per
    requested output head. The graph is then run through the network chosen
    by ``model_type``:

    * ``'hetnet'``: ``HetNet``
    * ``'hetgat_resnet'``: ``HetResNet``
    * ``'hgt'``: ``HGT``
    * ``'hgt_edge'``: ``HGTEdge``
    * any type containing ``'hgt_edge_resnet'``: ``HGTEdgeRes``
    """

    def __init__(self, num_heads=2, num_layers=5, outputs=None, task_assignment_node=True, wait_time_constraints=True, model_type='hetnet', device=None):
        """Configure the graph schema and build the underlying network.

        Args:
            num_heads (int): attention heads passed to the network.
            num_layers (int): number of layers passed to the network.
            outputs (list[str] | None): output heads; ``None`` means
                ``['agent', 'task']``. Each entry ``x`` creates an ``x_select``
                node type. If ``'task_assignment'`` is included, the
                ``task_assignment`` node becomes the only output node.
            task_assignment_node (bool): add ``task_assignment`` nodes (one per
                agent/task pair) and their edges.
            wait_time_constraints: when not ``None``, ``wait_time`` /
                ``depend_on`` task-to-task edges are added.
            model_type (str): network architecture (see class docstring).
            device (torch.device | None): defaults to CUDA when available,
                otherwise CPU.
        """
        super().__init__()
        self.num_heads = num_heads
        self.num_layers = num_layers
        # Avoid a shared mutable default list across instances.
        self.outputs = ['agent', 'task'] if outputs is None else outputs
        self.task_assignment_node = task_assignment_node
        self.wait_time_constraints = wait_time_constraints

        self.model_type = model_type
        print(f"Model Type: {self.model_type}")
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.set_model()

    def forward(self, observation, mode='agent', dependent_action: dict = None, memory=False, node_name=None):
        """Forward pass for the graph scheduler.

        Args:
            observation (dict): observation space.
            mode (str): one of 'agent', 'task', 'task_assignment', 'value',
                'individual_value' or 'q_value'.
            dependent_action (dict): dependent action for the task assignment
                node, e.g. {'agent': int, 'task': int, 'path': int}.
            memory (bool): if True, also return the graph, node features and
                edge features that were built (e.g. for :meth:`replay`).
            node_name (str | None): node/edge type whose output is returned.
                Defaults to ``f"{mode}_select"``. The value modes resolve to
                ``'value_select'``. In 'task_assignment' mode it is always
                ``'task_assignment'``.

        Returns:
            The selected output, or ``(output, g, node_feat, edge_feat)`` when
            ``memory`` is True.

        Raises:
            NotImplementedError: if ``node_name`` is in neither the node nor
                the edge output of the network.
        """
        if mode in ['task_assignment']:
            node_name = 'task_assignment'
        elif node_name is None:
            node_name = f"{mode}_select"

        g, node_feat, edge_feat = self.get_graph_features(observation, mode, dependent_action)
        # The value modes all share the single 'value_select' output node.
        output_mode = 'value' if mode in ['individual_value', 'q_value'] else mode
        hn, he = self.model(g, node_feat, edge_feat, self.get_output_node(output_mode))

        if node_name in ['individual_value', 'q_value']:
            node_name = 'value_select'
        elif (node_name in ['individual_value_select', 'q_value_select']
              and node_name not in hn and node_name not in he):
            # Default name derived from a value mode: the graph only has
            # 'value_select', so redirect instead of failing the lookup.
            node_name = 'value_select'

        if node_name in hn:
            output = hn[node_name]
        elif node_name in he:
            output = he[node_name]
        else:
            raise NotImplementedError(f"Node Name: {node_name} is not implemented")

        if memory:
            return output, g, node_feat, edge_feat
        return output

    def replay(self, g, node_feat, edge_feat, mode='agent', node_name=None):
        """Re-run the network on a previously built graph (see ``forward(memory=True)``).

        Returns the node output ``hn[node_name]``.

        NOTE(review): unlike ``forward``, the default ``node_name`` is ``mode``
        itself (e.g. the ``'agent'`` node output, not ``'agent_select'``), and
        value modes are not remapped to ``'value'``. Confirm callers pass
        ``node_name`` / ``mode`` explicitly.
        """
        if node_name is None:
            node_name = mode
        hn, _ = self.model(g, node_feat, edge_feat, self.get_output_node(mode))
        return hn[node_name]

    def get_action(self, observation, mode='agent', dependent_action: dict = None, memory=False):
        """Return the argmax index of ``forward``'s output, computed without gradients.

        With ``memory=True``, returns ``(index, g, node_feat, edge_feat)``.
        """
        with torch.no_grad():
            output = self.forward(observation, mode, dependent_action, memory)
            if memory:
                return output[0].argmax().item(), output[1], output[2], output[3]
            return output.argmax().item()

    def get_graph_features(self, observation, mode, dependent_action: dict = None):
        """Build the DGL heterograph plus its node and edge feature tensors on ``self.device``."""
        data_dict, node_feats, edge_feats = self.get_features(observation, mode, dependent_action)

        g = dgl.heterograph(data_dict, self.num_nodes, idtype=torch.int64, device=self.device)

        edge_feats_tensor = {key: torch.Tensor(value).to(self.device) for key, value in edge_feats.items()}
        node_feats_tensor = {key: torch.Tensor(value).to(self.device) for key, value in node_feats.items()}

        return g, node_feats_tensor, edge_feats_tensor

    def set_nodes(self):
        """Populate ``self.nodes`` (node types) and ``self.output_nodes``."""
        self.nodes = {
            'agent',
            'task',
            'state'
        }
        self.set_output_nodes()
        self.nodes.update(self.output_nodes)
        if self.task_assignment_node:
            self.nodes.add('task_assignment')

    def set_output_nodes(self):
        """Derive output node types from ``self.outputs``; requires ``self.nodes`` to exist."""
        if 'task_assignment' in self.outputs:
            self.output_nodes = ['task_assignment']
            return
        self.output_nodes = [f"{output}_select" for output in self.outputs]
        self.nodes.update(self.output_nodes)

    def set_edges(self):
        """Populate ``self.edges`` ((src, edge, dst) triples) and ``self.bidirectional_edges``."""
        self.edges = {
            ('agent', 'agent_self', 'agent'),
            ('task', 'task_self', 'task'),
            ('task', 'travel', 'task'),
            ('agent', 'assigned', 'task'),
            ('state', 'state_self', 'state'),
            ('agent', 'agent_to_state', 'state'),
            ('task', 'task_to_state', 'state'),
        }
        if 'task_assignment' not in self.outputs:
            # '*task_first*' model types feed each selector from tasks (plus
            # agents for task/value heads); otherwise from agents (plus tasks).
            task_first = 'task_first' in self.model_type
            primary, secondary = ('task', 'agent') if task_first else ('agent', 'task')
            for output, selector in zip(self.outputs, self.output_nodes):
                self.edges.add((primary, f"{primary}_to_{selector}", selector))
                self.edges.add(('state', f"state_to_{selector}", selector))
                if output in ['task', 'value']:
                    self.edges.add((secondary, f"{secondary}_to_{selector}", selector))

        if self.task_assignment_node:
            self.edges.update([
                ('task_assignment', 'task_assignment_self', 'task_assignment'),
                ('agent', 'agent_to_task_assignment', 'task_assignment'),
                ('task_assignment', 'task_assignment_to_agent', 'agent'),
                ('task', 'task_to_task_assignment', 'task_assignment'),
                ('task_assignment', 'task_assignment_to_task', 'task'),
                ('task_assignment', 'task_assignment_to_task_assignment', 'task_assignment')
            ])
        else:
            self.edges.update([
                ('agent', 'agent_to_task', 'task'),
                ('task', 'task_to_agent', 'agent')
            ])

        # NOTE(review): this tests `is not None`, so wait_time_constraints=False
        # still adds these edges. Kept as-is because changing the edge schema
        # would break previously trained weights; confirm intent.
        if self.wait_time_constraints is not None:
            self.edges.add(('task', 'wait_time', 'task'))
            self.edges.add(('task', 'depend_on', 'task'))

        # Edges built by reversing another observed edge (edge -> observed edge).
        self.bidirectional_edges = {
            'task_assignment_to_agent': 'agent_to_task_assignment',
            'task_assignment_to_task': 'task_to_task_assignment',
            'depend_on': 'wait_time',
        }

    def set_edge_features(self):
        """Edges that have Edge Features, mapped from edge_name to source_node (none in this model)."""
        self.edge_features = {}

    def set_attention_features(self):
        """Attention Features for the Graph Scheduler (none in this model)."""
        self.attention_features = {}

    def set_num_nodes(self, observation, output_mode='agent'):
        """Fill ``self.num_nodes`` (node type -> count) for the given output mode."""
        self.num_nodes = {
            node: observation[node].shape[0]
            for node in self.nodes
            if node in observation and node not in self.output_nodes
        }
        self.num_nodes['state'] = 1
        if self.task_assignment_node:
            self.num_nodes['task_assignment'] = len(observation['agent']) * len(observation['task'])
        if output_mode in ['agent']:
            self.num_nodes['agent_select'] = self.num_nodes['agent']
        elif output_mode in ['task']:
            self.num_nodes['task_select'] = len(observation['task'])
        elif output_mode in ['q_value']:
            # One value node per (candidate task, agent) pair.
            self.num_nodes['value_select'] = len(observation['task_to_task_select'][1]) * len(observation['agent'])
        else:
            self.num_nodes[f'{output_mode}_select'] = 1

    def set_dimensions(self):
        """Set ``in_dim`` / ``hid_dim`` / ``out_dim`` per node and edge type."""
        self.in_dim = self.get_input_dimensions()
        self.hid_dim = self.get_hidden_dimensions()
        self.out_dim = self.get_output_dimensions()

    def get_input_dimensions(self):
        """Input feature size per node/edge type, taken from the environment classes."""
        agent_task_dim = SchedulingEnvironment.get_agent_task_feature_space()[0]
        wait_time_dim = SchedulingEnvironment.get_wait_time_feature_space()[0]
        agent_assignment_dim = SchedulingEnvironment.get_agent_to_task_assignment_feature_space()[0]
        dims = {
            "agent": Agent.get_feature_size(),
            "task": Task.get_feature_size(),
            'state': SchedulingEnvironment.get_state_feature_space(),
            'task_assignment': SchedulingEnvironment.get_task_assignment_feature_space(),
            "agent_to_task": agent_task_dim,
            "task_to_agent": agent_task_dim,
            'travel': SchedulingEnvironment.get_task_to_task_travel_feature_space()[0],
            "wait_time": wait_time_dim,
            "depend_on": wait_time_dim,
            'agent_to_task_assignment': agent_assignment_dim,
            'task_assignment_to_agent': agent_assignment_dim,
            'task_assignment_to_task_assignment': SchedulingEnvironment.get_agent_task_task_travel_time_feature_space()[0],
        }
        dims.update({output_node: 1 for output_node in self.output_nodes})
        return dims

    def get_hidden_dimensions(self):
        """Hidden size ``HIDDEN_DIM_SIZE`` for every type in ``self.in_dim``."""
        return {key: HIDDEN_DIM_SIZE for key in self.in_dim}

    def get_output_dimensions(self):
        """Output size 1 for every type in ``self.in_dim``."""
        return {key: 1 for key in self.in_dim}

    def set_model(self):
        """Build the graph schema and instantiate the network selected by ``model_type``.

        Raises:
            NotImplementedError: for an unknown ``model_type``.
        """
        self.set_nodes()
        self.set_edges()
        self.set_edge_features()
        self.set_attention_features()
        self.set_dimensions()
        if self.model_type in ['hetnet', 'hetgat_resnet']:
            network_cls = HetNet if self.model_type == 'hetnet' else HetResNet
            self.model = network_cls(
                nodes=self.nodes,
                edges=self.edges,
                edge_features=self.edge_features,
                attention_nodes=self.attention_features,
                output_nodes=self.output_nodes,
                in_dim=self.in_dim,
                hid_dim=self.hid_dim,
                out_dim=self.out_dim,
                num_heads=self.num_heads,
                num_layers=self.num_layers,
                merge='cat',
                mode='leaky_relu',
                device=self.device,
                final_activation='avg'
            ).to(self.device)
        elif 'hgt' in self.model_type:
            # Edges without real features get a 1-d placeholder feature
            # (filled with ones in get_features).
            for _, edge, _ in self.edges:
                if edge in self.edge_features:
                    continue
                self.in_dim[edge] = 1
                self.hid_dim[edge] = HIDDEN_DIM_SIZE
                self.out_dim[edge] = 1

            if self.model_type == 'hgt':
                network_cls = HGT
            elif self.model_type == 'hgt_edge':
                network_cls = HGTEdge
            elif 'hgt_edge_resnet' in self.model_type:
                network_cls = HGTEdgeRes
            else:
                # Previously self.model was silently left unset here.
                raise NotImplementedError(f"Model Type: {self.model_type} is not implemented")
            self.model = network_cls(
                nodes=self.nodes,
                edges=self.edges,
                in_dim=self.in_dim,
                hid_dim=self.hid_dim,
                out_dim=self.out_dim,
                n_layers=self.num_layers,
                n_heads=self.num_heads,
                use_norm=False
            ).to(self.device)
        else:
            raise NotImplementedError(f"Model Type: {self.model_type} is not implemented")

    @staticmethod
    def _dependent_id(dependent_action, key):
        """Return ``dependent_action[key]``, or None when the action or key is missing."""
        if dependent_action is None or key not in dependent_action:
            return None
        return dependent_action[key]

    def get_features(self, observation, mode='agent', dependent_action: dict = None):
        """Get the features for the Graph Scheduler.

        Args:
            observation (dict): observation space.
            mode (str): mode of the scheduler.
            dependent_action (dict): dependent action for the task assignment
                node, e.g. {'agent': int, 'task': int, 'path': int}. A key
                whose value is None is treated as absent.

        Returns:
            data_dict (dict): (src, edge, dst) -> (source ids, target ids).
            node_feats (dict): node features for the graph.
            edge_feats (dict): edge features for the graph.

        Raises:
            NotImplementedError: for an unknown ``mode``.
        """
        # Node counts; individual_value scores a single value_select node.
        self.set_num_nodes(observation, 'value' if mode in ['individual_value'] else mode)

        if mode in ['agent']:
            task = self._dependent_id(dependent_action, 'task')
            if task is not None:
                out_feats = self.get_agent_select_features_given_task(observation, mode, task)
            else:
                out_feats = self.get_agent_select_features(observation, mode)
        elif mode in ['task']:
            agent = self._dependent_id(dependent_action, 'agent')
            if agent is not None:
                out_feats = self.get_task_select_features(observation, mode, agent)
            else:
                out_feats = self.get_task_first_select_features(observation, mode)
        elif mode in ['task_assignment']:
            out_feats = None
        elif mode in ['value']:
            out_feats = self.get_critic_edge_features(observation, mode)
        elif mode in ['individual_value']:
            mode = 'value'
            out_feats = self.get_individual_critic_edge_features(observation, mode, dependent_action['agent'], dependent_action['task'])
        elif mode in ['q_value']:
            mode = 'value'
            out_feats = self.get_q_value_critic_edge_features(observation, mode)
        else:
            raise NotImplementedError(f"Mode: {mode} is not implemented")

        node_feats = self.get_node_features(observation, mode)
        data_dict, edge_feats = self.get_edge_features(observation, mode)
        if mode not in ['task_assignment']:
            output_node = self.get_output_node(mode)
            num_outputs = self.num_nodes[output_node]
            # The single state node (id 0) connects to every output node.
            data_dict[('state', f"state_to_{mode}_select", output_node)] = ([0] * num_outputs, list(range(num_outputs)))
            data_dict.update(out_feats)

        if 'hgt' in self.model_type:
            # HGT variants get an all-ones placeholder feature on every edge
            # type without real features (empty when the edge is absent).
            logger = logging.getLogger(__name__)
            for src, edge, tgt in self.edges:
                if edge in self.edge_features:
                    continue
                key = (src, edge, tgt)
                if key in data_dict:
                    edge_feats[edge] = np.ones((len(data_dict[key][1]), 1))
                else:
                    logger.debug("Edge %s absent from graph (model_type=%s, output_nodes=%s)",
                                 key, self.model_type, self.output_nodes)
                    edge_feats[edge] = np.ones((0, 1))
        return data_dict, node_feats, edge_feats

    def get_node_features(self, observation, mode):
        """Observed node features, zero placeholders otherwise.

        Output nodes other than the current mode's get no features.
        """
        node_feats = {}
        output_node = self.get_output_node(mode)
        for node in self.nodes:
            if node in observation:
                node_feats[node] = observation[node]
            elif node not in self.output_nodes or node == output_node:
                node_feats[node] = torch.zeros(self.num_nodes[node], self.in_dim[node])
        return node_feats

    def get_edge_features(self, observation, mode):
        """Build edge index lists (and edge features) for every schema edge relevant to ``mode``."""
        data_dict = {}
        edge_feats = {}
        output_node = self.get_output_node(mode)
        for source, edge, target in self.edges:
            if target in self.output_nodes and target != output_node and target not in ['task_assignment']:
                # Edge into an output head that is not active in this mode.
                continue
            if edge in observation:
                # Edge observed directly from the environment.
                data_dict[(source, edge, target)] = (observation[edge].source_nodes.tolist(), observation[edge].target_nodes.tolist())
                if edge in self.edge_features:
                    edge_feats[edge] = observation[edge].edge_features.tolist()
            elif edge in self.bidirectional_edges:
                # Reverse of an observed edge: swap its source and target ids.
                reverse_edge = self.bidirectional_edges[edge]
                data_dict[(source, edge, target)] = (observation[reverse_edge].target_nodes.tolist(), observation[reverse_edge].source_nodes.tolist())
                if edge in self.edge_features:
                    edge_feats[edge] = observation[reverse_edge].edge_features.tolist()
            elif 'self' in edge:
                # Self loop.
                data_dict[(source, edge, target)] = (list(range(self.num_nodes[source])), list(range(self.num_nodes[target])))
        # Every agent and task feeds the single state node (id 0).
        data_dict[('agent', 'agent_to_state', 'state')] = (list(range(self.num_nodes['agent'])), [0] * self.num_nodes['agent'])
        data_dict[('task', 'task_to_state', 'state')] = (list(range(self.num_nodes['task'])), [0] * self.num_nodes['task'])
        return data_dict, edge_feats

    def get_agent_select_features(self, observation, mode):
        """Agent ``i`` -> selector node ``i``."""
        return {
            ('agent', f'agent_to_{mode}_select', f'{mode}_select'): (
                list(range(self.num_nodes['agent'])),
                list(range(self.num_nodes[f'{mode}_select'])),
            )
        }

    def get_agent_select_features_given_task(self, observation, mode, task_selected):
        """Agent ``i`` -> selector ``i``, plus the chosen task -> every selector node."""
        num_selects = self.num_nodes[f'{mode}_select']
        return {
            ('agent', f'agent_to_{mode}_select', f'{mode}_select'): (
                list(range(self.num_nodes['agent'])),
                list(range(num_selects)),
            ),
            ('task', f'task_to_{mode}_select', f'{mode}_select'): (
                [task_selected] * num_selects,
                list(range(num_selects)),
            ),
        }

    def get_task_select_features(self, observation, mode, agent_selected):
        """The chosen agent -> every selector node, plus task ``i`` -> selector ``i``."""
        num_selects = self.num_nodes[f'{mode}_select']
        return {
            ('agent', f'agent_to_{mode}_select', f'{mode}_select'): ([agent_selected] * num_selects, list(range(num_selects))),
            ('task', f'task_to_{mode}_select', f'{mode}_select'): (list(range(self.num_nodes['task'])), list(range(num_selects))),
        }

    def get_task_first_select_features(self, observation, mode):
        """Every agent -> every selector node, plus task ``i`` -> selector ``i``."""
        num_agents = self.num_nodes['agent']
        num_selects = self.num_nodes[f'{mode}_select']
        return {
            ('agent', f'agent_to_{mode}_select', f'{mode}_select'): (
                [agent for agent in range(num_agents) for _ in range(num_selects)],
                [select for _ in range(num_agents) for select in range(num_selects)],
            ),
            ('task', f'task_to_{mode}_select', f'{mode}_select'): (
                list(range(self.num_nodes['task'])),
                list(range(num_selects)),
            ),
        }

    def get_critic_edge_features(self, observation, mode):
        """Every agent and every task -> the single value node (id 0).

        Raises:
            NotImplementedError: if ``mode`` is not ``'value'``.
        """
        if mode not in ['value']:
            raise NotImplementedError
        num_agents = self.num_nodes['agent']
        num_tasks = self.num_nodes['task']
        return {
            ('agent', f'agent_to_{mode}_select', f'{mode}_select'): (list(range(num_agents)), [0] * num_agents),
            ('task', f'task_to_{mode}_select', f'{mode}_select'): (list(range(num_tasks)), [0] * num_tasks),
        }

    def get_individual_critic_edge_features(self, observation, mode, agent, task):
        """Connect one ``agent`` and one ``task`` to the single value node (id 0)."""
        mode = 'value'
        # TODO: Implement Individual Values for allowed agent and task pairing
        return {
            ('agent', f'agent_to_{mode}_select', f'{mode}_select'): ([agent], [0]),
            ('task', f'task_to_{mode}_select', f'{mode}_select'): ([task], [0]),
        }

    def get_q_value_critic_edge_features(self, observation, mode):
        """Connect each (agent, candidate task) pair to its own value node.

        Value node ``i + j * num_agents`` receives agent ``i`` and the ``j``-th
        task id of ``observation['task_to_task_select'][0]``. The ordering
        matches :meth:`get_task_agent_assignment_ids`.
        """
        mode = 'value'
        agent_ids, task_ids = self.get_task_agent_assignment_ids(observation)
        # The enumeration is j-major / i-minor, so the assignment ids are 0..N-1.
        assignment_ids = list(range(len(agent_ids)))
        return {
            ('agent', f'agent_to_{mode}_select', f'{mode}_select'): (agent_ids, assignment_ids),
            # Use the actual task node ids, not their positions in the candidate list.
            ('task', f'task_to_{mode}_select', f'{mode}_select'): (task_ids, list(assignment_ids)),
        }

    def get_task_agent_assignment_ids(self, observation):
        """Return ``(agent_ids, task_ids)`` enumerating every agent for each candidate task."""
        tasks = observation['task_to_task_select'][0]
        agent_ids = [i for _ in range(len(tasks)) for i in range(self.num_nodes['agent'])]
        task_ids = [task_id for task_id in tasks for _ in range(self.num_nodes['agent'])]
        return agent_ids, task_ids

    def get_output_node(self, mode):
        """Get the output node type for the mode (``f"{mode}_select"``)."""
        return f"{mode}_select"
  