import math
from collections import Counter

def calculate_entropy(labels):
    """Return the Shannon entropy, in bits, of the values in *labels*.

    Args:
        labels: A sized collection of hashable values (it is passed to both
            ``Counter`` and ``len``).

    Returns:
        The entropy as a float. An empty collection yields ``0.0`` because
        there are no distinct values to iterate over.
    """
    histogram = Counter(labels)
    n_labels = len(labels)

    bits = 0.0
    for occurrences in histogram.values():
        share = occurrences / n_labels
        # share is in (0, 1], so log2(share) <= 0 and each step adds a
        # non-negative amount to the running total.
        bits -= share * math.log2(share)
    return bits

def calculate_information_gain(clusters, dataset):
    """Return the information gain of a cluster assignment over the answers.

    The gain is the entropy of all answers minus the size-weighted average
    of the per-cluster answer entropies, as computed by
    ``calculate_entropy``.

    Args:
        clusters: Iterable of hashable cluster ids, one per dataset row and
            in the same order as ``dataset``.
        dataset: Sequence of dict-like rows. The ground-truth value is read
            from the ``"answer"`` key if the first row has a non-None value
            for it, otherwise from ``"label"``.

    Returns:
        The information gain as a float; ``0.0`` for an empty dataset.

    Raises:
        ValueError: If the first row has neither a non-None ``"answer"`` nor
            a non-None ``"label"``, or if ``clusters`` and ``dataset`` have
            different lengths.
    """
    # Nothing to score; this also avoids an IndexError on dataset[0] below.
    if len(dataset) == 0:
        return 0.0

    # The key is chosen from the first row only and then applied to all rows.
    first_row = dataset[0]
    if first_row.get("answer") is not None:
        answer_key = "answer"
    elif first_row.get("label") is not None:
        answer_key = "label"
    else:
        raise ValueError("Dataset does not contain 'answer' or 'label' key")
    answer_list = [prompt[answer_key] for prompt in dataset]

    # zip() would silently drop the unmatched tail, while the overall entropy
    # and the weights below would still be based on every answer, producing
    # a wrong result. Fail loudly instead.
    clusters = list(clusters)
    if len(clusters) != len(answer_list):
        raise ValueError(
            f"clusters and dataset must have the same length "
            f"(got {len(clusters)} and {len(answer_list)})"
        )

    # Entropy of the answers before any split.
    total_entropy = calculate_entropy(answer_list)

    # Collect the answers that fall into each cluster.
    answers_by_cluster = {}
    for cluster_id, answer in zip(clusters, answer_list):
        answers_by_cluster.setdefault(cluster_id, []).append(answer)

    # Average the per-cluster entropies, weighted by relative cluster size.
    total_points = len(answer_list)
    weighted_entropy = 0.0
    for answers in answers_by_cluster.values():
        cluster_entropy = calculate_entropy(answers)
        weighted_entropy += (len(answers) / total_points) * cluster_entropy

    return total_entropy - weighted_entropy

def calculate_information_gain_without_neutral(clusters, dataset, *, neutral_label="Neutral"):
    """Return the information gain of a clustering, ignoring neutral rows.

    Rows whose answer equals ``neutral_label`` are dropped together with
    their cluster id; the gain is then the entropy of the remaining answers
    minus the size-weighted average of the per-cluster answer entropies, as
    computed by ``calculate_entropy``.

    Args:
        clusters: Iterable of hashable cluster ids, one per dataset row and
            in the same order as ``dataset``.
        dataset: Sequence of dict-like rows. The ground-truth value is read
            from the ``"answer"`` key if the first row has a non-None value
            for it, otherwise from ``"label"``.
        neutral_label: Answer value to exclude (compared with ``!=``).
            Defaults to ``"Neutral"``.

    Returns:
        The information gain as a float; ``0.0`` when the dataset is empty
        or every row is neutral.

    Raises:
        ValueError: If the first row has neither a non-None ``"answer"`` nor
            a non-None ``"label"``, or if ``clusters`` and ``dataset`` have
            different lengths.
    """
    # Nothing to score; this also avoids an IndexError on dataset[0] below.
    if len(dataset) == 0:
        return 0.0

    # The key is chosen from the first row only and then applied to all rows.
    first_row = dataset[0]
    if first_row.get("answer") is not None:
        answer_key = "answer"
    elif first_row.get("label") is not None:
        answer_key = "label"
    else:
        raise ValueError("Dataset does not contain 'answer' or 'label' key")
    answer_list = [prompt[answer_key] for prompt in dataset]

    # zip() would silently drop the unmatched tail; a length mismatch almost
    # certainly means the inputs are misaligned, so fail loudly instead.
    clusters = list(clusters)
    if len(clusters) != len(answer_list):
        raise ValueError(
            f"clusters and dataset must have the same length "
            f"(got {len(clusters)} and {len(answer_list)})"
        )

    # Drop neutral rows, keeping each remaining answer paired with its cluster.
    kept_pairs = [
        (cluster_id, answer)
        for cluster_id, answer in zip(clusters, answer_list)
        if answer != neutral_label
    ]
    if not kept_pairs:
        return 0.0

    # Entropy of the remaining answers before any split.
    kept_answers = [answer for _, answer in kept_pairs]
    total_entropy = calculate_entropy(kept_answers)

    # Collect the remaining answers that fall into each cluster.
    answers_by_cluster = {}
    for cluster_id, answer in kept_pairs:
        answers_by_cluster.setdefault(cluster_id, []).append(answer)

    # Average the per-cluster entropies, weighted by relative cluster size.
    total_points = len(kept_answers)
    weighted_entropy = 0.0
    for answers in answers_by_cluster.values():
        cluster_entropy = calculate_entropy(answers)
        weighted_entropy += (len(answers) / total_points) * cluster_entropy

    return total_entropy - weighted_entropy