import torch

def model_evaluate_imagedata(model, data_loader, device, loss_function=torch.nn.functional.cross_entropy):
    '''
    Evaluates model on all datapoints in data_loader, without tracking gradients.

    Averages are taken on granularity of each datapoint, and not each batch: every
    batch loss is weighted by its batch size before the final division, so a
    smaller last batch does not skew the result.

    NOTE: the weighting assumes loss_function returns the mean loss over the batch
    (the default reduction of torch.nn.functional.cross_entropy). A custom
    loss_function with a different reduction would make avg_loss incorrect.

    Args:
        model: Callable mapping a batch of data to scores with the classes along
            dim 1. It is switched to eval mode and left in eval mode on return.
        data_loader: Iterable yielding (data, target) tensor pairs; the batch size
            is read from dimension 0 of data.
        device: Device that every batch is moved to before the forward pass.
        loss_function: Callable (output, target) -> scalar loss tensor.

    Returns:
        Tuple (avg_loss, avg_accuracy, total_data_points).

    Raises:
        ValueError: If data_loader yields no datapoints, since the averages are
            undefined in that case.
    '''
    model.eval()
    total_loss = 0
    total_correct = 0
    total_data_points = 0

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            batch_size = data.size(0)
            output = model(data)
            loss = loss_function(output, target)

            # Undo the per-batch mean so the final division is per datapoint.
            total_loss += loss.item() * batch_size
            # The predicted class is the index of the highest score along dim 1.
            total_correct += output.argmax(dim=1).eq(target).sum().item()
            total_data_points += batch_size

    # Fail with a descriptive error instead of an opaque ZeroDivisionError.
    if total_data_points == 0:
        raise ValueError(
            "data_loader yielded no data points; cannot compute average loss and accuracy"
        )

    avg_loss = total_loss / total_data_points
    avg_accuracy = total_correct / total_data_points
    return avg_loss, avg_accuracy, total_data_points

def model_train_imagedata(model, train_loader, device, optimizer, loss_function=torch.nn.functional.cross_entropy):
    '''
    Trains model for one pass over all datapoints in train_loader.

    Averages are taken on granularity of each datapoint, and not each batch: every
    batch loss is weighted by its batch size before the final division. The
    reported loss and accuracy of a batch come from the forward pass made before
    that batch's optimizer step.

    NOTE: the weighting assumes loss_function returns the mean loss over the batch
    (the default reduction of torch.nn.functional.cross_entropy). A custom
    loss_function with a different reduction would make train_loss incorrect.

    Args:
        model: Callable mapping a batch of data to scores with the classes along
            dim 1. It is switched to train mode and left in train mode on return.
        train_loader: Iterable yielding (data, target) tensor pairs; the batch
            size is read from dimension 0 of data.
        device: Device that every batch is moved to before the forward pass.
        optimizer: Optimizer whose gradients are zeroed before, and which is
            stepped after, the backward pass of every batch.
        loss_function: Callable (output, target) -> scalar loss tensor.

    Returns:
        Tuple (train_loss, train_accuracy, total_data_points).

    Raises:
        ValueError: If train_loader yields no datapoints, since the averages are
            undefined in that case.
    '''
    model.train()
    total_train_loss = 0
    total_correct = 0
    total_data_points = 0

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        batch_size = data.size(0)

        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()

        # Undo the per-batch mean so the final division is per datapoint.
        total_train_loss += loss.item() * batch_size
        # The predicted class is the index of the highest score along dim 1.
        total_correct += output.argmax(dim=1).eq(target).sum().item()
        total_data_points += batch_size

    # Fail with a descriptive error instead of an opaque ZeroDivisionError.
    if total_data_points == 0:
        raise ValueError(
            "train_loader yielded no data points; cannot compute average loss and accuracy"
        )

    train_loss = total_train_loss / total_data_points
    train_accuracy = total_correct / total_data_points
    return train_loss, train_accuracy, total_data_points

def initialize_optimizer(args, model):
    '''
    Builds the torch optimizer selected by the integer code in args.opt for the
    parameters of model.

    Codes: 0 = SGD, 1 = Adagrad, 2 = Adam, 3 = AdamW, 4 = RMSprop, 5 = Adadelta.

    Only the hyperparameters of the selected optimizer are read from args:
        lr and weight_decay -- every optimizer
        beta1               -- SGD / RMSprop momentum, first Adam / AdamW beta
        beta2               -- second Adam / AdamW beta, RMSprop alpha, Adadelta rho
        eps                 -- every optimizer except SGD

    Raises ValueError when args.opt is not one of the codes above.
    '''
    choice = args.opt
    optim = torch.optim

    if choice == 0:  # SGD
        return optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.beta1,
            weight_decay=args.weight_decay,
        )

    if choice == 1:  # Adagrad
        return optim.Adagrad(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            eps=args.eps,
        )

    if choice == 2:  # Adam
        return optim.Adam(
            model.parameters(),
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
            weight_decay=args.weight_decay,
        )

    if choice == 3:  # AdamW
        return optim.AdamW(
            model.parameters(),
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
            weight_decay=args.weight_decay,
        )

    if choice == 4:  # RMSprop
        return optim.RMSprop(
            model.parameters(),
            lr=args.lr,
            alpha=args.beta2,
            eps=args.eps,
            weight_decay=args.weight_decay,
            momentum=args.beta1,
            centered=False,
        )

    if choice == 5:  # Adadelta
        return optim.Adadelta(
            model.parameters(),
            lr=args.lr,
            rho=args.beta2,
            eps=args.eps,
            weight_decay=args.weight_decay,
        )

    raise ValueError(f"Unsupported optimizer type: {args.opt}")