import argparse
import torch, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.models as models
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from tqdm import tqdm
from train_model import train_model
from util.masking import Mask_attr

from pdb import set_trace as bp

import torch
from collections import OrderedDict


def key_transformation(old_key):
    """Strip a leading ``module.`` prefix from a state-dict key.

    Only the exact prefix ``module.`` (including the dot) is removed, and only
    once. Keys that merely start with the letters ``module`` (for example
    ``modules.0.weight``) are returned unchanged.

    Args:
        old_key: Key name as stored in the state dict.

    Returns:
        The key without the leading ``module.`` prefix, or ``old_key`` itself
        when the prefix is absent.
    """
    prefix = "module."
    if old_key.startswith(prefix):
        return old_key[len(prefix):]
    return old_key

def rename_state_dict_keys(source, key_transformation, target=None):
    """Rename the keys of a saved state dict and write the result to disk.

    The state dict is loaded from ``source``, every key is passed through
    ``key_transformation`` (values and their order are kept), and the renamed
    dict is saved to ``target``.

    Args:
        source: Path to the saved state dict.
        key_transformation: Function that accepts an old key name of the
            state dict as its only argument and returns the new key name.
        target: Optional path at which the new state dict should be saved.
            Defaults to ``source``, i.e. the file is overwritten in place.

    Raises:
        ValueError: If two keys are mapped to the same new name. Nothing is
            written in that case.

    Example:
        Rename the key `layer.0.weight` to `layer.1.weight` and keep the
        names of all other keys.

        ```py
        def key_transformation(old_key):
            if old_key == "layer.0.weight":
                return "layer.1.weight"
            return old_key
        rename_state_dict_keys(state_dict_path, key_transformation)
        ```
    """
    if target is None:
        target = source

    # Load onto the CPU so the rename does not require the device the
    # checkpoint was saved from; the saved result therefore holds CPU tensors.
    state_dict = torch.load(source, map_location="cpu")
    new_state_dict = OrderedDict()

    for key, value in state_dict.items():
        new_key = key_transformation(key)
        # A collision would silently drop a tensor and, with the default
        # target, overwrite the only copy on disk -- fail before saving.
        if new_key in new_state_dict:
            raise ValueError(
                "key_transformation maps more than one key to {!r} "
                "(last offending key: {!r})".format(new_key, key)
            )
        new_state_dict[new_key] = value

    torch.save(new_state_dict, target)

# Command-line options for the evaluation run.
parser = argparse.ArgumentParser(description='CNN')
# Off by default; when given, each validation batch is passed through
# Mask_attr before it is fed to the model.
parser.add_argument('--mask_attr', action='store_true', default=False,
                    help='apply Mask_attr to each validation batch before '
                         'evaluating it (default: off)')

# Used as the batch size of every data loader built by this script.
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for the data loaders (default: 32)')


# Parsed at module level: the rest of the script reads `args` directly.
args = parser.parse_args()

# Per-split preprocessing: both splits apply ToTensor() and nothing else.
data_transforms = {
    'train': transforms.Compose([transforms.ToTensor()]),
    'val': transforms.Compose([transforms.ToTensor()]),
}

# data_dir = 'images/64'
data_dir = 'tiny-imagenet-200'

# One ImageFolder per split, rooted at <data_dir>/<split>.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), split_transform)
    for split, split_transform in data_transforms.items()
}

# One shuffled loader per split, with the batch size taken from the command line.
dataloaders = {
    split: torch.utils.data.DataLoader(
        split_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
    )
    for split, split_dataset in image_datasets.items()
}

# Number of samples per split; the 'val' entry is the accuracy denominator.
dataset_sizes = {
    split: len(split_dataset) for split, split_dataset in image_datasets.items()
}

# Build the architecture only; the weights come from the checkpoint below.
model_ft = models.resnet18()
# Finetune final few layers to adjust for tiny imagenet input:
# adaptive average pooling down to 1x1 and a 200-way linear classifier.
model_ft.avgpool = nn.AdaptiveAvgPool2d(1)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 200)

# Replace the stem with a 3x3, stride-1 convolution and turn the max-pool
# into an empty Sequential, which passes its input through unchanged.
model_ft.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
model_ft.maxpool = nn.Sequential()

# Prefer GPU 2; fall back to the CPU when CUDA is not available.
device_id = "cuda:2"
device = torch.device(device_id if torch.cuda.is_available() else "cpu")
model_ft = model_ft.to(device)

saved_path = 'models/resnet18-pretrained-tinyimagenet-wo-mask.pt'
# NOTE: this rewrites the checkpoint file in place with the transformed keys
# (no `target` is passed), so the file on disk is modified by every run.
rename_state_dict_keys(saved_path, key_transformation)
# map_location keeps the load consistent with the CPU fallback above: tensors
# are placed on `device` instead of the device they were saved from.
model_ft.load_state_dict(torch.load(saved_path, map_location=device))

# Multi GPU
model_ft = torch.nn.DataParallel(model_ft, device_ids=[2, 1])

# Evaluation mode for the accuracy measurement that follows.
model_ft.eval()

# Count of correct predictions over the validation set. Kept as a tensor on
# `device` from the start so that `.double()` below is valid even when the
# loader yields no batches.
running_corrects = torch.zeros((), dtype=torch.long, device=device)

for inputs, labels in tqdm(dataloaders['val']):
    inputs = inputs.to(device)
    labels = labels.to(device)

    #........................
    if args.mask_attr:
        # Replace the batch by the output of Mask_attr.apply. This runs
        # outside torch.no_grad() -- presumably the masking needs gradients;
        # TODO confirm against util.masking.
        mask_fn = Mask_attr()
        inputs = mask_fn.apply(model_ft, inputs, labels)
    #........................

    # The forward pass itself needs no gradient tracking.
    with torch.no_grad():
        outputs = model_ft(inputs)

    # Predicted class = index of the largest output along dim 1.
    _, preds = torch.max(outputs, 1)
    running_corrects += torch.sum(preds == labels)

# Fraction of validation samples classified correctly.
epoch_acc = running_corrects.double() / dataset_sizes['val']

val_acc = epoch_acc

print('Val Acc: {:.4f}'.format(val_acc))