import os
import torch
import numpy as np
from torch import nn
from torch.nn.init import constant_
import torch.nn.functional as F

import torch
from torch.autograd import Function
import torch.nn as nn
# ``roll_cuda`` is the compiled extension module that RollFunction calls into.
# NOTE(review): despite the wording of the message, the fallback branch only
# prints -- nothing is compiled here. After a failed import the name
# ``roll_cuda`` stays undefined, so RollFunction.forward/backward will raise
# NameError the first time they run.
try:
    import roll_cuda
except ImportError:
    print("CUDA extension not found. Compiling from source...")

class RollFunction(Function):
    """Autograd bridge to the compiled ``roll_cuda`` extension.

    The forward pass delegates to ``roll_cuda.roll_cuda`` and the backward
    pass to ``roll_cuda.roll_backward_cuda``. No gradient is produced for
    ``shifts``.
    """

    @staticmethod
    def forward(ctx, input, shifts):
        """Return ``roll_cuda.roll_cuda(input, shifts)``.

        ``shifts`` is stored on ``ctx`` because the backward kernel takes it
        as its second argument.
        """
        ctx.save_for_backward(shifts)
        return roll_cuda.roll_cuda(input, shifts)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_input, None)``; ``None`` is the slot for ``shifts``."""
        (saved_shifts,) = ctx.saved_tensors
        input_grad = roll_cuda.roll_backward_cuda(grad_output, saved_shifts)
        return input_grad, None

class Roll(nn.Module):
    """Parameter-free ``nn.Module`` wrapper around ``RollFunction``."""

    def __init__(self):
        super().__init__()

    def forward(self, input, shifts):
        """Apply ``RollFunction`` to ``input`` with the given ``shifts``."""
        rolled = RollFunction.apply(input, shifts)
        return rolled


class RollConv(nn.Module):
    """Convolution whose kernel bank is rolled and re-weighted per sample.

    ``forward`` proceeds in four stages:

    1. A channel gate: global average pooling, a channel shuffle, a grouped
       1x1 convolution and a sigmoid; the gate multiplies the input.
    2. Two feature branches (grouped 3x3 conv -> norm -> activation -> global
       average pool) feed two linear heads: ``offset`` (two values per
       ``(kernel, group)`` pair) and ``alpha`` (one sigmoid-activated value
       per pair).
    3. The learned kernel bank ``self.weight`` is replicated per sample,
       rolled by the fractional offsets (see ``roll_float``), scaled by
       ``alpha`` and summed over the ``kernel_number`` axis.
    4. The resulting per-sample kernels are applied with one ``F.conv2d``
       call by folding the batch into the channel axis and using
       ``groups=B``.
    """

    def __init__(
            self,
            channels=64,
            out_channels=128,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1,
            group=4,
            kernel_number=1,
            act_layer='GELU',
            norm_layer='LN',
            dropout_rate=0.1,
    ):
        """
        RollConv Module
        :param channels: number of input channels; must be divisible by ``group``
        :param out_channels: number of output channels
        :param kernel_size: side length of the square kernels in ``self.weight``
        :param stride: stride forwarded to the final ``F.conv2d``
        :param padding: padding forwarded to the final ``F.conv2d``
        :param dilation: dilation forwarded to the final ``F.conv2d``
        :param group: number of channel groups
        :param kernel_number: number of kernel sets blended per sample
        :param act_layer: activation name passed to ``build_act_layer``
        :param norm_layer: normalisation name passed to ``build_norm_layer``
        :param dropout_rate: dropout probability used before both linear heads
        :raises ValueError: if ``channels`` is not divisible by ``group``
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')

        # Plain hyper-parameters.
        self.channels, self.out_channels = channels, out_channels
        self.kernel_size = kernel_size
        self.stride, self.dilation, self.padding = stride, dilation, padding
        self.group = group
        self.kernel_number = kernel_number

        def _feature_branch():
            # Grouped 3x3 conv -> norm -> activation, all in channels-first layout.
            return nn.Sequential(
                nn.Conv2d(
                    channels, channels,
                    kernel_size=3, stride=1, padding=1, groups=group),
                build_norm_layer(
                    channels, norm_layer, 'channels_first', 'channels_first'),
                build_act_layer(act_layer),
            )

        # NOTE: sub-modules are created in a fixed order on purpose. The order
        # determines both the parameter/state-dict ordering and the order in
        # which the random initialisers consume the RNG.
        self.dw_conv = _feature_branch()    # feeds the offset head
        self.dw_conv2 = _feature_branch()   # feeds the alpha head
        self.dropout1 = nn.Dropout(dropout_rate)
        self.offset = nn.Linear(channels, 2 * group * kernel_number, bias=False)
        self.tanh = nn.Tanh()  # registered but not used by forward()
        self.alphas_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout2 = nn.Dropout(dropout_rate)
        self.alpha = nn.Linear(channels, group * kernel_number, bias=True)
        self.sigmod = nn.Sigmoid()
        self.roll = Roll()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.channels_shuffle = ChannelShuffle(self.group)
        self.atten_conv1 = nn.Conv2d(
            in_channels=2 * channels, out_channels=channels,
            kernel_size=1, padding=0, stride=1, groups=group, bias=True)

        # Kernel bank: (kernel_number, group, out_channels, channels // group, k, k).
        bank_shape = (
            kernel_number, group, out_channels,
            channels // group, kernel_size, kernel_size,
        )
        self.weight = nn.Parameter(torch.Tensor(*bank_shape))
        nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')

        self._reset_parameters()

    def _reset_parameters(self):
        """Zero the offset/alpha head weights and set the alpha bias to 1.

        With these values the offset head outputs 0 and the alpha head outputs
        ``sigmoid(1)`` for every input until training updates them.
        """
        for tensor, value in (
                (self.offset.weight, 0.),
                (self.alpha.weight, 0.),
                (self.alpha.bias, 1.0),
        ):
            constant_(tensor.data, value)

    def roll_float(self, x, shifts):
        """Roll ``x`` by fractional ``shifts`` using a bilinear blend.

        ``self.roll`` is called with the floored shifts and with the floored
        shifts incremented by 1 in the first, second and both components. The
        four results are mixed with bilinear weights built from the fractional
        parts.

        :param x: 5-D tensor whose first axis indexes samples
        :param shifts: floating-point tensor of shape ``[N, 2]``, one pair of
            shifts per entry of ``x``'s first axis
        :return: tensor with the same shape as ``x``
        """
        # Which axes of ``x`` each shift component moves is decided inside the
        # roll op -- TODO confirm against the extension source.
        whole = shifts.floor().int()
        frac = shifts - whole
        # Fractional parts reshaped to [N, 1, 1, 1, 1] so they broadcast over x.
        frac0 = frac[:, 0, None, None, None, None]
        frac1 = frac[:, 1, None, None, None, None]

        def _step(d0, d1):
            return torch.tensor([d0, d1], device=x.device, dtype=torch.int32)

        corner00 = self.roll(x, whole)
        corner10 = self.roll(x, whole + _step(1, 0))
        corner01 = self.roll(x, whole + _step(0, 1))
        corner11 = self.roll(x, whole + _step(1, 1))

        w00 = (1 - frac0) * (1 - frac1)
        w10 = frac0 * (1 - frac1)
        w01 = (1 - frac0) * frac1
        w11 = frac0 * frac1

        return w00 * corner00 + w10 * corner10 + w01 * corner01 + w11 * corner11

    def forward(self, input):
        """Apply the per-sample rolled convolution to a ``(B, C, H, W)`` input.

        :param input: tensor with ``self.channels`` channels
        :return: tensor of shape ``(B, self.out_channels, H_out, W_out)``
        """
        batch = input.shape[0]
        per_group = self.channels // self.group
        k = self.kernel_size

        # ---- Stage 1: channel gate --------------------------------------
        pooled = self.avg_pool(input)
        pooled_shuffled = self.channels_shuffle(pooled)
        grouped_shape = (batch, per_group, self.group, 1, 1)
        # NOTE(review): once flattened again below, this reshape/cat/reshape
        # sequence yields the same tensor as
        # torch.cat([pooled, pooled_shuffled], dim=1) -- confirm that this is
        # the intended channel layout for the grouped 1x1 convolution.
        descriptor = torch.cat(
            [pooled.reshape(grouped_shape), pooled_shuffled.reshape(grouped_shape)],
            dim=1,
        ).reshape(batch, -1, 1, 1)
        gate = torch.sigmoid(self.atten_conv1(descriptor))
        gated = input * gate

        # ---- Stage 2: per-sample offsets and blend coefficients ---------
        offset_feat = self.alphas_pool(self.dw_conv(gated)).squeeze(dim=-1).squeeze(dim=-1)
        offset = self.offset(self.dropout1(offset_feat))
        alpha_feat = self.alphas_pool(self.dw_conv2(gated)).squeeze(dim=-1).squeeze(dim=-1)
        alpha = self.sigmod(self.alpha(self.dropout2(alpha_feat)))  # [B, group*kernel_number]

        alpha = alpha.reshape(batch, self.kernel_number, self.group, 1)
        offset = offset.reshape(batch, self.kernel_number, self.group, 2).reshape(-1, 2)

        # ---- Stage 3: roll and blend the kernel bank --------------------
        kernels = self.weight.unsqueeze(0).repeat(batch, 1, 1, 1, 1, 1, 1)
        # One entry per (sample, kernel, group) so it lines up with ``offset``.
        kernels = kernels.reshape(-1, self.out_channels, per_group, k, k)
        kernels = self.roll_float(kernels, offset)
        kernels = kernels.reshape(
            batch, self.kernel_number, self.group, self.out_channels, per_group, k, k)
        kernels = alpha[..., None, None, None] * kernels
        # Sum over kernel_number, then move out_channels ahead of group so the
        # (group, channels // group) axes merge into the input-channel axis.
        kernels = kernels.sum(1).permute(0, 2, 1, 3, 4, 5).reshape(
            batch, self.out_channels, -1, k, k)

        # ---- Stage 4: one grouped conv applies every sample's kernels ---
        out = F.conv2d(
            gated.reshape(1, -1, *gated.shape[2:]),
            weight=kernels.reshape(-1, *kernels.shape[2:]),
            bias=None,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=batch,
        )
        return out.reshape(batch, self.out_channels, *out.shape[2:])


class to_channels_first(nn.Module):
    """Permute a 4-D tensor so that its last axis becomes axis 1."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with axes reordered as ``(0, 3, 1, 2)``."""
        axis_order = (0, 3, 1, 2)
        return x.permute(*axis_order)


class to_channels_last(nn.Module):
    """Permute a 4-D tensor so that axis 1 becomes its last axis."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with axes reordered as ``(0, 2, 3, 1)``."""
        axis_order = (0, 2, 3, 1)
        return x.permute(*axis_order)


def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    """Build a normalisation layer wrapped with any needed layout permutes.

    ``'BN'`` yields ``nn.BatchNorm2d(dim)``; a permute to channels-first is
    inserted before it when ``in_format == 'channels_last'`` and a permute
    back is appended when ``out_format == 'channels_last'``.

    ``'LN'`` yields ``nn.LayerNorm(dim, eps=eps)``; a permute to channels-last
    is inserted before it when ``in_format == 'channels_first'`` and a permute
    back is appended when ``out_format == 'channels_first'``.

    :param dim: feature size handed to the normalisation layer
    :param norm_layer: ``'BN'`` or ``'LN'``
    :param in_format: layout of the incoming tensor
    :param out_format: layout the result should have
    :param eps: epsilon for ``nn.LayerNorm`` (not used for ``'BN'``)
    :return: an ``nn.Sequential`` holding the permutes and the norm layer
    :raises NotImplementedError: for any other ``norm_layer`` value
    """
    if norm_layer == 'BN':
        before = [to_channels_first()] if in_format == 'channels_last' else []
        norm = nn.BatchNorm2d(dim)
        after = [to_channels_last()] if out_format == 'channels_last' else []
        return nn.Sequential(*before, norm, *after)

    if norm_layer == 'LN':
        before = [to_channels_last()] if in_format == 'channels_first' else []
        norm = nn.LayerNorm(dim, eps=eps)
        after = [to_channels_first()] if out_format == 'channels_first' else []
        return nn.Sequential(*before, norm, *after)

    raise NotImplementedError(
        f'build_norm_layer does not support {norm_layer}')


def build_act_layer(act_layer):
    """Return a new activation module for the given name.

    Supported names are ``'ReLU'`` and ``'SiLU'`` (both created with
    ``inplace=True``) and ``'GELU'``.

    :param act_layer: name of the activation
    :return: a freshly constructed ``nn.Module``
    :raises NotImplementedError: for any other name
    """
    factories = (
        ('ReLU', lambda: nn.ReLU(inplace=True)),
        ('SiLU', lambda: nn.SiLU(inplace=True)),
        ('GELU', lambda: nn.GELU()),
    )
    for name, make in factories:
        if act_layer == name:
            return make()

    raise NotImplementedError(f'build_act_layer does not support {act_layer}')

class ChannelShuffle(nn.Module):
    """Interleave the channels of ``num_groups`` contiguous channel groups.

    The channel axis is viewed as ``(num_groups, channels_per_group)``, the
    two axes are swapped, and the result is flattened back to one channel
    axis.
    """

    def __init__(self, num_groups):
        super().__init__()
        self.num_groups = num_groups

    def forward(self, x: torch.FloatTensor):
        """Shuffle the channel axis of a ``(batch, channels, h, w)`` tensor."""
        n, c, height, width = x.shape
        # (n, num_groups, channels_per_group, h, w)
        grouped = x.reshape(n, self.num_groups, c // self.num_groups, height, width)
        # Swap the group axis with the within-group axis, then merge them.
        return grouped.transpose(1, 2).reshape(n, -1, height, width)



# Code for testing RollConv.
# First, compile with "python setup.py install" in this directory.
if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The same eight values are used for every small tensor below.
    values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

    # --- Part 1: the Roll CUDA operator ----------------------------------
    op_shape = (1, 1, 1, 4, 2)
    images = torch.tensor(values).view(op_shape).float().cuda()
    # view/float/cuda make these non-leaf tensors, hence retain_grad().
    weights = torch.tensor(values, requires_grad=True).view(op_shape).float().cuda()
    weights.retain_grad()
    shifts = torch.tensor([[-1, 2]], dtype=torch.int32).cuda()
    weights1 = torch.tensor(values, requires_grad=True).view(op_shape).float().cuda()
    weights1.retain_grad()

    roll = Roll()
    rolled_images = roll(images * weights, shifts)
    loss = (rolled_images * weights1).mean()
    loss.backward()
    for item in (weights.grad, weights1.grad, rolled_images):
        print(item)

    # --- Part 2: the same computation with torch.roll, for comparison -----
    ref_shape = (4, 2)
    x = torch.tensor(values).view(ref_shape).float().cuda()
    w = torch.tensor(values, requires_grad=True).view(ref_shape).float().cuda()
    w.retain_grad()
    x = torch.roll(x * w, shifts=(-1, 2), dims=(0, 1))
    w1 = torch.tensor(values, requires_grad=True).view(ref_shape).float().cuda()
    w1.retain_grad()
    loss = (x * w1).mean()
    loss.backward()
    for item in (w.grad, w1.grad, x):
        print(item)

    # --- Part 3: a full RollConv forward pass on random input -------------
    A = torch.from_numpy(np.random.rand(4, 32, 128, 128).astype(dtype=np.float32))
    conv0 = RollConv(
        channels=32,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        group=4,
        kernel_number=4,
        act_layer='GELU',
        norm_layer='LN',
        dropout_rate=0.2,
    )
    if torch.cuda.is_available():
        A = A.to(device)
        conv0 = conv0.to(device)
    out = conv0(A)
    print(out.shape)