# Training, evaluation, and NegGrad unlearning routines for the target and shadow models (with optional Opacus differential privacy).
import os
from itertools import cycle
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus.validators import ModuleValidator

def _run_optimization_pass(model, loader, optimizer, criterion, device):
    """Run one forward/backward/step pass over every batch of *loader*; return the summed batch loss."""
    running_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss


def train(model, trainloader, optimizer, criterion, epoch, device, dp, privacy_engine=None, neg=1, savepath=None,
          max_physical_batch_size=128):
    """
    Train ``model`` for one epoch over ``trainloader`` and return it.

    Args:
        model: the network to train; it is switched to train mode.
        trainloader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters. When DP is active this is
            expected to be the Opacus-wrapped optimizer.
        criterion: loss function called as ``criterion(outputs, labels)``.
        epoch: current epoch index (unused here; kept for interface compatibility).
        device: device the batches are moved to.
        dp: config dict; only ``dp['activate']`` is read. When truthy, batches are
            drawn through Opacus' ``BatchMemoryManager``.
        privacy_engine, neg, savepath: unused; kept for interface compatibility.
        max_physical_batch_size: largest physical batch ``BatchMemoryManager`` may
            feed through the model when DP is active (previously hard-coded to 128).

    Returns:
        The same ``model`` object, after the in-place parameter updates.
    """
    model.train()
    if dp['activate']:
        # BatchMemoryManager lets Opacus split large logical batches into physical
        # batches that fit in memory while keeping the privacy accounting per logical batch.
        with BatchMemoryManager(
            data_loader=trainloader,
            max_physical_batch_size=max_physical_batch_size,
            optimizer=optimizer
        ) as memory_safe_data_loader:
            _run_optimization_pass(model, memory_safe_data_loader, optimizer, criterion, device)
    else:
        _run_optimization_pass(model, trainloader, optimizer, criterion, device)
    return model


def neggrad(model, trainloader, fgloader, optimizer, criterion, epoch, device, privacy_engine=None, neg=1, savepath=None,
            alpha=0.9999):
    """
    Run one epoch of NegGrad unlearning and return ``model``.

    Each step pairs a retain batch from ``trainloader`` with a forget batch from
    ``fgloader`` and minimises ``alpha * retain_loss - (1 - alpha) * forget_loss``.
    This is gradient descent on the retain data and ascent on the forget data.

    Args:
        model: network being unlearned; switched to train mode.
        trainloader: iterable of ``(inputs, targets)`` retain batches. Its length
            determines the number of steps.
        fgloader: iterable of ``(inputs, targets)`` forget batches. It is wrapped in
            ``itertools.cycle``, which caches the batches produced on the first pass
            and replays them in that same order. If it yields nothing, no step is taken.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function called as ``criterion(outputs, targets)``.
        epoch: epoch index, used only in the progress message.
        device: device the batches are moved to.
        privacy_engine, neg, savepath: unused; kept for interface compatibility.
        alpha: weight of the retain loss; the forget loss gets ``1 - alpha``.

    Returns:
        The same ``model`` object, updated in place.
    """
    model.train()
    running_loss = 0.0
    paired_batches = zip(trainloader, cycle(fgloader))
    for idx, ((inputs, targets), (del_inputs, del_targets)) in enumerate(paired_batches):
        inputs, del_inputs = inputs.float().to(device), del_inputs.float().to(device)
        targets, del_targets = targets.to(device), del_targets.to(device)

        # Forward passes
        output = model(inputs)
        del_output = model(del_inputs)
        r_loss = criterion(output, targets)
        del_loss = criterion(del_output, del_targets)

        # Subtracting the forget loss pushes the model away from fitting the forget set.
        loss = alpha * r_loss - (1 - alpha) * del_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if idx % 100 == 99:  # Print every 100 mini-batches
            print('[Epoch: %d, Batch: %5d] loss: %.3f' % (epoch + 1, idx + 1, running_loss / 100))
            running_loss = 0.0
    return model


def test(model, testloader, criterion, device):
    # Testing
    model.eval()
    correct, total, running_loss = 0, 0, 0.0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            
            loss = criterion(outputs, labels)
            outputs = F.softmax(outputs, dim=-1)
            # Print statistics
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('Testing Accuracy: %.3f %%' % accuracy)
    print('Testing Loss: %.3f' % (running_loss / total))
    return accuracy


def NegGrad(total_epoch, target_net, target_trainloader, fgloader, target_optimizer, 
            lr_scheduler, criterion, target_testloader, device, savepath):
    """
    Unlearn with NegGrad for ``total_epoch`` epochs, then pickle the whole network.

    Every epoch runs one ``neggrad`` pass over the retain (``target_trainloader``)
    and forget (``fgloader``) data, then reports test metrics via ``test``. The
    complete module object, not just its state dict, is written with
    ``torch.save`` to ``{savepath}/NegGrad/neggrad.pth``.

    NOTE(review): ``lr_scheduler`` is accepted but never stepped here, unlike in
    ``target_train``. Confirm that a fixed learning rate is intended.
    """
    target_net.train()
    output_dir = f"{savepath}/NegGrad/"
    checkpoint_file = f"{savepath}/NegGrad/neggrad.pth"

    for epoch in range(total_epoch):
        print(f"Epoch {epoch} starts!")
        target_net = neggrad(
            target_net.to(device),
            target_trainloader,
            fgloader,
            target_optimizer,
            criterion,
            epoch,
            device,
            privacy_engine=None,
        )
        # Evaluation is for reporting only; the returned accuracy is not used.
        test(target_net.to(device), target_testloader, criterion, device)

    os.makedirs(output_dir, exist_ok=True)
    torch.save(target_net, checkpoint_file)


def target_train(dp, total_epoch, target_net, target_trainloader, target_optimizer, 
                 lr_scheduler, criterion, target_testloader, device, savepath, modelname, neggrad=False):
    """
    Train the target network for ``total_epoch`` epochs and save its weights.

    When ``dp['activate']`` is truthy, Opacus' ``PrivacyEngine.make_private_with_epsilon``
    wraps the model, optimizer and loader to meet ``dp['epsilon']``/``dp['delta']``
    with per-sample clipping at ``dp['max_grad_norm']``. The underlying (unwrapped)
    module's ``state_dict`` is then saved.

    After each epoch the learning-rate scheduler is stepped and test metrics are
    printed. Weights are written to ``{savepath}/{modelname}.pth``.

    NOTE(review): the ``neggrad`` flag is unused and shadows the module-level
    ``neggrad`` function inside this body.
    """
    target_net.train()
    privacy_engine = None

    if dp['activate']:
        privacy_engine = PrivacyEngine()
        private_parts = privacy_engine.make_private_with_epsilon(
            module=target_net,
            optimizer=target_optimizer,
            data_loader=target_trainloader,
            epochs=total_epoch,
            target_epsilon=dp['epsilon'],
            target_delta=dp['delta'],
            max_grad_norm=dp['max_grad_norm'],
        )
        target_net, target_optimizer, target_trainloader = private_parts
        print(f"Using sigma={target_optimizer.noise_multiplier} and C={dp['max_grad_norm']}")

    for epoch in range(total_epoch):
        print(f"Epoch {epoch} starts!")
        target_net = train(
            target_net.to(device), target_trainloader, target_optimizer,
            criterion, epoch, device, dp, privacy_engine
        )
        lr_scheduler.step()
        # Evaluation is for reporting only; the returned accuracy is not used.
        test(target_net.to(device), target_testloader, criterion, device)

    os.makedirs(savepath, exist_ok=True)
    # The Opacus wrapper keeps the original network in ``_module``; save that one's weights.
    network_to_save = target_net._module if dp['activate'] else target_net
    torch.save(network_to_save.state_dict(), f"{savepath}/{modelname}.pth")


def shadow_train(total_epoch, shadow_net, shadow_trainloader, shadow_optimizer, lr_scheduler, criterion, shadow_testloader, device, savepath):
    """
    Train a shadow network (never with DP) and pickle it under a name derived from its final test accuracy.

    Each epoch runs one ``train`` pass, steps ``lr_scheduler`` and evaluates with
    ``test``. The whole module object is saved with ``torch.save`` to
    ``{savepath}/{accuracy}.pth``, where ``accuracy`` is the last value returned by ``test``.

    If ``total_epoch <= 0``, the untrained network is evaluated once so the checkpoint
    can still be named. Previously this case raised ``UnboundLocalError``.
    """
    dp = {'activate': False}  # shadow models are always trained without differential privacy
    acc = None
    for epoch in range(total_epoch):
        print("Epoch ", epoch, "starts!")
        shadow_net = train(shadow_net.to(device), shadow_trainloader, shadow_optimizer, criterion, epoch, dp=dp, device=device)
        lr_scheduler.step()
        acc = test(shadow_net.to(device), shadow_testloader, criterion, device)
    if acc is None:
        # No epoch ran, so no accuracy was recorded; evaluate once to name the checkpoint.
        acc = test(shadow_net.to(device), shadow_testloader, criterion, device)
    os.makedirs(savepath, exist_ok=True)
    torch.save(shadow_net, savepath + '/' + str(acc) + '.pth')
