""" Multi-Agent Task Allocation Network
author: Anonymous
"""
import os
import sys
import inspect

# Locate the directory that contains this file and put its parent directory at
# the front of sys.path. Presumably this lets the absolute imports below
# (models.*, scheduling.*, solvers.*) resolve from that parent directory.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 

import torch
import torch.nn as nn
import dgl

from models.graph.hetnet import HetNet
from scheduling.environment import SchedulingEnvironment
from scheduling.agent import Agent
from scheduling.task import Task

from models.graph_model import GraphModel
from models.graph_model_with_edge import GraphModelWithEdge
from models.graph_model_with_edge_attention import GraphModelWithEdgeAttention, GraphModelWithEdgeAttentionSimple

from solvers.edf import EarliestDeadlineFirstAgentScheduler


def get_graph(graph, num_heads, num_layers, outputs, task_assignment_node, device):
    """Instantiate the graph model named by ``graph`` and move it to ``device``.

    The model class is picked by substring checks on ``graph``, evaluated in
    this order (the first match wins):

    1. ``'no_edge'``                         -> ``GraphModel``
    2. ``'edge'`` but not ``'hgt_edge'``     -> ``GraphModelWithEdge``
    3. ``'hetgat'``:
       - with ``'resnet'``  -> ``GraphModelWithEdgeAttention`` (``model_type=graph``)
       - with ``'simple'``  -> ``GraphModelWithEdgeAttentionSimple``
       - otherwise          -> ``GraphModelWithEdgeAttention``
    4. ``'hgt'`` (with or without ``'edge'``) -> ``GraphModelWithEdgeAttention``
       (``model_type=graph``)

    Args:
        graph: Identifier searched for the substrings listed above.
        num_heads: Forwarded unchanged to the model constructor.
        num_layers: Forwarded unchanged to the model constructor.
        outputs: Forwarded unchanged to the model constructor.
        task_assignment_node: Forwarded unchanged to the model constructor.
        device: Passed as the ``device`` keyword of the constructor and to
            ``.to(device)`` on the built model.

    Returns:
        The constructed model after ``.to(device)``.

    Raises:
        ValueError: If ``graph`` matches none of the known identifiers.
    """
    shared_args = (num_heads, num_layers, outputs, task_assignment_node)

    if 'no_edge' in graph:
        # hetgat without any edge features
        network = GraphModel(*shared_args, device=device)
    elif 'edge' in graph and 'hgt_edge' not in graph:
        # hetgat with edge features
        network = GraphModelWithEdge(*shared_args, device=device)
    elif 'hetgat' in graph:
        if 'resnet' in graph:
            # e.g. 'hetgat_resnet'
            network = GraphModelWithEdgeAttention(*shared_args, model_type=graph, device=device)
        elif 'simple' in graph:
            # e.g. 'hetgat_simple'
            network = GraphModelWithEdgeAttentionSimple(*shared_args, device=device)
        else:
            # plain 'hetgat'
            network = GraphModelWithEdgeAttention(*shared_args, device=device)
    elif 'hgt' in graph:
        # 'hgt', 'hgt_edge', 'hgt_edge_resnet', 'hgt_edge_resnet_bb': every
        # variant is built the same way, with the identifier as model_type.
        network = GraphModelWithEdgeAttention(*shared_args, model_type=graph, device=device)
    else:
        raise ValueError("Mode is not valid: {}".format(graph))

    return network.to(device)

class GraphSchedulerAgent(nn.Module):
    """Scheduler head that queries a graph model in ``'agent'`` mode.

    The underlying network is built by ``get_graph`` with ``outputs=['agent']``.
    """

    def __init__(self, num_heads=8, num_layers=4, task_assignment_node=False, graph='hetgat', device=None):
        """Build the underlying graph model.

        Args:
            num_heads: Forwarded to ``get_graph``.
            num_layers: Forwarded to ``get_graph``.
            task_assignment_node: Forwarded to ``get_graph``.
            graph: Model identifier understood by ``get_graph``.
            device: Device handed to ``get_graph``.
        """
        super(GraphSchedulerAgent, self).__init__()

        self.mode = 'agent'
        self.outputs = ['agent']
        self.graph = graph
        self.model = get_graph(graph, num_heads, num_layers, self.outputs, task_assignment_node, device)

    def forward(self, observation, memory=False):
        """Run the model on ``observation`` in ``'agent'`` mode.

        Args:
            observation: Input passed through to the model unchanged.
            memory: Forwarded to the model as the ``memory`` keyword.

        Returns:
            Whatever the underlying model returns.
        """
        return self.model(observation, self.mode, memory=memory)

    def get_action(self, observation, memory=False):
        """Return the argmax action without tracking gradients.

        Args:
            observation: Input passed to ``forward``.
            memory: When true, the model output is treated as a 4-item
                sequence whose first item holds the scores.

        Returns:
            ``int`` index of the highest score when ``memory`` is false;
            otherwise ``(index, output[1], output[2], output[3])``. The three
            extra items are presumably the ``(g, nf, ef)`` values accepted by
            ``replay`` -- confirm against the graph model.
        """
        with torch.no_grad():
            output = self.forward(observation, memory=memory)
            if memory:
                # Scores are the first element in memory mode, so take the
                # argmax of output[0] rather than of the whole sequence.
                return output[0].argmax().item(), output[1], output[2], output[3]
            return output.argmax().item()

    def replay(self, g, nf, ef, mode='agent'):
        """Delegate to the model's ``replay`` for the ``'agent'`` mode.

        Raises:
            NotImplementedError: If ``mode`` is not ``'agent'``.
        """
        if mode in ['agent']:
            return self.model.replay(g, nf, ef, mode)
        else:
            raise NotImplementedError("Replay mode must be one of ['agent']")
    
class GraphSchedulerTask(nn.Module):
    """Scheduler head that queries a graph model in ``'task'`` mode.

    The underlying network is built by ``get_graph`` with ``outputs=['task']``.
    An optional agent id is handed to the model as
    ``dependent_action={'agent': agent_id}``.
    """

    def __init__(self, num_heads=8, num_layers=4, task_assignment_node=False, graph='hetgat', device=None):
        """Build the underlying graph model.

        Args:
            num_heads: Forwarded to ``get_graph``.
            num_layers: Forwarded to ``get_graph``.
            task_assignment_node: Forwarded to ``get_graph``.
            graph: Model identifier understood by ``get_graph``.
            device: Device handed to ``get_graph``.
        """
        super(GraphSchedulerTask, self).__init__()
        self.mode = 'task'
        self.outputs = ['task']
        self.graph = graph
        self.model = get_graph(graph, num_heads, num_layers, self.outputs, task_assignment_node, device)

    def forward(self, observation, agent_id=None, memory=False):
        """Run the model in ``'task'`` mode, conditioned on ``agent_id``.

        Args:
            observation: Input passed through to the model unchanged.
            agent_id: Value stored under the ``'agent'`` key of the
                ``dependent_action`` dict given to the model (may be ``None``).
            memory: Forwarded to the model as the ``memory`` keyword.

        Returns:
            Whatever the underlying model returns.
        """
        dependent_action = {'agent': agent_id}
        return self.model(observation, self.mode, dependent_action, memory=memory)

    def get_action(self, observation, agent_id=None, memory=False):
        """Return the argmax task without tracking gradients.

        Args:
            observation: Input passed to ``forward``.
            agent_id: Agent id the task choice is conditioned on.
            memory: When true, the model output is treated as a 4-item
                sequence whose first item holds the scores.

        Returns:
            ``int`` index of the highest score when ``memory`` is false;
            otherwise ``(index, output[1], output[2], output[3])``.
        """
        with torch.no_grad():
            # Go through self.forward so the model receives mode='task' and the
            # dependent_action dict; calling the model directly with agent_id
            # would place it in the positional ``mode`` slot.
            output = self.forward(observation, agent_id, memory=memory)
            if memory:
                return output[0].argmax().item(), output[1], output[2], output[3]
            return output.argmax().item()

    def replay(self, g, nf, ef, mode='task'):
        """Delegate to the model's ``replay`` for the ``'task'`` mode.

        Raises:
            NotImplementedError: If ``mode`` is not ``'task'``.
        """
        if mode in ['task']:
            return self.model.replay(g, nf, ef, mode)
        else:
            raise NotImplementedError("Replay mode must be one of ['task']")

class GraphSchedulerCritic(nn.Module):
    """Critic built on a graph model created by ``get_graph`` with ``outputs=['value']``.

    ``forward`` queries the model in ``'individual_value'`` mode for a given
    (task, agent) action; the remaining methods work on the model's
    ``'q_value'`` output.
    """

    def __init__(self, num_heads=8, num_layers=4, task_assignment_node=True, graph='hetgat', device=None):
        """Build the underlying graph model.

        Args:
            num_heads: Forwarded to ``get_graph``.
            num_layers: Forwarded to ``get_graph``.
            task_assignment_node: Forwarded to ``get_graph``.
            graph: Model identifier understood by ``get_graph``.
            device: Device handed to ``get_graph``.
        """
        super().__init__()
        self.mode = 'individual_value'
        self.outputs = ['value']
        self.graph = graph
        self.model = get_graph(
            graph, num_heads, num_layers, self.outputs, task_assignment_node, device
        )

    def forward(self, observation, action, memory=False):
        """Evaluate ``action`` on ``observation``.

        Args:
            observation: Input passed through to the model unchanged.
            action: Pair unpacked as ``(task, agent)``.
            memory: Forwarded to the model as the ``memory`` keyword.

        Returns:
            The model output passed through a leaky ReLU.
        """
        task, agent = action
        raw_value = self.model(
            observation,
            self.mode,
            dependent_action={'agent': agent, 'task': task},
            memory=memory,
            node_name='value_select',
        )
        # The activation applied here is a leaky ReLU (not a sigmoid).
        return torch.nn.functional.leaky_relu(raw_value)

    def get_q_values(self, observation):
        """Query the model in ``'q_value'`` mode.

        Returns:
            tuple: ``(flattened q-values, agent_indices, task_indices)`` where
            the two index collections come from the model's
            ``get_task_agent_assignment_ids``.
        """
        q_values = self.model(observation, mode='q_value')
        agent_indices, task_indices = self.model.get_task_agent_assignment_ids(observation)
        return q_values.flatten(), agent_indices, task_indices

    def _softmax_over_q(self, observation):
        """Return ``(softmax of the q-values over dim 0, agent_indices, task_indices)``."""
        q_vals, agent_indices, task_indices = self.get_q_values(observation)
        probabilities = torch.nn.functional.softmax(q_vals, dim=0)
        return probabilities, agent_indices, task_indices

    def get_action(self, observation, greedy=False):
        """Pick a (task, agent) pair from the softmax over the q-values.

        Args:
            observation: Input passed to ``get_q_values``.
            greedy: If true take the argmax, otherwise sample from a
                categorical distribution over the softmax probabilities.

        Returns:
            tuple: ``((task, agent), None, None, None)``.
        """
        probabilities, agent_indices, task_indices = self._softmax_over_q(observation)
        if greedy:
            choice = probabilities.argmax()
        else:
            choice = torch.distributions.Categorical(probabilities).sample()
        agent = agent_indices[choice]
        task = task_indices[choice]
        return (task, agent), None, None, None

    def get_action_and_probability(self, x):
        """Pick the greedy (task, agent) pair and report its probability.

        Args:
            x: Input passed to ``get_q_values``.

        Returns:
            tuple: ``(action, log_prob, prob)``
                - action (tuple): ``(task, agent)`` at the argmax of the softmax
                - log_prob: log-probability of that choice under a categorical
                  distribution over the softmax probabilities
                - prob: softmax probability of that choice
        """
        probabilities, agent_indices, task_indices = self._softmax_over_q(x)
        distribution = torch.distributions.Categorical(probabilities)
        # The choice is always greedy here; the distribution is only used to
        # obtain the log-probability.
        choice = probabilities.argmax()
        agent = agent_indices[choice]
        task = task_indices[choice]
        return (task, agent), distribution.log_prob(choice), probabilities[choice]
        
class GraphSchedulerIndividual(nn.Module):
    """Two-stage scheduler: choose an agent first, then a task for that agent.

    The agent stage is an ``EarliestDeadlineFirstAgentScheduler`` when
    ``'edf_agent'`` occurs in ``graph`` and a ``GraphSchedulerAgent`` otherwise;
    the task stage is always a ``GraphSchedulerTask``.
    """

    def __init__(self, num_heads=8, num_layers=4, layer_data=None, task_assignment_node=False, graph='hetgat', device=None):
        """Build the agent and task schedulers.

        Args:
            num_heads: Forwarded to both sub-schedulers.
            num_layers: Default layer count for both sub-schedulers.
            layer_data: Optional mapping with ``'agent'`` and/or ``'task'``
                entries overriding ``num_layers`` per sub-scheduler. A missing
                key or a ``None`` value falls back to ``num_layers``.
            task_assignment_node: Forwarded to both sub-schedulers.
            graph: Model identifier; also checked for ``'edf_agent'``.
            device: Device handed to the sub-schedulers.
        """
        super(GraphSchedulerIndividual, self).__init__()
        agent_layers, task_layers = self._resolve_layers(layer_data, num_layers)

        if 'edf_agent' in graph:
            self.agent_scheduler = EarliestDeadlineFirstAgentScheduler()
        else:
            self.agent_scheduler = GraphSchedulerAgent(num_heads, agent_layers, task_assignment_node, graph, device)
        self.task_scheduler = GraphSchedulerTask(num_heads, task_layers, task_assignment_node, graph=graph, device=device)

    @staticmethod
    def _resolve_layers(layer_data, num_layers):
        """Return ``(agent_layers, task_layers)``, defaulting each to ``num_layers``."""
        if layer_data is None:
            return num_layers, num_layers

        def pick(key):
            # A missing key is treated like an explicit None instead of
            # raising KeyError.
            value = layer_data[key] if key in layer_data else None
            return num_layers if value is None else value

        return pick('agent'), pick('task')

    def get_action(self, observation, memory=False):
        """Greedily pick an agent, then a task conditioned on that agent.

        Args:
            observation: Input passed to both stages.
            memory: When true, each stage's output is treated as a 4-item
                sequence whose first item holds the scores.

        Returns:
            ``(agent_id, task_id)`` when ``memory`` is false. Otherwise
            ``(agent_id, task_id, (a1, t1), (a2, t2), (a3, t3))`` where ``aN`` /
            ``tN`` are item ``N`` of the agent / task stage output.
        """
        with torch.no_grad():
            agent_output = self.forward(observation, memory=memory)
            if memory:
                agent_id = agent_output[0].argmax().item()
            else:
                agent_id = agent_output.argmax().item()
            task_output = self.forward(observation, agent_id, memory=memory)
            if memory:
                task_id = task_output[0].argmax().item()
                return agent_id, task_id, (agent_output[1], task_output[1]), (agent_output[2], task_output[2]), (agent_output[3], task_output[3])
            else:
                task_id = task_output.argmax().item()
            return agent_id, task_id

    def forward(self, observation, preq_id = None, memory=False):
        """Run the agent stage, or the task stage when ``preq_id`` is given.

        Args:
            observation: Input passed to the selected stage.
            preq_id: ``None`` selects the agent scheduler; any other value is
                passed to the task scheduler as the agent id.
            memory: Forwarded to the selected stage.
        """
        if preq_id is None:
            # .forward is called explicitly: the agent scheduler may be the
            # EDF scheduler, which is not known here to be callable.
            return self.agent_scheduler.forward(observation, memory=memory)
        else:
            return self.task_scheduler(observation, preq_id, memory=memory)

    def replay(self, g, nf, ef, mode='agent'):
        """Dispatch ``replay`` to the agent scheduler for ``'agent'``, else to the task scheduler."""
        if mode in ['agent']:
            return self.agent_scheduler.replay(g, nf, ef, mode)
        else:
            return self.task_scheduler.replay(g, nf, ef, mode)
        
class GraphSchedulerAgentTask(nn.Module):
    """Agent-selection head that can be conditioned on an already chosen task.

    The underlying network is built by ``get_graph`` with ``outputs=['agent']``
    and is always queried in ``'agent'`` mode.
    """

    def __init__(self, num_heads=8, num_layers=4, task_assignment_node=False, graph='hetgat', device=None):
        """Build the underlying graph model.

        Args:
            num_heads: Forwarded to ``get_graph``.
            num_layers: Forwarded to ``get_graph``.
            task_assignment_node: Forwarded to ``get_graph``.
            graph: Model identifier understood by ``get_graph``.
            device: Device handed to ``get_graph``.
        """
        super(GraphSchedulerAgentTask, self).__init__()
        self.outputs = ['agent']

        self.graph = graph

        self.model = get_graph(graph, num_heads, num_layers, self.outputs, task_assignment_node, device)

    def forward(self, observation, task_id = None, memory=False):
        """Run the model in ``'agent'`` mode, optionally conditioned on a task.

        Args:
            observation: Input passed through to the model unchanged.
            task_id: When not ``None`` it is converted with ``int`` and passed
                as ``dependent_action={'task': ...}``.
            memory: Forwarded to the model as the ``memory`` keyword.

        Returns:
            Whatever the underlying model returns.
        """
        if task_id is None:
            return self.model(observation, 'agent', memory=memory)
        dependent_action = {'task': int(task_id)}
        # Call the module itself (not .forward) in both branches for consistency.
        return self.model(observation, 'agent', dependent_action, memory=memory)

    def get_action(self, observation, memory=False):
        """Run two greedy passes and return both argmax indices.

        The first pass is unconditioned; its argmax is then passed to
        ``forward`` as ``task_id`` for the second pass.

        NOTE(review): both passes query the model in ``'agent'`` mode, yet the
        first result is reused as a task id. This looks copied from a
        two-stage scheduler -- confirm the intended use before relying on it.

        Returns:
            ``(first_id, second_id)`` when ``memory`` is false. Otherwise
            ``(first_id, second_id, (f1, s1), (f2, s2), (f3, s3))`` where ``fN``
            / ``sN`` are item ``N`` of the first / second pass output.
        """
        with torch.no_grad():
            first_output = self.forward(observation, memory=memory)
            if memory:
                first_id = first_output[0].argmax().item()
            else:
                first_id = first_output.argmax().item()
            second_output = self.forward(observation, first_id, memory=memory)
            if memory:
                second_id = second_output[0].argmax().item()
                return (
                    first_id,
                    second_id,
                    (first_output[1], second_output[1]),
                    (first_output[2], second_output[2]),
                    (first_output[3], second_output[3]),
                )
            second_id = second_output.argmax().item()
            return first_id, second_id

    def replay(self, g, nf, ef, mode='agent'):
        """Delegate to the model's ``replay`` for the ``'agent'`` mode.

        Provided so that wrappers calling ``agent_scheduler.replay(...)`` on
        this class work the same way as with the other scheduler heads.

        Raises:
            NotImplementedError: If ``mode`` is not ``'agent'``.
        """
        if mode in ['agent']:
            return self.model.replay(g, nf, ef, mode)
        else:
            raise NotImplementedError("Replay mode must be one of ['agent']")
        
class GraphSchedulerTaskFirst(nn.Module):
    """Two-stage scheduler: choose a task first, then an agent for that task.

    The task stage is a ``GraphSchedulerTask`` (called without an agent id) and
    the agent stage is a ``GraphSchedulerAgentTask`` conditioned on the task.
    """

    def __init__(self, num_heads=8, num_layers=4, layer_data=None, task_assignment_node=False, graph='hetgat', device=None):
        """Build the agent and task schedulers.

        Args:
            num_heads: Forwarded to both sub-schedulers.
            num_layers: Default layer count for both sub-schedulers.
            layer_data: Optional mapping with ``'agent'`` and/or ``'task'``
                entries overriding ``num_layers`` per sub-scheduler. A missing
                key or a ``None`` value falls back to ``num_layers``.
            task_assignment_node: Forwarded to both sub-schedulers.
            graph: Model identifier forwarded to both sub-schedulers.
            device: Device handed to the sub-schedulers.
        """
        super(GraphSchedulerTaskFirst, self).__init__()
        agent_layers, task_layers = self._resolve_layers(layer_data, num_layers)

        self.agent_scheduler = GraphSchedulerAgentTask(num_heads, agent_layers, task_assignment_node, graph, device)
        self.task_scheduler = GraphSchedulerTask(num_heads, task_layers, task_assignment_node, graph=graph, device=device)

    @staticmethod
    def _resolve_layers(layer_data, num_layers):
        """Return ``(agent_layers, task_layers)``, defaulting each to ``num_layers``."""
        if layer_data is None:
            return num_layers, num_layers

        def pick(key):
            # A missing key is treated like an explicit None instead of
            # raising KeyError.
            value = layer_data[key] if key in layer_data else None
            return num_layers if value is None else value

        return pick('agent'), pick('task')

    def get_action(self, observation, memory=False):
        """Greedily pick a task, then an agent conditioned on that task.

        Args:
            observation: Input passed to both stages.
            memory: When true, each stage's output is treated as a 4-item
                sequence whose first item holds the scores.

        Returns:
            ``(task_id, agent_id)`` when ``memory`` is false -- the task index
            comes first because the task stage runs first. Otherwise
            ``(task_id, agent_id, (t1, a1), (t2, a2), (t3, a3))`` where ``tN`` /
            ``aN`` are item ``N`` of the task / agent stage output.

            NOTE(review): this is the order the code has always returned;
            confirm callers expect task-before-agent.
        """
        with torch.no_grad():
            # Stage 1: no task_id yet, so forward() dispatches to the task scheduler.
            task_output = self.forward(observation, memory=memory)
            if memory:
                task_id = task_output[0].argmax().item()
            else:
                task_id = task_output.argmax().item()
            # Stage 2: with a task_id, forward() dispatches to the agent scheduler.
            agent_output = self.forward(observation, task_id, memory=memory)
            if memory:
                agent_id = agent_output[0].argmax().item()
                return (
                    task_id,
                    agent_id,
                    (task_output[1], agent_output[1]),
                    (task_output[2], agent_output[2]),
                    (task_output[3], agent_output[3]),
                )
            agent_id = agent_output.argmax().item()
            return task_id, agent_id

    def forward(self, observation, task_id = None, memory=False):
        """Run the task stage, or the agent stage when ``task_id`` is given.

        Args:
            observation: Input passed to the selected stage.
            task_id: ``None`` selects the task scheduler; any other value is
                passed to the agent scheduler as the task id.
            memory: Forwarded to the selected stage.
        """
        if task_id is None:
            return self.task_scheduler(observation, memory=memory)
        else:
            return self.agent_scheduler(observation, task_id, memory=memory)

    def replay(self, g, nf, ef, mode='agent'):
        """Dispatch ``replay`` to the agent scheduler for ``'agent'``, else to the task scheduler.

        NOTE(review): relies on the agent scheduler exposing ``replay`` -- verify.
        """
        if mode in ['agent']:
            return self.agent_scheduler.replay(g, nf, ef, mode)
        else:
            return self.task_scheduler.replay(g, nf, ef, mode)
    


class GraphTaskAllocation(nn.Module):
    """Allocation head that queries a graph model in a configurable output mode.

    The underlying network is built by ``get_graph`` with
    ``outputs=[output_mode]`` and is queried with ``output_mode`` as its mode.
    """

    def __init__(self, num_heads=8, num_layers=4, output_mode='task_assignment', task_assignment_node=False, graph='hetgat', device=None):
        """Build the underlying graph model.

        Args:
            num_heads: Forwarded to ``get_graph``.
            num_layers: Forwarded to ``get_graph``.
            output_mode: Stored as ``self.mode`` and used as the single entry
                of ``self.outputs``.
            task_assignment_node: Forwarded to ``get_graph``.
            graph: Model identifier understood by ``get_graph``.
            device: Device handed to ``get_graph``.
        """
        super(GraphTaskAllocation, self).__init__()

        self.mode = output_mode
        self.outputs = [output_mode]
        self.graph = graph
        self.model = get_graph(graph, num_heads, num_layers, self.outputs, task_assignment_node, device)

    def forward(self, observation, memory=False):
        """Run the model on ``observation`` in ``self.mode``.

        Args:
            observation: Input passed through to the model unchanged.
            memory: Forwarded to the model as the ``memory`` keyword.

        Returns:
            Whatever the underlying model returns.
        """
        return self.model(observation, self.mode, memory=memory)

    def get_action(self, observation, memory=False):
        """Return the argmax action without tracking gradients.

        Args:
            observation: Input passed to ``forward``.
            memory: When true, the model output is treated as a 4-item
                sequence whose first item holds the scores.

        Returns:
            ``int`` index of the highest score when ``memory`` is false;
            otherwise ``(index, output[1], output[2], output[3])``.
        """
        with torch.no_grad():
            output = self.forward(observation, memory=memory)
            if memory:
                # Scores are the first element in memory mode, so take the
                # argmax of output[0] rather than of the whole sequence.
                return output[0].argmax().item(), output[1], output[2], output[3]
            return output.argmax().item()

    def replay(self, g, nf, ef, mode='task_assignment'):
        """Delegate to the model's ``replay`` for the ``'task_assignment'`` mode.

        Raises:
            NotImplementedError: If ``mode`` is not ``'task_assignment'``.
        """
        if mode in ['task_assignment']:
            return self.model.replay(g, nf, ef, mode)
        else:
            raise NotImplementedError("Replay mode must be one of ['task_assignment']")
        