import math
import torch

def init_delta(size, epsilon=1e-1, init_type="zero", device=None):
    """Create an initial perturbation tensor of the requested size.

    Args:
        size: Shape of the tensor to create (anything ``torch.zeros`` accepts).
        epsilon: Scale multiplied into the random samples. Ignored when
            ``init_type`` is ``"zero"``.
        init_type: One of
            ``"zero"``  - all zeros,
            ``"rand"``  - uniform samples from ``torch.rand`` times ``epsilon``,
            ``"randn"`` - standard-normal samples times ``epsilon``.
        device: Device to allocate on. When ``None`` the tensor goes to
            ``cuda:0`` if CUDA is available, otherwise to the CPU.

    Returns:
        A new tensor of shape ``size`` on the selected device.

    Raises:
        ValueError: If ``init_type`` is not one of the supported names.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if init_type == "zero":
        return torch.zeros(size, device=device)
    if init_type == "rand":
        return torch.rand(size, device=device) * epsilon
    if init_type == "randn":
        return torch.randn(size, device=device) * epsilon

    # Fail loudly: silently returning None only surfaces later as an
    # unrelated-looking error at the first tensor operation.
    raise ValueError(
        f"Unknown init_type {init_type!r}; expected 'zero', 'rand' or 'randn'."
    )

def ls(P, Q, task_name):
    """Measure the discrepancy between two model outputs ``P`` and ``Q``.

    Args:
        P: First output tensor (logits for classification tasks).
        Q: Second output tensor, same shape as ``P``.
        task_name: Task identifier. ``"STS-B"`` is treated as regression;
            every other name is treated as classification.

    Returns:
        A scalar tensor:
        * regression - the summed squared error between ``P`` and ``Q``;
        * classification - the symmetric KL divergence
          ``KL(q || p) + KL(p || q)`` between the softmax distributions of
          ``P`` and ``Q`` over the last dimension, using ``batchmean``
          reduction.
    """
    if task_name == "STS-B":
        return torch.nn.MSELoss(reduction="sum")(P, Q)

    # Stay in log space: softmax(...).log() underflows to -inf when one logit
    # dominates, which turns the divergence into inf/nan.
    log_p = P.log_softmax(dim=-1)
    log_q = Q.log_softmax(dim=-1)
    kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
    return kl(log_p, log_q) + kl(log_q, log_p)

def SGLD(x, grad, step, epsilon):
    """Apply one stochastic-gradient Langevin update to ``x``.

    Computes ``x - step * grad + sqrt(2 * step) * noise`` where ``noise`` is
    standard-normal noise scaled by ``epsilon``.

    Args:
        x: Current tensor value.
        grad: Gradient tensor, broadcastable against ``x``.
        step: Step size as a non-negative Python number (it is passed to
            ``math.sqrt``).
        epsilon: Scale applied to the Gaussian noise.

    Returns:
        The updated tensor; ``x`` itself is not modified in place.
    """
    # Draw the noise with x's own device and dtype so the update works no
    # matter where x lives (CPU, any GPU) and does not change its precision.
    noise = torch.randn_like(x) * epsilon
    return x - step * grad + math.sqrt(2 * step) * noise

def dynamic_rate(total_iterations, current_iteration, sampling_step, warm_up):
    """Scale ``sampling_step`` according to training progress.

    Args:
        total_iterations: Total number of iterations in the schedule.
        current_iteration: Index of the current iteration.
        sampling_step: Base step size to be scaled.
        warm_up: Fraction of ``total_iterations`` treated as warm-up.

    Returns:
        ``sampling_step * (current_iteration + 1) / total_iterations`` while
        ``current_iteration < total_iterations * warm_up``; otherwise
        ``sampling_step * (total_iterations - current_iteration) / total_iterations``.
    """
    in_warm_up = current_iteration < total_iterations * warm_up
    # Ramp up with the number of iterations done so far during warm-up,
    # then decay with the number of iterations remaining.
    numerator = (
        current_iteration + 1
        if in_warm_up
        else total_iterations - current_iteration
    )
    return sampling_step * numerator / total_iterations
