from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import copy
import argparse
from utils import initialize_model, initialize_dataset, random_sample
from attention_module import initialize_predictor

# Report the installed framework versions so every run's log records them.
for _label, _module in (("PyTorch Version: ", torch), ("Torchvision Version: ", torchvision)):
    print(_label, _module.__version__)
del _label, _module


"""
This file is the scheme of attention-based freezing.
An example of fine-tuning CNNs on CIFAR dataset.
"""


parser = argparse.ArgumentParser(description='test')
parser.add_argument('--epochs', type=int, default=10,
                    help='number of training epochs')
parser.add_argument('--model_name', type=str, default='resnet',
                    help='backbone name passed to initialize_model')
parser.add_argument('--total_layers', type=int, default=53,
                    help='total layer count (not read elsewhere in this script)')
parser.add_argument('--num_class', type=int, default=100,
                    help='number of classes in the dataset')
parser.add_argument('--dataset', type=str, default='cifar100',
                    help='dataset name passed to initialize_dataset')
parser.add_argument('--re_size', type=int, default=1024,
                    help='size each conv-weight snapshot is resampled to '
                         '(also the predictor input width)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='mini-batch size passed to initialize_dataset')

args = parser.parse_args()
num_epochs = args.epochs  # Number of epochs to train for
dataset_name = args.dataset
model_name = args.model_name
total_frozen_layers = args.total_layers
# Number of classes in the dataset
num_classes = args.num_class
# Batch size (lower it via --batch_size if you run out of memory).
batch_size = args.batch_size
# The uniform size of parameters after resizing.
re_size = args.re_size
# Flag for feature extracting. When False, we finetune the whole model,
#   when True we only update the reshaped layer params
feature_extract = False
# Number of training samples per epoch; 50000 is the size of the CIFAR-10/100
# training split. Used to space the weight snapshots evenly over an epoch.
num_training_samples = 50000
key_list = list()
window_size = 30  # The attention window size (weight snapshots per freezing decision)


def Get_freeze_digit(freeze_prob):
    """Draw one Bernoulli(freeze_prob) sample and return it as a Python int.

    Returns 1 ("freeze") with probability ``freeze_prob`` and 0 otherwise,
    using NumPy's global random state. ``freeze_prob`` must lie in [0, 1];
    ``np.random.choice`` raises ``ValueError`` for a negative probability.
    """
    outcomes = np.array([0, 1])
    weights = np.array([1 - freeze_prob, freeze_prob])
    # Positional form of choice(a, size, replace, p): a single draw with replacement.
    draw = np.random.choice(outcomes, 1, True, weights)
    return draw[0].item()


# Indices of conv / BN layers that are still trainable (freezing candidates).
conv_active, bn_active = [], []

# Layer handles collected from the model, grouped by layer type.
conv_layer, bn_layer, fc_layer = [], [], []

# Copies of parameter values saved when a layer is frozen, keyed by layer index.
conv_layer_param, bn_layer_param, fc_layer_param = {}, {}, {}

# Per-layer history of weight snapshots used as the freezing predictor's input.
conv_active_weights = {}

# Indices of layers that have been frozen so far.
conv_frozen, bn_frozen, fc_frozen = [], [], []


def _snapshot_active_conv_weights():
    """Record one weight snapshot for every still-active conv layer.

    Each snapshot is a detached CPU copy of the layer's weight, flattened to
    shape ``(1, numel)`` and passed through ``random_sample(..., re_size)``
    before being appended to ``conv_active_weights[layer_index]``.
    """
    for layer_index in conv_active:
        flat_weight = (conv_layer[layer_index].weight.clone().detach().cpu()
                       .reshape(-1).unsqueeze(0))
        conv_active_weights[layer_index].append(random_sample(flat_weight, re_size))


def _restore_frozen_params():
    """Write the saved values back into every frozen conv, BN and FC parameter.

    The optimizer step may still move a frozen parameter (e.g. through a
    momentum buffer accumulated before it was frozen), so the saved values are
    written back after every step. ``copy_`` is used instead of rebinding
    ``param.data = backup``: rebinding makes the parameter share storage with
    the backup, so the next in-place optimizer update would silently overwrite
    the backup as well and the restore would stop working.
    """
    groups = ((conv_frozen, conv_layer, conv_layer_param),
              (bn_frozen, bn_layer, bn_layer_param),
              (fc_frozen, fc_layer, fc_layer_param))
    for frozen_indices, layers, saved_params in groups:
        for k in frozen_indices:
            for param, backup in zip(layers[k].parameters(), saved_params[k]):
                param.data.copy_(backup)


def _predict_conv_freeze_list(freeze_predictor):
    """Ask ``freeze_predictor`` which active conv layers should be frozen.

    For each active conv layer, the first ``window_size`` recorded snapshots are
    concatenated along dim 1 and fed to the predictor. ``pred[0][1]`` is used as
    the freeze probability and sampled with ``Get_freeze_digit``. Layers with
    no recorded snapshots are skipped.

    Returns:
        List of conv layer indices selected for freezing, in ``conv_active`` order.
    """
    freeze_list = []
    for layer_index in conv_active:
        history = conv_active_weights[layer_index]
        if not history:
            continue
        freeze_input = torch.cat(history[:window_size], dim=1).to(device)
        with torch.no_grad():  # inference only; no autograd graph needed
            pred = freeze_predictor(freeze_input)
        if Get_freeze_digit(pred[0][1].item()):
            freeze_list.append(layer_index)
    return freeze_list


def _freeze_params(layer):
    """Disable gradients for ``layer``'s parameters and return clones of their values."""
    saved = []
    for param in layer.parameters():
        param.requires_grad = False
        saved.append(param.data.clone().detach())
    return saved


def _freeze_layers(conv_freeze_list):
    """Freeze the given conv layers and the BN layers with the same indices.

    Each conv index moves from ``conv_active`` to ``conv_frozen`` and loses its
    snapshot history. Its parameters (and those of the matching BN layer) stop
    requiring gradients, and their values are saved so
    ``_restore_frozen_params`` can write them back after each step.
    """
    for idx in conv_freeze_list:
        conv_frozen.append(idx)
        conv_active.remove(idx)
        conv_active_weights.pop(idx)
        conv_layer_param[idx] = _freeze_params(conv_layer[idx])
    # NOTE(review): assumes the i-th BatchNorm2d pairs with the i-th Conv2d
    # (presumably true for the default ResNet backbone) - confirm for other models.
    for idx in conv_freeze_list:
        bn_frozen.append(idx)
        bn_layer_param[idx] = _freeze_params(bn_layer[idx])


def train_model(model, dataloaders, freeze_predictor, criterion, optimizer, num_epochs=1, is_inception=False):
    """Fine-tune ``model`` while progressively freezing conv/BN layers.

    Each epoch:
      1. Trains on ``dataloaders['train']``. It takes evenly spaced weight
         snapshots of every active conv layer (about ``window_size`` per
         epoch) and restores frozen parameters after every optimizer step.
      2. Asks ``freeze_predictor`` which active conv layers to freeze, and
         freezes them together with the BN layer of the same index. It then
         resets the snapshot history of the layers that stay active, so the
         next decision is based on that epoch's snapshots.
      3. Evaluates on ``dataloaders['val']`` and keeps the best weights.

    Args:
        model: network being fine-tuned (already on ``device``).
        dataloaders: mapping with ``'train'`` and ``'val'`` loaders.
        freeze_predictor: module that maps a concatenated snapshot window to
            freeze scores; ``pred[0][1]`` is read as the freeze probability.
        criterion: loss function.
        optimizer: optimizer over the model parameters.
        num_epochs: number of epochs to run.
        is_inception: accepted for interface compatibility; unused.

    Returns:
        ``(model, val_acc_history)``: the model with the best validation
        weights loaded, and the validation accuracy of every epoch.
    """
    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Iterations between weight snapshots, chosen so that roughly `window_size`
    # snapshots are taken per epoch. Clamped to 1 so a small dataset or a large
    # batch cannot cause a modulo-by-zero.
    snapshot_interval = max(1, int((num_training_samples / batch_size) / window_size))

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Training phase
        phase = 'train'
        print('Training...')
        model.train()  # Set model to training mode
        running_loss = 0.0
        running_corrects = 0

        for i, (inputs, labels) in enumerate(dataloaders[phase]):
            if conv_active and i % snapshot_interval == 0:
                _snapshot_active_conv_weights()
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            with torch.enable_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                loss.backward()
                optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

            # Undo any change the optimizer step made to frozen parameters.
            _restore_frozen_params()

        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # Freeze layers at the end of each epoch.
        _freeze_layers(_predict_conv_freeze_list(freeze_predictor))
        # Start a fresh snapshot window for layers that stay active. Without
        # this the lists grow every epoch, and the predictor (which reads the
        # first `window_size` entries) would keep seeing epoch-0 weights.
        for layer_index in conv_active:
            conv_active_weights[layer_index].clear()

        # Validation phase
        phase = 'val'
        print('Validating...')
        model.eval()   # Set model to evaluate mode
        running_loss = 0.0
        running_corrects = 0
        with torch.no_grad():
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # Keep a copy of the best weights seen on the validation set.
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
        val_acc_history.append(epoch_acc)

    print('Best val Acc: {:.4f}'.format(best_acc))
    # Load the best model weights before returning.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history


# Build the backbone CNN (pretrained weights) with `num_classes` outputs.
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)

# Build the 'train' / 'val' dataloaders at the model's expected input size.
dataloaders = initialize_dataset(dataset_name, input_size, batch_size)

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Send the model to the selected device.
model_ft = model_ft.to(device)


# Initialize the attention-based freezing predictor from its checkpoint.
in_channel = re_size
hid_channel = 256
out_channel = 64
predictor = initialize_predictor(in_channel, hid_channel, out_channel)
predictor_path = './predictor.pth'
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
predictor.load_state_dict(torch.load(predictor_path, map_location=device))
predictor = predictor.to(device)
# The predictor is only used for inference here, so put it in evaluation mode.
predictor.eval()

# Collect layer handles by type, in module-traversal order. Every conv layer
# starts out "active" (a freezing candidate) with an empty snapshot history;
# its key is its index in conv_layer.
key = 0
for name, layer in model_ft.named_modules():
    if isinstance(layer, torch.nn.Conv2d):
        conv_layer.append(layer)
        key_list.append(key)
        conv_active_weights.setdefault(key, list())
        conv_active.append(key)
        bn_active.append(key)
        key += 1
    elif isinstance(layer, torch.nn.Linear):
        fc_layer.append(layer)
    elif isinstance(layer, torch.nn.BatchNorm2d):
        bn_layer.append(layer)

# Optimize every parameter that currently requires gradients.
params_to_update = [param for param in model_ft.parameters() if param.requires_grad]
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# Set up the loss function
criterion = nn.CrossEntropyLoss()

# Train and evaluate
model_ft, hist = train_model(model_ft, dataloaders, predictor, criterion, optimizer_ft, num_epochs=num_epochs, is_inception=(model_name == "inception"))
