"""Heterogenous Graph Attention Network (HetGAT) solver for Scheduling Problem"""

import numpy as np

import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

from scheduling.environment import SchedulingEnvironment
from models.graph_scheduler import GraphSchedulerTaskAssignment


class HetGatSolverTaskAssignment(nn.Module):
    """Two-stage policy that first selects an agent, then a task for that agent.

    Both stages are scored by a single ``GraphSchedulerTaskAssignment`` actor:
    the agent stage by calling it on the state alone, the task stage by calling
    it again with the chosen agent index.
    """

    def __init__(self, mode='attention', num_heads=8):
        """Heterogenous Graph Attention Network (HetGAT) solver for Scheduling Problem

        Args:
            mode (str): Mode of the graph, must be one of ['no_edge', 'edge', 'attention']
            num_heads (int): Forwarded to ``GraphSchedulerTaskAssignment`` as ``num_heads``.
        """
        super().__init__()
        self.actor = GraphSchedulerTaskAssignment(num_heads=num_heads, task_assignment_node=True, graph=mode)

    def get_action_and_memory(self, x, replay=None, greedy=False):
        """Select an action for state ``x`` (sampled, or argmax if ``greedy``).

        Thin wrapper around :meth:`get_action`; returns the same 4-tuple.
        """
        return self.get_action(x, replay=replay, greedy=greedy)

    def replay_baseline(self, x, action):
        """Re-score a given ``(task_id, agent)`` ``action`` on state ``x``.

        Thin wrapper around :meth:`get_action`; returns the same 4-tuple.
        """
        return self.get_action(x, action=action)

    def get_action(self, x, action=None, replay=None, greedy=False):
        """Choose (or re-score) an agent and a task and return their statistics.

        Args:
            x: Environment state passed to the actor. It is indexed as
                ``x['task_to_task_select'][0]``, which is used as an index array
                into the flattened task scores (presumably the selectable task
                ids -- confirm against the environment). Unused when ``replay``
                is given.
            action: Optional ``(task_id, agent)`` pair to score instead of
                choosing one. It takes precedence over ``greedy``.
            replay: Optional memory produced by a previous call, i.e.
                ``[(g, nf, ef), (g, nf, ef, mask)]``. When given, the actor's
                ``replay`` method is used and ``x`` is not read.
            greedy: If True (and ``action`` is None), take the argmax instead
                of sampling.

        Returns:
            ``((task_id, agent), log_prob, entropy, memory)``. ``log_prob`` is
            the sum of both stages' log-probabilities. ``entropy`` is their mean.
            ``memory`` holds the graph inputs for later replay and is empty
            when ``replay`` is given.

        Raises:
            ValueError: if ``action[0]`` does not occur exactly once in the
                task mask.
        """
        memory = []

        # ---- Stage 1: agent selection ----
        if replay is not None:
            g, nf, ef = replay[0]
            actor_output = self.actor.replay(g, nf, ef, mode='agent')
        else:
            actor_output, g, nf, ef = self.actor(x, memory=True)
            memory.append((g, nf, ef))
        # Pass log-probabilities (not softmax probabilities) as logits:
        # Categorical normalises its logits itself, so feeding it softmax
        # output would apply a second softmax and distort the distribution.
        agent_log_p = torch.log_softmax(actor_output, dim=0).flatten()
        agent_probs = Categorical(logits=agent_log_p)

        if action is not None:
            # Match the distribution's device so log_prob works on GPU.
            agent = torch.as_tensor(action[1], device=agent_log_p.device)
        elif greedy:
            agent = torch.argmax(agent_log_p)
        else:
            agent = agent_probs.sample()

        # ---- Stage 2: task selection for the chosen agent ----
        if replay is not None:
            g, nf, ef, mask = replay[1]
            task_output = self.actor.replay(g, nf, ef, mode='task')
        else:
            task_output, g, nf, ef = self.actor(x, agent.item(), memory=True)
            mask = x['task_to_task_select'][0]
            memory.append((g, nf, ef, mask))
        task_log_p = torch.log_softmax(task_output.flatten()[mask], dim=0)
        task_probs = Categorical(logits=task_log_p)

        if action is not None:
            # Position of the requested task id inside the mask.
            matches = np.flatnonzero(np.asarray(mask) == action[0])
            if matches.size != 1:
                raise ValueError(
                    f"task id {action[0]!r} must appear exactly once in the "
                    f"task mask, found {matches.size} occurrence(s)"
                )
            task = torch.tensor(int(matches[0]), device=task_log_p.device)
            task_id = torch.tensor(action[0])
        else:
            task = torch.argmax(task_log_p) if greedy else task_probs.sample()
            # Map the position in the masked scores back to a task id using the
            # same mask the scores were taken with (works in replay mode too).
            task_id = torch.tensor(mask[task.item()])

        log_prob = agent_probs.log_prob(agent) + task_probs.log_prob(task)
        entropy = (agent_probs.entropy() + task_probs.entropy()) / 2
        return (task_id.item(), agent.item()), log_prob, entropy, memory