"""Heterogenous Graph Attention Network (HetGAT) solver for Scheduling Problem"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from scheduling.environment import SchedulingEnvironment
from models.graph_scheduler import GraphSchedulerIndividual, GraphSchedulerCritic, GraphSchedulerTaskFirst


class HetGatSolverIndividual(nn.Module):
    """Agent-first HetGAT policy for the scheduling problem.

    Every decision is made in two passes through the same actor network: an
    agent is selected first, then a task is selected for that agent among the
    entries of ``x['task_to_task_select'][0]``.
    """

    def __init__(self, mode='attention', num_heads=8, num_layers=4, layer_data=None, temperature=1.0):
        """Heterogeneous Graph Attention Network (HetGAT) solver for Scheduling Problem

        Args:
            mode (str): Mode of the graph, must be one of ['no_edge', 'edge', 'attention']
            num_heads (int): Forwarded to ``GraphSchedulerIndividual``.
            num_layers (int): Forwarded to ``GraphSchedulerIndividual``.
            layer_data: Forwarded to ``GraphSchedulerIndividual``.
            temperature (float): Divisor applied to the raw scores in ``softmax``.
        """
        super().__init__()
        self.actor = GraphSchedulerIndividual(num_heads=num_heads, num_layers=num_layers, layer_data=layer_data, task_assignment_node=True, graph=mode)
        self.eps = 1e-5  # epsilon for numerical stability
        self.epsilon = 0.1  # epsilon for epsilon greedy
        self.temperature = temperature

    def set_environment(self, env, rerun=False):
        """No-op placeholder; both arguments are ignored."""
        pass

    def get_action_and_memory(self, x, replay=None, greedy=False, adaptive_temperature=False):
        """Run ``get_action`` and return its full 5-tuple unchanged.

        Returns:
            tuple: ``(action, log_prob, entropy, memory, select_probs)``.
        """
        action, log_prob, entropy, memory, select_probs = self.get_action(x, replay=replay, greedy=greedy, adaptive_temperature=adaptive_temperature)
        return action, log_prob, entropy, memory, select_probs

    def replay_baseline(self, x, action):
        """Score a given ``action`` on observation ``x``.

        Returns:
            tuple: ``(action, log_prob, entropy, memory)``; the selection
            probabilities produced by ``get_action`` are dropped.
        """
        action, log_prob, entropy, memory, select_probs = self.get_action(x, action=action)
        return action, log_prob, entropy, memory

    def softmax(self, logits):
        """Temperature-scale, min-max normalise and softmax ``logits`` along dim 0.

        NOTE(review): for a positive temperature the min-max rescale cancels
        the temperature division except inside the ``eps`` term, so
        ``self.temperature`` has almost no influence on the result -- confirm
        this is intended.
        """
        # run softmax with self.temperature
        logits = logits / self.temperature
        # normalize the logits into [0, 1); eps guards the all-equal case
        logits = (logits - torch.min(logits)) / (torch.max(logits) - torch.min(logits) + self.eps)

        return torch.softmax(logits, dim=0)

    def get_action(self, x, action=None, replay=None, greedy=True, adaptive_temperature=False):
        """Select (or score) an agent and then a task.

        Args:
            x (dict): Observation; ``x['task_to_task_select'][0]`` is used as
                the index set of selectable tasks.
            action (tuple | None): ``(task_id, agent_id)`` to score instead of
                choosing one. When given, ``greedy`` is ignored.
            replay (sequence | None): Stored graphs to re-evaluate instead of
                running the actor on ``x``: ``replay[0]`` is ``(g, nf, ef)``
                for the agent pass and ``replay[1]`` is ``(g, nf, ef, mask)``
                for the task pass.
            greedy (bool): Always take the arg-max. Otherwise sample with
                probability ``self.epsilon`` and take the arg-max the rest of
                the time.
            adaptive_temperature (bool): Post-process the scores with
                ``adaptive_temperature_run``.

        Returns:
            tuple: ``((task_id, agent_id), log_prob, entropy, memory,
            (agent_probs.probs, task_probs.probs))``. ``memory`` holds the
            graphs produced in this call and is empty when ``replay`` is used.
        """
        memory = []

        # ---- Stage 1: agent selection ----
        if replay is not None:
            g, nf, ef = replay[0]
            agent_logits = self.actor.replay(g, nf, ef, mode='agent')
        else:
            agent_logits, g, nf, ef = self.actor(x, memory=True)
            memory.append((g, nf, ef))

        # softmax the agent_logits with temperature
        agent_logits_ = self.softmax(agent_logits)
        if adaptive_temperature:
            agent_logits_ = self.adaptive_temperature_run(agent_logits_)
        # NOTE(review): agent_logits_ is already a softmax output yet is fed in
        # as logits (a second softmax is applied); kept as-is because the
        # log-probabilities below are computed the same way.
        agent_probs = Categorical(logits=agent_logits_.flatten())
        if action is not None:
            agent = torch.tensor(action[1])
        elif greedy or np.random.rand() >= self.epsilon:
            agent = torch.argmax(agent_logits_.flatten())
        else:
            # The sampling policy for multi-agent schedule requires a balance
            # between exploration and exploitation for different agents
            agent = agent_probs.sample()

        agent_log_prob = F.log_softmax(agent_logits_, dim=0)

        # ---- Stage 2: task selection for the chosen agent ----
        if replay is not None:
            g, nf, ef, mask = replay[1]
            task_output = self.actor.replay(g, nf, ef, mode='task')
        else:
            task_output, g, nf, ef = self.actor(x, agent.item(), memory=True)
            mask = x['task_to_task_select'][0]
            memory.append((g, nf, ef, mask))
        task_logits = task_output.flatten()[mask]

        task_logits_ = self.softmax(task_logits)
        if adaptive_temperature and task_logits.shape[0] > 1:
            task_logits_ = self.adaptive_temperature_run(task_logits_)
        task_probs = Categorical(logits=task_logits_)
        if action is not None:
            # position of the requested task id inside the selectable set
            task = torch.tensor(np.where(mask == action[0])[0].item())
            task_id = torch.tensor(action[0])
        else:
            if greedy or np.random.rand() >= self.epsilon:
                task = torch.argmax(task_logits_)
            else:
                task = task_probs.sample()
            task_id = torch.tensor(x['task_to_task_select'][0][task.item()])

        task_log_prob = F.log_softmax(task_logits_, dim=0)
        log_prob = agent_log_prob.flatten()[agent.item()] + task_log_prob.flatten()[task.item()]
        entropy = (agent_probs.entropy() + task_probs.entropy())/2

        return (task_id.item(), agent.item()), log_prob, entropy, memory, (agent_probs.probs, task_probs.probs)

    def get_agent_probs(self, x, greedy=False, adaptive_temperature=False):
        """Get agent probabilities
        Args:
            x (dict): Input data Observation
            greedy (bool): Greedy action (arg-max); otherwise sample
            adaptive_temperature (bool): Adaptive temperature
        Returns:
            tuple: Tuple of agent ID, agent distribution (``Categorical``),
                agent log probabilities
        """
        agent_logits, g, nf, ef = self.actor(x, memory=True)
        # softmax the agent_logits with temperature
        agent_logits_ = self.softmax(agent_logits)
        if adaptive_temperature:
            agent_logits_ = self.adaptive_temperature_run(agent_logits_)
        agent_probs = Categorical(logits=agent_logits_.flatten())
        if greedy:
            agent = torch.argmax(agent_logits_.flatten())
        else:
            agent = agent_probs.sample()

        agent_log_prob = F.log_softmax(agent_logits_, dim=0)

        return agent.item(), agent_probs, agent_log_prob

    def get_task_probs(self, x, agent, greedy=False, adaptive_temperature=False):
        """Get task probabilities
        Args:
            x (dict): Input data Observation
            agent (int): Agent ID
            greedy (bool): Greedy action (arg-max); otherwise sample
            adaptive_temperature (bool): Adaptive temperature
        Returns:
            tuple: Tuple of task ID, task distribution (``Categorical``),
                task log probabilities
        """
        task_output, g, nf, ef = self.actor(x, agent, memory=True)
        mask = x['task_to_task_select'][0]
        task_logits = task_output.flatten()[mask]
        task_logits_ = self.softmax(task_logits)
        if adaptive_temperature and task_logits.shape[0] > 1:
            task_logits_ = self.adaptive_temperature_run(task_logits_)
        task_probs = Categorical(logits=task_logits_)
        if greedy:
            task = torch.argmax(task_logits_)
        else:
            task = task_probs.sample()
        task_id = torch.tensor(x['task_to_task_select'][0][task.item()])
        task_log_prob = F.log_softmax(task_logits_, dim=0)
        return task_id.item(), task_probs, task_log_prob

    def adaptive_temperature_run(self, logits):
        """Softmax ``logits`` along dim 0, halving the temperature up to 10 times.

        Starting from temperature 1.0, the loop stops as soon as the gap
        between the largest and smallest probability is below ``1/N`` (``N`` is
        ``logits.shape[0]``). If the current distribution contains NaN, the
        temperature is doubled once, the softmax recomputed, and the loop ends.

        NOTE(review): lowering the temperature cannot shrink the max-min gap,
        so when the very first check fails the loop normally runs all 10
        halvings -- confirm the comparison direction is the intended one.

        Returns:
            torch.Tensor: A softmax distribution (probabilities, not logits)
            with the same shape as ``logits``.
        """
        temp = 1.0
        N = logits.shape[0]
        dist = F.softmax(logits / temp, dim=0)
        diff = dist.max() - dist.min()  # difference between the two values
        for _ in range(10):
            if diff < 1/N:
                break
            if torch.isnan(dist).any():
                # back off to the previous temperature and stop
                temp *= 2
                dist = F.softmax(logits / temp, dim=0)
                break
            temp *= 0.5
            dist = F.softmax(logits / temp, dim=0)
            diff = dist.max() - dist.min()  # difference between the two values
        return dist

    def get_action_and_probability(self, x, action=None):
        """Get action and probability of the action for the maximum probability
        Args:
            x (dict): Input data
            action (tuple | None): Optional ``(task_id, agent_id)`` to score
                instead of taking the arg-max at each stage.
        Returns:
            tuple: 4-tuple of
                - action (tuple): Tuple of task and agent id
                - log_prob: Sum of the agent and task log probabilities
                - prob: Product of the agent and task probabilities
                - entropy: Mean of the agent and task entropies
        """
        if action is not None:
            task_id = action[0]
            # position of the requested task id inside the selectable set
            task = torch.tensor(np.where(x['task_to_task_select'][0] == task_id)[0].item())

            agent = torch.tensor(action[1])
            agent_id = action[1]

        agent_output, a_g, a_nf, a_ef = self.actor(x, memory=True)
        # NOTE(review): softmax output is passed as logits (double softmax);
        # kept as-is to preserve the existing numbers.
        agent_logits = torch.softmax(agent_output, dim=0).flatten()
        agent_probs = Categorical(logits=agent_logits)
        if action is None:
            # This function is called after the initial action is taken for s -> a -> s',
            # to move from s' to get a' for the critic as such this action should be on policy,
            # therefore it is an exploitation only policy unless the action is given by the user
            agent = torch.argmax(agent_logits)
            agent_id = agent.item()

        agent_log_prob = agent_probs.logits[agent.item()]
        agent_prob = agent_probs.probs[agent.item()]

        task_output, t_g, t_nf, t_ef = self.actor(x, agent.item(), memory=True)
        task_output_masked = task_output.flatten()[x['task_to_task_select'][0]]
        task_logits = torch.softmax(task_output_masked, dim=0).flatten()
        task_probs = Categorical(logits=task_logits)
        if action is None:
            task = torch.argmax(task_logits)
            task_id = x['task_to_task_select'][0][task.item()]

        task_log_prob = task_probs.logits[task.item()]
        task_prob = task_probs.probs[task.item()]

        log_prob = agent_log_prob + task_log_prob
        return (task_id, agent_id), log_prob, torch.tensor([1.0]) * agent_prob * task_prob, torch.tensor([1.0]) * (agent_probs.entropy() + task_probs.entropy())/2
    

class HetGatSolverIndividualTaskFirst(nn.Module):
    """Task-first HetGAT policy for the scheduling problem.

    Every decision is made in two passes through the same actor network: a
    task is selected first among the entries of
    ``x['task_to_task_select'][0]``, then an agent is selected for that task.
    """

    def __init__(self, mode='attention', num_heads=8, num_layers=4, layer_data=None):
        """Heterogeneous Graph Attention Network (HetGAT) solver for Scheduling Problem

        Args:
            mode (str): Mode of the graph, must be one of ['no_edge', 'edge', 'attention']
            num_heads (int): Forwarded to ``GraphSchedulerTaskFirst``.
            num_layers (int): Forwarded to ``GraphSchedulerTaskFirst``.
            layer_data: Forwarded to ``GraphSchedulerTaskFirst``.
        """
        super().__init__()
        self.actor = GraphSchedulerTaskFirst(num_heads=num_heads, num_layers=num_layers, layer_data=layer_data, task_assignment_node=True, graph=mode)
        self.eps = 1e-5

    def set_environment(self, env):
        """No-op placeholder; ``env`` is ignored."""
        pass

    def get_action_and_memory(self, x, replay=None, greedy=False):
        """Run ``get_action`` and drop the selection probabilities.

        Returns:
            tuple: ``(action, log_prob, entropy, memory)``.
        """
        action, log_prob, entropy, memory, select_probs = self.get_action(x, replay=replay, greedy=greedy)
        return action, log_prob, entropy, memory

    def replay_baseline(self, x, action):
        """Score a given ``action`` on observation ``x``.

        Returns:
            tuple: ``(action, log_prob, entropy, memory)``.
        """
        action, log_prob, entropy, memory, select_probs = self.get_action(x, action=action)
        return action, log_prob, entropy, memory

    @staticmethod
    def _split_replay(replay):
        """Return ``(task_memory, agent_memory)`` from a two-entry replay record.

        ``get_action`` stores the task pass first as ``(g, nf, ef, mask)`` and
        the agent pass second as ``(g, nf, ef)``. The 4-element entry is taken
        as the task record, so the older agent-first ordering is accepted too.
        """
        first, second = replay[0], replay[1]
        if len(first) == 4:
            return first, second
        return second, first

    def get_action(self, x, action=None, replay=None, greedy=False):
        """Select (or score) a task and then an agent.

        Args:
            x (dict): Observation; ``x['task_to_task_select'][0]`` is used as
                the index set of selectable tasks.
            action (tuple | None): ``(task_id, agent_id)`` to score instead of
                choosing one. When given, ``greedy`` is ignored.
            replay (sequence | None): Two stored graph records to re-evaluate
                instead of running the actor on ``x``: one ``(g, nf, ef, mask)``
                for the task pass and one ``(g, nf, ef)`` for the agent pass,
                in either order (the ``memory`` returned by this method
                qualifies).
            greedy (bool): Take the arg-max instead of sampling.

        Returns:
            tuple: ``((task_id, agent_id), log_prob, entropy, memory,
            (agent_probs.probs, task_probs.probs))``. ``memory`` holds the
            graphs produced in this call and is empty when ``replay`` is used.
        """
        memory = []
        if replay is not None:
            # Previously the indices were swapped relative to the layout this
            # method writes, so replaying its own memory failed on unpacking.
            task_memory, agent_memory = self._split_replay(replay)

        # ---- Stage 1: task selection ----
        if replay is not None:
            g, nf, ef, mask = task_memory
            task_output = self.actor.replay(g, nf, ef, mode='task')
        else:
            task_output, g, nf, ef = self.actor(x, memory=True)
            mask = x['task_to_task_select'][0]
            memory.append((g, nf, ef, mask))
        task_logits = task_output.flatten()[mask]

        task_probs = Categorical(logits=task_logits)
        if action is not None:
            # position of the requested task id inside the selectable set
            task = torch.tensor(np.where(mask == action[0])[0].item())
            task_id = torch.tensor(action[0])
        else:
            if greedy:
                task = torch.argmax(task_logits)
            else:
                task = task_probs.sample()
            task_id = torch.tensor(x['task_to_task_select'][0][task.item()])

        task_log_prob = F.log_softmax(task_logits, dim=0)

        # ---- Stage 2: agent selection for the chosen task ----
        if replay is not None:
            g, nf, ef = agent_memory
            agent_logits = self.actor.replay(g, nf, ef, mode='agent')
        else:
            # NOTE(review): task_id is a 0-d tensor here, while
            # get_action_and_probability passes the raw mask entry -- confirm
            # the actor accepts both.
            agent_logits, g, nf, ef = self.actor(x, task_id, memory=True)
            memory.append((g, nf, ef))

        agent_probs = Categorical(logits=agent_logits.flatten())
        if action is not None:
            agent = torch.tensor(action[1])
        elif greedy:
            agent = torch.argmax(agent_logits.flatten())
        else:
            agent = agent_probs.sample()

        agent_log_prob = F.log_softmax(agent_logits, dim=0)

        log_prob = agent_log_prob.flatten()[agent.item()] + task_log_prob.flatten()[task.item()]
        entropy = (agent_probs.entropy() + task_probs.entropy())/2

        return (task_id.item(), agent.item()), log_prob, entropy, memory, (agent_probs.probs, task_probs.probs)

    def get_action_and_probability(self, x):
        """Get the arg-max action and its probability
        Args:
            x (dict): Input data
        Returns:
            tuple: 3-tuple of
                - action (tuple): Tuple of task and agent id
                - log_prob: Sum of the agent and task log probabilities
                - prob: Product of the agent and task probabilities
        """
        task_output, t_g, t_nf, t_ef = self.actor(x, memory=True)
        task_output_masked = task_output.flatten()[x['task_to_task_select'][0]]
        # NOTE(review): softmax output is passed as logits (double softmax);
        # kept as-is to preserve the existing numbers.
        task_logits = torch.softmax(task_output_masked, dim=0).flatten()
        task_probs = Categorical(logits=task_logits)
        task = torch.argmax(task_logits)
        task_id = x['task_to_task_select'][0][task.item()]

        task_log_prob = task_probs.logits[task.item()]
        task_prob = task_probs.probs[task.item()]

        agent_output, a_g, a_nf, a_ef = self.actor(x, task_id, memory=True)
        agent_logits = torch.softmax(agent_output, dim=0).flatten()
        agent_probs = Categorical(logits=agent_logits)
        agent = torch.argmax(agent_logits)
        agent_id = agent.item()

        agent_log_prob = agent_probs.logits[agent.item()]
        agent_prob = agent_probs.probs[agent.item()]

        log_prob = agent_log_prob + task_log_prob
        return (task_id, agent_id), log_prob, agent_prob * task_prob
    