import random
import os
import numpy as np
import torch
from torch.nn import functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset


def set_seed(seed: int = 42) -> None:
    """Seed the random generators used here and pin cuDNN to deterministic kernels.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (``torch.manual_seed`` plus ``torch.cuda.manual_seed``). It then disables
    cuDNN autotuning, so the kernel choice cannot vary between runs.

    Note: ``PYTHONHASHSEED`` is exported to the environment, but Python reads
    it only at interpreter start-up. It therefore affects child processes
    launched afterwards, not the current process's ``hash()``.

    Args:
        seed: Value fed to every generator. Defaults to 42.
    """
    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)

    # cuDNN: require deterministic algorithms and switch off the benchmark-driven
    # algorithm search, which may pick different kernels from run to run.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def train_epoch(data_iter, model, optimizer, device):
    """Run one optimisation pass over ``data_iter`` and return the mean batch loss.

    Moves ``model`` to ``device`` and switches it to training mode. For each
    ``(data, targets)`` batch it then does a forward pass, computes the
    cross-entropy loss, backpropagates and calls ``optimizer.step()``.

    Args:
        data_iter: Iterable of ``(data, targets)`` tensor pairs, e.g. a
            ``DataLoader``. It need not implement ``len()``.
        model: Module whose output is passed to ``F.cross_entropy`` as
            logits (presumably ``(batch, num_classes)``; confirm with callers).
        optimizer: Optimizer already built over ``model``'s parameters.
        device: Device that the model and every batch are moved to.

    Returns:
        The unweighted mean of the per-batch losses, or ``float("nan")`` if
        ``data_iter`` yielded no batches.
    """
    # nn.Module.to moves parameters in place, so an optimizer built before
    # this call still references the same Parameter objects.
    model = model.to(device)
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data, targets in tqdm(data_iter, leave=False):
        data, targets = data.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Count batches while iterating instead of relying on len(data_iter).
    # This supports generators and iterable-style loaders, and it avoids a
    # ZeroDivisionError when nothing was yielded.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches
        
def get_accuracy(data_iter, model, device, batch_size=64, num_workers=4):
    """Evaluate ``model`` without gradients and return ``(accuracy, mean_loss)``.

    Args:
        data_iter: Either a ``Dataset``, which is wrapped in an unshuffled
            ``DataLoader``, or any iterable of ``(data, targets)`` tensor pairs.
        model: Module whose output is treated as class logits; its argmax over
            dim 1 is taken as the prediction.
        device: Device that the model and every batch are moved to.
        batch_size: Batch size for the wrapping ``DataLoader``. Used only
            when ``data_iter`` is a ``Dataset``.
        num_workers: Worker count for the wrapping ``DataLoader``. Used only
            when ``data_iter`` is a ``Dataset``.

    Returns:
        A tuple ``(accuracy, loss)``, both averaged over the samples actually
        evaluated:

        - ``accuracy`` is the fraction of predictions equal to the target.
        - ``loss`` is the per-sample mean cross-entropy.

        Both values are ``float("nan")`` if no samples were seen.

    The model is left in eval mode on return.
    """
    if isinstance(data_iter, Dataset):
        data_iter = DataLoader(data_iter, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    model = model.to(device)
    model.eval()
    correct = 0
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for data, targets in tqdm(data_iter, leave=False, desc="Get Accuracy"):
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            pred = outputs.argmax(dim=1, keepdim=True)
            # Sum rather than average within the batch, so a short final batch
            # is weighted by its true size in the overall mean.
            total_loss += F.cross_entropy(outputs, targets, reduction="sum").item()
            correct += pred.eq(targets.view_as(pred)).sum().item()
            num_samples += targets.size(0)
    if num_samples == 0:
        return float("nan"), float("nan")
    # Divide by the samples actually evaluated, not len(dataset). A DataLoader
    # with a subset sampler or drop_last=True visits fewer items than the
    # dataset holds, and plain iterables have no .dataset at all.
    return correct / num_samples, total_loss / num_samples

def get_optimizer(model, optimizer_name, learning_rate, weight_decay, momentum):
    """Build a ``torch.optim`` optimizer over ``model.parameters()``.

    Args:
        model: Module whose parameters will be optimized.
        optimizer_name: ``"adam"`` or ``"sgd"``.
        learning_rate: Learning rate passed as ``lr``.
        weight_decay: Weight-decay (L2) factor, applied to both Adam and SGD.
        momentum: SGD momentum factor. It is not used by Adam.

    Returns:
        The constructed ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    if optimizer_name == "adam":
        return torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    if optimizer_name == "sgd":
        # SGD accepts weight_decay too. Forward it so the argument is honored
        # for both optimizers instead of being silently dropped.
        return torch.optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    raise ValueError(f"Optimizer {optimizer_name} not supported")