import numpy as np
import cv2
import sys
import torch
from torch.autograd import Variable
sys.path.append('model/')
from model.PieAPPv0pt1_PT import PieAPP
sys.path.append('utils/')
from image_utils import *
import argparse
import os
from torch import autograd
import time
from apex import amp

# Base directory that is prepended to the relative PieAPP weights path
# ('M3D-RPN/PerceptualImageError/weights/PieAPPv0.1.pth') in the torch.load calls of this module.
home_path = '/home/ubuntu/'
def computeHessian(img1, gpu_id):
    """Return the Hessian of the weighted PieAPP error w.r.t. a patch perturbation.

    The image is cut into 64x64 patches (stride 6). The patches are scored
    against ``patches + delta`` with ``delta`` initialised to zero, and the
    second derivatives of ``sum(patch_error * patch_weight)`` with respect to
    ``delta`` are collected one row at a time.

    Args:
        img1: 3-D image tensor. Axes 0 and 2 are swapped and the result goes
            through ``cv2.cvtColor(..., cv2.COLOR_RGB2BGR)``, so the input is
            presumably channel-first RGB -- TODO confirm against callers.
        gpu_id: String written to ``CUDA_VISIBLE_DEVICES``. When it parses as
            a number the model and inputs are moved to CUDA; otherwise the
            whole computation stays on the CPU.

    Returns:
        A detached ``(D, D)`` tensor, where ``D`` is the number of elements in
        the flattened patches of the last processed sub-image. It lives on
        CUDA when the GPU path is used and on the CPU otherwise.

    Raises:
        RuntimeError: If the PieAPP score does not depend on ``delta``.
    """
    patch_size = 64
    batch_size = 1
    sampling_mode = 'dense'
    stride_val = 27 if sampling_mode == 'sparse' else 6

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id

    img1 = cv2.cvtColor(np.float32(torch.transpose(img1, 0, 2).cpu().detach().numpy()), cv2.COLOR_RGB2BGR)
    imagesA = np.expand_dims(img1, axis=0).astype('float32')
    _, rows, cols, _ = imagesA.shape

    # A numeric gpu_id selects the CUDA path; anything else runs on the CPU.
    try:
        float(gpu_id)
        use_gpu = True
    except (ValueError, TypeError):
        use_gpu = False

    # Top-left corners of all patches; the last entry pins a patch to the border.
    y_loc = np.concatenate((np.arange(0, rows - patch_size, stride_val), np.array([rows - patch_size])), axis=0)
    num_y = len(y_loc)
    x_loc = np.concatenate((np.arange(0, cols - patch_size, stride_val), np.array([cols - patch_size])), axis=0)
    num_x = len(x_loc)
    num_patches_per_dim = 10
    # NOTE(review): the loop bounds below divide by num_patches (100) while the
    # tiles are indexed with num_patches_per_dim (10), so images with more than
    # 10 patch positions per axis are only partially visited. Kept as-is because
    # only the last tile's Hessian is returned anyway -- confirm the intent.
    num_patches = 100

    ######## initialize the model
    PieAPP_net = PieAPP(batch_size, num_patches_per_dim)
    s = torch.load(home_path + 'M3D-RPN/PerceptualImageError/weights/PieAPPv0.1.pth')
    # presumably the checkpoint stores this weight with one dimension fewer
    # than the module expects -- TODO confirm
    s['ref_score_subtract.weight'] = s['ref_score_subtract.weight'].unsqueeze(0)
    PieAPP_net.load_state_dict(s)

    if use_gpu:
        PieAPP_net.cuda()

    # iterate through smaller size sub-images (to prevent memory overload)
    for x_iter in range(-(-num_x // num_patches)):
        x0 = x_loc[num_patches_per_dim * x_iter]
        if num_patches_per_dim * (x_iter + 1) >= num_x:
            size_slice_cols = cols - x0
        else:
            size_slice_cols = x_loc[num_patches_per_dim * (x_iter + 1)] - x0 + patch_size - stride_val

        for y_iter in range(-(-num_y // num_patches)):
            y0 = y_loc[num_patches_per_dim * y_iter]
            if num_patches_per_dim * (y_iter + 1) >= num_y:
                size_slice_rows = rows - y0
            else:
                size_slice_rows = y_loc[num_patches_per_dim * (y_iter + 1)] - y0 + patch_size - stride_val

            # obtain the subimage and sample patches
            A_sub_im = imagesA[:, y0:y0 + size_slice_rows, x0:x0 + size_slice_cols, :]
            # The other callers in this module unpack four return values from
            # sample_patches; star-unpacking keeps this call working whichever
            # arity the helper has.
            A_patches, *_ = sample_patches(A_sub_im, A_sub_im, patch_size=patch_size, strideval=stride_val,
                                           random_selection=False, uniform_grid_mode='strided')
            PieAPP_net.num_patches = A_patches.shape[0] // batch_size

            # Flattened NCHW patches plus a zero perturbation to differentiate against.
            A_patches_var = torch.from_numpy(np.transpose(A_patches, (0, 3, 1, 2))).requires_grad_(False)
            patch_shape = A_patches_var.shape
            A_patches_var = A_patches_var.reshape(-1)
            delta = torch.zeros(A_patches_var.shape, requires_grad=True)
            ref_patches_var = A_patches_var + delta

            if use_gpu:
                A_patches_var = A_patches_var.cuda()
                ref_patches_var = ref_patches_var.cuda()

            # forward pass
            _, patch_errors, patch_weights = PieAPP_net.compute_score(
                A_patches_var.float(), ref_patches_var.float(), patch_shape)

            # Differentiate only this tile's score: earlier tiles do not depend
            # on the current delta, so the gradient is the same as for the
            # running sum, without keeping every previous graph alive.
            tile_score = torch.sum(torch.multiply(patch_errors, patch_weights))
            grad = autograd.grad(tile_score, delta, retain_graph=True, create_graph=True, allow_unused=True)[0]
            if grad is None:
                raise RuntimeError('PieAPP score does not depend on the perturbation delta')

            hessian_rows = []
            for i in range(grad.shape[0]):
                row = autograd.grad(grad[i], delta, retain_graph=True, allow_unused=True)[0]
                if row is None:
                    # No dependence on delta means an all-zero second derivative.
                    row = torch.zeros_like(delta)
                hessian_rows.append(row.unsqueeze(0))

            hessian_ret = torch.cat(hessian_rows).detach()
            if use_gpu:
                # Only move to CUDA when the GPU path is active.
                hessian_ret = hessian_ret.cuda()
    return hessian_ret

def computeGradient(img1, img2, gpu_id, mode = 'dense'):
    """Gradient of the weighted PieAPP error between ``img1`` and ``img2`` w.r.t. ``img1``.

    Both images are cut into 64x64 patches, the network is run tile by tile,
    and the gradient of ``sum(patch_error * patch_weight)`` with respect to
    each patch of ``img1`` is added back at that patch's image location
    (overlapping patches accumulate).

    Args:
        img1: 3-D tensor that is permuted with ``(1, 2, 0)``, i.e. axis 0 is
            treated as channels and axes 1/2 as rows/cols. The result buffer
            assumes 3 channels.
        img2: Reference image with the same layout; its shape defines the
            patch grid.
        gpu_id: Unused. Kept for backward compatibility; the device is taken
            from ``img1``.
        mode: ``'sparse'`` samples patches with stride 27, any other value
            uses stride 6.

    Returns:
        Tensor of shape ``(1, 3, rows, cols)`` on the device of ``img1``.
    """
    patch_size = 64
    batch_size = 1
    stride_val = 27 if mode == 'sparse' else 6

    # Tensor.get_device() returns an int (-1 for CPU tensors), so comparing it
    # with the string 'cpu' was always true and CPU inputs ended up with the
    # invalid device -1. Tensor.device is valid for both CPU and CUDA.
    device = img1.device

    imagesA = np.expand_dims(np.float32(img1.permute(1, 2, 0).detach().cpu().numpy()), axis=0)
    imagesRef = np.expand_dims(np.float32(img2.permute(1, 2, 0).detach().cpu().numpy()), axis=0)
    _, rows, cols, _ = imagesRef.shape

    # Top-left corners of all patches; the last entry pins a patch to the border.
    y_loc = np.concatenate((np.arange(0, rows - patch_size, stride_val), np.array([rows - patch_size])), axis=0)
    num_y = len(y_loc)
    x_loc = np.concatenate((np.arange(0, cols - patch_size, stride_val), np.array([cols - patch_size])), axis=0)
    num_x = len(x_loc)
    num_patches_per_dim = 20  # find the largest number that can fit in memory

    ######## initialize the model
    PieAPP_net = PieAPP(batch_size, num_patches_per_dim)
    s = torch.load(home_path + 'M3D-RPN/PerceptualImageError/weights/PieAPPv0.1.pth')
    # presumably the checkpoint stores this weight with one dimension fewer
    # than the module expects -- TODO confirm
    s['ref_score_subtract.weight'] = s['ref_score_subtract.weight'].unsqueeze(0)
    PieAPP_net.load_state_dict(s)
    PieAPP_net.to(device)
    PieAPP_net.eval()

    # The buffer is indexed below as [.., row, col], so it must be
    # (N, 3, rows, cols); the previous (N, 3, cols, rows) layout only worked
    # for square images.
    grad_total = torch.zeros((imagesRef.shape[0], 3, rows, cols), device=device)

    # iterate through smaller size sub-images (to prevent memory overload)
    for x_iter in range(-(-num_x // num_patches_per_dim)):
        x0 = x_loc[num_patches_per_dim * x_iter]
        if num_patches_per_dim * (x_iter + 1) >= num_x:
            size_slice_cols = cols - x0
        else:
            size_slice_cols = x_loc[num_patches_per_dim * (x_iter + 1)] - x0 + patch_size - stride_val

        for y_iter in range(-(-num_y // num_patches_per_dim)):
            y0 = y_loc[num_patches_per_dim * y_iter]
            if num_patches_per_dim * (y_iter + 1) >= num_y:
                size_slice_rows = rows - y0
            else:
                size_slice_rows = y_loc[num_patches_per_dim * (y_iter + 1)] - y0 + patch_size - stride_val

            # obtain the subimage and sample patches
            A_sub_im = imagesA[:, y0:y0 + size_slice_rows, x0:x0 + size_slice_cols, :]
            ref_sub_im = imagesRef[:, y0:y0 + size_slice_rows, x0:x0 + size_slice_cols, :]
            A_patches, ref_patches, sample_rows, sample_cols = sample_patches(
                A_sub_im, ref_sub_im, patch_size=patch_size, strideval=stride_val,
                random_selection=False, uniform_grid_mode='strided')
            # NOTE(review): this is the NHWC numpy shape, while the tensors
            # handed to compute_score are NCHW -- confirm compute_score only
            # relies on the leading dimension.
            patch_shape = A_patches.shape
            PieAPP_net.num_patches = A_patches.shape[0] // batch_size

            # Build float32 NCHW tensors directly on the target device instead
            # of bouncing them through a CPU FloatTensor.
            A_patches_var = torch.from_numpy(np.transpose(A_patches, (0, 3, 1, 2))).float().to(device)
            A_patches_var.requires_grad_(True)
            ref_patches_var = torch.from_numpy(np.transpose(ref_patches, (0, 3, 1, 2))).float().to(device)

            # forward pass
            _, patch_errors, patch_weights = PieAPP_net.compute_score(A_patches_var, ref_patches_var, patch_shape)
            weighted_error = torch.sum(torch.multiply(patch_errors, patch_weights))

            # Only first-order gradients are needed, so no double-backward
            # graph is built or retained.
            grad = autograd.grad(weighted_error, A_patches_var)[0]

            # Scatter every patch gradient back to its position in the image.
            for patch_grad, row_off, col_off in zip(grad, sample_rows, sample_cols):
                y = y0 + row_off
                x = x0 + col_off
                grad_total[:, :, y:y + patch_size, x:x + patch_size] += patch_grad
    return grad_total

def computeGradient_batch(img1, img2, gpu_id):
    """Batched gradient of the weighted PieAPP error w.r.t. ``img1``.

    Each image pair in the batch is cut into 64x64 patches (stride 13); the
    gradient of ``sum(patch_error * patch_weight)`` with respect to every
    patch of ``img1`` is added back at that patch's location in the image it
    was taken from.

    Args:
        img1: 4-D tensor that is permuted with ``(0, 2, 3, 1)``, i.e. treated
            as ``(N, C, rows, cols)``. The result buffer assumes 3 channels.
        img2: Reference batch with the same layout; its shape defines the
            patch grid and the batch size.
        gpu_id: String written to ``CUDA_VISIBLE_DEVICES``. When it parses as
            a number the device ``"cuda:" + gpu_id`` is used, otherwise the
            CPU.

    Returns:
        Tensor of shape ``(N, 3, rows, cols)`` on the selected device.
    """
    patch_size = 64
    sampling_mode = 'dense'
    stride_val = 27 if sampling_mode == 'sparse' else 13

    # NOTE(review): restricting CUDA_VISIBLE_DEVICES to gpu_id and then
    # addressing "cuda:<gpu_id>" only agree if CUDA was initialised before
    # this call (or gpu_id is "0") -- confirm the intended device selection.
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    try:
        float(gpu_id)
        use_gpu = True
    except (ValueError, TypeError):
        use_gpu = False
    # Fall back to the CPU instead of building an invalid "cuda:<text>" device.
    device = "cuda:" + gpu_id if use_gpu else "cpu"

    imagesA = np.float32(img1.permute(0, 2, 3, 1).detach().cpu().numpy())
    imagesRef = np.float32(img2.permute(0, 2, 3, 1).detach().cpu().numpy())

    # Take the batch size from the input (it was hard-coded to 2).
    batch_size, rows, cols, _ = imagesRef.shape

    # Top-left corners of all patches; the last entry pins a patch to the border.
    y_loc = np.concatenate((np.arange(0, rows - patch_size, stride_val), np.array([rows - patch_size])), axis=0)
    num_y = len(y_loc)
    x_loc = np.concatenate((np.arange(0, cols - patch_size, stride_val), np.array([cols - patch_size])), axis=0)
    num_x = len(x_loc)

    num_patches_per_dim = num_x  # find the largest number that can fit in memory

    ######## initialize the model
    PieAPP_net = PieAPP(batch_size, num_patches_per_dim)
    s = torch.load(home_path + 'M3D-RPN/PerceptualImageError/weights/PieAPPv0.1.pth')
    # presumably the checkpoint stores this weight with one dimension fewer
    # than the module expects -- TODO confirm
    s['ref_score_subtract.weight'] = s['ref_score_subtract.weight'].unsqueeze(0)
    PieAPP_net.load_state_dict(s)
    PieAPP_net.to(device)
    PieAPP_net.eval()

    # Sized from the input; the buffer used to be hard-coded to 256x256.
    grad_total = torch.zeros((batch_size, 3, rows, cols), device=device)

    # iterate through smaller size sub-images (to prevent memory overload)
    for x_iter in range(-(-num_x // num_patches_per_dim)):
        x0 = x_loc[num_patches_per_dim * x_iter]
        if num_patches_per_dim * (x_iter + 1) >= num_x:
            size_slice_cols = cols - x0
        else:
            size_slice_cols = x_loc[num_patches_per_dim * (x_iter + 1)] - x0 + patch_size - stride_val

        for y_iter in range(-(-num_y // num_patches_per_dim)):
            y0 = y_loc[num_patches_per_dim * y_iter]
            if num_patches_per_dim * (y_iter + 1) >= num_y:
                size_slice_rows = rows - y0
            else:
                size_slice_rows = y_loc[num_patches_per_dim * (y_iter + 1)] - y0 + patch_size - stride_val

            # obtain the subimage and sample patches
            A_sub_im = imagesA[:, y0:y0 + size_slice_rows, x0:x0 + size_slice_cols, :]
            ref_sub_im = imagesRef[:, y0:y0 + size_slice_rows, x0:x0 + size_slice_cols, :]
            A_patches, ref_patches, sample_rows, sample_cols = sample_patches(
                A_sub_im, ref_sub_im, patch_size=patch_size, strideval=stride_val,
                random_selection=False, uniform_grid_mode='strided')
            # NOTE(review): this is the NHWC numpy shape, while the tensors
            # handed to compute_score are NCHW -- confirm compute_score only
            # relies on the leading dimension.
            patch_shape = A_patches.shape
            PieAPP_net.num_patches = A_patches.shape[0] // batch_size

            A_patches_var = torch.from_numpy(np.transpose(A_patches, (0, 3, 1, 2))).float().to(device)
            A_patches_var.requires_grad_(True)
            ref_patches_var = torch.from_numpy(np.transpose(ref_patches, (0, 3, 1, 2))).float().to(device)

            # forward pass
            _, patch_errors, patch_weights = PieAPP_net.compute_score(A_patches_var, ref_patches_var, patch_shape)
            weighted_error = torch.sum(torch.multiply(patch_errors, patch_weights))

            # Only first-order gradients are needed, so no double-backward
            # graph is built or retained.
            grad = autograd.grad(weighted_error, A_patches_var)[0]

            # The patch location repeats every len(sample_rows) patches, so
            # patch i belongs to image i // n at location i % n. Previously
            # every patch gradient was added to every batch entry.
            patches_per_image = len(sample_rows)
            for i in range(grad.shape[0]):
                b, loc = divmod(i, patches_per_image)
                y = y0 + sample_rows[loc]
                x = x0 + sample_cols[loc]
                grad_total[b, :, y:y + patch_size, x:x + patch_size] += grad[i]
    return grad_total


def computePieAppDist(img1, img2, gpu_id):
    """Return the PieAPP error of ``img1`` with respect to the reference ``img2``.

    Both images are scaled by 255, cut into 64x64 patches (stride 27) and
    scored tile by tile; the result is the weighted mean of the patch errors,
    ``sum(error * weight) / sum(weight)``.

    Args:
        img1: 3-D tensor whose axes are read as ``(rows, cols, channels)``
            after a batch axis is prepended; values are multiplied by 255, so
            they are presumably in ``[0, 1]`` -- TODO confirm against callers.
        img2: Reference image with the same layout; its shape defines the
            patch grid.
        gpu_id: String written to ``CUDA_VISIBLE_DEVICES``. When it parses as
            a number the model and inputs are moved to CUDA.

    Returns:
        The accumulated weighted error divided by the accumulated weight.
    """
    patch_size = 64
    batch_size = 1
    sampling_mode = 'sparse'
    stride_val = 27 if sampling_mode == 'sparse' else 6

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id

    imagesA = np.expand_dims(np.float32(img1.detach().cpu().numpy()) * 255.0, axis=0).astype('float32')
    imagesRef = np.expand_dims(np.float32(img2.detach().cpu().numpy()) * 255.0, axis=0).astype('float32')
    _, rows, cols, _ = imagesRef.shape

    # A numeric gpu_id selects the CUDA path; anything else runs on the CPU.
    try:
        float(gpu_id)
        use_gpu = True
    except (ValueError, TypeError):
        use_gpu = False

    # Top-left corners of all patches; the last entry pins a patch to the border.
    y_loc = np.concatenate((np.arange(0, rows - patch_size, stride_val), np.array([rows - patch_size])), axis=0)
    num_y = len(y_loc)
    x_loc = np.concatenate((np.arange(0, cols - patch_size, stride_val), np.array([cols - patch_size])), axis=0)
    num_x = len(x_loc)
    # One constant drives both the loop bounds and the tile indexing (the
    # separate num_patches constant had the same value, 10).
    num_patches_per_dim = 10

    ######## initialize the model
    PieAPP_net = PieAPP(batch_size, num_patches_per_dim)
    s = torch.load(home_path + 'M3D-RPN/PerceptualImageError/weights/PieAPPv0.1.pth')
    # presumably the checkpoint stores this weight with one dimension fewer
    # than the module expects -- TODO confirm
    s['ref_score_subtract.weight'] = s['ref_score_subtract.weight'].unsqueeze(0)
    PieAPP_net.load_state_dict(s)

    if use_gpu:
        PieAPP_net.cuda()
    # Inference only: match the gradient helpers, which also switch to eval mode.
    PieAPP_net.eval()

    score_accum = 0.0
    weight_accum = 0.0

    # No gradients are needed for scoring, so skip building the autograd graph.
    with torch.no_grad():
        # iterate through smaller size sub-images (to prevent memory overload)
        for x_iter in range(-(-num_x // num_patches_per_dim)):
            x0 = x_loc[num_patches_per_dim * x_iter]
            if num_patches_per_dim * (x_iter + 1) >= num_x:
                size_slice_cols = cols - x0
            else:
                size_slice_cols = x_loc[num_patches_per_dim * (x_iter + 1)] - x0 + patch_size - stride_val

            for y_iter in range(-(-num_y // num_patches_per_dim)):
                y0 = y_loc[num_patches_per_dim * y_iter]
                if num_patches_per_dim * (y_iter + 1) >= num_y:
                    size_slice_rows = rows - y0
                else:
                    size_slice_rows = y_loc[num_patches_per_dim * (y_iter + 1)] - y0 + patch_size - stride_val

                # obtain the subimage and sample patches
                A_sub_im = imagesA[:, y0:y0 + size_slice_rows, x0:x0 + size_slice_cols, :]
                ref_sub_im = imagesRef[:, y0:y0 + size_slice_rows, x0:x0 + size_slice_cols, :]
                A_patches, ref_patches, _, _ = sample_patches(
                    A_sub_im, ref_sub_im, patch_size=patch_size, strideval=stride_val,
                    random_selection=False, uniform_grid_mode='strided')
                PieAPP_net.num_patches = A_patches.shape[0] // batch_size

                # NCHW tensors to be fed to PieAPP_net
                A_patches_var = torch.from_numpy(np.transpose(A_patches, (0, 3, 1, 2)))
                ref_patches_var = torch.from_numpy(np.transpose(ref_patches, (0, 3, 1, 2)))
                if use_gpu:
                    A_patches_var = A_patches_var.cuda()
                    ref_patches_var = ref_patches_var.cuda()

                # forward pass
                _, patch_errors, patch_weights = PieAPP_net.compute_score(
                    A_patches_var.float(), ref_patches_var.float(), A_patches_var.shape)
                curr_err = patch_errors.detach().cpu().numpy()
                curr_weights = patch_weights.detach().cpu().numpy()
                score_accum += np.sum(np.multiply(curr_err, curr_weights))
                weight_accum += np.sum(curr_weights)
    return score_accum / weight_accum