import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

class Attention(nn.Module):
    """Self-attention layer with one head per cell of a ``kernel_size`` x ``kernel_size`` window.

    The layer takes a square feature map of shape ``(B, in_dim, S, S)``,
    zero-pads it, and lets every (cropped, strided) query location attend to
    all padded locations.  Each head mixes two attention maps:

    * a *content* map computed from learned query/key projections, and
    * a *positional* map computed from relative offsets and the learnable
      per-head parameters ``pos_x``, ``pos_y`` and ``pos_span``.

    The mix is gated per head by ``sigmoid(2 * alpha)``.  The per-head outputs
    are concatenated and projected to ``out_dim`` by ``self.filter``, and the
    result is reshaped back to ``(B, out_dim, S_out, S_out)``.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Side of the window; the layer uses ``kernel_size ** 2`` heads.
        qkv_bias: Whether the query/key and value projections have a bias.
        qk_scale: Scale applied to the query-key products; defaults to
            ``in_dim ** -0.5`` when falsy.
        qk_dim: Per-head query/key width; defaults to ``in_dim // num_heads``
            when falsy.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        stride: Step between query locations along each spatial axis.
        padding: Zero padding added on every side.  NOTE: any falsy value
            (``None`` *and* ``0``) selects ``kernel_size // 2``.
        groups: Number of channel groups for the value and output projections.
        locality_strength: Initial value of ``pos_span`` (when ``use_local_init``).
        positional_strength: Initial value of ``alpha`` (when ``use_local_init``).
        use_local_init: Whether to call :meth:`local_init` at construction.
        bias: Whether the output projection ``self.filter`` has a bias.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_dim, out_dim, kernel_size=3, qkv_bias=False, qk_scale=None, qk_dim=None, attn_drop=0., proj_drop=0., stride=1, padding=None, groups=1,
                 locality_strength=1., positional_strength=1., use_local_init=True, bias=False, **kwargs):
        super().__init__()
        self.kernel_size = kernel_size
        self.num_heads = kernel_size**2
        self.stride = stride
        # Kept as a truthiness test for backward compatibility: padding=0 also
        # falls back to kernel_size // 2.
        self.padding = self.kernel_size//2 if not padding else padding
        # Width of the border cropped from the padded map to obtain the queries.
        self.remove = self.kernel_size//2
        self.groups = groups
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.qk_dim = qk_dim if qk_dim else self.in_dim//self.num_heads
        self.scale = qk_scale or self.in_dim ** -0.5

        self.qk = nn.Linear(self.in_dim, self.qk_dim*self.num_heads*2, bias=qkv_bias)
        self.v = nn.Linear(self.in_dim//self.groups, self.in_dim//self.groups, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.pos_span = nn.Parameter(torch.ones(self.num_heads))
        self.pos_x = nn.Parameter(torch.randn(self.num_heads))
        self.pos_y = nn.Parameter(torch.randn(self.num_heads))

        self.filter = nn.Linear(self.in_dim * self.num_heads // self.groups, self.out_dim, bias=bias)
        self.filter_drop = nn.Dropout(proj_drop)

        self.locality_strength = locality_strength
        self.positional_strength = positional_strength
        self.alpha = nn.Parameter(torch.randn(self.num_heads))
        self.apply(self._init_weights)
        if use_local_init:
            self.local_init(locality_strength=locality_strength, positional_strength=positional_strength)

        self.zeropad = nn.ZeroPad2d(self.padding)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` sub-modules."""
        if isinstance(m, nn.Linear):
            # torch's own truncated-normal initialiser (default bounds [-2, 2]).
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _crop_slice(self, size):
        """Return the slice selecting query positions along one axis of length ``size``.

        The end bound is written as ``size - remove`` rather than ``-remove``
        so that ``remove == 0`` (``kernel_size == 1``) keeps the whole axis
        instead of producing an empty slice.
        """
        return slice(self.remove, size - self.remove, self.stride)

    def _ensure_rel_indices(self, num_patches):
        """Build ``self.rel_indices`` unless a matching tensor is already cached."""
        cached = getattr(self, 'rel_indices', None)
        # Dimension 2 of the cache is the number of padded patches (keys);
        # dimension 1 is the number of cropped queries.
        if (cached is None
                or cached.size(2) != num_patches
                or cached.device != self.filter.weight.device):
            self.get_rel_indices(num_patches)

    def forward(self, x):
        """Apply the layer to a square feature map.

        Args:
            x: Tensor of shape ``(B, in_dim, S, S)``.

        Returns:
            Tensor of shape ``(B, out_dim, S_out, S_out)``.
        """
        x = self.flatten(self.zeropad(x))  # B, Np, C
        B, N_padded, C = x.shape

        attn = self.get_attention(x)  # B, H, Nq, Np

        # The value projection is identical for every head, so it is computed
        # once and broadcast over the head dimension in the matmul.
        if self.groups > 1:
            group_dim = self.in_dim // self.groups
            v = self.v(x.reshape(B, N_padded, self.groups, group_dim)).reshape(B, N_padded, -1)
            x = (attn @ v.unsqueeze(1)).transpose(1, 2)  # B, Nq, H, C
            chunks = x.split(group_dim, dim=-1)
            filters = self.filter.weight.split(self.out_dim//self.groups, dim=0)
            x = torch.cat(
                [g.reshape(B, -1, self.num_heads*group_dim) @ f.t() for f, g in zip(filters, chunks)],
                dim=-1,
            )
            # The grouped path applies the weight manually, so the bias has to
            # be added explicitly.
            if self.filter.bias is not None:
                x = x + self.filter.bias
        else:
            x = (attn @ self.v(x).unsqueeze(1)).transpose(1, 2)  # B, Nq, H, C
            x = x.reshape(B, -1, self.num_heads*self.in_dim)  # B, Nq, H*C
            x = self.filter(x)  # B, Nq, out_dim

        x = self.filter_drop(x)
        x = self.spatialize(x)
        return x

    def get_pos_attention(self, x):
        """Return the positional attention map of shape ``(B, H, Nq, Np)``.

        Only the shape of ``x`` (``(B, Np, C)``) is used; the scores depend on
        relative offsets and on ``pos_span``, ``pos_x`` and ``pos_y``.
        """
        B, N, C = x.shape
        self._ensure_rel_indices(N)
        pos_score = self.rel_indices.unsqueeze(-1)  # 1, Nq, Np, 3, 1
        pos_score = F.softplus(self.pos_span, beta=5) * (
            - pos_score[:, :, :, 2]
            + pos_score[:, :, :, 0] @ self.pos_x.unsqueeze(0)
            + pos_score[:, :, :, 1] @ self.pos_y.unsqueeze(0)
        )  # 1, Nq, Np, H
        pos_score = pos_score.expand(B, -1, -1, -1).permute(0, 3, 1, 2)
        pos_score = pos_score.softmax(dim=-1)
        return pos_score

    def get_cont_attention(self, x):
        """Return the content (query-key) attention map of shape ``(B, H, Nq, Np)``.

        Keys are taken at every padded location; queries only at the cropped,
        strided locations.
        """
        B, N, C = x.shape
        qk = self.qk(x).reshape(B, N, 2, self.num_heads*self.qk_dim).permute(2, 0, 1, 3)
        q, k = qk[0], qk[1]
        q_map = self.spatialize(q)
        crop = self._crop_slice(q_map.size(-1))
        q = self.flatten(q_map[:, :, crop, crop])
        q = q.reshape(B, -1, self.num_heads, self.qk_dim).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.num_heads, self.qk_dim).permute(0, 2, 1, 3)

        patch_score = (q @ k.transpose(-2, -1)) * self.scale
        patch_score = patch_score.softmax(dim=-1)
        return patch_score

    def get_attention(self, x):
        """Return the gated mix of content and positional attention, ``(B, H, Nq, Np)``."""
        pos_score, patch_score = self.get_pos_attention(x), self.get_cont_attention(x)

        gate = torch.sigmoid(2*self.alpha.view(1, -1, 1, 1))
        attn = (1.-gate) * patch_score + gate * pos_score
        attn = attn / attn.sum(dim=-1, keepdim=True)
        attn = self.attn_drop(attn)

        return attn

    def get_attention_map(self, x, return_map = False, type=None):
        """Compute an attention-weighted distance statistic, optionally with the map.

        Args:
            x: Tensor of shape ``(B, in_dim, S, S)``.
            return_map: When true, also return the batch-averaged attention map.
            type: ``'pos'`` for the positional map, ``'cont'`` for the content
                map, anything else for the gated mix.

        Returns:
            ``dist`` (a Python float), or ``(dist, attn_map)`` when
            ``return_map`` is true.
        """
        x = self.zeropad(x)
        x = self.flatten(x)

        if type == 'pos':
            attn_map = self.get_pos_attention(x)
        elif type == 'cont':
            attn_map = self.get_cont_attention(x)
        else:
            attn_map = self.get_attention(x)

        attn_map = attn_map.mean(0)  # average over batch

        # The content-only path never touches the relative indices, so make
        # sure they exist before reading them.
        self._ensure_rel_indices(x.size(1))
        distances = self.rel_indices[0, :, :, -1]**.5
        dist = torch.einsum('nm,hnm->h', distances, attn_map).mean()  # average over heads
        dist = dist.item() / distances.size(0)
        if return_map:
            return dist, attn_map
        else:
            return dist

    def local_init(self, locality_strength, positional_strength):
        """Initialise the positional parameters so each head targets one window cell.

        ``alpha`` is set to ``positional_strength``, the value projection to
        the identity, ``pos_span`` to ``locality_strength``, and
        ``(pos_x, pos_y)`` of head ``h2 + kernel_size * h1`` to twice the
        offset of cell ``(h1, h2)`` from the window centre.
        """
        nn.init.constant_(self.alpha, positional_strength)

        with torch.no_grad():
            self.v.weight.copy_(torch.eye(self.v.weight.size(0)))

            k = self.kernel_size
            offsets = torch.arange(k, dtype=self.pos_x.dtype) - (k - 1) / 2
            self.pos_x.copy_(2 * offsets.repeat(k))             # column index varies fastest
            self.pos_y.copy_(2 * offsets.repeat_interleave(k))  # row index varies slowest
            self.pos_span.fill_(locality_strength)

    def load_filters(self, conv):
        """Copy the weights (and bias, if any) of a convolution into ``self.filter``.

        The kernel of shape ``(out, in, k, k)`` is rearranged to
        ``(out, k * k * in)`` so that the kernel-position index is the slowest
        varying one.

        Raises:
            ValueError: If ``conv`` has a bias but ``self.filter`` does not.
        """
        if conv.bias is not None and self.filter.bias is None:
            raise ValueError("conv has a bias but this layer was built with bias=False")
        w = conv.weight.data
        ch_out, ch_in, k_h, k_w = w.shape
        w = w.permute(0, 2, 3, 1).reshape(ch_out, ch_in*k_h*k_w)
        self.filter.weight.data.copy_(w)
        if conv.bias is not None:
            self.filter.bias.data.copy_(conv.bias.data)

    def load_pooling(self):
        """Load the filters of an average-pooling convolution into ``self.filter``.

        Builds an ungrouped convolution whose channel ``i`` averages input
        channel ``i`` over the window, then embeds it via :meth:`load_filters`.
        The ungrouped kernel only matches ``self.filter`` when ``groups == 1``.
        """
        conv = nn.Conv2d(self.in_dim, self.out_dim, self.kernel_size, stride=1, padding=self.padding, bias=False)
        # Turn the conv layer into a pooling layer
        conv.weight.data.fill_(0)
        for i in range(min(self.in_dim, self.out_dim)):
            conv.weight.data[i, i].fill_(1/(self.kernel_size**2))
        # Embed as a SA layer
        self.load_filters(conv)

    def spatialize(self, x):
        """Reshape ``(B, N, C)`` tokens into a square ``(B, C, sqrt(N), sqrt(N))`` map."""
        B, N, C = x.shape
        img_size = int(N**.5)
        x = x.permute(0, 2, 1)
        x = x.reshape(B, C, img_size, img_size)
        return x

    def flatten(self, x):
        """Reshape a square ``(B, C, S, S)`` map into ``(B, S * S, C)`` tokens."""
        B, C, img_size, img_size = x.shape
        N = img_size**2
        x = x.reshape(B, C, N)
        x = x.permute(0, 2, 1)
        return x

    def get_rel_indices(self, num_patches):
        """Build and cache ``self.rel_indices`` of shape ``(1, Nq, num_patches, 3)``.

        For every (query, key) pair the last axis holds the two relative
        offsets along the image axes and their squared Euclidean norm.
        Queries are restricted to the cropped, strided grid.
        """
        img_size = int(num_patches**.5)
        ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
        indx = ind.repeat(img_size, img_size)
        indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
        indd = indx**2 + indy**2
        rel_indices = torch.stack((indx, indy, indd), dim=-1).to(torch.get_default_dtype())
        crop = self._crop_slice(img_size)
        rel_indices = rel_indices.reshape(1, img_size, img_size, num_patches, 3)[:, crop, crop, :, :]
        rel_indices = rel_indices.reshape(1, -1, num_patches, 3)
        device = self.filter.weight.device
        self.rel_indices = rel_indices.to(device)

