import torch
import numpy as np
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from mmcv.cnn import xavier_init
from timm.models.layers import trunc_normal_


class LayerNormProxy(nn.Module):
    """Apply ``nn.LayerNorm`` across the channel axis of a ``(B, C, H, W)`` tensor.

    ``nn.LayerNorm`` normalizes the trailing dimension, so the input is moved
    to channels-last, normalized, and moved back; the result is contiguous.
    """

    def __init__(self, dim):
        super().__init__()
        # Normalized shape equals the channel count C of the incoming tensor.
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
        normed = self.norm(channels_last)
        # Back to (B, C, H, W), materialized contiguously for downstream views.
        return normed.permute(0, 3, 1, 2).contiguous()

class LDAU(nn.Module):
    """Deformable-attention upsampler.

    Upsamples ``x`` of shape ``(B, C, H, W)`` to
    ``(B, C, int(H * up_factor), int(W * up_factor))``. Every output pixel owns
    one query (projected from the layer-normalized input, then interpolated to
    the output size). That query attends over ``k_u ** 2`` keys/values that are
    bilinearly sampled from the low-resolution maps at ``base grid + learned
    offset`` locations. Offsets are predicted per group from the query
    features. Keys are a projection of the normalized input. Values are the raw
    (un-normalized) input, so the output keeps ``C`` channels.

    Args:
        in_c: Number of input/output channels.
        reduction: Query/key width is ``in_c // reduction``.
        up_factor: Spatial upsampling factor.
        k_e: Kernel size of the final offset-prediction conv.
        k_u: Side of the sampled neighbourhood; must be odd.
        n_groups: Number of offset groups. ``in_c`` and ``in_c // reduction``
            must both be divisible by it.
        range_factor: Multiplier applied to ``tanh`` of the predicted offsets,
            i.e. the largest learned displacement, in output-pixel units.
        rpb: If True, add a learned per-neighbour bias to the sampled keys.
        up_m: ``"nearest"`` interpolates the queries with nearest-neighbour;
            any other value uses bilinear with ``align_corners=True``.

    Raises:
        ValueError: If ``k_u`` is even.
    """

    def __init__(self, in_c, reduction=4, up_factor=2., k_e=3, k_u=3, n_groups=2,
                 range_factor=11, rpb=True, up_m="nearest"):
        super(LDAU, self).__init__()
        if k_u % 2 == 0:
            # For even k_u the symmetric base grid below has only k_u - 1 taps
            # per axis, so base_offset would not match the 2 * k_u ** 2
            # predicted offset channels and forward() would fail obscurely.
            raise ValueError(f"k_u must be odd, got {k_u}")
        self.rpb = rpb
        self.k_u = k_u
        self.k_e = k_e
        self.up_factor = up_factor
        self.n_groups = n_groups
        self.offset_range_factor = range_factor
        self.upsample_method = up_m

        self.attn_dim = in_c // reduction
        self.scale = self.attn_dim ** -0.5
        self.hidden_dim = in_c // reduction
        self.proj_q = nn.Conv2d(
            in_c, self.hidden_dim,
            kernel_size=1, stride=1, padding=0, bias=False
        )

        self.proj_k = nn.Conv2d(
            in_c, self.hidden_dim,
            kernel_size=1, stride=1, padding=0, bias=False
        )

        # Channels of the query seen by each offset group.
        self.group_channel = in_c // (reduction * self.n_groups)
        # Per-group offset predictor: depthwise 3x3 -> LN -> GELU -> conv that
        # emits one (dy, dx) pair for each of the k_u ** 2 neighbours.
        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.group_channel, self.group_channel, 3, 1, 1,
                      groups=self.group_channel, bias=False),
            LayerNormProxy(self.group_channel),
            nn.GELU(),
            nn.Conv2d(self.group_channel, 2 * k_u ** 2, k_e, 1, k_e // 2)
        )
        self.layer_norm = LayerNormProxy(in_c)

        # Fixed k_u x k_u grid of integer displacements around each output
        # pixel, flattened as (kh, kw, (dy, dx)) to match conv_offset's output.
        self.pad = int((self.k_u - 1) / 2)
        base = np.arange(-self.pad, self.pad + 1).astype(np.float32)
        base_y = np.repeat(base, self.k_u)
        base_x = np.tile(base, self.k_u)
        base_offset = np.stack([base_y, base_x], axis=1).flatten()
        base_offset = torch.tensor(base_offset).view(1, -1, 1, 1)
        self.register_buffer("base_offset", base_offset, persistent=False)

        if self.rpb:
            # Learned bias added to the sampled keys, one vector per neighbour.
            self.relative_position_bias_table = nn.Parameter(torch.zeros(1, 1, 1, self.k_u ** 2, self.hidden_dim))
            trunc_normal_(self.relative_position_bias_table, std=.02)

        # Previously never invoked, which left the zero-init of the offset head
        # as dead code. Apply it so training starts from the base grid.
        self._init_weights()

    def _init_weights(self):
        """Initialize convs (Xavier-uniform) and LayerNorms (weight 1, bias 0).

        The last offset conv is zeroed, so ``tanh(0) == 0`` and the initial
        sampling locations are exactly ``base_offset``.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(self.conv_offset[-1].weight, 0)
        nn.init.constant_(self.conv_offset[-1].bias, 0)

    def extract_feats(self, k, v, Hout, Wout, offset, ks=3):
        """Bilinearly sample ``ks * ks`` keys/values per output pixel.

        Args:
            k: Keys, shape ``(N, Ck, Hin, Win)``.
            v: Values, shape ``(N, Cv, Hin, Win)`` with the same ``N``.
            Hout, Wout: Output spatial size.
            offset: Shape ``(N, ks * ks * 2, Hout, Wout)``, channels ordered as
                ``(kh, kw, (dy, dx))``. Displacements are in output-pixel
                units, relative to each output pixel.
            ks: Side length of the sampled neighbourhood.

        Returns:
            ``(out_k, out_v)`` with shapes ``(N, ks * ks, Ck, Hout, Wout)`` and
            ``(N, ks * ks, Cv, Hout, Wout)``.
        """
        N = k.shape[0]
        device = offset.device

        rows = torch.arange(Hout, device=device)
        cols = torch.arange(Wout, device=device)
        grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
        index_tensor = torch.stack((grid_y, grid_x), dim=-1).view(1, 1, Hout, 1, Wout, 2)
        offset = rearrange(offset, "b (kh kw d) h w -> b kh h kw w d", kh=ks, kw=ks)
        # Absolute (y, x) sampling positions in output-pixel coordinates,
        # tiled into a (ks * Hout, ks * Wout) grid: row = kh * Hout + h.
        offset = offset + index_tensor
        offset = offset.contiguous().view(N, ks * Hout, ks * Wout, 2)

        # Map [0, size - 1] onto grid_sample's [-1, 1] (align_corners=True).
        # Done out of place rather than by slice assignment into a tensor that
        # is part of the autograd graph. max(..., 1) avoids dividing by zero
        # when an output side has a single pixel.
        norm_y = 2 * offset[..., 0] / max(Hout - 1, 1) - 1
        norm_x = 2 * offset[..., 1] / max(Wout - 1, 1) - 1
        # grid_sample expects the last dimension ordered as (x, y).
        grid = torch.stack((norm_x, norm_y), dim=-1)

        out_k = nn.functional.grid_sample(k, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
        out_v = nn.functional.grid_sample(v, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
        out_k = rearrange(out_k, "b c (ksh h) (ksw w) -> b (ksh ksw) c h w", ksh=ks, ksw=ks, h=Hout, w=Wout)
        out_v = rearrange(out_v, "b c (ksh h) (ksw w) -> b (ksh ksw) c h w", ksh=ks, ksw=ks, h=Hout, w=Wout)
        return out_k, out_v

    def forward(self, x):
        """Upsample ``x`` of shape ``(B, C, H, W)`` by ``self.up_factor``.

        Returns:
            Tensor of shape ``(B, C, int(H * up_factor), int(W * up_factor))``.
        """
        B, C, H, W = x.shape
        out_H, out_W = int(H * self.up_factor), int(W * self.up_factor)

        v = x  # values are taken from the un-normalized input
        x = self.layer_norm(x)
        q = self.proj_q(x)
        k = self.proj_k(x)

        if self.upsample_method == "nearest":
            q = torch.nn.functional.interpolate(q, (out_H, out_W), mode="nearest")
        else:
            q = torch.nn.functional.interpolate(q, (out_H, out_W), mode="bilinear", align_corners=True)

        # Offsets are predicted independently per group from the query features.
        q_off = q.reshape(B * self.n_groups, self.group_channel, out_H, out_W)
        pred_offset = self.conv_offset(q_off)
        offset = pred_offset.tanh().mul(self.offset_range_factor) + self.base_offset.to(x.dtype)

        # reshape (not view) so non-contiguous inputs such as channels_last
        # tensors are accepted; it still returns a view whenever possible.
        k = k.reshape(B * self.n_groups, self.hidden_dim // self.n_groups, H, W)
        v = v.reshape(B * self.n_groups, C // self.n_groups, H, W)
        k, v = self.extract_feats(k, v, out_H, out_W, offset=offset, ks=self.k_u)

        # One query per output pixel attending over its k_u ** 2 samples;
        # groups are folded back into the channel axis.
        q = rearrange(q, "b c h w -> b (h w) () c")
        k = rearrange(k, "(b g) n c h w -> b (h w) n (g c)", b=B, g=self.n_groups)
        v = rearrange(v, "(b g) n c h w -> b (h w) n (g c)", b=B, g=self.n_groups)

        if self.rpb:
            k = k + self.relative_position_bias_table.squeeze(0)
        q = q * self.scale
        attn = q @ k.transpose(-1, -2)
        attn = attn.softmax(dim=-1)
        out = attn @ v

        out = rearrange(out, "b (h w) t c -> b c (t h) w", h=out_H, w=out_W)
        return out