import numpy as np
import time

import torch
import torch.nn.functional as F
from torch import nn

import cuda_gridsample as cu
import naive_gridsample as nv
from torch.autograd import grad
from functools import partial

import unittest
from torch.testing import assert_close

@unittest.skipUnless(torch.cuda.is_available(), "CUDA device required")
class CudaGridsampleTest(unittest.TestCase):
    """Tests for the CUDA ``grid_sample_2d`` op exposed by ``cuda_gridsample``.

    Two kinds of checks are performed:

    * comparison of values, first and second derivatives against the naive
      implementation in ``naive_gridsample`` (``cmp_with_naive``);
    * numerical ``gradcheck`` / ``gradgradcheck`` of the CUDA op for several
      padding modes and ``align_corners`` settings (``gradcheck``).

    Every test moves its tensors to the GPU, so the whole class is skipped
    when no CUDA device is available.
    """

    # ------------------------------------------------------------------
    # Comparison with the naive implementation
    # ------------------------------------------------------------------

    def test_naive_constant(self):
        """Compare with the naive op on a fixed 3x3 image and one grid point."""
        image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).reshape(1, 1, 3, 3)
        optical = np.array([0.1, 0.1]).reshape(1, 1, 1, 2)
        self.cmp_with_naive(image, optical)

    def test_naive_oob(self):
        """Compare with the naive op when one grid coordinate (1.1) exceeds 1."""
        image = np.array([[1, 2], [3, 4]]).reshape(1, 1, 2, 2)
        optical = np.array([0.1, 1.1]).reshape(1, 1, 1, 2)
        self.cmp_with_naive(image, optical)

    def test_naive_random(self):
        """Compare with the naive op on ten random inputs with grid in [-1, 1]."""
        for _ in range(10):
            image, optical = self.create_random_input(oob=False)
            self.cmp_with_naive(image, optical)

    def test_float(self):
        """Same as ``test_naive_constant`` but with float32 tensors."""
        image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).reshape(1, 1, 3, 3)
        optical = np.array([0.1, 0.1]).reshape(1, 1, 1, 2)
        self.cmp_with_naive(image, optical, double=False)

    # ------------------------------------------------------------------
    # Numerical gradient checks
    # ------------------------------------------------------------------

    def test_gradcheck_constant(self):
        """Gradcheck with zero padding and a grid x-coordinate of -2.1."""
        image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).reshape(1, 1, 3, 3)
        optical = np.array([-2.1, 0.1]).reshape(1, 1, 1, 2)
        self.gradcheck(image, optical, padding_mode='zeros')

    def test_gradcheck_random(self):
        """Gradcheck on ten small random inputs with the helper's defaults."""
        for _ in range(10):
            image, optical = self.create_random_input(max_dim=5)
            self.gradcheck(image, optical)

    def test_grad_output(self):
        """Gradcheck the backward function w.r.t. grad_output, input and grid."""
        for _ in range(10):
            arrays = self.create_random_input(oob=True, max_dim=10, with_grad_output=True)
            # Unlike the other tests the tensors are moved to the GPU *before*
            # requires_grad is set.
            image, optical, grad_output = (torch.DoubleTensor(a).cuda() for a in arrays)
            for tensor in (image, optical, grad_output):
                tensor.requires_grad = True

            # gradcheck raises on failure by default; the assertion keeps this
            # test consistent with the others and also covers a False return.
            self.assertTrue(torch.autograd.gradcheck(
                cu._GridSample2dBackward.apply, (grad_output, image, optical)))

    def test_gradcheck_random_oob_zeros(self):
        """Gradcheck with grid values in [-2, 2] and zero padding."""
        for _ in range(10):
            image, optical = self.create_random_input(oob=True, max_dim=5)
            self.gradcheck(image, optical, padding_mode='zeros')

    def test_gradcheck_random_oob_border(self):
        """Gradcheck with grid values in [-2, 2] and border padding."""
        for _ in range(10):
            image, optical = self.create_random_input(oob=True, max_dim=5)
            self.gradcheck(image, optical, padding_mode='border')

    def test_gradcheck_random_nocorners(self):
        """Gradcheck with grid values in [-2, 2], zero padding, align_corners=False."""
        for _ in range(10):
            image, optical = self.create_random_input(oob=True, max_dim=5)
            self.gradcheck(image, optical, padding_mode='zeros', align_corners=False)

    def test_strided(self):
        """Gradcheck the op on strided (sliced) views of the input and the grid."""
        for i in range(100):
            image, optical = self.create_random_input(max_dim=10)
            if i % 2 == 0:
                # Strides up to the full extent of each dimension.
                upper = (
                    image.shape[0] + 1, image.shape[1] + 1,
                    image.shape[2] + 1, image.shape[3] + 1,
                    optical.shape[1] + 1, optical.shape[2] + 1,
                )
            else:
                # Small strides only: 1 or 2.
                upper = (3,) * 6
            bs_stride, c_stride, h_stride, w_stride, hg_stride, wg_stride = (
                np.random.randint(1, hi) for hi in upper
            )

            image, optical = self._as_cuda_doubles(image, optical)

            # The batch dimension of both tensors is sliced with the same step
            # so that their batch sizes keep matching.
            image = image[::bs_stride, ::c_stride, ::h_stride, ::w_stride]
            optical = optical[::bs_stride, ::hg_stride, ::wg_stride]

            self.assertTrue(torch.autograd.gradcheck(cu.grid_sample_2d, inputs=(image, optical)))
            self.assertTrue(torch.autograd.gradgradcheck(cu.grid_sample_2d, inputs=(image, optical)))

    def test_use_case(self):
        """Gradcheck the op embedded between two 1x1 convolutions."""
        # The conv layers must be created with float64 parameters, hence the
        # temporary change of the global default dtype.  It is restored in a
        # ``finally`` so that a failing check cannot leak float64 into the
        # tests that run afterwards.
        previous_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float64)
        try:
            for _ in range(10):
                image, optical = self.create_random_input(max_dim=10)
                channels = image.shape[1]

                l1 = nn.Conv2d(channels, channels, 1)
                l2 = nn.Conv2d(channels, 1, 1)

                image, optical = self._as_cuda_doubles(image, optical)
                l1.cuda()
                l2.cuda()

                def fn(image, optical):
                    out = F.relu(l1(image))
                    out = cu.grid_sample_2d(out, optical)
                    out = l2(out)
                    return (out * out).sum()

                self.assertTrue(torch.autograd.gradcheck(fn, inputs=(image, optical), nondet_tol=1e-05))
                self.assertTrue(torch.autograd.gradgradcheck(fn, inputs=(image, optical), nondet_tol=1e-05))
        finally:
            torch.set_default_dtype(previous_dtype)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_cuda_doubles(*arrays):
        """Turn arrays into float64 tensors requiring grad, then copy them to the GPU."""
        tensors = []
        for array in arrays:
            tensor = torch.DoubleTensor(array)
            tensor.requires_grad = True
            tensors.append(tensor.cuda())
        return tensors

    def create_random_input(self, oob=False, max_dim=20, with_grad_output=False):
        """Draw a random image / grid pair as NumPy arrays.

        Args:
            oob: if true, grid values are drawn uniformly from [-2, 2],
                otherwise from [-1, 1].
            max_dim: exclusive upper bound for every randomly drawn dimension
                (each one lies in ``[1, max_dim)``).
            with_grad_output: if true, additionally return a normally
                distributed array of shape ``(bs, c, hg, wg)``.

        Returns:
            ``(input, grid)`` with shapes ``(bs, c, h, w)`` and
            ``(bs, hg, wg, 2)``, or ``(input, grid, grad_output)`` when
            ``with_grad_output`` is set.
        """
        bs = np.random.randint(1, max_dim)
        c = np.random.randint(1, max_dim)
        h = np.random.randint(1, max_dim)
        w = np.random.randint(1, max_dim)

        hg = np.random.randint(1, max_dim)
        wg = np.random.randint(1, max_dim)

        low, high = (-2, 2) if oob else (-1, 1)
        grid = np.random.uniform(low, high, size=(bs, hg, wg, 2))
        image = np.random.normal(size=(bs, c, h, w))
        if not with_grad_output:
            return image, grid
        grad_output = np.random.normal(size=(bs, c, hg, wg))
        return image, grid, grad_output

    def cmp_with_naive(self, image, optical, double=True):
        """Assert that the CUDA op matches the naive op up to second derivatives.

        The scalar ``sum(out ** 2)``, its gradients w.r.t. image and grid, and
        the gradients of the summed first-order gradients are compared with
        ``assert_close``.

        Args:
            image: array-like image, converted to a tensor.
            optical: array-like sampling grid, converted to a tensor.
            double: use float64 tensors if true, float32 otherwise.
        """
        tensor_type = torch.DoubleTensor if double else torch.FloatTensor
        image = tensor_type(image)
        optical = tensor_type(optical)

        image.requires_grad = True
        optical.requires_grad = True

        def value_and_grads(sample, img, flow):
            """Return sum(out**2), its gradients and the second-order gradients."""
            out = torch.sum(sample(img, flow) ** 2)
            first = grad(out, [img, flow], create_graph=True)
            second = grad(torch.sum(first[0]) + torch.sum(first[1]), [img, flow])
            return out, first, second

        # The naive op is called with its own defaults on the CPU tensors.
        # NOTE(review): presumably those defaults correspond to border padding
        # with align_corners=True, as used for the CUDA op below - confirm.
        nv_out, nv_first, nv_second = value_and_grads(nv.grid_sample_2d, image, optical)

        cu_sample = partial(cu.grid_sample_2d, padding_mode='border', align_corners=True)
        out, first, second = value_and_grads(cu_sample, image.cuda(), optical.cuda())

        assert_close(nv_out, out.cpu())

        assert_close(nv_first[0], first[0].cpu())
        assert_close(nv_first[1], first[1].cpu())

        assert_close(nv_second[1], second[1].cpu())
        assert_close(nv_second[0], second[0].cpu())

    def gradcheck(self, image, optical, padding_mode='border', align_corners=True):
        """Run gradcheck and gradgradcheck on the CUDA op in float64.

        Args:
            image: array-like image, converted to a float64 CUDA tensor.
            optical: array-like sampling grid, converted likewise.
            padding_mode: forwarded to ``cu.grid_sample_2d``.
            align_corners: forwarded to ``cu.grid_sample_2d``.
        """
        image, optical = self._as_cuda_doubles(image, optical)
        sample = partial(cu.grid_sample_2d, padding_mode=padding_mode, align_corners=align_corners)

        self.assertTrue(torch.autograd.gradcheck(sample, inputs=(image, optical)))
        self.assertTrue(torch.autograd.gradgradcheck(sample, inputs=(image, optical)))


if __name__ == "__main__":
    # unittest.main()
    # NOTE(review): the unit tests are not started from this entry point; it
    # runs a quick speed/accuracy comparison of the CUDA op against the naive
    # one instead.

    def _synced_now():
        """Return a timestamp taken after all queued CUDA work has finished.

        CUDA kernels are launched asynchronously, so reading the clock without
        synchronising would mostly measure launch overhead and attribute the
        remaining GPU work to whichever interval happens to synchronise next.
        """
        torch.cuda.synchronize()
        return time.perf_counter()

    image = torch.randn(1, 3, 512, 512).cuda()
    grid = F.affine_grid(torch.eye(2, 3).unsqueeze(0) + 0.01 * torch.randn((1, 2, 3)), [1, 1, 200, 200], align_corners=True).requires_grad_(True).cuda()

    # Untimed call whose result is overwritten below - presumably a warm-up so
    # that one-time initialisation cost is not counted in the timed call.
    image_new = cu.grid_sample_2d(image, grid+0.001)

    # forward
    t1 = _synced_now()
    image_new = cu.grid_sample_2d(image, grid)
    t2 = _synced_now()
    image_new_naive = nv.grid_sample_2d(image, grid)
    t3 = _synced_now()
    print("Naive: ", t3 - t2)
    print("Cuda: ", t2 - t1)
    print("Speedup: ", (t3 - t2) / (t2 - t1))
    print((image_new - image_new_naive).abs().mean())

    # backward
    t4 = _synced_now()
    grid_grad_naive = torch.autograd.grad(image_new_naive.mean(), grid, create_graph=True)[0]
    t5 = _synced_now()
    grid_grad_cu = torch.autograd.grad(image_new.mean(), grid, create_graph=True)[0]
    t6 = _synced_now()
    print("Naive grad: ", t5 - t4)
    print("Cuda grad: ", t6 - t5)
    print("Speedup: ", (t5 - t4) / (t6 - t5))
    print((grid_grad_naive - grid_grad_cu).abs().mean())

    ### Test 2
    # image = torch.randn(1, 1, 128, 128).cuda()
    # grid = F.affine_grid(torch.eye(2, 3).unsqueeze(0), [1, 1, 50, 50], align_corners=True).cuda().detach().requires_grad_(True)
    # image_new = cu.grid_sample_2d(image, grid)
    # grid_grad = torch.autograd.grad(image_new.mean(), grid, create_graph=True)[0]
    # loss = (grid_grad**2).sum()
    # loss.backward()
    # print(grid.grad.shape, grid.grad.min(), grid.grad.max())
    # print(image.grad)
