
import torch
import torch.nn as nn
import torchvision
import argparse
from opacus.accountants.utils import get_noise_multiplier


def compute_sigma(epsilon, delta):
    """Return the noise multiplier (sigma) for a target (epsilon, delta) budget.

    The value is obtained from Opacus' ``get_noise_multiplier``; the sampling
    rate and the number of epochs are both fixed to 1 here.

    Args:
        epsilon: Target epsilon forwarded as ``target_epsilon``.
        delta: Target delta forwarded as ``target_delta``.

    Returns:
        Whatever ``get_noise_multiplier`` returns for these settings.
    """
    accountant_args = {
        "target_epsilon": epsilon,
        "target_delta": delta,
        "sample_rate": 1,
        # A single pass is assumed (original note: one pass for the
        # embedding creation).
        "epochs": 1,
    }
    return get_noise_multiplier(**accountant_args)


def train_step(model, data_loader, optimizer, scheduler, device):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(inputs.float())``; its output is passed
            to ``nn.functional.cross_entropy``.
        data_loader: Iterable yielding ``(inputs, labels)`` tensor pairs. It
            does not need to implement ``__len__``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        scheduler: Learning-rate scheduler stepped once after the pass.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The sum of the per-batch losses divided by the number of batches that
        were actually processed, or ``0.0`` if the loader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    # Count batches while iterating instead of calling len(data_loader):
    # this also works for loaders without __len__ and avoids dividing by zero.
    num_batches = 0
    for inputs, labels in data_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass; inputs are cast to float32 and labels to int64, which
        # is what cross_entropy expects for class-index targets.
        outputs = model(inputs.float())
        loss = nn.functional.cross_entropy(outputs, labels.long())

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # The scheduler advances once per call, not once per batch.
    scheduler.step()

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

def test(model, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` and return its accuracy in percent.

    The model is put in eval mode and run under ``torch.no_grad()``. The
    predicted class of each sample is the argmax over the last dimension of
    the model output. The accuracy is also printed.

    Args:
        model: Module called as ``model(images.float())``.
        data_loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yielded no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)

            # Cast to float32 exactly as train_step does, so batches that
            # train successfully can also be evaluated.
            outputs = model(images.float())
            predicted = outputs.argmax(-1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else 0.0
    print(f'Accuracy on the test set: {accuracy:.2f}%')
    return accuracy


# The name starts with "test", so pytest would otherwise try to collect this
# helper as a test case whenever it is imported into a test module.
test.__test__ = False


def get_transform(image_size, is_train=True):
    """Build the torchvision preprocessing pipeline for training or evaluation.

    Both pipelines resize to ``image_size``, convert to a tensor and normalize
    with fixed per-channel statistics. The training pipeline additionally
    applies a random horizontal flip with probability 0.5 right after the
    resize.

    Args:
        image_size: Size forwarded to ``torchvision.transforms.Resize``.
        is_train: When truthy, include the random horizontal flip.

    Returns:
        A ``torchvision.transforms.Compose`` of the selected steps.
    """
    tv = torchvision.transforms

    steps = [tv.Resize(image_size)]
    if is_train:
        # Augmentation is only applied to training data.
        steps.append(tv.RandomHorizontalFlip(0.5))
    steps.append(tv.ToTensor())
    # NOTE(review): these look like CIFAR-10 channel statistics; confirm.
    steps.append(
        tv.Normalize(mean=[0.4914, 0.482, 0.447], std=[0.247, 0.243, 0.262])
    )
    return tv.Compose(steps)