# --------------------------------------------------------
# Fused kernels for the window process in Swin Transformer
# Copyright (c) 2022 Nvidia
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
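# Unit tests and a simple benchmark for the fused window-process kernels.
# Requires the swin_window_process CUDA extension to be built and a GPU.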

import torch
import swin_window_process
import time
import unittest


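# Autograd wrapper for the fused CUDA op that applies the cyclic shift
# (torch.roll) and window partition in one pass. The op takes the signed
# roll amount directly, so callers pass -shift_size (see the tests below).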
class WindowProcess(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)

        # Save shape and window parameters for the backward pass.
        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size
        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
        # One gradient per forward input: the tensor gets grad_out, the six
        # integer arguments get None.
        return grad_out, None, None, None, None, None, None


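# Autograd wrapper for the fused CUDA op that merges windows back to
# (B, H, W, C) and applies the reverse cyclic shift in one pass.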
class WindowProcessReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, B, H, W, C, shift_size, window_size):
        output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)

        # Save shape and window parameters for the backward pass.
        ctx.B = B
        ctx.H = H
        ctx.W = W
        ctx.C = C
        ctx.shift_size = shift_size
        ctx.window_size = window_size

        return output

    @staticmethod
    def backward(ctx, grad_in):
        B = ctx.B
        H = ctx.H
        W = ctx.W
        C = ctx.C
        shift_size = ctx.shift_size
        window_size = ctx.window_size

        grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
        # One gradient per forward input: the tensor gets grad_out, the six
        # integer arguments get None.
        return grad_out, None, None, None, None, None, None


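# Reference (unfused) implementations below, as in the official Swin
# Transformer model code; they serve as ground truth for the fused kernels.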
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
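
# Note: window_reverse is the exact inverse of window_partition, i.e.
# window_reverse(window_partition(x, ws), ws, H, W) reproduces x exactly.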


def pyt_forward(x, shift_size, window_size):
    # x in shape (B, H, W, C)
    # cyclic shift
    if shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
    else:
        shifted_x = x
    # partition windows
    x_windows = window_partition(shifted_x, window_size)
    return x_windows


def reverse_pyt_forward(attn_windows, shift_size, window_size, H, W):
    # attn_windows in shape (B*nH*nW, window_size, window_size, C)
    shifted_x = window_reverse(attn_windows, window_size, H, W)
    if shift_size > 0:
        x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
    else:
        x = shifted_x
    return x


def copy_one_tensor(input, requires_grad=True):
    # Clone onto the GPU first, then mark as a leaf that requires grad.
    # Calling .cuda() after .requires_grad_() would return a non-leaf copy
    # whose .grad is never populated by backward().
    input1 = input.clone().detach().cuda().requires_grad_(requires_grad)
    return input1

class Test_WindowProcess(unittest.TestCase):
    def setUp(self):
        # Problem sizes shared by all tests; nH/nW count windows per side.
        self.B = 192
        self.H = 56
        self.W = 56
        self.C = 96
        self.shift_size = 2
        self.window_size = 7
        self.nH = self.H // self.window_size
        self.nW = self.W // self.window_size
    
    def test_roll_and_window_partition_forward(self, dtype=torch.float32):
        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, device='cuda')

        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        with torch.no_grad():
            # original PyTorch implementation
            expected = pyt_forward(input1, self.shift_size, self.window_size)
            # fused kernel; it takes the signed roll amount, hence -shift_size
            fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)

        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
    
    def test_roll_and_window_partition_backward(self, dtype=torch.float32):
        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, device='cuda')
        d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype, device='cuda')

        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        # original PyTorch implementation
        expected = pyt_forward(input1, self.shift_size, self.window_size)
        expected.backward(d_loss_tensor)
        # fused kernel
        fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
        fused_output.backward(d_loss_tensor)

        self.assertTrue(torch.equal(expected, fused_output))
        # the input gradients are what actually exercise the backward kernel
        self.assertTrue(torch.equal(input1.grad, input2.grad))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_window_merge_and_roll_forward(self, dtype=torch.float32):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, device='cuda')

        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        with torch.no_grad():
            # original PyTorch implementation
            expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
            # fused kernel
            fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)

        self.assertTrue(torch.equal(expected, fused_output))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_window_merge_and_roll_backward(self, dtype=torch.float32):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, device='cuda')
        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, device='cuda')

        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        # original PyTorch implementation
        expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
        expected.backward(d_loss_tensor)
        # fused kernel
        fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
        fused_output.backward(d_loss_tensor)

        self.assertTrue(torch.equal(expected, fused_output))
        # the input gradients are what actually exercise the backward kernel
        self.assertTrue(torch.equal(input1.grad, input2.grad))
        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))

    def test_forward_backward_speed(self, dtype=torch.float32, times=1000):
        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, device='cuda')
        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, device='cuda')

        input1 = copy_one_tensor(input, True)
        input2 = copy_one_tensor(input, True)

        # original implementation from the official Swin Transformer repo
        def run_pyt(t=1000):
            for _ in range(t):
                expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
                expected.backward(d_loss_tensor)

        # fused op
        def run_fusedop(t=1000):
            for _ in range(t):
                fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
                fused_output.backward(d_loss_tensor)

        # warm up both paths so one-time CUDA initialization is not timed
        run_pyt(t=10)
        run_fusedop(t=10)

        torch.cuda.synchronize()
        t1 = time.time()
        run_pyt(t=times)
        torch.cuda.synchronize()
        t2 = time.time()
        run_fusedop(t=times)
        torch.cuda.synchronize()
        t3 = time.time()

        print('Run {} times'.format(times))
        print('Original time cost: {}'.format(t2 - t1))
        print('Fused op time cost: {}'.format(t3 - t2))
        self.assertTrue((t3 - t2) < (t2 - t1))
    
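    # fp16 variants reuse the fp32 test bodies with dtype=torch.float16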
    def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16):
        self.test_roll_and_window_partition_forward(dtype=dtype)

    def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16):
        self.test_roll_and_window_partition_backward(dtype=dtype)

    def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16):
        self.test_window_merge_and_roll_forward(dtype=dtype)
    
    def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16):
        self.test_window_merge_and_roll_backward(dtype=dtype)

    def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000):
        self.test_forward_backward_speed(dtype=dtype, times=times)


if __name__ == '__main__':
    print('Tests pass only if the two tensors are exactly equal (torch.equal).\n')
    torch.manual_seed(0)
    unittest.main(verbosity=2)
