import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch import nn, optim
from torchbearer import Trial
from torchbearer.callbacks import Mixup, MultiStepLR
from torchvision.datasets import VisionDataset
import matplotlib.pyplot as plt
from models import resnet
import torch.nn.functional as F
import augs_alter
import random
import numpy as np
import cv2
import math
import argparse

# Command-line interface for the occlusion evaluation script.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--device', type=str, default='cuda', help='Device on which to run')
parser.add_argument('--percentage', type=float, default=0.1,
                    help='percentage of the image to be occluded')
parser.add_argument('--path', type=str)
parser.add_argument('--dataset-path', type=str, default=None, help='Optional dataset path')
args = parser.parse_args()

# Seeds Python's `random` module only; torch's RNG is not seeded in this block.
random.seed(0)
batch_size = 1

# Channel statistics handed to transforms.Normalize for both splits.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Train split first, then test split; both use the same tensor + normalise
# preprocessing and are downloaded on demand.
trainset, valset = (
    torchvision.datasets.CIFAR10(
        root=args.dataset_path,
        train=is_train,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
        ]),
    )
    for is_train in (True, False)
)

loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=8)
valloader = torch.utils.data.DataLoader(valset, **loader_options)
trainloader = torch.utils.data.DataLoader(trainset, **loader_options)

# Fall back to the CPU whenever CUDA is unavailable, whatever --device says.
device = args.device if torch.cuda.is_available() else "cpu"

class ResNet_CAM(nn.Module):
    """Wrap a classifier so gradients at an intermediate layer can be read back.

    All children of ``net`` except the last are chained into one sequence and
    split at index ``layer_k``; the last child is used as the classifier head.
    During ``forward`` a hook is attached to the output of the first part so
    that, after a backward pass, the gradient with respect to that activation
    is available through ``get_activations_gradient``.

    Args:
        net: Network whose direct children are re-used (not copied).
        layer_k: Number of leading children that form the first part.
    """

    def __init__(self, net, layer_k):
        super(ResNet_CAM, self).__init__()
        self.resnet = net
        children = list(net.children())
        convs = nn.Sequential(*children[:-1])
        self.first_part_conv = convs[:layer_k]
        self.second_part_conv = convs[layer_k:]
        self.linear = nn.Sequential(*children[-1:])
        # Defined up front so the getter returns None, rather than raising
        # AttributeError, when no backward pass has happened yet.
        self.gradients = None

    def forward(self, x):
        """Run the full network, hooking the activation after the first part."""
        bs = x.shape[0]
        x = self.first_part_conv(x)
        # register_hook raises on tensors that do not require grad (e.g. under
        # torch.no_grad()), so only hook when a gradient can actually flow.
        if x.requires_grad:
            x.register_hook(self.activations_hook)
        x = self.second_part_conv(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view((bs, -1))
        x = self.linear(x)
        return x

    def activations_hook(self, grad):
        """Store the gradient flowing into the hooked activation."""
        self.gradients = grad

    def get_activations_gradient(self):
        """Return the last stored gradient, or None if none was recorded."""
        return self.gradients

    def get_activations(self, x):
        """Return the output of the first part of the network for ``x``."""
        return self.first_part_conv(x)

def superimpose_heatmap(heatmap, img, percentage, first):
    """Mask every image according to the ranking of its heatmap pixels.

    Each heatmap is resized to the image's spatial size, quantised to 8 bits
    and its pixels ranked from highest to lowest. The top ``percentage``
    fraction of pixels gets the mask value ``first`` and all others get
    ``abs(first - 1)``; the image is then multiplied by that mask.

    Args:
        heatmap: CPU tensor holding one 2-D map per sample (``heatmap[i]``).
            Presumably scaled to [0, 1] by the caller -- TODO confirm.
        img: Image batch indexed as ``(N, C, H, W)``.
        percentage: Fraction of pixels treated as top-ranked.
        first: Mask value for the top-ranked pixels. Masks are built in uint8,
            so 0 (zero out the top pixels) or 1 (keep only the top pixels)
            are the meaningful choices.

    Returns:
        ``img * masks`` on the module-level ``device``.
    """
    height, width = img.shape[2], img.shape[3]
    masks = torch.zeros((img.shape[0], 1, height, width))
    for i in range(img.shape[0]):
        # cv2.resize expects dsize as (width, height), not (rows, cols).
        resized_heatmap = cv2.resize(heatmap[i].numpy(), (width, height))
        mask = np.uint8(255 * resized_heatmap).reshape(-1)
        # Rank before overwriting: `mask` is filled in place below.
        idx = mask.argsort()[::-1]
        num = math.ceil(percentage * mask.size)
        mask[idx[:num]] = first
        mask[idx[num:]] = abs(first - 1)
        # Use the real image size instead of a fixed 32x32.
        masks[i] = torch.from_numpy(mask.reshape((1, height, width)))
    masks = masks.to(device)
    img = img.to(device)
    return img * masks

def get_grad_cam(net, imgs, percentage, first):
    """Compute a Grad-CAM style heatmap per image and return the masked images.

    The top logit of every sample is back-propagated through ``net``; the
    gradient captured at the hooked layer is averaged over its spatial
    dimensions and used to weight the activation channels. The channel mean
    of the weighted activations is clipped at zero, scaled by its per-sample
    maximum and handed to ``superimpose_heatmap``.

    Args:
        net: Model exposing ``get_activations_gradient`` and
            ``get_activations`` (e.g. ``ResNet_CAM``). It is put in eval mode.
        imgs: Image batch fed to ``net``.
        percentage: Forwarded to ``superimpose_heatmap``.
        first: Forwarded to ``superimpose_heatmap``.

    Returns:
        Whatever ``superimpose_heatmap`` returns for the computed heatmaps.
    """
    net.eval()
    pred = net(imgs)
    # Summing the selected logits and calling backward() is equivalent to
    # backward(ones) on the vector, without needing a device for the seed.
    top_class = pred.argmax(dim=1, keepdim=True)
    pred.gather(1, top_class).sum().backward()

    gradients = net.get_activations_gradient()
    pooled_gradients = torch.mean(gradients, dim=[2, 3])
    # The activations are only read, so skip building an autograd graph.
    with torch.no_grad():
        activations = net.get_activations(imgs)
    # Scale every channel by its pooled gradient in one broadcasted multiply.
    weighted = activations.detach() * pooled_gradients.detach()[:, :, None, None]

    heatmap = torch.clamp(torch.mean(weighted, dim=1), min=0).cpu()
    peak = heatmap.reshape(heatmap.shape[0], -1).max(dim=1)[0]
    # A map with no positive response would otherwise be divided by zero and
    # turn into NaNs; leave such maps at all-zero instead.
    peak = torch.where(peak > 0, peak, torch.ones_like(peak))
    heatmap = heatmap / peak.view(-1, 1, 1)
    return superimpose_heatmap(heatmap, imgs, percentage, first)

def get_acc(net, cam_net, loader, percentage, first):
    """Return the classification accuracy of ``net`` over ``loader``.

    Args:
        net: Classifier to evaluate; it is put in eval mode.
        cam_net: Optional model passed to ``get_grad_cam``. When not ``None``
            every batch is replaced by the masked images it returns before
            being classified.
        loader: Iterable yielding ``(image, labels)`` batches.
        percentage: Forwarded to ``get_grad_cam``.
        first: Forwarded to ``get_grad_cam``.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    net.eval()
    correct = 0
    total = 0

    for image, labels in loader:
        image = image.to(device)
        labels = labels.to(device)
        if cam_net is not None:
            # Needs autograd, so it must stay outside the no_grad block below.
            image = get_grad_cam(cam_net, image, percentage, first)
        # Pure inference: no graph is needed for the accuracy computation.
        with torch.no_grad():
            outputs = net(image)

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = correct / total
    return accuracy

# Restore the trained classifier from the checkpoint's 'model' entry; the
# map_location callable returns each storage unchanged.
net = resnet.ResNet18(10).to(device)
checkpoint = torch.load(args.path, map_location=lambda storage, loc: storage)
net.load_state_dict(checkpoint['model'])

# Grad-CAM wrapper that splits the network after its first 6 children.
cam_net = ResNet_CAM(net, 6).to(device)

# Accuracy when the top-ranked `percentage` of heatmap pixels is masked with 0.
val_i = get_acc(net, cam_net, valloader, args.percentage, 0)
train_i = get_acc(net, cam_net, trainloader, args.percentage, 0)

# Accuracy on the unmodified images.
val = get_acc(net, None, valloader, args.percentage, 0)
train = get_acc(net, None, trainloader, args.percentage, 0)

# Train/val accuracy gap under masking, relative to the gap on clean images.
masked_gap = abs(train_i - val_i)
clean_gap = abs(train - val)
print(masked_gap / clean_gap)
