from collections import Counter
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import scipy.stats

# Placeholder role: the initial ``v1_role`` of every node record, and the value
# a node falls back to when none of its samples is a non-blank role string.
DEFAULT_UNDETERMINED_ROLE = "UNDETERMINED" 

def initialize_node_data(nodes: List[str]) -> Dict[str, Dict[str, Any]]:
    
    data = {}
    for node in nodes:
        data[node] = {
            "samples": [],  
            "v1_role": DEFAULT_UNDETERMINED_ROLE,  
            "v1_count": 0,  
            "v2_role": None,  
            "v2_count": 0,  
            "confidence": 0.0, 
            "stop_sampling": False
        }
    return data

def update_role_frequencies(node_specific_data: Dict[str, Any]) -> None:
    
    samples = node_specific_data["samples"]
    if not samples:
        node_specific_data["v1_count"] = 0
        node_specific_data["v2_role"] = None
        node_specific_data["v2_count"] = 0
        return

    valid_samples = [s for s in samples if s and s.strip()]
    if not valid_samples:
        node_specific_data["v1_role"] = DEFAULT_UNDETERMINED_ROLE
        node_specific_data["v1_count"] = 0
        node_specific_data["v2_role"] = None
        node_specific_data["v2_count"] = 0
        return

    counter = Counter(valid_samples)
    most_common_roles = counter.most_common(2)  

    if len(most_common_roles) >= 1:
        node_specific_data["v1_role"] = most_common_roles[0][0]
        node_specific_data["v1_count"] = most_common_roles[0][1]

    if len(most_common_roles) >= 2:
        node_specific_data["v2_role"] = most_common_roles[1][0]
        node_specific_data["v2_count"] = most_common_roles[1][1]
    else:
        node_specific_data["v2_role"] = None
        node_specific_data["v2_count"] = 0

def has_sufficient_samples_for_confidence(
    node_specific_data: Dict[str, Any],
    min_total_samples: int = 1 
) -> bool:
    return (len(node_specific_data["samples"]) >= min_total_samples and
            node_specific_data["v1_count"] > 0)

def calculate_beta_confidence(v1_count: int, v2_count: int, v2_bias: float=1.0, v1_bias: float=1.0) -> float:
    """Return the posterior probability that the leading role beats the runner-up.

    The runner-up's share ``p`` of the votes between the top two roles is
    modelled as ``Beta(v2_count + v2_bias, v1_count + v1_bias)``. The result is
    ``P(p < 0.5)``, i.e. the Beta CDF evaluated at 0.5.

    Args:
        v1_count: Votes for the most frequent role.
        v2_count: Votes for the runner-up role; negative values are clamped to 0.
        v2_bias: Prior pseudo-count added to ``v2_count``.
        v1_bias: Prior pseudo-count added to ``v1_count``.

    Returns:
        A probability in ``[0, 1]``. Returns 0.0 in three cases: when
        ``v1_count <= 0``, when SciPy raises (a warning is printed), or when
        the CDF is not finite.
    """
    if v1_count <= 0:
        return 0.0

    v2_count = max(0, v2_count)

    alpha = v2_count + v2_bias
    beta_parameter = v1_count + v1_bias

    try:
        confidence = float(scipy.stats.beta.cdf(0.5, a=alpha, b=beta_parameter))
    except Exception as e:
        print(f"Warning: Calculating BetaCDF Error(v1={v1_count}, v2={v2_count}): {e}")
        return 0.0

    # SciPy does not raise on invalid shape parameters (a <= 0 or b <= 0, e.g.
    # from a non-positive bias); it returns NaN. A NaN confidence fails every
    # `<` comparison, so a `conf < threshold` check would not flag it as
    # unconfident. Report it as "no confidence" instead.
    if not np.isfinite(confidence):
        return 0.0
    return confidence

def _normalize_role(raw_role: Any) -> Optional[str]:
    """Convert one raw extractor answer to a cleaned role string, or None if blank."""
    if raw_role is None:
        return None
    # Normalize every answer, not only non-string ones, so that the same role
    # with surrounding whitespace or a trailing newline is counted once.
    role = str(raw_role).strip().replace("\n", "")
    return role or None


def _geometric_mean(values: List[float], floor: float = 1e-9) -> float:
    """Return the geometric mean of *values*, clipping each to >= *floor* so log() is defined."""
    clipped = [max(floor, value) for value in values]
    return float(np.exp(np.mean(np.log(clipped))))


def adaptive_role_extraction_for_node_set(
    target_nodes_list: List[str],
    llm_role_extraction_func: Callable[[Dict[str, Dict[str, Any]]], Dict[str, str]],
    node_confidence_threshold_tau_node: float = 0.9,
    overall_stopping_logic: str = "ALL_NODES_CONFIDENT",  # or "GEOMETRIC_MEAN_CONFIDENT"
    v1_bias: float = 1.0,
    v2_bias: float = 1.0,
    global_geometric_mean_threshold: float = 0.9,  # used only by GEOMETRIC_MEAN_CONFIDENT
    max_sampling_rounds_N_max: int = 10,
    min_samples_per_node_for_beta: int = 2,
) -> Tuple[Dict[str, str], Dict[str, Dict[str, Any]]]:
    """Repeatedly sample node roles from an extractor until the majority votes are confident.

    Each round does the following:

    1. Calls ``llm_role_extraction_func(node_data)`` once.
    2. Appends the normalized role it returns for every node to that node's
       samples.
    3. Recomputes the node's top-two role counts and a Beta-posterior
       confidence.
    4. Decides whether to stop early:

       * ``"ALL_NODES_CONFIDENT"``: stop when every node's confidence is
         ``>= node_confidence_threshold_tau_node``.
       * ``"GEOMETRIC_MEAN_CONFIDENT"``: stop when the geometric mean of all
         node confidences is ``>= global_geometric_mean_threshold``.

    At most ``max_sampling_rounds_N_max`` rounds are run.

    Args:
        target_nodes_list: Node ids to assign roles to.
        llm_role_extraction_func: Called once per round with the per-node state
            dict (the same object returned as ``node_data``). It should return a
            mapping ``node_id -> role``. A missing id, a non-mapping result, or
            a blank role is recorded as a ``None`` sample.
        node_confidence_threshold_tau_node: Per-node confidence threshold. A
            node reaching it gets ``stop_sampling=True``. The flag is never
            cleared here and is not used to skip the node; presumably the
            extractor reads it (verify against callers).
        overall_stopping_logic: ``"ALL_NODES_CONFIDENT"`` or
            ``"GEOMETRIC_MEAN_CONFIDENT"``.
        v1_bias: Beta prior pseudo-count for the leading role.
        v2_bias: Beta prior pseudo-count for the runner-up role.
        global_geometric_mean_threshold: Threshold for the geometric-mean rule.
        max_sampling_rounds_N_max: Hard cap on rounds (extractor calls).
        min_samples_per_node_for_beta: Samples a node needs before its
            confidence is computed. Until then its confidence is 0.0.

    Returns:
        ``(final_node_roles, node_data)``. ``final_node_roles`` maps each node
        to its most frequent role (``DEFAULT_UNDETERMINED_ROLE`` if it has none).
        ``node_data`` is the full per-node state.

    Raises:
        ValueError: if ``overall_stopping_logic`` is not a supported value.
    """
    if overall_stopping_logic not in ("ALL_NODES_CONFIDENT", "GEOMETRIC_MEAN_CONFIDENT"):
        raise ValueError(
            "overall_stopping_logic must be 'ALL_NODES_CONFIDENT' or "
            f"'GEOMETRIC_MEAN_CONFIDENT', got {overall_stopping_logic!r}"
        )

    node_data = initialize_node_data(target_nodes_list)
    if not target_nodes_list:
        # Nothing to sample; avoid spending extractor calls on an empty set.
        return {}, node_data

    for _ in range(max_sampling_rounds_N_max):
        raw_assignments = llm_role_extraction_func(node_data)

        round_confidences: List[float] = []
        for node_id in target_nodes_list:
            try:
                raw_role = raw_assignments.get(node_id)
            except (AttributeError, TypeError):
                # The extractor returned something that is not dict-like (e.g. None).
                raw_role = None

            state = node_data[node_id]
            state["samples"].append(_normalize_role(raw_role))
            update_role_frequencies(state)

            confidence = 0.0
            if has_sufficient_samples_for_confidence(state, min_samples_per_node_for_beta):
                confidence = calculate_beta_confidence(
                    v1_count=state["v1_count"],
                    v2_count=state["v2_count"],
                    v1_bias=v1_bias,
                    v2_bias=v2_bias,
                )
                # A NaN confidence (invalid Beta parameters) would fail every
                # comparison below and could be mistaken for "confident".
                if not np.isfinite(confidence):
                    confidence = 0.0
            state["confidence"] = confidence

            # Use the same ">=" boundary as the all-nodes stopping rule.
            if confidence >= node_confidence_threshold_tau_node:
                state["stop_sampling"] = True
            round_confidences.append(confidence)

        if overall_stopping_logic == "ALL_NODES_CONFIDENT":
            stop_sampling_now = all(
                conf >= node_confidence_threshold_tau_node for conf in round_confidences
            )
        else:
            stop_sampling_now = (
                _geometric_mean(round_confidences) >= global_geometric_mean_threshold
            )

        if stop_sampling_now:
            break

    final_node_roles = {node_id: node_data[node_id]["v1_role"] for node_id in target_nodes_list}
    return final_node_roles, node_data