
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from scheduling.environment import SchedulingEnvironment
from models.graph_scheduler import GraphSchedulerIndividual, GraphSchedulerCritic, GraphSchedulerTaskFirst, GraphTaskAllocation

from collections import defaultdict

class HetGatSolverSimultaneous(nn.Module):
    """Heterogenous Graph Attention Network (HetGAT) solver for the scheduling problem.

    Two ``GraphTaskAllocation`` networks are evaluated on the same observation:

    * ``task_allocator`` scores every (agent, task) pair; for each task an
      agent is picked from a categorical distribution over the agents.
    * ``scheduler`` scores every (agent, task) pair as well; the score of each
      task under its chosen agent is used to order the tasks.

    NOTE(review): ``get_agent_probs``, ``get_task_probs`` and
    ``get_action_and_probability`` read ``self.actor``, which ``__init__`` does
    not create (the construction line is commented out). They raise
    ``AttributeError`` unless an ``actor`` attribute is attached externally.
    """

    def __init__(self, mode='attention', num_heads = 8, num_layers=4, layer_data=None, temperature=1.0):
        """Build the task-allocation and scheduling networks.

        Args:
            mode (str): Mode of the graph, must be one of ['no_edge', 'edge', 'attention'].
            num_heads (int): Forwarded to both ``GraphTaskAllocation`` networks.
            num_layers (int): Forwarded to both ``GraphTaskAllocation`` networks.
            layer_data: Accepted for interface compatibility; not used here.
            temperature (float): Divisor applied to logits in ``self.softmax``.
        """
        super().__init__()
        # self.actor = GraphSchedulerIndividual(num_heads=num_heads, num_layers=num_layers, layer_data=layer_data, task_assignment_node=True, graph=mode)

        self.task_allocator = GraphTaskAllocation(num_heads=num_heads, num_layers=num_layers, output_mode='task_assignment', task_assignment_node=True, graph=mode)
        self.scheduler = GraphTaskAllocation(num_heads=num_heads, num_layers=num_layers, output_mode='task_assignment', task_assignment_node=True, graph=mode)
        # eps / epsilon are stored but not referenced by any method of this class.
        self.eps = 1e-5
        self.temperature = temperature
        self.epsilon = 0.1

    def set_environment(self, env, rerun=False):
        """Attach an environment and reset the per-episode bookkeeping.

        Args:
            env: Environment object, stored as ``self.env``.
            rerun (bool): Accepted for interface compatibility; not used.
        """
        self.complete = False
        self.schedule = []
        self.env = env

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _assigned_task_scores(scheduling_matrix, task_allocation):
        """Return, for every task, the score under the agent assigned to it.

        ``scheduling_matrix`` is indexed as ``[agent, task]``;
        ``task_allocation`` holds one integer agent index per task.
        """
        device = scheduling_matrix.device
        agent_ids = torch.as_tensor(task_allocation, dtype=torch.long, device=device)
        task_ids = torch.arange(scheduling_matrix.shape[1], device=device)
        return scheduling_matrix[agent_ids, task_ids]

    @staticmethod
    def _sample_order(task_scores):
        """Sample a task ordering without replacement from ``task_scores``.

        At every position a categorical distribution over the scores of the
        still-unscheduled tasks is sampled.

        Returns:
            tuple: ``(order, log_probs, entropies)`` where ``order`` is an
            int64 tensor of task ids and the other two are lists with one
            tensor per position.
        """
        order = torch.zeros_like(task_scores, dtype=torch.int64)
        # Kept in ascending task-id order, so the sampled index maps to the
        # same task it would with a freshly filtered list.
        remaining = list(range(len(task_scores)))
        log_probs = []
        entropies = []
        for position in range(len(order)):
            m = Categorical(logits=task_scores[remaining])
            choice = m.sample().item()
            order[position] = remaining.pop(choice)
            log_probs.append(m.logits[choice])
            entropies.append(m.entropy())
        return order, log_probs, entropies

    # ------------------------------------------------------------------
    # Task allocation / scheduling
    # ------------------------------------------------------------------
    def get_task_allocation(self, observation, greedy=False):
        """Pick an agent for every task.

        Args:
            observation (dict): Must provide ``'agent'`` and ``'task'`` entries
                supporting ``len``; it is also passed to ``self.task_allocator``.
            greedy (bool): If True take the arg-max agent per task, otherwise
                sample from the per-task categorical distribution.

        Returns:
            tuple: ``(task_assignment, log_prob, entropy)`` -- a numpy array
            with one agent index per task, and the per-task log-probability
            and entropy tensors of the categorical distribution.
        """
        num_agent = len(observation['agent'])
        num_tasks = len(observation['task'])
        task_allocation_matrix = self.task_allocator(observation).view(num_agent, num_tasks)

        # One categorical distribution over agents per task (rows of the transpose).
        selector = Categorical(logits=task_allocation_matrix.T)

        if greedy:
            task_assignment = torch.argmax(task_allocation_matrix, dim=0)
        else:
            task_assignment = selector.sample()

        log_prob = selector.log_prob(task_assignment)
        entropy = selector.entropy()
        return task_assignment.detach().cpu().numpy(), log_prob, entropy

    def get_scheduling(self, observation, task_allocation, greedy=False):
        """Order the tasks given a task allocation.

        Args:
            observation (dict): Same structure as for ``get_task_allocation``.
            task_allocation: One integer agent index per task.
            greedy (bool): If True order tasks by descending score, otherwise
                sample the order without replacement.

        Returns:
            tuple: ``(order, log_probs, mean_entropy)`` -- a numpy array of
            task ids, a tensor with the log-probability of each pick among the
            tasks still unscheduled at that step, and the mean entropy of
            those per-step distributions.
        """
        num_agent = len(observation['agent'])
        num_tasks = len(observation['task'])
        scheduling_matrix = self.scheduler(observation).view(num_agent, num_tasks)
        # Score of each task under the agent it was allocated to.
        task_scores = self._assigned_task_scores(scheduling_matrix, task_allocation)

        if greedy:
            order = torch.argsort(task_scores, descending=True)
            ordered_scores = task_scores[order]
            log_probs = []
            entropies = []
            for i in range(len(order)):
                # Raw scores are the logits, exactly as in the sampling branch,
                # so greedy log-probs are evaluated under the same policy that
                # is sampled from. The greedy pick is always index 0 of the
                # remaining (sorted) scores.
                m = Categorical(logits=ordered_scores[i:])
                log_probs.append(m.logits[0])
                entropies.append(m.entropy())
        else:
            order, log_probs, entropies = self._sample_order(task_scores)

        log_probs = torch.stack(log_probs)
        entropies = torch.stack(entropies)
        return order.detach().cpu().numpy(), log_probs, entropies.mean()

    def get_action_and_memory(self, x, replay=None, greedy=False, adaptive_temperature=False):
        """Return ``get_action``'s result without its trailing element.

        Returns:
            tuple: ``(action, log_prob, entropy, memory)``.
        """
        action, log_prob, entropy, memory, select_probs = self.get_action(x, replay=replay, greedy=greedy, adaptive_temperature=adaptive_temperature)
        return action, log_prob, entropy, memory

    def get_action_and_memory_batched(self, x, batch_size, replay=None, greedy=False, adaptive_temperature=False):
        """Produce one greedy schedule plus ``batch_size`` sampled schedules.

        Both networks are evaluated once; every candidate re-uses their output.

        Args:
            x (dict): Observation (see ``get_task_allocation``).
            batch_size (int | None): Number of sampled schedules. When ``None``
                or ``0`` the call is forwarded to ``get_action_and_memory`` and
                that method's 4-tuple is returned instead.
            replay, greedy, adaptive_temperature: Only used by the fallback path.

        Returns:
            list: ``batch_size + 1`` schedules, each a list of
            ``(task_id, agent_id)`` tuples in execution order; the first entry
            is the greedy schedule.
        """
        if batch_size is None or batch_size == 0:
            return self.get_action_and_memory(x, replay=replay, greedy=greedy, adaptive_temperature=adaptive_temperature)

        observation = x
        num_agent = len(observation['agent'])
        num_tasks = len(observation['task'])
        task_allocation_matrix = self.task_allocator(observation).view(num_agent, num_tasks)
        scheduling_matrix = self.scheduler(observation).view(num_agent, num_tasks)
        selector = Categorical(logits=task_allocation_matrix.T)

        def pair_up(order, assignment):
            # Same element types as get_action: python int task id, numpy agent id.
            order_np = order.detach().cpu().numpy()
            assignment_np = assignment.detach().cpu().numpy()
            return [(task.item(), assignment_np[task]) for task in order_np]

        # Greedy candidate: arg-max agent per task, tasks by descending score.
        greedy_assignment = torch.argmax(task_allocation_matrix, dim=0)
        greedy_scores = self._assigned_task_scores(scheduling_matrix, greedy_assignment)
        greedy_order = torch.argsort(greedy_scores, descending=True)
        ret_schedules = [pair_up(greedy_order, greedy_assignment)]

        # Sampled candidates: the ordering scores depend on the integer
        # assignment sampled for that candidate, not on the raw allocator output.
        for _ in range(batch_size):
            assignment = selector.sample()
            scores = self._assigned_task_scores(scheduling_matrix, assignment)
            order, _, _ = self._sample_order(scores)
            ret_schedules.append(pair_up(order, assignment))
        return ret_schedules

    def replay_baseline(self, x, action):
        """Call ``get_action`` with ``action`` and drop the trailing element.

        NOTE(review): ``get_action`` does not read its ``action`` argument, so
        this returns a freshly computed (default greedy) action rather than
        replaying the one passed in -- confirm this is intended.
        """
        action, log_prob, entropy, memory, select_probs = self.get_action(x, action=action)
        return action, log_prob, entropy, memory

    def softmax(self, logits):
        """Softmax over dim 0 after dividing ``logits`` by ``self.temperature``."""
        logits = logits / self.temperature
        return torch.softmax(logits, dim=0)

    def get_action(self, x, action=None, replay=None, greedy=True, adaptive_temperature=False):
        """Compute a full schedule: task allocation followed by task ordering.

        Args:
            x (dict): Observation (see ``get_task_allocation``).
            action, replay, adaptive_temperature: Accepted for interface
                compatibility; not used by this implementation.
            greedy (bool): Forwarded to allocation and scheduling.

        Returns:
            tuple: ``(full_schedule, (ta_log_prob, s_log_prob),
            (ta_entropy, s_entropy), None, None)`` where ``full_schedule`` is a
            list of ``(task_id, agent_id)`` in execution order and
            ``ta_log_prob`` is re-ordered to match that schedule.
        """
        task_allocation, ta_log_prob, ta_entropy = self.get_task_allocation(x, greedy=greedy)
        schedule, s_log_prob, s_entropy = self.get_scheduling(x, task_allocation, greedy=greedy)

        # reorder the task allocation log probabilities to match the schedule
        ta_log_prob = ta_log_prob[schedule] # this is going to work with discounted rewards

        full_schedule = [(task.item(), task_allocation[task]) for task in schedule]
        return full_schedule, (ta_log_prob, s_log_prob), (ta_entropy, s_entropy), None, None

    # ------------------------------------------------------------------
    # Actor-based API (requires ``self.actor``, which __init__ does not set)
    # ------------------------------------------------------------------
    def get_agent_probs(self, x, greedy=False, adaptive_temperature=False):
        """Get agent probabilities
        Args:
            x (dict): Input data Observation
            greedy (bool): Greedy action
            adaptive_temperature (bool): Adaptive temperature
        Returns:
            tuple: Tuple of agent ID, agent probabilities, agent log probabilities

        NOTE(review): requires ``self.actor``. The temperature-softmaxed values
        are passed on as ``logits`` to ``Categorical``/``log_softmax``; kept
        as-is, confirm whether that is intended.
        """
        agent_logits, g, nf, ef = self.actor(x, memory=True)
        # softmax the agent_logits with temperature
        agent_logits_ = self.softmax(agent_logits)
        if adaptive_temperature:
            agent_logits_ = self.adaptive_temperature_run(agent_logits_)
        agent_probs = Categorical(logits=agent_logits_.flatten())
        if greedy:
            agent = torch.argmax(agent_logits_.flatten())
        else:
            agent = agent_probs.sample()

        agent_log_prob = F.log_softmax(agent_logits_, dim=0)

        return agent.item(), agent_probs, agent_log_prob

    def get_task_probs(self, x, agent, greedy=False, adaptive_temperature=False):
        """Get task probabilities
        Args:
            x (dict): Input data Observation
            agent (int): Agent ID
            greedy (bool): Greedy action
            adaptive_temperature (bool): Adaptive temperature
        Returns:
            tuple: Tuple of task ID, task probabilities, task log probabilities

        NOTE(review): requires ``self.actor``; see ``get_agent_probs`` for the
        remark about softmaxed values being used as logits.
        """
        task_output, g, nf, ef = self.actor(x, agent, memory=True)
        # Restrict the actor output to the entries listed in task_to_task_select[0].
        mask = x['task_to_task_select'][0]
        task_logits = task_output.flatten()[mask]
        task_logits_ = self.softmax(task_logits)
        if adaptive_temperature and task_logits.shape[0] > 1:
            task_logits_ = self.adaptive_temperature_run(task_logits_)
        task_probs = Categorical(logits=task_logits_)
        if greedy:
            task = torch.argmax(task_logits_)
        else:
            task = task_probs.sample()
        # Map the position within the masked vector back to the task id.
        task_id = torch.tensor(x['task_to_task_select'][0][task.item()])
        task_log_prob = F.log_softmax(task_logits_, dim=0)
        return task_id.item(), task_probs, task_log_prob

    def adaptive_temperature_run(self, logits):
        """Re-softmax ``logits`` while repeatedly halving the temperature.

        Starting from temperature 1.0 the softmax over dim 0 is recomputed for
        at most 10 iterations, halving the temperature each time, until the
        spread ``max - min`` of the distribution is below ``1 / N`` (``N`` being
        the size of dim 0). If the distribution contains NaN the temperature
        is doubled once, the softmax recomputed, and the loop stops.

        Returns:
            torch.Tensor: The last computed softmax distribution.
        """
        temp = 1.0
        N = logits.shape[0]
        dist = F.softmax(logits / temp, dim=0)
        diff = dist.max() - dist.min() # difference between the two values
        for i in range(10):
            if diff < 1/N:
                break
            if torch.isnan(dist).any():
                # Step back to the previous temperature and stop.
                temp *= 2
                dist = F.softmax(logits / temp, dim=0)
                break
            temp *= 0.5
            dist = F.softmax(logits / temp, dim=0)
            diff = dist.max() - dist.min() # difference between the two values
        return dist

    def get_action_and_probability(self, x, action=None):
        """Get action and probability of the action for the maximum probability
        Args:
            x (dict): Input data
            action (tuple | None): Optional ``(task_id, agent_id)`` to evaluate
                instead of the arg-max choice.
        Returns:
            tuple: ``((task_id, agent_id), log_prob, prob, entropy)`` where
                - log_prob is the sum of the agent and task log probabilities
                - prob is the product of the agent and task probabilities
                - entropy is the mean of the agent and task entropies

        NOTE(review): requires ``self.actor``; softmaxed actor outputs are
        passed to ``Categorical`` as logits -- kept as-is.
        """
        if action is not None:
            task_id = action[0]
            # Position of the given task id inside the selectable-task list.
            task = torch.tensor(np.where(x['task_to_task_select'][0] == task_id)[0].item())

            agent = torch.tensor(action[1])
            agent_id = action[1]

        agent_output, a_g, a_nf, a_ef = self.actor(x, memory=True)
        agent_logits = torch.softmax(agent_output, dim=0).flatten()
        agent_probs = Categorical(logits=agent_logits)
        if action is None:
            agent = torch.argmax(agent_logits)
            agent_id = agent.item()

        agent_log_prob = agent_probs.logits[agent.item()]
        agent_prob = agent_probs.probs[agent.item()]

        task_output, t_g, t_nf, t_ef = self.actor(x, agent.item(), memory=True)
        task_output_masked = task_output.flatten()[x['task_to_task_select'][0]]
        task_logits = torch.softmax(task_output_masked, dim=0).flatten()
        task_probs = Categorical(logits=task_logits)
        if action is None:
            task = torch.argmax(task_logits)
            task_id = x['task_to_task_select'][0][task.item()]

        task_log_prob = task_probs.logits[task.item()]
        task_prob = task_probs.probs[task.item()]

        log_prob = agent_log_prob + task_log_prob
        return (task_id, agent_id), log_prob, torch.tensor([1.0]) * agent_prob * task_prob, torch.tensor([1.0]) * (agent_probs.entropy() + task_probs.entropy())/2