from resnetPAAM import *
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import random
import torch
import torchvision
import torch.optim as optim
import argparse
import os

# Command-line options for the pruning run. Restricting --model to the two
# supported architectures makes argparse reject typos up front instead of the
# script failing later with an unbound `model`.
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="ResNet56", choices=["ResNet56", "ResNet110"],
                    help="CIFAR ResNet variant to prune")
parser.add_argument("--batch_size", default=128, type=int,
                    help="mini-batch size used by both the training and test loaders")
parser.add_argument("--landa", default=0.0005, type=float,
                    help="base coefficient of the L1 penalty on the score tensors")
parser.add_argument("--input_path", default='./resnet56_50warmup', type=str,
                    help="directory containing the warm-up checkpoint.pt")
parser.add_argument("--output_path", default='./resnet56_pruned', type=str,
                    help="directory where Model.p and Res_scores.p are written")

args = parser.parse_args()

# Runtime configuration derived from the command line.
batch_size = args.batch_size
test_batch_size = args.batch_size
num_workers = 2
landa = args.landa  # base weight of the L1 penalty on the score tensors
output_dir = args.output_path
# NOTE: an integer index selects a CUDA device, so this script assumes a GPU.
device = torch.device(0)
print(device)

# Seed every RNG the script touches (torch CPU/CUDA, numpy, random).
torch.manual_seed(123)
torch.cuda.manual_seed(123)
np.random.seed(123)
random.seed(123)

# Build the requested CIFAR ResNet. block_length is the per-stage block count
# (the value repeated three times in the layer list); the training loop uses it
# to decide how strongly each score tensor is penalized.
if args.model == "ResNet56":
    model = CifarResNet(BasicBlock, [9] * 3)
    block_length = 9
elif args.model == "ResNet110":
    model = CifarResNet(BasicBlock, [18] * 3)
    block_length = 18
else:
    # Fail loudly here rather than with a NameError on `model` below.
    raise ValueError("Unsupported --model %r; expected 'ResNet56' or 'ResNet110'" % args.model)

# Load the warm-up checkpoint. strict=False tolerates keys that are missing
# from (or extra in) the checkpoint.
# NOTE(review): presumably this is because the score parameters are not part
# of the warm-up checkpoint — confirm.
PATH = os.path.join(args.input_path, 'checkpoint.pt')
model.load_state_dict(torch.load(PATH), strict=False)


model.to(device)

# Training-time augmentation: random 32x32 crop of the 4-pixel-padded image and
# a random horizontal flip, followed by per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Normalize the test set the same way as the training set, without augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
n_training_samples = 49999  # Max: 50 000 - n_val_samples
n_val_samples = 1
n_test_samples = 10000

# Index samplers: the first n_training_samples training images are used for
# training; the next n_val_samples are held out.
# NOTE(review): val_sampler and test_sampler are not passed to any loader in
# this file as shown.
train_sampler = SubsetRandomSampler(np.arange(n_training_samples, dtype=np.int64))
val_sampler = SubsetRandomSampler(np.arange(n_training_samples, n_training_samples + n_val_samples, dtype=np.int64))
test_sampler = SubsetRandomSampler(np.arange(n_test_samples, dtype=np.int64))

# Each loader uses its own batch-size variable and the shared worker count
# (previously the two batch sizes were swapped and num_workers was hard-coded).
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
                                         shuffle=False, num_workers=num_workers)

train_set_augmented = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                   download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(train_set_augmented, batch_size=batch_size,
                                          sampler=train_sampler, num_workers=num_workers)


def trainable_network_weights(net, requires_grad=False):
    """Freeze or unfreeze every non-score parameter of ``net``.

    The module is also switched to training mode when ``requires_grad`` is
    true and to evaluation mode otherwise. Parameters whose name contains the
    character ``'S'`` are treated as score parameters and left untouched.
    """
    net.train(mode=requires_grad)
    network_params = [p for pname, p in net.named_parameters() if 'S' not in pname]
    for p in network_params:
        p.requires_grad = requires_grad


def trainable_score_weights(net, requires_grad=False):
    """Freeze or unfreeze the score parameters of ``net``.

    Score parameters are those whose name contains the character ``'S'``;
    all other parameters and the module's train/eval mode are left as they are.
    """
    score_params = (p for pname, p in net.named_parameters() if 'S' in pname)
    for p in score_params:
        p.requires_grad = requires_grad


# Baseline test accuracy of the loaded checkpoint before any score/weight
# training. eval() is required: the freshly built model is still in training
# mode, and without it normalization layers (presumably BatchNorm in
# CifarResNet) would update their running statistics from test batches even
# under no_grad.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # The model returns (scores, outputs); only the class outputs are needed here.
        _, outputs = model(inputs, is_score_training=False)
        # The predicted class is the index of the largest output.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Test Accuracy: {:.1f}'.format(
    100 * float(correct) / total))


# Two Adam optimizers for the alternating schedule: a tiny learning rate for
# the score parameters and a larger one for the network weights. Both are built
# over all parameters; which group actually moves in a phase is decided by the
# requires_grad toggles below (frozen parameters receive no gradient).
scores_optimizer = optim.Adam(model.parameters(), lr=0.000001)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

num_epochs = 90
# Each 9-epoch cycle trains the scores for epochs 0-2 and the weights for 3-8.
cycle_length = 9
score_phase_epochs = (0, 1, 2)
# Print the average training loss every `log_interval` mini-batches.
log_interval = 200


def _score_penalty(scores):
    """Return the weighted L1 penalty summed over the sequence of score tensors.

    Index 0 and indices >= 4 * block_length + 1 are weighted by ``landa``;
    indices 1 .. 2 * block_length by ``16 * landa``; and indices
    2 * block_length + 1 .. 4 * block_length by ``4 * landa``.
    NOTE(review): presumably index 0 is the stem and each BasicBlock contributes
    two score tensors, so the ranges follow the three ResNet stages — confirm
    against resnetPAAM.
    """
    penalty = 0
    for index, score in enumerate(scores):
        if index == 0 or index >= 4 * block_length + 1:
            weight = 1
        elif index < 2 * block_length + 1:
            weight = 16
        else:
            weight = 4
        # An L1 norm is already non-negative, so no abs() is needed.
        penalty = penalty + torch.norm(score, p=1) * landa * weight
    return penalty


def _train_one_epoch(epoch, train_optimizer, score_mode, penalize_scores):
    """Run one pass over ``trainloader`` and step ``train_optimizer`` per batch.

    ``score_mode`` is forwarded to the model as ``is_score_training``. When
    ``penalize_scores`` is true, the L1 score penalty is added to the loss and
    each score tensor (and its sum) is printed with every loss report.
    """
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        train_optimizer.zero_grad()

        scores, outputs = model(inputs, is_score_training=score_mode)
        loss = criterion(outputs, labels)
        if penalize_scores:
            loss = loss + _score_penalty(scores)
        loss.backward()
        train_optimizer.step()

        running_loss += loss.item()
        # Report once every `log_interval` batches, averaged over exactly
        # that many batches (consistent across both phases).
        if (i + 1) % log_interval == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_interval))
            if penalize_scores:
                for score in scores:
                    print(torch.sum(score))
                    print(score)
            running_loss = 0.0


for epoch in range(num_epochs):
    if epoch % cycle_length in score_phase_epochs:
        # Score phase: freeze the network weights (which also switches the model
        # to eval mode) and train only the scores, with the L1 penalty.
        trainable_network_weights(model, requires_grad=False)
        trainable_score_weights(model, requires_grad=True)
        _train_one_epoch(epoch, scores_optimizer, score_mode=0, penalize_scores=True)
    else:
        # Weight phase: train the network weights (model in train mode) with the
        # scores frozen.
        # NOTE(review): is_score_training=2 is never used by this schedule,
        # since epoch % 9 always falls in 0..8. Confirm whether a phase using
        # mode 2 was intended.
        trainable_network_weights(model, requires_grad=True)
        trainable_score_weights(model, requires_grad=False)
        _train_one_epoch(epoch, optimizer, score_mode=1, penalize_scores=False)


import pickle

# Persist the trained state dict. output_dir is a plain string, so the path
# must be built with os.path.join: `str / str` raises TypeError and would
# discard the whole training run. Create the directory if it does not exist.
os.makedirs(output_dir, exist_ok=True)
PATH = os.path.join(output_dir, "Model.p")
torch.save(model.state_dict(), PATH)

# Final test accuracy. The counters are reset so the baseline evaluation does
# not leak into this figure. eval() is needed because the last training epoch
# leaves the model in train mode.
# NOTE(review): confirm the PAAM layers behave as intended in eval mode with
# is_score_training=1.
model.eval()
correct = 0
total = 0
scores = None
with torch.no_grad():
    for inputs, labels in testloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Keep the score tensors from the last batch; they are dumped below.
        scores, outputs = model(inputs, is_score_training=1)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Test Accuracy: {:.1f}'.format(
    100 * float(correct) / total))

# NOTE(review): `scores` holds tensors on `device`; unpickling them requires
# torch with the same device type available.
with open(os.path.join(output_dir, "Res_scores.p"), "wb") as scores_file:
    pickle.dump(scores, scores_file)




