import torch
import torch.nn as nn
import numpy as np


def estimate_sigma(dataset, net, device, *, batch_size=128, num_workers=2):
    """Estimate the residual noise level of ``net`` on ``dataset`` as an RMSE.

    Runs ``net`` over every sample in dataset order without gradients and
    returns ``sqrt(mean((prediction - target) ** 2))`` over all target
    elements.

    Args:
        dataset: Map-style dataset yielding ``(input, target)`` pairs.
        net: Callable mapping a batch of inputs to predictions. Its
            train/eval mode is not changed here -- NOTE(review): callers
            presumably want ``net.eval()`` first if it contains dropout or
            batch-norm layers.
        device: Device each input/target batch is moved to.
        batch_size: Number of samples per forward pass.
        num_workers: Number of DataLoader worker processes.

    Returns:
        The root-mean-squared error (a NumPy float).

    Raises:
        ValueError: If ``dataset`` yields no samples.
    """
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    criterion = nn.MSELoss(reduction="mean")
    weighted_loss = 0.
    count = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            outputs = net(inputs)
            # Cast to the prediction dtype so integer-valued labels are
            # accepted by MSELoss.
            targets = targets.to(device=device, dtype=outputs.dtype)
            # Align shapes explicitly. A bare squeeze() would broadcast (B,)
            # predictions against (B, 1) targets into a (B, B) matrix and
            # would turn a batch of one into a 0-d tensor.
            if outputs.numel() == targets.numel():
                outputs = outputs.reshape(targets.shape)
            else:
                outputs = outputs.squeeze()
            batch = inputs.shape[0]
            # Weight each batch mean by its size so a short final batch
            # does not skew the overall mean.
            weighted_loss += criterion(outputs, targets).item() * batch
            count += batch
    if count == 0:
        raise ValueError("estimate_sigma requires a non-empty dataset")
    return np.sqrt(weighted_loss / count)


def generate_pseudo_labels(dataset, net, device, sigma=1.0, *,
                           batch_size=128, num_workers=2):
    """Replace ``dataset.targets`` with noisy predictions from ``net``.

    For each sample ``x``, in dataset order, the new target is
    ``net(x) + sigma * eps``, where ``eps`` is drawn with
    ``torch.randn_like`` (standard normal).

    Args:
        dataset: Map-style dataset yielding ``(input, target)`` pairs and
            exposing a writable ``targets`` attribute. The original targets
            are loaded but ignored.
        net: Callable mapping a batch of inputs to one prediction per sample.
            Its train/eval mode is not changed here -- NOTE(review): callers
            presumably want ``net.eval()`` first.
        device: Device each input batch is moved to.
        sigma: Standard deviation of the added Gaussian noise.
        batch_size: Number of samples per forward pass.
        num_workers: Number of DataLoader worker processes.

    Returns:
        The same ``dataset`` object, mutated in place. ``dataset.targets`` is
        now a list of ``len(dataset)`` Python floats.
    """
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    # Pseudo-labels are continuous (prediction + Gaussian noise), so they are
    # stored as floats. An integer buffer would silently truncate them, and
    # the `np.int` alias no longer exists in NumPy >= 1.24.
    new_targets = np.zeros(len(dataset), dtype=np.float64)
    count = 0
    with torch.no_grad():
        for inputs, _ in loader:
            batch = inputs.shape[0]
            # Flatten to one value per sample. numpy raises below if the net
            # emits more than one value per sample.
            outputs = net(inputs.to(device)).reshape(-1)
            # randn_like already allocates on the outputs' device.
            pseudo_targets = outputs + sigma * torch.randn_like(outputs)
            new_targets[count:count + batch] = pseudo_targets.cpu().numpy()
            count += batch
    dataset.targets = new_targets.tolist()
    return dataset
