
import torch
import torch.nn.functional as F

def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax probabilities of ``logits / temperature`` along the last dim.

    A temperature above 1 flattens the distribution; below 1 sharpens it.

    Args:
        logits: tensor of unnormalized scores; normalization is over dim -1.
        temperature: divisor applied to ``logits`` before the softmax.

    Returns:
        Tensor with the same shape as ``logits`` whose last-dim slices sum to 1.
    """
    scaled_logits = logits / temperature
    return torch.softmax(scaled_logits, dim=-1)

def kl_divergence(p_logits, q_logits, temperature=1.0):
    """Return the temperature-scaled KL divergence between two sets of logits.

    Both logit tensors are divided by ``temperature`` and normalized over the
    last dimension. ``F.kl_div(input, target)`` computes
    ``sum(target * (log(target) - input))``. With
    ``input = log_softmax(p_logits / T)`` and ``target = softmax(q_logits / T)``,
    this is KL(q || p): the distribution built from ``q_logits`` is the
    reference.

    Args:
        p_logits: logits whose scaled log-probabilities form the ``kl_div`` input.
        q_logits: logits whose scaled probabilities form the ``kl_div`` target.
        temperature: divisor applied to both logit tensors before normalization.

    Returns:
        Scalar tensor: the summed pointwise KL terms divided by the size of
        the first dimension (``reduction='batchmean'``).
    """
    # log_softmax is computed stably in a single op. Using log(softmax + eps)
    # instead clamps log-probabilities near log(eps) (~ -18.4 for 1e-8), which
    # badly under-penalizes confident mismatches and makes exp(input) not sum to 1.
    p_log = F.log_softmax(p_logits / temperature, dim=-1)
    q = F.softmax(q_logits / temperature, dim=-1)
    return F.kl_div(p_log, q, reduction='batchmean')

def average_logits_across_clients(logit_list, client_sizes):
    """Return the size-weighted average of per-client logits.

    Each client's logits are multiplied by its size, summed, and divided by
    the total of all sizes.

    Args:
        logit_list: iterable of logit tensors, one per client; they must be
            broadcast-compatible with each other.
        client_sizes: iterable of non-negative weights (e.g. sample counts),
            one per entry of ``logit_list``.

    Returns:
        Tensor holding ``sum(size_i * logits_i) / sum(size_i)``.

    Raises:
        ValueError: if the two inputs differ in length, are empty, or the
            sizes sum to a non-positive total.
    """
    # Materialize both inputs so iterators are not exhausted by the first pass
    # (sum) before the second pass (zip).
    logit_list = list(logit_list)
    client_sizes = list(client_sizes)
    # zip() would silently drop the extras while `total` still counted them,
    # producing a mis-scaled average.
    if len(logit_list) != len(client_sizes):
        raise ValueError(
            f"got {len(logit_list)} logit tensors but {len(client_sizes)} client sizes"
        )
    if not logit_list:
        raise ValueError("cannot average logits over zero clients")
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError(f"client sizes must sum to a positive value, got {total}")
    weighted_sum = sum(logits * size for logits, size in zip(logit_list, client_sizes))
    return weighted_sum / total
