import torch
import torch.nn as nn
import torch.nn.functional as F
from .droppath import DropPath
from .swin import Mlp
import numpy as np
from functools import partial
from typing import Any, Union, Callable, Optional
from itertools import repeat
import collections.abc


def constant_init(tensor, constant=0.0):
    """Fill ``tensor`` in place with ``constant`` and return it.

    Args:
        tensor: Tensor to overwrite; it is modified in place.
        constant: Fill value (defaults to ``0.0``).

    Returns:
        The same ``tensor`` object that was passed in.
    """
    # nn.init.constant_ fills in place and hands back the very same tensor.
    return nn.init.constant_(tensor, constant)


def _ntuple(n):
    """Build a parser that broadcasts a scalar to an ``n``-tuple.

    The returned ``parse(x)`` hands back ``x`` unchanged when it is a
    non-string iterable, and otherwise returns ``(x,) * n``.
    """
    def parse(x):
        is_iterable = isinstance(x, collections.abc.Iterable)
        # Strings are iterable but are treated as scalars here.
        if isinstance(x, str) or not is_iterable:
            return tuple(repeat(x, n))
        return x
    return parse


class Mlp(nn.Module):
    """Two-layer feed-forward network: ``fc1 -> activation -> drop -> fc2 -> drop``.

    Note: this definition shadows the ``Mlp`` name imported from ``.swin``
    at the top of the module.

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the last output dimension; falls back to
            ``in_features`` when falsy.
        activation: Callable applied between the two linear layers.
        drop: Dropout probability shared by both dropout applications.
    """

    def __init__(self, in_features=None, hidden_features=None, out_features=None,
                 activation=F.gelu, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = activation
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, train: bool = True):
        """Apply the MLP; dropout is skipped entirely when ``train`` is false."""
        hidden = self.act(self.fc1(x))
        if train:
            hidden = self.drop(hidden)

        out = self.fc2(hidden)
        if train:
            out = self.drop(out)
        return out


class Attention(nn.Module):
    """
    Standard multi-head self-attention over a ``(B, N, C)`` input.

    Args:
        dim: Channel size ``C`` of the input and output.
        num_heads: Number of heads; each head works on ``dim // num_heads``
            channels.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1 / sqrt(head_dim) scaling of the dot products.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Re-initialise both projection weights with Xavier-uniform.
        for layer in (self.qkv, self.proj):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x, train: bool = True):
        """Return the attended sequence; both dropouts are skipped when ``train`` is false."""
        B, N, C = x.shape
        per_head = C // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, per_head)
        fused = self.qkv(x).reshape(B, N, 3, self.num_heads, per_head)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = torch.softmax(scores, dim=-1)
        if train:
            weights = self.attn_drop(weights)

        # (B, heads, N, per_head) -> (B, N, C)
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(merged)
        if train:
            out = self.proj_drop(out)
        return out


def window_partition1d(x, window_size):
    """Split a ``(B, W, C)`` sequence into non-overlapping windows.

    Args:
        x: Tensor of shape ``(B, W, C)``. It does not need to be contiguous.
        window_size: Window length; must be positive and divide ``W``.

    Returns:
        Tensor of shape ``(B * W // window_size, window_size, C)``. Windows
        of the same batch element are adjacent along the first axis.

    Raises:
        ValueError: If ``window_size`` is not positive or does not divide ``W``.
    """
    B, W, C = x.shape
    if window_size <= 0 or W % window_size != 0:
        raise ValueError(
            f"sequence length {W} is not divisible by window_size {window_size}"
        )
    # reshape (rather than view) also handles non-contiguous inputs whose
    # batch and window axes cannot be merged without a copy; for contiguous
    # inputs it still returns a view.
    return x.reshape(B * (W // window_size), window_size, C)


def window_reverse1d(windows, window_size, W: int):
    """Merge windows produced by a 1-D window partition back into sequences.

    Args:
        windows: Tensor of shape ``(B * W // window_size, window_size, C)``.
            It does not need to be contiguous.
        window_size: Length of each window.
        W: Length of the reconstructed sequence.

    Returns:
        Tensor of shape ``(B, W, C)``.
    """
    # Integer arithmetic avoids the float round-trip when recovering B.
    num_windows = W // window_size
    B = windows.shape[0] // num_windows
    # reshape (rather than view) also handles non-contiguous inputs.
    return windows.reshape(B, W, -1)


def get_relative_position_index1d(win_w):
    """Return the pair-wise relative position index inside a 1-D window.

    Entry ``[i, j]`` equals ``i - j + (win_w - 1)``, so values lie in
    ``[0, 2 * win_w - 2]`` and can index a table with ``2 * win_w - 1`` rows.

    Args:
        win_w: Window length.

    Returns:
        Integer tensor of shape ``(win_w, win_w)``.
    """
    positions = torch.arange(win_w)
    # Broadcasting gives the (win_w, win_w) difference table directly; this
    # avoids calling torch.meshgrid without its ``indexing`` argument, which
    # warns on recent PyTorch releases.
    relative = positions[:, None] - positions[None, :]
    # Shift so the smallest offset, -(win_w - 1), maps to index 0.
    return relative + (win_w - 1)


class WindowedAttentionHead(nn.Module):
    """Single attention head restricted to non-overlapping 1-D windows.

    The ``(B, W, C)`` query/key/value tensors are cut into windows of
    ``window_size`` positions and attention is computed independently inside
    each window, with a learned relative position bias added to the scores.
    When ``shift_windows`` is set, the sequence is rolled by
    ``window_size // 2`` before partitioning and an additive mask (``-100``)
    suppresses attention between positions that come from different regions
    of the rolled sequence.

    Args:
        head_dim: Channel size of this head; sets the ``head_dim ** -0.5``
            score scaling.
        window_size: Window length. The sequence length passed to
            ``forward`` has to be divisible by it.
        shift_windows: Whether to use shifted windows.
        attn_drop: Dropout probability on the attention weights; no dropout
            layer is created when it is not positive.
    """

    def __init__(self, head_dim, window_size, shift_windows=False, attn_drop=0.):
        super().__init__()
        self.head_dim = head_dim
        self.window_size = window_size
        self.shift_windows = shift_windows
        self.attn_drop = attn_drop

        self.scale = self.head_dim ** -0.5
        # 1-D windows: the number of positions per window is the window length.
        self.window_area = self.window_size * 1

        # One learnable bias per relative offset in [-(Ww - 1), Ww - 1].
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1, 1))
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Lookup from (query position, key position) into the bias table.
        self.register_buffer("relative_position_index", get_relative_position_index1d(window_size))

        self.drop_layer = nn.Dropout(attn_drop) if attn_drop > 0 else None

        self.shift_size = window_size // 2 if shift_windows else 0
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

    def forward(self, q, k, v, train: bool = True):
        """Run windowed attention for one head.

        Args:
            q, k, v: Tensors of shape ``(B, W, C)``.
            train: When false, attention dropout is skipped.

        Returns:
            Tuple ``(x, attn)``: ``x`` has shape ``(B, W, C)`` and ``attn``
            holds the per-window attention weights with shape
            ``(B * W // window_size, window_size, window_size)``.
        """
        _, W, _ = q.shape

        mask = None
        if self.shift_size > 0:
            # Build the mask in the input dtype so half/bfloat16 scores are
            # not promoted to float32 (which would make ``attn @ v`` fail on
            # a dtype mismatch).
            mask = self._build_shift_mask(W, q.device, q.dtype)

            q = torch.roll(q, shifts=-self.shift_size, dims=1)
            k = torch.roll(k, shifts=-self.shift_size, dims=1)
            v = torch.roll(v, shifts=-self.shift_size, dims=1)

        q = window_partition1d(q, self.window_size)
        k = window_partition1d(k, self.window_size)
        v = window_partition1d(v, self.window_size)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # The bias is added regardless of ``train``.
        attn = attn + self._get_rel_pos_bias()

        if mask is not None:
            B_, N, _ = attn.shape
            num_win = mask.shape[0]
            # Windows are batch-major, so expose the window axis to broadcast
            # the per-window mask over the batch.
            attn = attn.view(B_ // num_win, num_win, N, N) + mask.unsqueeze(0)
            attn = attn.view(-1, N, N)
        attn = attn.softmax(dim=-1)

        if self.drop_layer is not None and train:
            attn = self.drop_layer(attn)

        x = (attn @ v)

        # merge windows
        shifted_x = window_reverse1d(x, self.window_size, W=W)

        if self.shift_size > 0:
            # Undo the roll applied before partitioning.
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=1)
        else:
            x = shifted_x

        return x, attn

    def _build_shift_mask(self, W, device, dtype):
        """Build the additive ``(num_windows, Ww, Ww)`` mask for shifted windows."""
        img_mask = torch.zeros((1, W, 1), device=device, dtype=dtype)
        regions = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        # Label each region of the rolled sequence with its own id.
        for region_id, region in enumerate(regions):
            img_mask[:, region, :] = region_id

        mask_windows = window_partition1d(img_mask, self.window_size)
        mask_windows = mask_windows.reshape(-1, self.window_size)
        # Non-zero difference <=> the two positions carry different region ids.
        diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        return diff.masked_fill(diff != 0, -100.0).masked_fill(diff == 0, 0.0)

    def _get_rel_pos_bias(self):
        """Gather the relative position bias as a ``(1, Ww, Ww)`` tensor."""
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_area, self.window_area, -1)  # Ww,Ww,1
        relative_position_bias = relative_position_bias.permute(2, 0, 1)  # 1, Ww, Ww
        return relative_position_bias


class AttentionHead(nn.Module):
    """Single global (un-windowed) scaled dot-product attention head.

    Args:
        head_dim: Channel size of the head; sets the ``head_dim ** -0.5``
            score scaling.
        attn_drop: Dropout probability on the attention weights; no dropout
            layer is created when it is not positive.
    """

    def __init__(self, head_dim, attn_drop=0.):
        super().__init__()
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        if attn_drop > 0:
            self.drop_layer = nn.Dropout(attn_drop)
        else:
            self.drop_layer = None

    def forward(self, q, k, v, train: bool = True):
        """Return ``(output, attention_weights)`` for the given q/k/v."""
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = torch.softmax(scores, dim=-1)

        # Dropout needs both the train flag and a configured layer.
        if train and self.drop_layer is not None:
            weights = self.drop_layer(weights)

        return torch.matmul(weights, v), weights


class WindowedMultiHeadAttention(nn.Module):
    """Multi-head attention in which every head may use its own window size.

    Args:
        dim: Channel size of the input and output.
        window_sizes: Either one int used for every head or a sequence with
            one entry per head. A value of ``0`` selects a global
            ``AttentionHead``; any other value a ``WindowedAttentionHead``.
        shift_windows: Forwarded to every windowed head.
        num_heads: Number of heads; each works on ``dim // num_heads`` channels.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        attn_drop: Attention dropout probability forwarded to every head.
        proj_drop: Dropout probability after the output projection; no
            dropout layer is created when it is not positive.
    """

    def __init__(self, dim, window_sizes, shift_windows=False, num_heads=8, qkv_bias=False,
                 attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        nn.init.xavier_uniform_(self.qkv.weight)

        if isinstance(window_sizes, int):
            # A single int is broadcast to one window size per head.
            window_sizes = _ntuple(num_heads)(window_sizes)
        else:
            assert len(window_sizes) == num_heads

        heads = []
        for head_idx in range(num_heads):
            size = window_sizes[head_idx]
            if size == 0:
                head = AttentionHead(self.head_dim, attn_drop)
            else:
                head = WindowedAttentionHead(
                    self.head_dim,
                    window_size=size,
                    shift_windows=shift_windows,
                    attn_drop=attn_drop,
                )
            heads.append(head)
        self.attn_heads = nn.ModuleList(heads)

        self.proj = nn.Linear(dim, dim)
        nn.init.xavier_uniform_(self.proj.weight)
        if proj_drop > 0:
            self.drop_layer = nn.Dropout(proj_drop)
        else:
            self.drop_layer = None

    def forward(self, x, train: bool = True):
        """Attend over a ``(B, N, C)`` input and return a ``(B, N, C)`` output."""
        B, N, C = x.shape
        per_head = C // self.num_heads

        # (B, N, 3C) -> (3, heads, B, N, per_head)
        fused = self.qkv(x).reshape(B, N, 3, self.num_heads, per_head)
        q, k, v = fused.permute(2, 3, 0, 1, 4).unbind(0)

        # Each head returns (output, attention); only the output is kept.
        head_outputs = [
            head(q[idx], k[idx], v[idx], train=train)[0]
            for idx, head in enumerate(self.attn_heads)
        ]

        # (heads, B, N, per_head) -> (B, N, heads * per_head)
        merged = torch.stack(head_outputs, dim=0).permute(1, 2, 0, 3).reshape(B, N, -1)
        out = self.proj(merged)

        if train and self.drop_layer is not None:
            out = self.drop_layer(out)

        return out


class LayerScale(nn.Module):
    """Learnable per-channel scaling of the input.

    Args:
        dim: Size of the learnable scale vector ``gamma``.
        init_values: Initial value of every entry of ``gamma``.
    """

    def __init__(self, dim, init_values=1e-5):
        super().__init__()
        initial_gamma = torch.ones(dim) * init_values
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, x):
        """Multiply ``x`` by ``gamma`` (broadcast against ``x``)."""
        return torch.mul(x, self.gamma)


class BNWrapper(nn.Module):
    """``nn.BatchNorm1d`` with an explicit per-call ``train`` flag.

    Args:
        num_features: Number of channels normalised by the wrapped layer.
        use_running_average: Accepted for signature compatibility; this
            wrapper does not read it. Batch vs. running statistics are
            selected by the ``train`` argument of ``forward``.
        use_bias: See ``use_scale``.
        use_scale: ``nn.BatchNorm1d`` has a single ``affine`` switch, so the
            learnable scale and shift are both enabled when either flag is
            true.
    """

    def __init__(self, num_features, use_running_average=True, use_bias=True, use_scale=True):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features, affine=use_scale or use_bias)

    def forward(self, x, train=True):
        """Normalise ``x`` with batch statistics (``train``) or running statistics.

        ``nn.BatchNorm1d.forward`` takes only the input tensor, so the flag
        is routed through the functional API rather than passed to the
        module call. With ``train`` true the running statistics are updated
        in place.

        NOTE(review): the functional call does not increment the wrapped
        layer's ``num_batches_tracked`` buffer; that buffer only matters
        when ``momentum`` is ``None``, which is not how the layer is built
        here.
        """
        bn = self.bn
        return F.batch_norm(
            x,
            bn.running_mean,
            bn.running_var,
            bn.weight,
            bn.bias,
            training=train,
            momentum=bn.momentum,
            eps=bn.eps,
        )


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each in a residual branch.

    Args:
        dim: Channel size of the input and output.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the q/k/v projection has a bias term.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability on the attention weights.
        init_values: Initial LayerScale value; LayerScale is only created
            and applied when this is not ``None``.
        drop_path: ``DropPath`` rate; an identity is used when not positive.
        act_layer: Activation callable handed to the MLP.
        norm_layer: Factory called with ``dim`` to build each norm layer.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 init_values=None, drop_path=0., act_layer=F.gelu, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            out_features=dim,
            activation=act_layer,
            drop=drop,
        )

        self.init_values = init_values
        if init_values is not None:
            self.layer_scale1 = LayerScale(dim, init_values)
            self.layer_scale2 = LayerScale(dim, init_values)

    def forward(self, x, train: bool = True):
        """Apply both residual branches; ``drop_path`` is only used when ``train`` is true."""
        scaled = self.init_values is not None

        # Attention branch.
        branch = self.attn(self.norm1(x), train=train)
        if scaled:
            branch = self.layer_scale1(branch)
        if train:
            branch = self.drop_path(branch)
        x = x + branch

        # MLP branch.
        branch = self.mlp(self.norm2(x), train=train)
        if scaled:
            branch = self.layer_scale2(branch)
        if train:
            branch = self.drop_path(branch)
        x = x + branch
        return x

class MWMHABlock(nn.Module):
    """Pre-norm transformer block built around ``WindowedMultiHeadAttention``.

    Args:
        dim: Channel size of the input and output.
        num_heads: Number of attention heads.
        window_sizes: Window size(s) forwarded to the attention module.
        shift_windows: Forwarded to the attention module.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the q/k/v projection has a bias term.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability on the attention weights.
        init_values: Initial LayerScale value; LayerScale is only created
            and applied when this is not ``None``.
        drop_path: ``DropPath`` rate; an identity is used when not positive.
        act_layer: Activation callable handed to the MLP.
        norm_layer: Factory called with ``dim`` to build each norm layer.
    """

    def __init__(self, dim, num_heads, window_sizes, shift_windows=False, mlp_ratio=4.,
                 qkv_bias=False, drop=0., attn_drop=0., init_values=None, drop_path=0.,
                 act_layer=F.gelu, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.wmha = WindowedMultiHeadAttention(
            dim,
            window_sizes=window_sizes,
            shift_windows=shift_windows,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            out_features=dim,
            activation=act_layer,
            drop=drop,
        )

        self.init_values = init_values
        if init_values is not None:
            self.layer_scale1 = LayerScale(dim, init_values)
            self.layer_scale2 = LayerScale(dim, init_values)

    def forward(self, x, train: bool = True):
        """Apply both residual branches; ``drop_path`` is only used when ``train`` is true."""
        scaled = self.init_values is not None

        # Windowed attention branch.
        branch = self.wmha(self.norm1(x), train=train)
        if scaled:
            branch = self.layer_scale1(branch)
        if train:
            branch = self.drop_path(branch)
        x = x + branch

        # MLP branch.
        branch = self.mlp(self.norm2(x), train=train)
        if scaled:
            branch = self.layer_scale2(branch)
        if train:
            branch = self.drop_path(branch)
        x = x + branch
        return x