import shortuuid
from typing import Any, List, Optional, Dict
from abc import ABC
import numpy as np
import torch
import asyncio
import random
import math
from HeGFlow.graph.node import Node
from HeGFlow.agents.agent_registry import AgentRegistry



class Edge:
    """
    Represents a directed edge between two nodes in the graph.

    Attributes:
        source_node_id (str): The name/ID of the source node.
        target_node_id (str): The name/ID of the target node.
        source_node_role (str): The role of the source node.
        target_node_role (str): The role of the target node.
        logit_value (float): The logit value associated with this edge.
        edge_type (str): Caller-supplied label for the kind of edge (stored as given).
        mask_value (int): The mask value for this edge (1 for allowed, 0 for disallowed).
        selected_count (int): Times this edge has been selected; incremented by `mark_selected`.
        correct_answer_count (int): Times this edge contributed to a correct answer;
            incremented by `mark_correct`.
    """

    def __init__(self, source_node_id: str, target_node_id: str, source_node_role: str, target_node_role: str, logit_value: float, edge_type: str, mask_value: int = 1):
        self.source_node_id = source_node_id
        self.target_node_id = target_node_id
        self.source_node_role = source_node_role
        self.target_node_role = target_node_role
        self.logit_value = logit_value
        self.edge_type = edge_type
        self.mask_value = mask_value
        # Statistics start at zero and are only updated through mark_selected / mark_correct.
        self.selected_count = 0
        self.correct_answer_count = 0

    def __repr__(self):
        return f"Edge(Source: {self.source_node_id}, Target: {self.target_node_id}, Logit: {self.logit_value}, Mask: {self.mask_value},type: {self.edge_type})"

    def mark_selected(self):
        """Increments the selected count."""
        self.selected_count += 1

    def mark_correct(self):
        """Increments the correct answer count."""
        self.correct_answer_count += 1

    def show_edge(self):
        """Returns a string representation of the edge (without its type)."""
        # Uses the *_id attributes set in __init__; the previous *_name attributes never existed.
        return f"Edge(Source: {self.source_node_id}, Target: {self.target_node_id}, Logit: {self.logit_value}, Mask: {self.mask_value})"

    def get_proucb(self, all_num: int, *, alpha: float = 1, beta: float = 1) -> float:
        """Return an upper-confidence-bound style score for this edge.

        score = alpha * correct / selected
                + sqrt(2 * ln(all_num + 1) / selected)
                + beta / (selected + 1)

        A tiny epsilon is added to ``selected`` in the first two terms so that an
        edge that has never been selected gets a very large exploration term
        instead of a ZeroDivisionError.

        Args:
            all_num: Count fed into the log term. Presumably the total number of
                selections across all candidate edges; confirm against callers.
            alpha: Weight of the average-reward term (default 1, as before).
            beta: Numerator of the diversity bonus (default 1, as before).

        Returns:
            The combined score as a float.

        Raises:
            ValueError: From ``math.log``/``math.sqrt`` if ``all_num + 1 <= 0``
                or the log term is negative.
        """
        eps = 1e-10  # guards the division when selected_count == 0
        denom = self.selected_count + eps
        average_reward = alpha * (self.correct_answer_count / denom)
        exploration = math.sqrt(2 * math.log(all_num + 1) / denom)
        diversity_bonus = beta / (self.selected_count + 1)
        return average_reward + exploration + diversity_bonus
        





