import shortuuid
from typing import Any, List, Optional, Dict
from abc import ABC
import numpy as np
import torch
import asyncio

from HeGFlow.graph.node import Node
from HeGFlow.graph.edge import Edge
from HeGFlow.agents.agent_registry import AgentRegistry
import random

# NOTE(review): not referenced anywhere in the visible part of this module; the name
# suggests a cap on critic reflection rounds — confirm where (or whether) it is used.
max_reflections = 3

class Graph(ABC):
    """
    A network of LLM-backed agent nodes connected by sampled spatial and temporal edges.

    Each entry of ``nodes`` is an agent instantiated from ``agent_names`` through
    ``AgentRegistry``. Within a round, nodes are executed in topological order of their
    spatial edges (``node.spatial_predecessors`` / ``node.spatial_successors``).
    Temporal edges (``node.temporal_predecessors`` / ``node.temporal_successors``) are
    only sampled from round 1 onward; presumably they carry information from the
    previous round into the current one. After the last round, a separate decision
    node is connected to every agent and executed to produce the final answers.

    Attributes:
        domain (str): The domain for which this graph is used.
        llm_name (str): The name of the LLM used for processing within the nodes.
        nodes (dict): Agent nodes keyed by a unique string id (the node's own ``id``,
            or a random 4-character shortuuid when that is ``None`` or already taken).
        decision_node: Node that produces the final answer after the last round.
        critic_node: Node given a critic connection to every agent when
            ``reflection`` is enabled.

    Methods:
        add_node(node): Adds a node to the graph under a unique id.
        run(inputs, num_rounds=3, max_tries=3, max_time=600): Synchronously executes
            the graph for ``num_rounds`` rounds and returns the decision node's answers
            together with the accumulated edge log-probability.
        arun(input, num_rounds=2, max_tries=3, max_time=6000, skip=False, case=False):
            Asynchronous counterpart of ``run``.
    """

    def __init__(self, 
                domain: str,  
                llm_name: Optional[str],
                agent_names: List[str],
                decision_method: str,   
                optimized_spatial:bool = False,   
                initial_spatial_probability: float = 0.5,    
                fixed_spatial_masks:Optional[List[List[int]]] = None,    
                optimized_temporal:bool = False,   
                diff:bool = False,  
                dec:bool = False,   
                rounds: int = 2,
                initial_temporal_probability: float = 0.5,   
                fixed_temporal_masks:Optional[List[List[int]]] = None,   
                node_kwargs:Optional[List[Dict]] = None,
                reflection:bool = True,  
                ):
        """Build the agent nodes, the candidate edges and the edge logit/mask parameters.

        Args:
            domain: Domain passed to every agent, the decision node and the critic node.
            llm_name: LLM name passed to the same nodes.
            agent_names: Registry names of the agents; one node per registered name.
            decision_method: Registry name of the decision node.
            optimized_spatial: Whether spatial edge logits are trainable.
            initial_spatial_probability: Initial spatial edge probability (converted to
                a logit) when ``optimized_spatial``; otherwise the logit is 10.0.
            fixed_spatial_masks: ``len(agent_names)`` x ``len(agent_names)`` 0/1 mask.
                Defaults to fully connected without self-loops.
            optimized_temporal: Whether temporal edge logits are trainable.
            diff: If True, keep separate logits/masks per round (ParameterLists).
            dec: If True, also create ``decision_logits`` and the extra
                ``spatial_logits_1`` / ``temporal_logits_1`` parameters.
            rounds: Number of rounds to allocate per-round parameters for (diff mode).
            initial_temporal_probability: Same as the spatial one, for temporal edges.
            fixed_temporal_masks: Like ``fixed_spatial_masks``; defaults to all ones.
            node_kwargs: Per-agent constructor kwargs, aligned with ``agent_names``.
            reflection: If True, connect the critic node to every agent node.
        """
        num_agents = len(agent_names)
        # Fill in the defaults *before* building any tensor: torch.tensor(None) raises.
        if fixed_spatial_masks is None:
            # Fully connected, no self-loops.
            fixed_spatial_masks = [[1 if i != j else 0 for j in range(num_agents)] for i in range(num_agents)]
        if fixed_temporal_masks is None:
            # Fully connected, self-loops allowed.
            fixed_temporal_masks = [[1 for _ in range(num_agents)] for _ in range(num_agents)]
        # Unflattened copies kept as attributes.
        self.fixed_spatial_masks = torch.tensor(fixed_spatial_masks)
        self.fixed_temporal_masks = torch.tensor(fixed_temporal_masks)
        # Flattened, independent tensors (row-major over agent pairs) used for the mask parameters.
        fixed_spatial_masks = torch.tensor(fixed_spatial_masks).view(-1)
        fixed_temporal_masks = torch.tensor(fixed_temporal_masks).view(-1)
        assert len(fixed_spatial_masks)==num_agents*num_agents,"The fixed_spatial_masks doesn't match the number of agents" 
        assert len(fixed_temporal_masks)==num_agents*num_agents,"The fixed_temporal_masks doesn't match the number of agents"
        
        self.id:str = shortuuid.ShortUUID().random(length=4)
        self.domain:str = domain
        self.llm_name:str = llm_name
        self.agent_names:List[str] = agent_names
        self.optimized_spatial = optimized_spatial
        self.optimized_temporal = optimized_temporal
        self.decision_node:Node = AgentRegistry.get(decision_method, domain=self.domain, llm_name=self.llm_name)
        print(f"domain: {self.domain}")
        self.critic_node:Node = AgentRegistry.get("CriticNode", domain=self.domain, llm_name=self.llm_name)
        self.nodes:Dict[str,Node] = {}
        
        self.potential_spatial_edges:List[List[str]] = []
        self.potential_temporal_edges:List[List[str]] = []

        self.node_kwargs = node_kwargs if node_kwargs is not None else [{} for _ in agent_names]
        self.diff=diff
        self.rounds=rounds
        self.dec_1=False
        # Per-round node indices to skip in arun (empty: no scheduled skips).
        self.skip_nodes = []
        self.reflection = reflection
        
        self.spatial_edges: Dict[tuple[str, str], Edge] = {}  
        self.temporal_edges: Dict[tuple[str, str], Edge] = {}   
        
        self.init_nodes() 
        self.init_potential_edges() 

        if self.reflection: 
            # Give the critic node a connection to every agent node.
            for node in self.nodes.values():
                self.critic_node.add_critic_connection(node)
            print(self.critic_node.critic_connections)

        # logit(p) = log(p / (1 - p)); fixed (non-optimized) edges use a large logit instead.
        init_spatial_logit = torch.log(torch.tensor(initial_spatial_probability / (1 - initial_spatial_probability))) if optimized_spatial else 10.0
        init_temporal_logit = torch.log(torch.tensor(initial_temporal_probability / (1 - initial_temporal_probability))) if optimized_temporal else 10.0
        
        init_critic_logit = torch.log(torch.tensor(0.5 / (1 - 0.5)))  
        print(self.potential_spatial_edges)
        print(self.potential_temporal_edges)

        if dec:
            self.decision_logits = torch.ones(5) * torch.log(torch.tensor(1.0))
     
        if not diff:
            # One logit/mask vector shared by all rounds.
            self.spatial_masks = torch.nn.Parameter(fixed_spatial_masks,requires_grad=False)  # fixed edge masks
            self.spatial_logits = torch.nn.Parameter(torch.ones(len(self.potential_spatial_edges), requires_grad=optimized_spatial) * init_spatial_logit,requires_grad=optimized_spatial) # trainable edge logits
            self.temporal_logits = torch.nn.Parameter(torch.ones(len(self.potential_temporal_edges), requires_grad=optimized_temporal) * init_temporal_logit,requires_grad=optimized_temporal) # trainable edge logits
            self.temporal_masks = torch.nn.Parameter(fixed_temporal_masks,requires_grad=False)  # fixed edge masks

            if dec:
                self.spatial_logits_1 = torch.nn.Parameter(torch.ones(len(self.potential_spatial_edges), requires_grad=optimized_spatial) * init_spatial_logit,requires_grad=optimized_spatial)
                self.temporal_logits_1 = torch.nn.Parameter(torch.ones(len(self.potential_temporal_edges), requires_grad=optimized_temporal) * init_temporal_logit,requires_grad=optimized_temporal)
        else:
            # Separate parameters per round; temporal ones start at round 1, hence rounds-1 entries.
            if dec:
                self.spatial_logits_1 = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(len(self.potential_spatial_edges), requires_grad=optimized_spatial) * init_spatial_logit,requires_grad=optimized_spatial) for _ in range(rounds)])
                self.temporal_logits_1 = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(len(self.potential_temporal_edges), requires_grad=optimized_temporal) * init_temporal_logit,requires_grad=optimized_temporal) for _ in range(rounds-1)])
                
            self.spatial_masks = torch.nn.ParameterList([torch.nn.Parameter(fixed_spatial_masks.clone(), requires_grad=False) for _ in range(rounds)])  # fixed edge masks
            self.spatial_logits = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(len(self.potential_spatial_edges), requires_grad=optimized_spatial) * init_spatial_logit,requires_grad=optimized_spatial) for _ in range(rounds)])
            self.temporal_logits = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(len(self.potential_temporal_edges), requires_grad=optimized_temporal) * init_temporal_logit,requires_grad=optimized_temporal) for _ in range(rounds-1)])
            self.temporal_masks = torch.nn.ParameterList([torch.nn.Parameter(fixed_temporal_masks.clone(), requires_grad=False) for _ in range(rounds-1)])
            
            self.critic_logits = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(len(self.nodes), requires_grad=True) * init_critic_logit, requires_grad=True) for _ in range(rounds)])
            self.critic_masks = torch.nn.ParameterList([torch.nn.Parameter(torch.ones(len(self.nodes)), requires_grad=False) for _ in range(rounds)])
        
    @property
    def spatial_adj_matrix(self):
        """Dense 0/1 matrix of the current spatial edges.

        Rows and columns follow ``self.nodes`` insertion order; entry ``[r, c]`` is 1
        when node ``c`` is among node ``r``'s spatial successors.
        """
        ordered = list(self.nodes.values())
        size = len(ordered)
        adjacency = np.zeros((size, size))
        for row, source in enumerate(ordered):
            for col, target in enumerate(ordered):
                if target in source.spatial_successors:
                    adjacency[row, col] = 1
        return adjacency

    @property
    def temporal_adj_matrix(self):
        """Dense 0/1 matrix of the current temporal edges.

        Rows and columns follow ``self.nodes`` insertion order; entry ``[r, c]`` is 1
        when node ``c`` is among node ``r``'s temporal successors.
        """
        ordered = list(self.nodes.values())
        size = len(ordered)
        adjacency = np.zeros((size, size))
        for row, source in enumerate(ordered):
            for col, target in enumerate(ordered):
                if target in source.temporal_successors:
                    adjacency[row, col] = 1
        return adjacency

    @property
    def num_edges(self):
        """Total number of spatial successor links over all nodes in ``self.nodes``."""
        return sum(len(node.spatial_successors) for node in self.nodes.values())
    
    @property
    def num_nodes(self):
        """Number of nodes stored in ``self.nodes``.

        The decision and critic nodes are held in separate attributes, so they are not counted.
        """
        node_count = len(self.nodes)
        return node_count

    def find_node(self, id: str):
        """Return the node registered under ``id``.

        Raises:
            Exception: If no such node exists; the message lists the known node ids.
        """
        if id not in self.nodes:
            known_ids = [node.id for node in self.nodes.values()]
            raise Exception(f"Node not found: {id} among "
                            f"{known_ids}")
        return self.nodes[id]
        
    def add_node(self, node: Node):
        """Register ``node`` under a unique id and return it.

        The node keeps its own ``id`` when that is set and not already taken.
        Otherwise random 4-character shortuuids are drawn until one is unique.
        The chosen id is written back to ``node.id``.
        """
        def fresh_id():
            return shortuuid.ShortUUID().random(length=4)

        candidate = fresh_id() if node.id is None else node.id
        while candidate in self.nodes:
            candidate = fresh_id()
        node.id = candidate
        self.nodes[candidate] = node
        return node
    
    def init_nodes(self):  
        """Create one node per registered agent name and add it to the graph.

        Each agent is built from its entry in ``self.node_kwargs`` plus the graph's
        ``domain`` and ``llm_name``, which override same-named keys. Names not present
        in ``AgentRegistry.registry`` are skipped.

        NOTE(review): a skipped name leaves fewer nodes than ``agent_names``, while
        the fixed masks are sized by ``len(agent_names)``; confirm callers only pass
        registered names.
        """
        for agent_name, kwargs in zip(self.agent_names, self.node_kwargs):
            if agent_name not in AgentRegistry.registry:
                continue
            # Build a fresh dict rather than writing into the caller-supplied one, so a
            # node_kwargs list shared between graphs is not mutated.
            agent_kwargs = {**kwargs, "domain": self.domain, "llm_name": self.llm_name}
            self.add_node(AgentRegistry.get(agent_name, **agent_kwargs))
    
    def init_potential_edges(self): 
        """Register every ordered pair of nodes (self-pairs included) as a potential edge.

        For each pair ``(src, dst)`` in ``self.nodes`` order:
        * ``[src_id, dst_id]`` is appended to both ``potential_spatial_edges`` and
          ``potential_temporal_edges``;
        * a fresh ``Edge`` with ``logit_value=0.0`` is stored under ``(src_id, dst_id)``
          in ``spatial_edges`` and ``temporal_edges``.
        """
        ordered_pairs = [(src, dst) for src in self.nodes.items() for dst in self.nodes.items()]
        for (src_id, src_node), (dst_id, dst_node) in ordered_pairs:
            self.potential_spatial_edges.append([src_id, dst_id])
            self.potential_temporal_edges.append([src_id, dst_id])
            spatial = Edge(src_id, dst_id, src_node.role, dst_node.role, logit_value=0.0, edge_type="spatial")
            temporal = Edge(src_id, dst_id, src_node.role, dst_node.role, logit_value=0.0, edge_type="temporal")
            key = (src_id, dst_id)
            self.spatial_edges[key] = spatial
            self.temporal_edges[key] = temporal
 

    def clear_spatial_connection(self):
        """
        Clear all the spatial connection of the nodes in the graph.
        """
        for node_id in self.nodes.keys():
            self.nodes[node_id].spatial_predecessors = []
            self.nodes[node_id].spatial_successors = []
        self.decision_node.spatial_predecessors = []
        self.decision_node.spatial_successors = []

        # for (source_id, target_id), edge in self.spatial_edges.items():
        #     edge.selected_count = 0
        #     edge.correct_answer_count = 0
        # self.spatial_edges: Dict[tuple[str, str], Edge] = {}
        
    
    def clear_temporal_connection(self):
        """
        Clear all the temporal connection of the nodes in the graph.
        """
        for node_id in self.nodes.keys():
            self.nodes[node_id].temporal_predecessors = []
            self.nodes[node_id].temporal_successors = []

        # for (source_id, target_id), edge in self.temporal_edges.items():
        #     edge.selected_count = 0
        #     edge.correct_answer_count = 0
        # self.temporal_edges: Dict[tuple[str, str], Edge] = {}

    def connect_decision_node(self):
        """Add the decision node as a successor of every node in ``self.nodes``.

        Uses ``add_successor``'s default edge type.
        """
        for node in self.nodes.values():
            node.add_successor(self.decision_node)



    def construct_spatial_connection(self, temperature: float = 1.0, threshold: float = None,): 
        """Sample spatial edges from the shared (non-diff) logits and masks.

        Clears all current spatial connections, then walks ``potential_spatial_edges``
        in lock-step with ``self.spatial_logits`` and ``self.spatial_masks``:

        * mask == 0: the edge is never added;
        * mask == 1 and ``optimized_spatial`` is False: the edge is added
          deterministically unless it would close a spatial cycle; no log-probability
          term is recorded;
        * otherwise: unless it would close a cycle, the edge is kept with probability
          ``sigmoid((logit + Proucb) / temperature)`` and the log-probability of the
          sampled outcome is recorded.

        Args:
            temperature: Divisor applied to the logit before the sigmoid.
            threshold: When truthy, the probability is replaced by a hard 1 (if it
                exceeds ``threshold``) or 0. ``threshold=0.0`` is falsy and has no effect.

        Returns:
            tuple: ``(log_prob, selected_edges)``. ``log_prob`` is the summed
            log-probability of all sampling decisions (a 0-d tensor). ``selected_edges``
            lists the ``(source_id, target_id)`` pairs that were connected, in visiting order.
        """

        # Total selections over all spatial edges; fed to Edge.get_proucb below.
        all_selected_count = 0
        selected_edges = []  

        for edge in self.spatial_edges.values():
            all_selected_count += edge.selected_count 

        self.clear_spatial_connection() 
        # Seed with a zero so torch.stack never sees an empty list.
        log_probs = [torch.tensor(0.0, requires_grad=self.optimized_spatial)]
        
        for potential_connection, edge_logit, edge_mask in zip(self.potential_spatial_edges, self.spatial_logits, self.spatial_masks):
            out_node:Node = self.find_node(potential_connection[0]) 
            in_node:Node = self.find_node(potential_connection[1])  
            # Per-edge bonus added to the logit below; its meaning lives in Edge.get_proucb
            # (presumably a UCB-style exploration term — confirm there).
            Proucb = self.spatial_edges[(out_node.id, in_node.id)].get_proucb(all_selected_count)
            
            if edge_mask == 0.0:
                continue
            elif edge_mask == 1.0 and self.optimized_spatial==False:
                # Fixed edge: connect unless out_node is reachable from in_node (would form a cycle).
                if not self.check_cycle(in_node, {out_node}):
                    out_node.add_successor(in_node,'spatial')

                    self.spatial_edges[(out_node.id, in_node.id)].mark_selected()  
                    selected_edges.append((out_node.id, in_node.id))
                continue
            if not self.check_cycle(in_node, {out_node}): 
                
                final_logit = edge_logit + Proucb


                edge_prob = torch.sigmoid(final_logit / temperature)

                if threshold:  
                    edge_prob = torch.tensor(1 if edge_prob > threshold else 0)
                if torch.rand(1) < edge_prob: 
                    out_node.add_successor(in_node,'spatial')
                    log_probs.append(torch.log(edge_prob))

                    self.spatial_edges[(out_node.id, in_node.id)].mark_selected()  
                    selected_edges.append((out_node.id, in_node.id))
                else:
                    log_probs.append(torch.log(1 - edge_prob))
                    
        return torch.sum(torch.stack(log_probs)),selected_edges
    
    
    
    def construct_spatial_connection_diff(self, round:int = 0, temperature: float = 1.0, threshold: float = None,): # temperature must >= 1.0
        self.clear_spatial_connection()

        all_selected_count = 0 
        selected_edges = [] 

        for edge in self.spatial_edges.values():
            all_selected_count += edge.selected_count
            

        #print(f"All selected count: {all_selected_count}")
        log_probs = [torch.tensor(0.0, requires_grad=self.optimized_spatial)]

        for potential_connection, edge_logit, edge_mask in zip(self.potential_spatial_edges, self.spatial_logits[round], self.spatial_masks[round]):
            out_node:Node = self.find_node(potential_connection[0])
            in_node:Node = self.find_node(potential_connection[1])
            out_id = list(self.nodes).index(out_node.id)
            in_id = list(self.nodes).index(in_node.id)

            Proucb = self.spatial_edges[(out_node.id, in_node.id)].get_proucb(all_selected_count)
            if edge_mask == 0.0:
                continue
            elif edge_mask == 1.0 and self.optimized_spatial==False:
                # if round == self.rounds-1 and self.dec_1 and (self.decision_masks[in_id]==0 or self.decision_masks[out_id]==0):
                #     print(11111)
                #     continue
                if not self.check_cycle(in_node, {out_node}):
                    out_node.add_successor(in_node,'spatial')
                    self.spatial_edges[(out_node.id, in_node.id)].mark_selected()
                    # print(potential_connection)

                  
                    selected_edges.append((out_node.id, in_node.id))
                continue
            if not self.check_cycle(in_node, {out_node}):  

                final_logit = edge_logit + Proucb

                edge_prob = torch.sigmoid(edge_logit / temperature)
                #
                if threshold:
                    edge_prob = torch.tensor(1 if edge_prob > threshold else 0)
                if torch.rand(1) < edge_prob:
                    out_node.add_successor(in_node,'spatial')
                    self.spatial_edges[(out_node.id, in_node.id)].mark_selected()

                    selected_edges.append((out_node.id, in_node.id))
                    log_probs.append(torch.log(edge_prob))
                else:
                    log_probs.append(torch.log(1 - edge_prob))
                    
        return torch.sum(torch.stack(log_probs)),selected_edges
    
    def construct_temporal_connection(self, round:int = 0, temperature: float = 1.0, threshold: float = None,):  # temperature must >= 1.0
        self.clear_temporal_connection()

        all_selected_count = 0 
        selected_edges = []  
        for edge in self.temporal_edges.values():
            all_selected_count += edge.selected_count


        log_probs = [torch.tensor(0.0, requires_grad=self.optimized_temporal)]
        if round == 0:
            return torch.sum(torch.stack(log_probs))  
        for potential_connection, edge_logit, edge_mask in zip(self.potential_temporal_edges, self.temporal_logits, self.temporal_masks):
            out_node:Node = self.find_node(potential_connection[0])
            in_node:Node = self.find_node(potential_connection[1])
            Proucb = self.temporal_edges[(out_node.id, in_node.id)].get_proucb(all_selected_count)
            if edge_mask == 0.0:
                continue
            elif edge_mask == 1.0 and self.optimized_temporal==False:
                if not self.check_cycle(in_node, {out_node}):
                    out_node.add_successor(in_node,'temporal')
                    self.temporal_edges[(out_node.id, in_node.id)].mark_selected()
                    selected_edges.append((out_node.id, in_node.id))
                    # print(potential_connection)
                continue
            final_logit = edge_logit + Proucb

            edge_prob = torch.sigmoid(edge_logit / temperature)

            
            if threshold:
                edge_prob = torch.tensor(1 if edge_prob > threshold else 0)
            if torch.rand(1) < edge_prob:
                out_node.add_successor(in_node,'temporal')
                self.temporal_edges[(out_node.id, in_node.id)].mark_selected()
                selected_edges.append((out_node.id, in_node.id))
                log_probs.append(torch.log(edge_prob))
            else:
                log_probs.append(torch.log(1 - edge_prob))
                    
        return torch.sum(torch.stack(log_probs)),selected_edges


    def construct_temporal_connection_diff(self, round:int = 0, temperature: float = 1.0, threshold: float = None,):  # temperature must >= 1.0
        self.clear_temporal_connection()
        log_probs = [torch.tensor(0.0, requires_grad=self.optimized_temporal)]
        all_selected_count = 0 
        selected_edges = [] 
        for edge in self.temporal_edges.values():
            all_selected_count += edge.selected_count
        if round == 0:
            return torch.sum(torch.stack(log_probs)),selected_edges  
        for potential_connection, edge_logit, edge_mask in zip(self.potential_temporal_edges, self.temporal_logits[round-1], self.temporal_masks[round-1]):
            out_node:Node = self.find_node(potential_connection[0])
            in_node:Node = self.find_node(potential_connection[1])
            Proucb = self.temporal_edges[(out_node.id, in_node.id)].get_proucb(all_selected_count)
            if edge_mask == 0.0:
                continue
            elif edge_mask == 1.0 and self.optimized_temporal==False:
                if not self.check_cycle(in_node, {out_node}):
                    out_node.add_successor(in_node,'temporal')
                    self.temporal_edges[(out_node.id, in_node.id)].mark_selected()
                    selected_edges.append((out_node.id, in_node.id))

                continue
            final_logit = edge_logit + Proucb
          

            edge_prob = torch.sigmoid(edge_logit / temperature)
        

            edge_prob = torch.sigmoid(final_logit / temperature)
            if threshold:
                edge_prob = torch.tensor(1 if edge_prob > threshold else 0)
            if torch.rand(1) < edge_prob:
                out_node.add_successor(in_node,'temporal')
                self.temporal_edges[(out_node.id, in_node.id)].mark_selected()
                selected_edges.append((out_node.id, in_node.id))

                log_probs.append(torch.log(edge_prob))
            else:
                log_probs.append(torch.log(1 - edge_prob))
                    
        return torch.sum(torch.stack(log_probs)),selected_edges


 
    
   


    
    
    def run(self, inputs: Any, 
                  num_rounds:int = 3, 
                  max_tries: int = 3, 
                  max_time: int = 600,) -> List[Any]:
        # inputs:{'task':"xxx"}
        log_probs = 0
        for round in range(num_rounds):
            log_probs += self.construct_spatial_connection()
            log_probs += self.construct_temporal_connection(round)
            
            in_degree = {node_id: len(node.spatial_predecessors) for node_id, node in self.nodes.items()}
            zero_in_degree_queue = [node_id for node_id, deg in in_degree.items() if deg == 0]

            while zero_in_degree_queue:
                current_node_id = zero_in_degree_queue.pop(0)
                tries = 0
                while tries < max_tries:
                    try:
                        self.nodes[current_node_id].execute(inputs) # output is saved in the node.outputs
                        break
                    except Exception as e:
                        print(f"Error during execution of node {current_node_id}: {e}")
                    tries += 1
                for successor in self.nodes[current_node_id].spatial_successors:
                    if successor.id not in self.nodes.keys():
                        continue
                    in_degree[successor.id] -= 1
                    if in_degree[successor.id] == 0:
                        zero_in_degree_queue.append(successor.id)
            
            self.update_memory() 
            
        self.connect_decision_node()
        self.decision_node.execute(inputs)  
        final_answers = self.decision_node.outputs
        if len(final_answers) == 0:
            final_answers.append("No answer of the decision node")
            
        return final_answers, log_probs

    async def arun(self, input: Dict[str,str],
                  num_rounds:int = 2, 
                  max_tries: int = 3, 
                  max_time: int = 6000,
                  skip: bool=False,
                  case: bool=False) -> List[Any]:
                  
        log_probs = 0
        log_probs_skip = 0
        all_answers = []
        selected_spatial_edges_all = []  
        selected_temporal_edges_all = []  

        for round in range(num_rounds):  
            round_answers = {}
            if not self.diff:
                log_prob ,selected_spatial_edges = self.construct_spatial_connection
           
                log_probs += log_prob

                log_prob ,selected_temporal_edges = self.construct_temporal_connection(round)

                log_probs += log_prob
                
            else:
 
                log_prob ,selected_spatial_edges =  self.construct_spatial_connection_diff(round)  #
                
                log_probs += log_prob
                log_prob ,selected_temporal_edges =  self.construct_temporal_connection_diff(round)
                log_probs += log_prob
            

            selected_spatial_edges_all.extend(selected_spatial_edges)  
            selected_temporal_edges_all.extend(selected_temporal_edges)  

            in_degree = {node_id: len(node.spatial_predecessors) for node_id, node in self.nodes.items()}
            zero_in_degree_queue = [node_id for node_id, deg in in_degree.items() if deg == 0]

            in_degree_t = {node_id: len(node.temporal_predecessors) for node_id, node in self.nodes.items()}
            zero_in_degree_queue_t = [node_id for node_id, deg in in_degree_t.items() if deg == 0]

            selected_index=-1

            if round <= 5 and skip:
                # log_probs = 0
                min_logit=100
                min_node=None
                min_loss=100
                loss_t_list = []
                loss_f_list = []
                log_list=[]
                for node_id, node in self.nodes.items():
                    in_id = list(self.nodes).index(node_id)
                    count=0
                    logits_count=0.
                    loss_t=0
                    loss_f=0
                    t=1.0
                    
                    for last_node in node.spatial_successors:
                        last_id = list(self.nodes).index(last_node.id)
                        count+=1
                        # logits_count+=torch.sigmoid(t*self.spatial_logits_1[round][in_id*5+last_id])
                        
                        logits_count+=t*self.spatial_logits_1[round][in_id*5+last_id]
                        loss_t+=torch.log(1-torch.sigmoid(self.spatial_logits_1[round][in_id*5+last_id]))
                        loss_f+=torch.log(torch.sigmoid(self.spatial_logits_1[round][in_id*5+last_id]))
                    for last_node in node.spatial_predecessors:
                        last_id = list(self.nodes).index(last_node.id)
                        count+=1
                        # logits_count+=torch.sigmoid(t*self.spatial_logits_1[round][last_id*5+in_id])
                        logits_count+=t*self.spatial_logits_1[round][last_id*5+in_id]
                        loss_t+=torch.log(1-torch.sigmoid(self.spatial_logits_1[round][last_id*5+in_id]))
                        loss_f+=torch.log(torch.sigmoid(self.spatial_logits_1[round][last_id*5+in_id]))

                    log_list.append(logits_count)
                    if count==0:
                        count=1.0
                    loss_t_list.append(loss_t)
                    loss_f_list.append(loss_f)
          
                    
                p = torch.softmax(torch.tensor(log_list),dim=0)
                selected_index = torch.multinomial(p, num_samples=1, replacement=False)
                # selected_index = list(self.nodes).index(min_node)
                # selected_index = 1
                for i in range(5):
                    if i==selected_index:
                        log_probs_skip += 4.0*loss_t_list[i]
                    else:
                        log_probs_skip += loss_f_list[i]
          
            while zero_in_degree_queue: 

                current_node_id = zero_in_degree_queue.pop(0)
                tries = 0
                while tries < max_tries:
                    try:
                        # if current_node_id in need_skip:
                        
                        if list(self.nodes).index(current_node_id) == selected_index and skip:
                            # print(111)
                            # if selected_index==1:
                            self.find_node(current_node_id).outputs = ['None.']
                            
                            break
                            
                        elif self.skip_nodes:
                            
                            if list(self.nodes).index(current_node_id) == self.skip_nodes[round]:
                                self.find_node(current_node_id).outputs = ['None.']
                                break
                            
                        
                        await asyncio.wait_for(self.nodes[current_node_id].async_execute(input, critic_enabled = self.reflection, critic_node = self.critic_node),timeout=max_time) 
                        break
                    except Exception as e:
                        print(f"Error during execution of node {current_node_id}: {e}")
                    tries += 1
                for successor in self.nodes[current_node_id].spatial_successors:
                    if successor.id not in self.nodes.keys():
                        continue
                    in_degree[successor.id] -= 1
                    if in_degree[successor.id] == 0:
                        zero_in_degree_queue.append(successor.id)
            for node in self.nodes:
                round_answers[self.nodes[node].role+str(node)] = self.nodes[node].outputs
            all_answers.append(round_answers)
            self.update_memory()


        if len(self.potential_spatial_edges)>0:
            self.connect_decision_node()
            await self.decision_node.async_execute(input)
            final_answers = self.decision_node.outputs
        else:
            final_answers = list(self.nodes.values())[0].outputs
        if len(final_answers) == 0:
            final_answers.append("No answer of the decision node")

        if skip:
            return final_answers, log_probs_skip, selected_spatial_edges_all, selected_temporal_edges_all
        elif case:
            return final_answers, log_probs, all_answers
        else:
            return final_answers, log_probs, selected_spatial_edges_all, selected_temporal_edges_all
    
    def update_memory(self):
        """Call ``update_memory()`` on every node in ``self.nodes``, in insertion order."""
        for node in self.nodes.values():
            node.update_memory()
    
    def check_cycle(self, new_node, target_nodes):
        """Return True if a node in ``target_nodes`` is ``new_node`` or reachable from it.

        Reachability follows ``spatial_successors`` depth-first. Callers pass
        ``check_cycle(in_node, {out_node})`` to test whether adding ``out_node -> in_node``
        would create a cycle.
        """
        return new_node in target_nodes or any(
            self.check_cycle(successor, target_nodes)
            for successor in new_node.spatial_successors
        )

    def update_masks(self, pruning_rate: float) -> torch.Tensor:
        if self.optimized_spatial:
            num_edges = (self.spatial_masks > 0).sum()
            num_masks = (self.spatial_masks == 0).sum()
            prune_num_edges = torch.round(num_edges*pruning_rate) if torch.round(num_edges*pruning_rate)>0 else 1
            _edge_logits = self.spatial_logits.clone()
            min_edge_logit = _edge_logits.min()
            _edge_logits[self.spatial_masks == 0] = min_edge_logit - 1.0
            sorted_edges_idx = torch.argsort(_edge_logits)
            prune_idx = sorted_edges_idx[:int(prune_num_edges + num_masks)]
            self.spatial_masks[prune_idx] = 0
            for i, (source_target, edge) in enumerate(self.spatial_edges.items()):
                if i in prune_idx:
                    edge.mask_value = 0
        
        if self.optimized_temporal:
            num_edges = (self.temporal_masks > 0).sum()
            num_masks = (self.temporal_masks == 0).sum()
            prune_num_edges = torch.round(num_edges*pruning_rate) if torch.round(num_edges*pruning_rate)>0 else 1
            _edge_logits = self.temporal_logits.clone()
            min_edge_logit = _edge_logits.min()
            _edge_logits[self.temporal_masks == 0] = min_edge_logit - 1.0
            sorted_edges_idx = torch.argsort(_edge_logits)
            prune_idx = sorted_edges_idx[:int(prune_num_edges + num_masks)]
            self.temporal_masks[prune_idx] = 0

            for i, (source_target, edge) in enumerate(self.temporal_edges.items()):
                if i in prune_idx:
                    edge.mask_value = 0
        return self.spatial_masks, self.temporal_masks
    



    def update_masks_diff(self, pruning_rate: float) -> torch.Tensor:
        if self.optimized_spatial:
            for i in range(self.rounds):
                num_edges = (self.spatial_masks[i] > 0).sum()
                num_masks = (self.spatial_masks[i] == 0).sum()
                prune_num_edges = torch.round(num_edges*pruning_rate) if torch.round(num_edges*pruning_rate)>0 else 1
                _edge_logits = self.spatial_logits[i].clone()
                min_edge_logit = _edge_logits.min()
                _edge_logits[self.spatial_masks[i] == 0] = min_edge_logit - 1.0
                sorted_edges_idx = torch.argsort(_edge_logits)
   
                prune_idx = sorted_edges_idx[:int(prune_num_edges + num_masks)]
                self.spatial_masks[i][prune_idx] = 0
        
        if self.optimized_temporal:
            for i in range(self.rounds-1):
                num_edges = (self.temporal_masks[i] > 0).sum()
                num_masks = (self.temporal_masks[i] == 0).sum()
                prune_num_edges = torch.round(num_edges*pruning_rate) if torch.round(num_edges*pruning_rate)>0 else 1
                _edge_logits = self.temporal_logits[i].clone()
                min_edge_logit = _edge_logits.min()
                _edge_logits[self.temporal_masks[i] == 0] = min_edge_logit - 1.0
                sorted_edges_idx = torch.argsort(_edge_logits)
       
                prune_idx = sorted_edges_idx[:int(prune_num_edges + num_masks)]
                self.temporal_masks[i][prune_idx] = 0
    
        return self.spatial_masks, self.temporal_masks

    def update_masks_dec(self, num_node:Any):  
        
        num_node = num_node[0]
        spatial_matrix_train = [param.reshape((num_node, num_node)) for param in self.spatial_logits_1]
        temporal_matrix_train = [param.reshape((num_node, num_node)) for param in self.temporal_logits_1]
        

        for i in range(len(spatial_matrix_train)):
            print(len(spatial_matrix_train))
            
            min = 100
            min_node = -1
            for j in range(num_node):
                sum = torch.sum(spatial_matrix_train[i][j,:]).item() + torch.sum(spatial_matrix_train[i][:,j]).item()
                # if i >= 1:
                #     sum += torch.sum(temporal_matrix_train[i-1][j,:]).item() + torch.sum(temporal_matrix_train[i-1][:,j]).item()
                count = torch.sum(self.fixed_spatial_masks[j,:]).item() + torch.sum(self.fixed_spatial_masks[:,j]).item()
                sum = sum / count
                if sum < min:
                    min = sum
                    min_node = j
            # min_node=random.randint(0, 4)
            self.skip_nodes.append(min_node)
            print(f"Round {i}: Disconnecting node {min_node}")
            for k in range(num_node):
                self.spatial_masks[i][min_node*num_node+k]=0
                self.spatial_masks[i][k*num_node+min_node]=0
            if i > 0:
                for k in range(num_node):
                    self.temporal_masks[i-1][k*num_node+min_node]=0
            if i < len(spatial_matrix_train) - 1:
                for k in range(num_node):
                    self.temporal_masks[i][min_node*num_node+k]=0

    def update_masks_stage_one(self, num_node_tuple: Any, num_nodes_to_keep: int): 
        
        num_node = num_node_tuple[0] 

        
        spatial_logits_final = [param.reshape((num_node, num_node)) for param in self.spatial_logits_1]
        temporal_logits_final = [param.reshape((num_node, num_node)) for param in self.temporal_logits_1]
        
        node_strengths_per_round = []
        for i_round in range(len(spatial_logits_final)):
            spatial_probs_round = torch.sigmoid(spatial_logits_final[i_round])
            node_activity_spatial = torch.sum(spatial_probs_round, dim=0) + torch.sum(spatial_probs_round, dim=1)
            
            node_activity_temporal = torch.zeros_like(node_activity_spatial)
            if i_round < len(temporal_logits_final): 
                temporal_probs_outgoing = torch.sigmoid(temporal_logits_final[i_round])
                node_activity_temporal += torch.sum(temporal_probs_outgoing, dim=1)
            if i_round > 0:
                temporal_probs_incoming = torch.sigmoid(temporal_logits_final[i_round-1])
                node_activity_temporal += torch.sum(temporal_probs_incoming, dim=0)
                
            node_strengths_per_round.append(node_activity_spatial + node_activity_temporal)

        avg_node_strength_across_rounds = torch.mean(torch.stack(node_strengths_per_round), dim=0)

     
        N_keep = num_nodes_to_keep 
        _, top_k_indices = torch.topk(avg_node_strength_across_rounds, k=N_keep)
        retained_nodes = top_k_indices.tolist()
        
        print(f"Retaining {N_keep} nodes: {retained_nodes}")
        self.retained_nodes = retained_nodes 
        for i in range(len(self.spatial_masks)):
            self.spatial_masks[i].fill_(0) 
        for i in range(len(self.temporal_masks)):
            self.temporal_masks[i].fill_(0) 

  
        for i in range(len(spatial_logits_final)):
           
            for r_node in retained_nodes:
                for c_node in retained_nodes:
                    
                    self.spatial_masks[i][r_node * num_node + c_node] = 1 
                    self.spatial_masks[i][c_node * num_node + r_node] = 1 
                    
            
            if i < len(temporal_logits_final): 
                for r_node_curr in retained_nodes:
                    for r_node_next in retained_nodes:
                        self.temporal_masks[i][r_node_curr * num_node + r_node_next] = 1
            
 
            if i > 0 and i-1 < len(temporal_logits_final): # Incoming to current round
                for r_node_prev in retained_nodes:
                    for r_node_curr in retained_nodes:
                        self.temporal_masks[i-1][r_node_prev * num_node + r_node_curr] = 1
                        
