from pytorch_ssim import *
from torch import from_numpy
from torch import optim
from torch import autograd
import cv2
from scipy import sparse
import numpy as np

def eval_Hessian_mnist(npImg1):
    """Return the dense Hessian of ``1 - SSIM(img, img + delta)`` at ``delta = 0``.

    The image is flattened, a zero perturbation ``delta`` of the same length
    is added to it, and the loss is differentiated twice with respect to
    ``delta``, one Hessian row per backward pass.

    Args:
        npImg1: Image tensor. Despite the name it must be a ``torch.Tensor``:
            ``requires_grad_(False)`` is called on it, which also changes the
            flag on the caller's tensor in place. Its original shape is passed
            to the ``SSIM`` module as the third argument.

    Returns:
        A detached ``(n, n)`` tensor located on the same device as
        ``npImg1``, where ``n`` is the number of elements of the input.
        Row ``i`` holds the derivative of the ``i``-th gradient entry with
        respect to ``delta``.
    """
    img1 = npImg1.requires_grad_(False)
    img_shape = img1.shape
    img1 = img1.reshape(-1)

    # Allocate the perturbation next to the image rather than forcing
    # ``.cuda()``: this keeps both operands of ``img1 + delta`` on one device
    # (also on multi-GPU hosts) and does not require CUDA by itself.
    # NOTE(review): CPU use additionally needs SSIM to be device-agnostic.
    delta = torch.zeros(img1.shape, device=img1.device, requires_grad=True)

    ssim_loss = SSIM()
    s = 1 - ssim_loss(img1, img1 + delta, img_shape)

    # create_graph=True keeps the gradient differentiable so that each entry
    # can be differentiated again below.
    grad = autograd.grad(s, delta, retain_graph=True, create_graph=True)[0]

    # One backward pass per Hessian row; the graph is retained because every
    # row re-uses it.
    hessian = [
        autograd.grad(grad[i], delta, retain_graph=True)[0].unsqueeze(0)
        for i in range(grad.shape[0])
    ]

    return torch.cat(hessian).detach()

def eval_Hessian_mnist_masked(npImg1):
    """Return a band-masked dense Hessian of ``1 - SSIM(img, img + delta)``.

    The Hessian is evaluated at ``delta = 0`` row by row. In row ``i`` only
    the columns ``[i - band, i + band)`` (clipped to the matrix) are kept,
    with ``band = img_shape[2] * 11``; every other entry is set to zero.

    Args:
        npImg1: Image tensor with at least three dimensions (``shape[2]`` is
            read). ``requires_grad_(False)`` is called on it, which also
            changes the flag on the caller's tensor in place. Its original
            shape is passed to the ``SSIM`` module as the third argument.

    Returns:
        A detached ``(n, n)`` tensor on the same device as ``npImg1``, where
        ``n`` is the number of elements of the input.
    """
    img1 = npImg1.requires_grad_(False)
    img_shape = img1.shape
    img1 = img1.reshape(-1)

    # Follow the image's device instead of forcing ``.cuda()``.
    # NOTE(review): CPU use additionally needs SSIM to be device-agnostic.
    delta = torch.zeros(img1.shape, device=img1.device, requires_grad=True)

    ssim_loss = SSIM()
    s = 1 - ssim_loss(img1, img1 + delta, img_shape)
    grad = autograd.grad(s, delta, retain_graph=True, create_graph=True)[0]

    size = grad.shape[0]
    # NOTE(review): presumably 11 image rows, matching an SSIM window of 11;
    # for non-square images confirm that shape[2] is the flattened row stride.
    band = img_shape[2] * 11

    hessian = []
    for i in range(size):
        # autograd.grad returns the row directly, so no tensor hook or
        # ``delta.grad`` bookkeeping is needed (registering a hook per
        # iteration would make the hooks pile up and slow every later row).
        row = autograd.grad(grad[i], delta, retain_graph=True)[0]

        # Keep only the band around the diagonal. Computing the window from
        # ``i`` guarantees that column 0 leaves the band as well once
        # ``i > band``.
        lo = max(0, i - band)
        hi = min(size, i + band)
        masked = torch.zeros_like(row)
        masked[lo:hi] = row[lo:hi]
        hessian.append(masked.unsqueeze(0))

    return torch.cat(hessian).detach()

def eval_Hessian_mnist_sp(npImg1):
    """Return the Hessian of ``1 - SSIM(img, img + delta)`` as a sparse matrix.

    The Hessian is evaluated at ``delta = 0`` one row at a time; only the
    non-zero entries of each row are stored.

    Args:
        npImg1: Image tensor. ``requires_grad_(False)`` is called on it, which
            also changes the flag on the caller's tensor in place. Its
            original shape is passed to the ``SSIM`` module as the third
            argument.

    Returns:
        ``scipy.sparse.csc_matrix`` of shape ``(n, n)`` with float64 data,
        where ``n`` is the number of elements of the input.
    """
    img1 = npImg1.requires_grad_(False)
    img_shape = img1.shape
    img1 = img1.reshape(-1)

    # Follow the image's device instead of forcing ``.cuda()``.
    # NOTE(review): CPU use additionally needs SSIM to be device-agnostic.
    delta = torch.zeros(img1.shape, device=img1.device, requires_grad=True)

    ssim_loss = SSIM(window_size=11)
    s = 1 - ssim_loss(img1, img1 + delta, img_shape)
    grad = autograd.grad(s, delta, retain_graph=True, create_graph=True)[0]
    hessian_size = grad.shape[0]

    # Collect per-row pieces and concatenate once at the end; concatenating
    # inside the loop re-copies everything gathered so far on every row.
    # The empty seeds fix the dtypes (float64 values, integer indices) and
    # keep np.concatenate valid when there are no rows.
    data_parts = [np.empty(0, dtype=np.float64)]
    row_parts = [np.empty(0, dtype=np.int64)]
    col_parts = [np.empty(0, dtype=np.int64)]

    for i in range(hessian_size):
        gradi = autograd.grad(grad[i], delta, retain_graph=True)[0]
        gradi = gradi.detach().cpu().numpy()
        nonzero_cols = np.flatnonzero(gradi)
        data_parts.append(gradi[nonzero_cols])
        row_parts.append(np.full(nonzero_cols.shape, i, dtype=np.int64))
        col_parts.append(nonzero_cols)

    data = np.concatenate(data_parts)
    rows = np.concatenate(row_parts)
    cols = np.concatenate(col_parts)
    hessian = sparse.csc_matrix(
        (data, (rows, cols)), shape=(hessian_size, hessian_size)
    )

    return hessian


def eval_Hessian(npImg1, x1, y1, x2, y2):
    """Dense Hessian of ``1 - SSIM`` for a cropped patch of a BGR image.

    The image is converted to grayscale, moved to the GPU, scaled by
    ``1 / 255`` and cropped. The loss ``1 - SSIM(patch, patch + delta)`` is
    then differentiated twice with respect to a zero perturbation ``delta``.

    Args:
        npImg1: BGR image array as accepted by ``cv2.cvtColor``.
        x1, x2: Start/stop of the slice taken along the first spatial axis
            (dimension 2 of the ``(1, 1, H, W)`` tensor).
        y1, y2: Start/stop of the slice taken along the second spatial axis
            (dimension 3).

    Returns:
        CPU tensor of shape ``(n, n)``, ``n`` being the number of pixels in
        the crop. Row ``i`` is the derivative of the ``i``-th gradient entry
        with respect to ``delta``.
    """
    grayscale = cv2.cvtColor(npImg1, cv2.COLOR_BGR2GRAY)

    # Build a (1, 1, H, W) float tensor on the GPU with values scaled by 1/255.
    full_image = from_numpy(grayscale).float().cuda()
    full_image = full_image.unsqueeze(0).unsqueeze(1) / 255.0

    patch = full_image[:, :, x1:x2, y1:y2].requires_grad_(False)
    patch_shape = patch.shape
    flat_patch = patch.reshape(-1)

    perturbation = torch.zeros(flat_patch.shape, requires_grad=True).cuda()

    criterion = SSIM()
    loss = 1 - criterion(flat_patch, flat_patch + perturbation, patch_shape)

    # First derivative, kept differentiable for the second pass below.
    first_order = autograd.grad(
        loss, perturbation, retain_graph=True, create_graph=True
    )[0]

    size = first_order.shape[0]
    result = torch.zeros((size, size))
    for row_index in range(size):
        second_order = autograd.grad(
            first_order[row_index], perturbation, retain_graph=True
        )[0]
        # The result matrix lives on the CPU, so each row is copied over.
        result[row_index] = second_order.cpu()
    return result

def eval_Hessian_sp(npImg1, x1, y1, x2, y2):
    """Sparse Hessian of ``1 - SSIM`` for a cropped patch of a BGR image.

    The image is converted to grayscale, moved to the GPU, scaled by
    ``1 / 255`` and cropped. The loss ``1 - SSIM(patch, patch + delta)`` is
    differentiated twice with respect to a zero perturbation ``delta``; only
    the non-zero entries of each Hessian row are stored.

    Args:
        npImg1: BGR image array as accepted by ``cv2.cvtColor``.
        x1, x2: Start/stop of the slice taken along the first spatial axis
            (dimension 2 of the ``(1, 1, H, W)`` tensor).
        y1, y2: Start/stop of the slice taken along the second spatial axis
            (dimension 3).

    Returns:
        ``scipy.sparse.csc_matrix`` of shape ``(n, n)`` with float64 data,
        ``n`` being the number of pixels in the crop.
    """
    gray = cv2.cvtColor(npImg1, cv2.COLOR_BGR2GRAY)

    img1 = from_numpy(gray).float().cuda().unsqueeze(0).unsqueeze(1) / 255.0
    cropped = img1[:, :, x1:x2, y1:y2]
    cropped = cropped.requires_grad_(False)

    img_shape = cropped.shape
    cropped = cropped.reshape(-1)
    delta = torch.zeros(cropped.shape, requires_grad=True).cuda()
    ssim_loss = SSIM(window_size=11)
    s = 1 - ssim_loss(cropped, cropped + delta, img_shape)
    grad = autograd.grad(s, delta, retain_graph=True, create_graph=True)[0]
    hessian_size = grad.shape[0]

    # Collect per-row pieces and concatenate once at the end; concatenating
    # inside the loop re-copies everything gathered so far on every row.
    # The empty seeds fix the dtypes (float64 values, integer indices) and
    # keep np.concatenate valid when there are no rows.
    data_parts = [np.empty(0, dtype=np.float64)]
    row_parts = [np.empty(0, dtype=np.int64)]
    col_parts = [np.empty(0, dtype=np.int64)]

    for i in range(hessian_size):
        gradi = autograd.grad(grad[i], delta, retain_graph=True)[0]
        # Hand SciPy/NumPy a real ndarray rather than a torch tensor.
        gradi = gradi.detach().cpu().numpy()
        nonzero_cols = np.flatnonzero(gradi)
        data_parts.append(gradi[nonzero_cols])
        row_parts.append(np.full(nonzero_cols.shape, i, dtype=np.int64))
        col_parts.append(nonzero_cols)

    data = np.concatenate(data_parts)
    rows = np.concatenate(row_parts)
    cols = np.concatenate(col_parts)
    hessian = sparse.csc_matrix(
        (data, (rows, cols)), shape=(hessian_size, hessian_size)
    )

    return hessian

if __name__ == "__main__":
    # Experiment: sample 10 small random perturbations, build the dense
    # Hessian of 1 - SSIM(img, img + delta) for each, and count how often the
    # smallest eigenvalue is positive, exactly zero, or negative.
    npImg1 = cv2.imread("000651.png")
    if npImg1 is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        raise FileNotFoundError("could not read image '000651.png'")
    gray = cv2.cvtColor(npImg1, cv2.COLOR_BGR2GRAY)
    print("image shape", gray.shape)

    img1 = torch.from_numpy(gray).float().unsqueeze(0).unsqueeze(1)/255.0
    img_shape = img1.shape
    img1 = img1.reshape(-1)

    # The image does not change between iterations, so move it once.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        img1 = img1.cuda()

    # SSIM is brought into scope by the star import at the top of the file;
    # the module name ``pytorch_ssim`` itself is not bound.
    ssim_loss = SSIM()

    # torch.symeig is absent from recent PyTorch releases; prefer
    # torch.linalg.eigvalsh when available. UPLO="U" reads the upper
    # triangle, which is what symeig does by default.
    linalg = getattr(torch, "linalg", None)
    has_eigvalsh = linalg is not None and hasattr(linalg, "eigvalsh")

    PD = 0
    PSD = 0
    NPD = 0
    print(npImg1)
    for _ in range(10):
        delta = torch.rand(img1.size(), requires_grad=True)*0.001
        if use_cuda:
            delta = delta.cuda()

        # Named ``loss`` so the SSIM class is not shadowed by its result.
        loss = 1 - ssim_loss(img1, img1 + delta, img_shape)
        grad = autograd.grad(loss, delta, retain_graph=True, create_graph=True)[0]
        hessian = torch.zeros((grad.shape[0], grad.shape[0]))
        for i in range(grad.shape[0]):
            hessian[i] = autograd.grad(grad[i], delta, retain_graph=True)[0]
        print(hessian.shape)

        # Eigenvalues come back in ascending order, so e[0] is the smallest.
        if has_eigvalsh:
            e = linalg.eigvalsh(hessian, UPLO="U")
        else:
            e, _ = torch.symeig(hessian, eigenvectors=False)

        # NOTE(review): the semidefinite case is an exact floating-point
        # comparison with zero and will rarely trigger.
        if e[0] > 0:
            PD += 1
        elif e[0] == 0:
            PSD += 1
        else:
            NPD += 1
    print("Positive definite: ", PD)
    print("Postive semidefinite: ", PSD)
    print("Not postive semidefinite: ", NPD)
