import numpy as np 
import pandas as pd
import subprocess
import os
import torchvision.models as models
import torchvision
import torch
import scipy
import scipy.ndimage
import time
import random
import argparse
import cv2

from tqdm import tqdm

from copy import deepcopy

from torch.utils.data import Dataset, DataLoader

from torchvision import datasets, transforms
from torchvision.utils import save_image

from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

from torchsummary import summary

from xml.dom import minidom

from os.path import basename

from PIL import Image

from scipy.ndimage.filters import maximum_filter
from scipy.ndimage.morphology import generate_binary_structure, binary_erosion

from skimage.segmentation import slic, mark_boundaries
from skimage.util import img_as_float
from skimage.data import astronaut
from skimage.measure import label, regionprops
from skimage.color import label2rgb

from labels_imagenet import labels_dict
from functions_imagenet import *



# Global variables
# NOTE(review): '???' is a placeholder - set dataroot to the real data location.
dataroot     = '/home/people/???/scratch'
xml_folder   = dataroot + '/data/ILSVRC/Annotations/CLS-LOC/train/'
train_folder = dataroot + '/data/ILSVRC/Data/CLS-LOC/train/'
val_folder   = dataroot + '/data/ILSVRC/Data/CLS-LOC/val/'
IMG_SIZE     = 224
# Per-channel normalisation statistics (used by `normalize` below and inverse_normalize).
MEAN_NORM=(0.485, 0.456, 0.406)
STD_NORM=(0.229, 0.224, 0.225)
FILTER_SIZE = 7
# Saliency threshold divisor: a pixel counts as salient when value >= max / ALPHA.
ALPHA = 1
LEARNING_RATE = 0.001      # Default = 0.001
WEIGHT_DECAY  = 0.0001
MOMENTUM      = 0.9
WORKERS       = 2
BATCH_SIZE    = 256
DEVICE        = 'cuda'
NUM_EPOCHS    = 1
USE_SP_BOX = False
# SLIC n_segments used when splitting the query image into superpixels.
SEGMENT_DIVISIONS = 20

PERC_USE = 0.3
# True: train on images that keep only the chosen superpixels.
# False: train on images with the chosen superpixels blacked out.
TRAIN_ON_SEGMENT = True
NUM_SPS_USE = 1


# Image transforms (resize + centre crop to IMG_SIZE, optionally normalised).


transformNoNormalize = transforms.Compose([
	transforms.Resize(IMG_SIZE),
	transforms.CenterCrop(IMG_SIZE),
	transforms.ToTensor()
])

# Built from MEAN_NORM / STD_NORM so the statistics live in one place.
normalize = transforms.Normalize(mean=list(MEAN_NORM),
								 std=list(STD_NORM))

transformNormalize = transforms.Compose([
		transforms.Resize(IMG_SIZE),
		transforms.CenterCrop(IMG_SIZE),
		transforms.ToTensor(),
		normalize
])



class NetC(torch.nn.Module):
	"""
	Wrap a classifier so one forward pass exposes the logits, the pooled
	feature vector and the final convolutional feature map.

	`net` must expose `avgpool` and `fc` attributes, and its last two
	children must be that avgpool and fc pair (as in the torchvision
	ResNet-50 that collect_model passes in).
	"""

	def __init__(self, net):
		super().__init__()
		self.net = net
		# Every child except the trailing (avgpool, fc) pair.
		full_stack = torch.nn.Sequential(*list(net.children()))
		self.main = full_stack[:-2]
		self.avgpool = net.avgpool
		self.linear = net.fc

	def forward(self, x):
		"""
		return: (logits, pooled features reshaped to (N, channels), conv feature map)
		"""
		feature_map = self.main(x)
		pooled = self.avgpool(feature_map)
		batch, channels = pooled.shape[0], pooled.shape[1]
		pooled = pooled.view(batch, channels)
		return self.linear(pooled), pooled, feature_map


class netClassifier(torch.nn.Module):
	"""
	Classification head only: applies a NetC-like model's `avgpool` and
	`linear` layers to a precomputed convolutional feature map.
	"""

	def __init__(self, netC):
		super().__init__()
		self.net = netC

	def forward(self, C):
		"""
		C: conv feature map, assumed (N, channels, H, W) with channels equal to
			self.net.linear.in_features (2048 for ResNet-50) - TODO confirm.
		return: logits of shape (N, num_classes).
		"""
		x = self.net.avgpool(C)
		# Flatten to the head's own input width rather than a hard-coded 2048,
		# so heads of other widths work too.
		x = x.view(-1, self.net.linear.in_features)
		return self.net.linear(x)


def get_cam(C, pred, cam_type='CAM', weights=None):
	"""
	Build a spatial saliency map from a convolutional feature map.

	C: feature map (channels, H, W); if it has a leading batch dimension,
		only C[0] is used.
	pred: class index selecting a row of the classifier weights.
	cam_type:
		'CAM'    - sum over channels of C weighted by weights[pred].
		'FAM'    - the single channel whose (spatial mean * weight) is largest.
		'Random' - uniform random map with C's spatial size (float64).
	weights: (num_classes, channels) classifier weights; defaults to the
		global WEIGHTS.

	return: (H, W) tensor.
	raises: ValueError for an unknown cam_type.
	"""
	if len(C.shape) == 4:
		C = C[0]

	if cam_type == 'Random':
		return torch.tensor(np.random.rand(*C.shape[-2:]))

	if cam_type not in ('CAM', 'FAM'):
		raise ValueError(f"Unknown cam_type: {cam_type!r}")

	if weights is None:
		weights = WEIGHTS
	class_weights = weights[pred]

	if cam_type == 'CAM':
		return (class_weights.reshape(-1, 1, 1) * C).sum(axis=0)

	# FAM: per-channel spatial mean (identical to AvgPool2d(7) on a 7x7 map,
	# but valid for any spatial size), scaled by the class weights.
	channel_means = C.mean(dim=(-2, -1))
	nb_feature_idx = torch.argmax(channel_means * class_weights).item()
	return C[nb_feature_idx]


def inverse_normalize(in_tensor, mean=None, std=None):
	"""
	Undo per-channel normalisation: returns in_tensor * std + mean.

	in_tensor: (C, H, W) or (N, C, H, W) tensor; the channel axis is dim -3.
	mean, std: per-channel sequences; default to MEAN_NORM / STD_NORM.

	return: a new, detached tensor; in_tensor itself is not modified.
	"""
	if mean is None:
		mean = MEAN_NORM
	if std is None:
		std = STD_NORM

	tensor = in_tensor.clone().detach()
	# (C, 1, 1) broadcasts over the channel axis of both (C, H, W) and
	# (N, C, H, W). Iterating dim 0 instead would pair a batched
	# (1, C, H, W) input with only the first channel's statistics.
	mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
	std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
	return tensor.mul_(std_t).add_(mean_t)


def get_transformed_traning_data(idx, loader):
	"""
	Load one training image by index and transform it.

	idx: index into the global train_dataset.imgs (element [0] is the path).
	loader: not used; the image is always read via the global train_dataset.
	return: new_transform(image) reshaped with view(-1, 3, IMG_SIZE, IMG_SIZE).

	NOTE(review): `new_transform` is not defined in this file; presumably it
	comes from `from functions_imagenet import *` - confirm.
	"""
	img_path = train_dataset.imgs[idx][0]
	rgb_image = Image.open(img_path).convert('RGB')
	transformed = new_transform(rgb_image)
	return transformed.view(-1, 3, IMG_SIZE, IMG_SIZE)


# In[16]:


def largest_indices(ary, n):
	"""
	Return the indices of the n largest values of an array, largest first.

	ary: array-like of any shape.
	n: number of indices wanted. Values <= 0 give empty index arrays; values
		larger than ary.size are clamped to ary.size.

	return: tuple of index arrays (one per dimension of ary), as produced by
		np.unravel_index, so ary[result] lists the selected values in
		descending order.
	"""
	ary = np.asarray(ary)
	flat = ary.ravel()
	n = min(int(n), flat.size)
	if n <= 0:
		# np.argpartition(flat, -0)[-0:] would select EVERY index, not none.
		return np.unravel_index(np.empty(0, dtype=np.intp), ary.shape)

	indices = np.argpartition(flat, -n)[-n:]
	indices = indices[np.argsort(-flat[indices])]
	return np.unravel_index(indices, ary.shape)


def get_superpixel_ccrs(keep_sp=True):
	"""
	Score every SLIC superpixel of the current query image.

	Each superpixel is isolated with keep_ccr_superpixel (kept alone when
	keep_sp is true, removed when false), re-normalised with
	transformNormalize and passed through netC. Its score is the logit of
	QUERY_PRED.

	Reads the globals QUERY_IMG, QUERY_PRED, netC, SEGMENT_DIVISIONS, DEVICE.

	return: list of rows
		[logit, superpixel label, image tensor, None, None, regionprops region,
		 SEGMENT_DIVISIONS, SLIC label map],
		sorted by logit: highest first when keep_sp, lowest first otherwise.
	"""
	# Undo normalisation and convert the first image of the batch to an
	# H x W x 3 uint8 array for SLIC.
	denorm = inverse_normalize(QUERY_IMG, mean=MEAN_NORM, std=STD_NORM) * 255
	image = denorm[0].permute(1, 2, 0).cpu().detach().numpy().round().astype('uint8')

	segments = slic(image, n_segments=SEGMENT_DIVISIONS, sigma=5, start_label=1)

	results = list()
	with torch.no_grad():
		for region in regionprops(segments):
			# region.label is the id actually present in `segments`; downstream
			# code selects pixels with `segments == label`, so store that
			# rather than a position-derived id.
			sp_label = region.label
			occlude_sp = keep_ccr_superpixel(image, segments, sp_label, [], keep_sp=keep_sp)
			torch_img = transformNormalize(Image.fromarray(occlude_sp)).view(-1, 3, 224, 224)
			new_logits, _, _ = netC(torch_img.to(DEVICE))
			feature_logit = new_logits[0][QUERY_PRED].item()
			results.append([feature_logit, sp_label, torch_img, None, None, region, SEGMENT_DIVISIONS, segments])

	# keep_sp: most supportive superpixel first; otherwise the superpixel
	# whose removal lowers the logit most comes first.
	return sorted(results, key=lambda row: row[0], reverse=bool(keep_sp))


def crop_center(img, cropx, cropy):
	"""
	Return the central cropy x cropx window of an image array.

	img: array indexed (rows, cols, ...), e.g. (H, W) or (H, W, C).
	cropx, cropy: crop width and height. A crop larger than the image is
		clamped to the image extent instead of wrapping around through
		negative slice starts.
	"""
	y, x = img.shape[:2]
	startx = max(0, x // 2 - cropx // 2)
	starty = max(0, y // 2 - cropy // 2)
	return img[starty:starty + cropy, startx:startx + cropx, ...]


def get_centre_cropped_image(query_idx, training=False):
	"""
	Read an image from disk as RGB and crop it to its largest centred square.

	query_idx: index into the global train_dataset (training=True) or the
		global val_dataset (training=False).
	return: RGB numpy array cropped to side = min(height, width).

	NOTE(review): cv2.imread returns None for an unreadable path, which makes
	cvtColor fail with a cv2 error rather than a clear message.
	"""
	if not training:
		path = val_dataset.img_dir + val_dataset.img_names[query_idx]
	else:
		path = train_dataset.imgs[query_idx][0]

	rgb = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
	side = min(rgb.shape[0], rgb.shape[1])
	return crop_center(rgb, side, side)


def keep_ccr_superpixel(img, segments, segVal, seg_list=None, keep_sp=True):
	"""
	Mask an image by superpixel membership.

	img: image array; segments: label map whose shape matches img's leading
		(spatial) dimensions, e.g. (H, W) for an (H, W, 3) image.
	segVal: superpixel label to select; seg_list: optional extra labels to
		select (None means no extras).
	keep_sp: True keeps only the selected superpixels and zeroes the rest;
		False zeroes the selected superpixels and keeps the rest.

	return: img * mask, where mask is a uint8 0/1 array, so the result dtype
		follows numpy's promotion of img.dtype with uint8.
	"""
	labels = [segVal] if seg_list is None else [segVal, *seg_list]
	selected = np.isin(segments, labels)

	mask = np.zeros(img.shape, dtype="uint8")
	mask[selected if keep_sp else ~selected] = 1
	return img * mask


def crop_ccr_superpixel_box(img, minr, minc, maxr, maxc, crop_image=False):
	"""
	Restrict an (H, W, C) image to the box rows [minr, maxr), cols [minc, maxc).

	crop_image=True: return just the box region (a slice of img).
	crop_image=False: return a full-size copy with everything outside the box
		set to 0, cast to uint8.
	"""
	rows = slice(minr, maxr)
	cols = slice(minc, maxc)

	if crop_image:
		return img[rows, cols, :]

	box_mask = np.zeros(img.shape)
	box_mask[rows, cols, :] = 1
	return (img * box_mask).astype('uint8')


def expand2square(pil_img, background_color):
	"""
	Pad a PIL image to a square canvas filled with `background_color`,
	centring the original along its shorter side. Square inputs are returned
	unchanged (the same object).
	"""
	width, height = pil_img.size
	if width == height:
		return pil_img

	side = max(width, height)
	canvas = Image.new(pil_img.mode, (side, side), background_color)
	if width > height:
		offset = (0, (width - height) // 2)
	else:
		offset = ((height - width) // 2, 0)
	canvas.paste(pil_img, offset)
	return canvas


def get_nn_superpixel_ccrs(query_idx, segment_divisions=(20,), use_box=False):
	"""
	Cut every SLIC superpixel out of the global QUERY_PIL_IMG and embed it with netC.

	query_idx: not used; kept for call compatibility.
	segment_divisions: iterable of SLIC n_segments values; each one produces
		its own segmentation (a tuple default avoids a shared mutable default).
	use_box: True crops the superpixel's bounding box as-is; False first zeroes
		pixels outside the superpixel, then crops the box.

	Each crop is padded to a square (black background), passed through
	transformNormalize and netC.

	return: list of rows
		[None, superpixel label, image tensor, pooled features x, None,
		 regionprops region, n_segments, SLIC label map].
	"""
	image = QUERY_PIL_IMG
	results = list()

	for numSegments in segment_divisions:
		segments = slic(image, n_segments=numSegments, sigma=5, start_label=1)

		for region in regionprops(segments):
			minr, minc, maxr, maxc = region.bbox
			# region.label is the id actually present in `segments`.
			sp_label = region.label

			if use_box:
				patch = crop_ccr_superpixel_box(image, minr, minc, maxr, maxc, crop_image=True)
			else:
				patch = keep_ccr_superpixel(image, segments, sp_label, [])
				patch = crop_ccr_superpixel_box(patch, minr, minc, maxr, maxc, crop_image=True)

			PIL_image = expand2square(Image.fromarray(patch), background_color=0)
			torch_img = transformNormalize(PIL_image).view(-1, 3, 224, 224)

			# NOTE(review): not under torch.no_grad(), so x keeps its autograd
			# graph; wrap in no_grad if callers never backprop through x.
			_, x, _ = netC(torch_img.to(DEVICE))
			results.append([None, sp_label, torch_img, x, None, region, numSegments, segments])

	return results


def get_coords_nb_feature_in_nn(Conv, cam, query_feature, alpha=None):
	"""
	Find the salient spatial cell of `Conv` whose feature vector is closest
	(Euclidean distance) to `query_feature`.

	Conv: feature map indexed Conv[:, :, i, j]; assumed batch size 1 - TODO confirm.
	cam: 2-D saliency map with Conv's spatial size.
	query_feature: tensor with Conv.shape[1] elements.
	alpha: threshold divisor, defaulting to the global ALPHA. A cell is
		salient when cam[i][j] >= cam.max() / alpha. This matches the >= used
		by get_num_cam_pixels_to_occlude; with a strict > and alpha == 1 no
		cell could ever qualify.

	return: (smallest distance, [i, j] of that cell), or (inf, None) when no
		cell is salient.
	"""
	if alpha is None:
		alpha = ALPHA

	n_channels = Conv.shape[1]
	threshold = cam.flatten().max() / alpha
	query_vec = query_feature.view(-1, n_channels)

	coords = None
	min_dist = float('inf')
	for i in range(Conv.shape[2]):
		for j in range(Conv.shape[3]):
			# Only salient cells matter, so skip the distance for the rest.
			if cam[i][j] >= threshold:
				cell = Conv[:, :, i:i+1, j:j+1].reshape(-1, n_channels)
				dist = torch.cdist(query_vec, cell, p=2.0).item()
				if dist < min_dist:
					min_dist = dist
					coords = [i, j]

	return min_dist, coords


def get_max_2d(a):
	"""
	Return the index tuple of the largest element of numpy array `a`
	(first occurrence on ties), e.g. (row, col) for a 2-D array.
	"""
	return np.unravel_index(a.argmax(), a.shape)


def crop_ccr_cam_box(img, x, y, h, w, scale_factor=7):
	"""
	Crop rows [y, y+h) and cols [x, x+w) from an image tensor and upsample
	the crop with torch.nn.Upsample (default nearest) by `scale_factor`.

	img: 3-D tensor in (H, W, C) layout (matching the slicing below) or
		4-D tensor in (N, C, H, W) layout.
	x, y, h, w: box coordinates; converted with int().
	scale_factor: spatial upsampling factor (default 7).

	return: the upsampled crop in the same layout as the input.
	"""
	upsample = torch.nn.Upsample(scale_factor=scale_factor)
	rows = slice(int(y), int(y + h))
	cols = slice(int(x), int(x + w))

	# Dispatch on rank; len(img) would be the size of the first dimension.
	if img.dim() == 3:
		crop = img[rows, cols, :]
		# Upsample scales the trailing spatial dims of (N, C, H, W), so move
		# channels first, add a batch dim, then restore (H, W, C).
		crop = upsample(crop.permute(2, 0, 1).unsqueeze(0))
		return crop.squeeze(0).permute(1, 2, 0)

	crop = img[:, :, rows, cols]
	return upsample(crop)


def get_num_cam_pixels_to_occlude(cam_type='CAM'):
	"""
	Count the pixels of the query's upsampled saliency map that reach the
	ALPHA threshold (value >= max / ALPHA).

	cam_type: passed through to get_upsampled_cam_query ('CAM', 'FAM', 'Random').
	return: (number of qualifying pixels, the upsampled saliency map)
	"""
	saliency = get_upsampled_cam_query(netC, cam_type=cam_type, upsample=True)
	cutoff = saliency.flatten().max() / ALPHA
	is_salient = saliency >= cutoff
	return is_salient.flatten().sum(), saliency


def find_num_sp_to_use(results, cam_pix_num, train_on_segment, perc_use):
	"""
	Greedily apply the highest-scoring superpixels to the query image.

	Superpixels are taken in the order of `results` until one scores below
	max_score / BETA (global BETA).

	results: rows as built by get_superpixel_ccrs(); row[0] is the score,
		row[1] the superpixel label and row[-1] the SLIC label map. Assumed
		sorted best-first (get_superpixel_ccrs does this when keep_sp=True).
	cam_pix_num, perc_use: not used; kept so existing calls keep working.
	train_on_segment: True starts from an all-zero image and copies each chosen
		superpixel in from QUERY_IMG; False starts from QUERY_IMG and zeroes
		each chosen superpixel (both via blackout_ccr_superpixel).

	return: (pixel count of the chosen superpixels, CPU image tensor in which
		every exactly-zero value is replaced by -2.1179). When nothing is
		chosen, or results is empty, the count is 224**2 and the image is
		-2.1179 everywhere.
	"""
	if train_on_segment:
		img = torch.zeros(QUERY_IMG.shape)
	else:
		img = QUERY_IMG.clone().detach()

	total_blacked_out = 0
	num_used = 0

	if results:
		# NOTE(review): for a negative best score, dividing by BETA > 1 lifts
		# the threshold above that score, so nothing gets selected - confirm intent.
		threshold = max(row[0] for row in results) / BETA

		for row in results:
			score, sp_idx, segments = row[0], row[1], row[-1]
			if score < threshold:
				break
			img = blackout_ccr_superpixel(img, segments, sp_idx, train_on_segment=train_on_segment)
			total_blacked_out += (segments == sp_idx).sum()
			num_used += 1

	if num_used == 0:
		return 224**2, torch.full(QUERY_IMG.shape, -2.1179)

	# -2.1179 == (0 - 0.485) / 0.229: pixel value 0 under the first channel's
	# MEAN_NORM/STD_NORM statistics, applied here to all three channels.
	temp = img.clone()
	temp[temp == 0] = -2.1179
	return total_blacked_out, temp


def resize_cam_with_sp(cam, threshold, train_on_segment):
	"""
	Mask QUERY_IMG using the `threshold` highest-valued pixels of `cam`.

	cam: 2-D map matching QUERY_IMG's spatial size.
	threshold: number of top pixels to select (passed to largest_indices).
	train_on_segment: True keeps only the selected pixels; False blanks them.

	Exactly-zero values of the result are then set to -2.1179.
	return: CPU tensor shaped like QUERY_IMG.
	"""
	rows, cols = largest_indices(cam, threshold)

	selected = torch.zeros(cam.shape)
	selected[rows, cols] = 1
	keep_mask = selected if train_on_segment else 1. - selected

	out = QUERY_IMG.clone().cpu().detach()
	out *= keep_mask
	out[out == 0] = -2.1179
	return out


def return_occluded_imgs(train_on_segment, perc_use=0.5):
	"""
	Build the two occluded training inputs for the current query image.

	train_on_segment, perc_use: forwarded to find_num_sp_to_use (and
		train_on_segment also to resize_cam_with_sp).

	return: (superpixel-based image, random-map image selecting the same
		number of pixels, that pixel count)
	"""
	n_cam_pixels, _ = get_num_cam_pixels_to_occlude()
	ranked_superpixels = get_superpixel_ccrs()

	n_sp_pixels, sp_image = find_num_sp_to_use(
		ranked_superpixels, n_cam_pixels,
		train_on_segment=train_on_segment, perc_use=perc_use)

	# Random baseline, matched on the number of pixels the superpixels covered.
	_, random_map = get_num_cam_pixels_to_occlude(cam_type='Random')
	random_image = resize_cam_with_sp(random_map, n_sp_pixels, train_on_segment)

	return sp_image, random_image, n_sp_pixels





def blackout_ccr_superpixel(img, segments, segVal, train_on_segment=False):
	"""
	Remove or add one superpixel.

	img: image tensor, broadcast against a (1, 3, IMG_SIZE, IMG_SIZE) mask.
	segments: (IMG_SIZE, IMG_SIZE) label map; segVal: superpixel label.
	train_on_segment:
		False - zero the superpixel's pixels in img.
		True  - add the superpixel's pixels from QUERY_IMG onto img.

	return: CPU tensor.
	"""
	img = img.cpu().detach()

	region = np.zeros((IMG_SIZE, IMG_SIZE), dtype="float32")
	region[segments == segVal] = 1
	if not train_on_segment:
		region = 1 - region

	channel_mask = torch.tensor(
		np.stack([region, region, region]).reshape(1, 3, IMG_SIZE, IMG_SIZE))

	if train_on_segment:
		return img + QUERY_IMG.cpu() * channel_mask
	return img * channel_mask


def get_upsampled_cam_query(netC, cam_type='CAM', upsample=True):
	"""
	Saliency map for the current query (globals QUERY_C, QUERY_X, QUERY_PRED, WEIGHTS).

	netC: not used; kept for call compatibility.
	cam_type:
		'CAM'    - class activation map from get_cam(QUERY_C, QUERY_PRED).
		'FAM'    - the single channel of QUERY_C whose contribution
		           QUERY_X * WEIGHTS[QUERY_PRED] is largest.
		'Random' - 7x7 uniform random map.
	upsample: if True, zoom by 32 along each axis with cubic (order=3)
		interpolation, e.g. 7x7 -> 224x224.

	return: numpy array.
	raises: ValueError for an unknown cam_type.
	"""
	if cam_type == 'FAM':
		contributions = QUERY_X * WEIGHTS[QUERY_PRED]
		nb_feature_idx = torch.argmax(contributions).item()
		cam = QUERY_C[nb_feature_idx].cpu().detach().numpy()
	elif cam_type == 'CAM':
		cam = get_cam(QUERY_C, QUERY_PRED).cpu().detach().numpy()
	elif cam_type == 'Random':
		cam = np.random.rand(7, 7)
	else:
		raise ValueError(f"Unknown cam_type: {cam_type!r}")

	if upsample:
		cam = scipy.ndimage.zoom(cam, (32, 32), order=3)
	return cam


def get_cam_image_masked(cam, threshold, occlude_pos=True):
	"""
	Mask QUERY_IMG using the `threshold` highest-valued pixels of `cam`.

	cam: 2-D map matching QUERY_IMG's spatial size.
	threshold: number of top pixels to select (passed to largest_indices).
	occlude_pos: True blanks the selected pixels; False keeps only them.

	Exactly-zero values of the result are then set to -2.1179.
	return: tensor on QUERY_IMG's device, shaped like QUERY_IMG.
	"""
	idx, idy = largest_indices(cam, threshold)

	# Vectorised assignment instead of a per-pixel Python loop.
	selected = torch.zeros(cam.shape)
	selected[idx, idy] = 1.
	mask = 1. - selected if occlude_pos else selected

	img = QUERY_IMG.clone().detach()
	# The mask is built on the CPU; move it to the image's device so the
	# in-place multiply also works when QUERY_IMG lives on the GPU.
	img *= mask.to(img.device)
	img[img == 0] = -2.1179
	return img


def save(occ_type, accs, iterations, epochs):
	"""
	Write one technique's accuracy curve to
	data/Training_BETA_SPS_<occ_type>_<BETA>.csv.

	Reads the module-level globals BETA and avg_size (avg_size must be
	non-empty). accs, iterations and epochs must have equal length.
	"""
	mean_size = sum(avg_size) / len(avg_size)
	frame = pd.DataFrame({
		'Accuracy': accs,
		'Dataset': 'ImageNet',
		'Iterations': iterations,
		'Technique': occ_type,
		'Epoch': epochs,
		'BETA': BETA,
		'PrecIncluded': mean_size,
	})
	out_path = 'data/Training_BETA_SPS_' + occ_type + '_' + str(BETA) + '.csv'
	frame.to_csv(out_path)


def collect_model():
	"""
	Load a pretrained torchvision ResNet-50, wrap it in NetC and move it to DEVICE.

	return: (NetC model in train mode, the ResNet's fc weight parameter)
	"""
	backbone = models.resnet50(pretrained=True)
	backbone.train()
	fc_weights = backbone.fc.weight
	wrapped = NetC(backbone)
	wrapped.to(DEVICE)
	return wrapped, fc_weights


def evaluate_validation(netC, val_loader, val_dataset, device=None):
	"""
	Top-1 accuracy of `netC` over `val_loader`, computed without gradients.

	netC: callable returning (logits, ...) as NetC does; the caller is
		responsible for putting it in eval mode.
	val_loader: iterable of (images, labels) batches.
	val_dataset: supplies the denominator via val_dataset.y.size.
	device: where each batch is moved; defaults to the global DEVICE.

	return: float, correct top-1 predictions / val_dataset.y.size
		(0.0 for an empty loader).
	"""
	if device is None:
		device = DEVICE

	top1_correct = 0
	with torch.no_grad():
		for imgs, labels in val_loader:
			imgs, labels = imgs.to(device), labels.to(device)
			logits, _, _ = netC(imgs)
			preds = torch.argmax(logits, dim=1)
			top1_correct += (preds == labels).sum()

	# int() handles both the tensor total and the untouched int 0.
	return int(top1_correct) / val_dataset.y.size





for BETA in [1.05, 1.1, 1.25, 1.5, 2, 3]:
	# Each BETA gets fresh models/optimisers and its own CSV output;
	# save() reads BETA and avg_size as module-level globals.

	print(" ")
	print(" ")
	print(" =======================================  ")
	print("BETA:", BETA)
	print(" =======================================  ")
	print(" ")


	#### Training Loop

	torch.cuda.empty_cache()

	netC, WEIGHTS = collect_model()    # for calculating occlusions

	netC_retrain_sp, _ = collect_model()  # for retraining
	netC_retrain_rand, _ = collect_model()

	netC = netC.eval()

	netC_retrain_sp = netC_retrain_sp.train()
	netC_retrain_rand = netC_retrain_rand.train()

	# Baseline accuracy of the unmodified model before any retraining.
	train_loader, test_loader, train_dataset, test_dataset = imagenet_dataloaders(transform_train=False)
	acc = evaluate_validation(netC, test_loader, test_dataset)
	cce_loss = torch.nn.CrossEntropyLoss()

	optimizer_sp = torch.optim.SGD(netC_retrain_sp.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
	optimizer_rand = torch.optim.SGD(netC_retrain_rand.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

	print("Start Training...")
	train_loader, test_loader, train_dataset, test_dataset = imagenet_dataloaders(transform_train=True)

	avg_size = list()
	iterations = [0]
	epochs = [0]
	current_iter = 0
	reached_max_iters = False

	start_time = time.time()

	sp_accs = [acc]
	rand_accs = [acc]

	for epoch in range(NUM_EPOCHS):
		for data in train_loader:

			current_iter += 1

			netC_retrain_sp.zero_grad()
			netC_retrain_sp.train()
			netC_retrain_rand.zero_grad()
			netC_retrain_rand.train()

			imgs, labels = data
			imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

			sp_torch_imgs   = torch.zeros(imgs.shape)
			rand_torch_imgs = torch.zeros(imgs.shape)

			with torch.no_grad():

				rolling_size = list()

				logits, imgs_x, imgs_C = netC(imgs)
				preds = torch.argmax(logits, dim=1)

				for k in range(len(imgs)):
					# The occlusion helpers read these module-level QUERY_* globals.
					QUERY_LABEL = labels[k]
					QUERY_IMG = imgs[k].view(1, 3, IMG_SIZE, IMG_SIZE)
					QUERY_PRED = preds[k]
					QUERY_X = imgs_x[k]
					QUERY_C = imgs_C[k]

					sp_torch_img, rand_torch_img, a_size = return_occluded_imgs(
						train_on_segment=TRAIN_ON_SEGMENT, perc_use=PERC_USE)

					sp_torch_imgs[k] = sp_torch_img[0]
					rand_torch_imgs[k] = rand_torch_img[0]

					# convert from pixels to % of image
					a_size /= 224**2

					rolling_size.append(a_size)

			logits, _, _ = netC_retrain_sp(sp_torch_imgs.to(DEVICE))
			classify_loss = cce_loss(logits, labels)
			classify_loss.backward()
			optimizer_sp.step()

			logits, _, _ = netC_retrain_rand(rand_torch_imgs.to(DEVICE))
			classify_loss = cce_loss(logits, labels)
			classify_loss.backward()
			optimizer_rand.step()

			if current_iter == 10 or current_iter == 50 or current_iter % 100 == 0:
				acc_sp   = evaluate_validation(netC_retrain_sp.eval(),   test_loader, test_dataset)
				acc_rand = evaluate_validation(netC_retrain_rand.eval(), test_loader, test_dataset)
				print("Meta Data:", current_iter, acc_sp, acc_rand)

				sp_accs.append(acc_sp)
				rand_accs.append(acc_rand)

				iterations.append(current_iter)
				epochs.append(epoch)
				avg_size.append(sum(rolling_size) / len(rolling_size))

				save('Superpixels', sp_accs, iterations, epochs)
				save('Random', rand_accs, iterations, epochs)
				print("Average Size:", sum(avg_size) / len(avg_size))
				print(" ")

			# Stop this BETA after 1000 iterations; the flag also ends the
			# epoch loop, which a bare `break` here would not.
			if current_iter == 1000:
				reached_max_iters = True
				break

		if reached_max_iters:
			break


	print("Time Taken:", time.time() - start_time)
	if avg_size:
		print("Average Size:", sum(avg_size) / len(avg_size))
	else:
		print("Average Size: n/a (no evaluation ran)")
	print(" ")
	print(" ")



