import sys

import torch
from tqdm import tqdm

sys.path.append(".")
from src.tools.sharpness_tools.math_utils import compute_loss
from src.tools.sharpness_tools.utils import load_weights
from torch.cuda.amp import GradScaler


def low_pass(model, data_loader, sigma, mcmc_itr):
    """Monte Carlo estimate of the loss averaged over Gaussian weight noise.

    For each of ``mcmc_itr`` samples, every parameter of ``model`` is set to
    its original value plus i.i.d. noise drawn from N(0, sigma**2). The first
    element returned by ``compute_loss`` is then accumulated. The original
    parameters are restored before returning, including when ``compute_loss``
    raises, so the caller's model is not left in a perturbed state.

    Args:
        model: ``torch.nn.Module`` whose parameters are perturbed in place.
        data_loader: Passed unchanged to ``compute_loss``.
        sigma: Standard deviation of the Gaussian noise added to each weight.
        mcmc_itr: Number of noise samples to average over; must be positive.

    Returns:
        The mean of ``compute_loss(...)[0]`` over the ``mcmc_itr`` samples.
        Whether this is a Python float or a tensor depends on what
        ``compute_loss`` returns (NOTE(review): confirm against its definition).

    Raises:
        ValueError: If ``mcmc_itr`` is not positive. Without this check,
            0 would end in a ZeroDivisionError after the loop and a negative
            value would silently return -0.0.
    """
    if mcmc_itr <= 0:
        raise ValueError(f"mcmc_itr must be positive, got {mcmc_itr}")

    scaler = GradScaler()
    total = 0.0
    # Snapshot of the unperturbed weights. detach() ensures the copies carry
    # no autograd history.
    theta_star = [p.detach().clone() for p in model.parameters()]
    try:
        for _ in tqdm(range(mcmc_itr), ncols=120):
            # Weight overwrites must not be recorded by autograd.
            with torch.no_grad():
                for mp, p in zip(model.parameters(), theta_star):
                    # empty_like keeps the noise in the weight's own dtype and
                    # device (torch.zeros would always produce float32).
                    mp.copy_(p + torch.empty_like(p).normal_(0, sigma))

            total += compute_loss(model, data_loader, scaler)[0]
    finally:
        # Always put the original weights back, even if compute_loss raised
        # or the loop was interrupted.
        load_weights(model, theta_star)
    return total / mcmc_itr
