
import argparse
import torch
import numpy as np
import random
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from copy import deepcopy
# import wandb
from utils import *



def Trainer_pro(arg, net, trainloader, testloader, device, mask):
    """Train the "protector": update only masked weights so as to *reduce* accuracy.

    Each step performs gradient *ascent* on the cross-entropy loss (the loss is
    negated). Gradients are multiplied by ``mask[name].long()``, so only the
    masked entries of each parameter change. After every epoch, the masked
    entries are clamped to ``[arg.min_w, arg.max_w]``. Validation accuracy from
    ``inference`` is checked after each epoch, and the state with the *lowest*
    accuracy is kept. Training stops early once the accuracy has failed to drop
    for ``arg.max_val_iter`` consecutive epochs.

    Args:
        arg: Config namespace. Reads ``pro_lr``, ``num_epoch_tr``, ``min_w``,
            ``max_w``, ``verbose_flag`` and ``max_val_iter``.
        net: Model, called as ``net(inputs, 0)``.
        trainloader: Iterable of ``(inputs, targets)`` batches.
        testloader: Loader passed to ``inference`` for validation.
        device: Device the batches are moved to.
        mask: Mapping from every parameter name of ``net`` to a tensor that
            broadcasts against that parameter. Presumably a 0/1 mask on the
            parameter's device (TODO confirm). Non-zero entries mark the
            weights that are trained and clamped.

    Returns:
        ``(val_accuracy, state_dict)``: the lowest validation accuracy seen and
        a deep copy of the matching ``net.state_dict()``. If no epoch runs, or
        none improves on the initial sentinel ``100.``, the sentinel and a
        snapshot of the state at call time are returned.
    """
    val_accuracy = 100.
    val_iter = 0
    # Baseline snapshot so there is always a state to return, even when the
    # loop body never records an improvement.
    state_dict = deepcopy(net.state_dict())

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        filter(lambda param: param.requires_grad, net.parameters()), lr=arg.pro_lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    for epoch in range(arg.num_epoch_tr):
        # Eval mode keeps batchnorm statistics fixed: only the masked weights
        # are meant to change.
        net.eval()
        sum_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = net(inputs, 0)
            # Negated loss => gradient ascent, i.e. deliberately degrade accuracy.
            loss = -criterion(outputs, targets)
            loss.backward()

            for name, para in net.named_parameters():
                # Frozen (requires_grad=False) or unused parameters have no
                # gradient; masking None would raise TypeError.
                if para.grad is not None:
                    para.grad *= mask[name].long()
            optimizer.step()

            sum_loss += loss.item()
            num_batches += 1
            predicted = outputs.detach().argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        scheduler.step()

        # Project the masked weights back into [min_w, max_w]. torch.where
        # leaves unmasked weights exactly unchanged, even when 0 lies outside
        # the clamp range.
        with torch.no_grad():
            for name, para in net.named_parameters():
                selected = mask[name].bool()
                para.copy_(torch.where(selected, para.clamp(arg.min_w, arg.max_w), para))

        if arg.verbose_flag == 1:
            # sum_loss holds per-batch mean losses, so average over batches.
            print('========== Protector Training: [epoch:%d] Loss: %.03f | Acc: %.3f%%' % (
                epoch + 1, sum_loss / max(num_batches, 1), 100. * correct / max(total, 1)))

        acc = inference(net, device, testloader)
        if arg.verbose_flag == 1:
            print(f'Val: | Acc: {acc:.5f}')

        # For the protector, *lower* validation accuracy counts as improvement.
        if acc < val_accuracy:
            val_accuracy = acc
            val_iter = 0
            state_dict = deepcopy(net.state_dict())
        else:
            val_iter += 1
        if val_iter >= arg.max_val_iter:
            if arg.verbose_flag == 1:
                print("Validation accuracy did not improve for the last {} validation runs. Early stopping..."
                      .format(arg.max_val_iter))
            break

    return val_accuracy, state_dict
 


def Trainer_att(arg, net, trainloader, testloader, device):
    """Fine-tune ``net`` as the "attacker" and keep the most accurate state.

    All parameters are trained with Adam on the cross-entropy loss. After each
    epoch, validation accuracy is measured with ``inference``, and the state
    with the *highest* accuracy is kept. Training stops early once the accuracy
    has failed to improve for ``arg.max_val_iter`` consecutive epochs.

    Args:
        arg: Config namespace. Reads ``att_lr``, ``num_epoch_att``,
            ``verbose_flag`` and ``max_val_iter``.
        net: Model, called as ``net(inputs, 0)``.
        trainloader: Iterable of ``(inputs, targets)`` batches.
        testloader: Loader passed to ``inference`` for validation.
        device: Device the batches are moved to.

    Returns:
        ``(val_accuracy, state_dict)``: the highest validation accuracy seen and
        a deep copy of the matching ``net.state_dict()``. If no epoch runs, or
        none beats the initial ``0.``, the initial value and a snapshot of the
        state at call time are returned.
    """
    val_accuracy = 0.
    # Bound up front: a first epoch with accuracy 0 (possible on a heavily
    # protected model) or zero epochs must not raise UnboundLocalError.
    val_iter = 0
    state_dict = deepcopy(net.state_dict())

    criterion = nn.CrossEntropyLoss()
    # Every parameter is fine-tuned with Adam at the configured learning rate.
    optimizer = optim.Adam(net.parameters(), lr=arg.att_lr)

    for epoch in range(arg.num_epoch_att):
        # Train mode each epoch, so batchnorm statistics are updated too.
        # Re-entered every epoch in case the validation call below switches
        # the mode (TODO confirm what inference() does).
        net.train()
        sum_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = net(inputs, 0)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            sum_loss += loss.item()
            num_batches += 1
            predicted = outputs.detach().argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        if arg.verbose_flag == 1:
            # sum_loss holds per-batch mean losses, so average over batches.
            print('========== Attacker Training: [epoch:%d] Loss: %.03f | Acc: %.3f%%' % (
                epoch + 1, sum_loss / max(num_batches, 1), 100. * correct / max(total, 1)))

        acc = inference(net, device, testloader)
        if arg.verbose_flag == 1:
            print(f'Val: | Acc: {acc:.5f}')

        if acc > val_accuracy:
            val_accuracy = acc
            val_iter = 0
            state_dict = deepcopy(net.state_dict())
        else:
            val_iter += 1
        if val_iter >= arg.max_val_iter:
            if arg.verbose_flag == 1:
                print("Validation accuracy did not improve for the last {} validation runs. Early stopping..."
                      .format(arg.max_val_iter))
            break

    return val_accuracy, state_dict

# if __name__ == '__main__':
    # net.load_state_dict(victim_state_dict)
    # # Generate mask
    # layer_list = [64, 128, 256, 256, 512, 512, 512, 512, 4096, 4096, 10]
    # layer_filters = [np.random.choice(i) for i in layer_list]
    # mask_basic = mask_layer(layer_filters, net)
    # # Get the model with degraded accuracy
    # basic_accuracy, basic_state = Trainer_pro(args, net, trainloader_pro, testloader_pro, device, mask_basic)
    # print('The model with base protection: {:.3%}'.format(basic_accuracy))
    # net.load_state_dict(basic_state)
    # # eval_accuracy = inference(net, device, testloader_att)
    # # dir_use_acc.append(eval_accuracy)
    # # print('Direct use under base protection: {:.3%}'.format(eval_accuracy))
    # # Fine-tune model under base protect
    # basic_tran_acc, _ = Trainer_att(args, net, trainloader_att, testloader_att, device)
    # trans_acc.append(basic_tran_acc)
    # print('The transfer results with base protection: {:.3%}'.format(basic_tran_acc))
