import sys
import os
project_dir = os.getcwd()
sys.path.append(project_dir)

import math
import torch


def compute_kl(message_len):
    """Return the complexity (KL) term charged for a message of length ``message_len``.

    The value is ``message_len * ln(2) + 2 * ln(message_len)``, using natural
    logarithms, so the result is in nats. The first term converts a length
    counted in bits into nats (assumes ``message_len`` is a bit count -- TODO
    confirm with callers). The second term grows logarithmically with the
    length; presumably it pays for encoding the length itself -- confirm.

    Args:
        message_len: Positive message length (int or float).

    Returns:
        The complexity term (a float for plain numeric input).

    Raises:
        ValueError: If ``message_len`` is not positive, where ``ln`` is undefined.
    """
    if message_len <= 0:
        # Without this check math.log fails with an opaque "math domain error".
        raise ValueError(f"message_len must be positive, got {message_len!r}")
    return message_len * math.log(2) + 2 * math.log(message_len)

def compute_bound(m=None, n=None, delta=0.05, global_message_len=None, sum_task_message_lens=None):
    """Return the ``(global, task-specific)`` square-root complexity terms.

    Both terms have the form ``sqrt((complexity + log(confidence)) / (2 * count))``:

    * global term:  complexity ``compute_kl(global_message_len)``,
      confidence ``4 * sqrt(n) / delta``, count ``n``.
    * task term:    complexity ``compute_kl(global_message_len)
      + compute_kl(sum_task_message_lens)``, confidence ``1 / delta``,
      count ``m * n``.

    All of ``m``, ``n``, ``global_message_len`` and ``sum_task_message_lens``
    must be supplied; the ``None`` defaults only make them keyword-friendly.

    NOTE(review): the two terms use different confidence budgets
    (``log(4*sqrt(n)/delta)`` vs ``log(1/delta)``); confirm this matches the
    intended split of ``delta`` between them.
    """
    kl_global = compute_kl(global_message_len)

    global_numerator = kl_global + math.log(4 * math.sqrt(n) / delta)
    global_term = math.sqrt(global_numerator / (2 * n))

    kl_tasks = compute_kl(sum_task_message_lens)
    task_numerator = kl_global + kl_tasks + math.log(1 / delta)
    task_term = math.sqrt(task_numerator / (2 * m * n))

    return global_term, task_term


def compute_bound_single_task(m=None, delta=0.05, message_len=None):
    """Return ``sqrt((compute_kl(message_len) + log(1 / delta)) / (2 * m))``.

    ``m`` and ``message_len`` must be supplied; the ``None`` defaults only make
    them keyword-friendly. ``delta`` is the confidence parameter inside the log.
    """
    numerator = compute_kl(message_len) + math.log(1 / delta)
    return math.sqrt(numerator / (2 * m))

