import os
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial
from itertools import repeat
from einops import rearrange
import collections.abc

from timm.models.vision_transformer import PatchEmbed
from timm.models.layers import trunc_normal_, DropPath
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from transformers import logging
# Module-level setup.
# Absolute path of the directory that contains this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Restrict `transformers` logging to error-level messages.
logging.set_verbosity_error()

def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


# Ready-made converters for the common arities; `to_ntuple` exposes the
# factory itself so callers can build a converter for any other length.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = map(_ntuple, (1, 2, 3, 4))
to_ntuple = _ntuple

class GlobalEmbedding(nn.Module):
    """Projection head: Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d.

    The final batch norm is created with ``affine=False``, so it has no
    learnable scale or shift.
    """

    def __init__(self,
                 input_dim: int = 512,
                 hidden_dim: int = 2048,
                 output_dim: int = 128) -> None:
        super().__init__()

        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim, affine=False),  # output layer
        ]
        self.head = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the projection head and return the result."""
        projected = self.head(x)
        return projected

class PatchEmbed3D(nn.Module):
    """3D volume to patch embedding.

    The volume is cut into non-overlapping patches by a strided ``Conv3d``
    whose kernel and stride both equal the patch size, so every patch is
    projected to an ``embed_dim`` vector.

    Args:
        img_size: In-plane size (int or pair) of the input volume.
        patch_size: In-plane patch size (int or pair).
        slices: Size of the third spatial axis of the input.
        slice_patch: Patch size along the third spatial axis.
        in_chans: Number of input channels.
        embed_dim: Output embedding dimension per patch.
        norm_layer: Optional callable ``norm_layer(embed_dim)`` building the
            normalization applied after projection; identity when falsy.
        flatten: When true, the patch grid is flattened to a token sequence.
    """
    def __init__(self, img_size=128, patch_size=16, slices=32, slice_patch=4, in_chans=1, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        # Coerce to tuples: a list-valued size would otherwise raise
        # TypeError when concatenated with the depth tuple below.
        img_size = tuple(to_2tuple(img_size))
        patch_size = tuple(to_2tuple(patch_size))
        self.img_size = img_size + (slices,)
        self.patch_size = patch_size + (slice_patch,)
        self.grid_size = (
            self.img_size[0] // self.patch_size[0],
            self.img_size[1] // self.patch_size[1],
            self.img_size[2] // self.patch_size[2],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed a 5D input whose shape unpacks as ``(B, C, H, W, D)``.

        Returns the projected patches; when ``flatten`` is set they are
        reshaped from ``B C H W D`` layout to ``B N C``.

        Raises:
            AssertionError: If the three spatial sizes differ from the sizes
                given at construction.
        """
        B, C, H, W, D = x.shape
        assert H == self.img_size[0] and W == self.img_size[1] and D == self.img_size[2], \
            f"Input image size ({H}*{W}*{D}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}*{self.img_size[2]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHWD -> BNC
        x = self.norm(x)
        return x

class PatchEmbed(nn.Module):
    """Split a 2D image into non-overlapping patches and embed each one.

    Patches are produced by a ``Conv2d`` whose kernel and stride both equal
    the patch size; an optional normalization is applied afterwards.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        grid_rows = self.img_size[0] // self.patch_size[0]
        grid_cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (grid_rows, grid_cols)
        self.num_patches = grid_rows * grid_cols
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Embed a 4D input whose shape unpacks as ``(B, C, H, W)``.

        Asserts that ``H`` and ``W`` equal the configured image size.
        """
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        patches = self.proj(x)
        if self.flatten:
            # BCHW -> BNC
            patches = patches.flatten(2).transpose(1, 2)
        return self.norm(patches)

class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given (or are otherwise falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        # The same dropout module is applied after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Attention(nn.Module):
    """Multi-head self-attention that can record its attention map and gradients.

    When ``forward`` is called with ``register_hook=True`` the post-dropout
    attention weights are stored on the module, and a backward hook stores
    their gradient if they require grad.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or per_head_dim ** -0.5
        # A single linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Filled in only when forward() runs with register_hook=True.
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient of the attention weights (used as a tensor hook)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the last stored attention gradient, or None."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store the attention weights of the current forward pass."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the last stored attention weights, or None."""
        return self.attention_map

    def forward(self, x, register_hook=False):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split along axis 0.
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        if register_hook:
            self.save_attention_map(weights)
            if weights.requires_grad:
                weights.register_hook(self.save_attn_gradients)

        # Merge the heads back into the channel axis before the projection.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

class Block(nn.Module):
    """Transformer block: normalized attention and MLP, each added residually.

    With ``use_grad_checkpointing`` the attention and MLP sub-modules are
    wrapped with ``checkpoint_wrapper``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
        super().__init__()
        self.norm1 = norm_layer(dim)

        attention = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if use_grad_checkpointing:
            attention = checkpoint_wrapper(attention)
        self.attn = attention

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)

        feed_forward = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        if use_grad_checkpointing:
            feed_forward = checkpoint_wrapper(feed_forward)
        self.mlp = feed_forward

    def forward(self, x, register_hook=False):
        attended = self.attn(self.norm1(x), register_hook=register_hook)
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path(transformed)

class VisionTransformer(nn.Module):
    """ Vision Transformer shared between 2D images and 3D volumes
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
        https://arxiv.org/abs/2010.11929

    Each modality has its own patch embedding, class token and positional
    embedding; the transformer blocks and the final norm are shared.
    """

    def __init__(self, image2D_size=224, patch2D_size=16, image3D_size=128, patch3D_size=16, slices=32, slice_patch_size=4,
                 in_chans2D=3, in_chans3D=1, embed_dim=768, hidden_dim=2048, output_dim=128, depth=12, num_heads=12, mlp_ratio=4.,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, use_grad_checkpointing=False,
                 ckpt_layer=0):
        """
        Args:
            image2D_size (int, tuple): 2D input image size
            patch2D_size (int, tuple): 2D patch size
            image3D_size (int, tuple): in-plane size of the 3D input
            patch3D_size (int, tuple): in-plane patch size of the 3D input
            slices (int): size of the third spatial axis of the 3D input
            slice_patch_size (int): patch size along the third spatial axis
            in_chans2D (int): number of channels of the 2D input
            in_chans3D (int): number of channels of the 3D input
            embed_dim (int): embedding dimension
            hidden_dim (int): hidden width of the `global_embedding` head
            output_dim (int): output width of the `global_embedding` head
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            use_grad_checkpointing (bool): enable activation checkpointing
            ckpt_layer (int): number of trailing blocks that are checkpointed
                when `use_grad_checkpointing` is True (0 checkpoints none)
        """
        super().__init__()
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.slices = slices
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        # different patchembed for 2D and 3D
        self.patch_embed = PatchEmbed(
            img_size=image2D_size, patch_size=patch2D_size, in_chans=in_chans2D, embed_dim=embed_dim)
        self.patch_embed3D = PatchEmbed3D(
            img_size=image3D_size, patch_size=patch3D_size, slices=slices, slice_patch=slice_patch_size,
            in_chans=in_chans3D, embed_dim=embed_dim
        )

        num_patches = self.patch_embed.num_patches
        num_patches3D = self.patch_embed3D.num_patches

        # One extra position per modality for the class token. The initial
        # fill values are irrelevant: trunc_normal_ below overwrites them.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token3D = nn.Parameter(torch.ones(1, 1, embed_dim))
        self.pos_embed3D = nn.Parameter(
            torch.ones(1, num_patches3D + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                use_grad_checkpointing=(
                    use_grad_checkpointing and i >= depth-ckpt_layer)
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # NOTE(review): this head is constructed here but is not called by
        # any forward method of this class.
        self.global_embedding = GlobalEmbedding(embed_dim, hidden_dim, output_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embed3D, std=.02)
        trunc_normal_(self.cls_token3D, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear weights (truncated normal) and LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of parameters to exclude from weight decay.

        Covers the class token and positional embedding of both modalities.
        """
        return {'pos_embed', 'cls_token', 'pos_embed3D', 'cls_token3D'}

    def _encode(self, x, cls_token, pos_embed, register_blk=-1):
        """Run patch tokens through the shared transformer.

        Prepends ``cls_token``, adds ``pos_embed`` (sliced to the sequence
        length), applies dropout, every block and the final norm.
        ``register_blk`` is the index of the block whose attention map and
        gradients are recorded; -1 matches no block.
        """
        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x, register_blk == i)
        return self.norm(x)

    def forward2D(self, x, register_blk=-1):
        """Encode a batch of 2D images; returns the normalized token sequence."""
        x = self.patch_embed(x)
        return self._encode(x, self.cls_token, self.pos_embed, register_blk)

    def forward3D(self, x, register_blk=-1):
        """Encode a batch of 3D volumes; returns the normalized token sequence."""
        x = self.patch_embed3D(x)
        return self._encode(x, self.cls_token3D, self.pos_embed3D, register_blk)

    def forward(self, x, modal_type="3D"):
        """Encode ``x`` with the branch selected by ``modal_type``.

        Args:
            x: For ``'2D'``, a 4D image batch. For ``'3D'``, either a 4D
                stack of slices whose leading axis is ``batch * slices``
                (regrouped into ``B C H W D``), or a 5D tensor that is
                passed to ``forward3D`` unchanged.
            modal_type: ``'2D'`` or ``'3D'``.

        Returns:
            The normalized token sequence, class token first.

        Raises:
            ValueError: If ``modal_type`` is neither ``'2D'`` nor ``'3D'``.
        """
        if modal_type == '2D':
            return self.forward2D(x)
        if modal_type == '3D':
            if x.dim() == 4:
                x = rearrange(x, "(B D) C H W -> B C H W D", D=self.slices)
            return self.forward3D(x)
        # Previously an unknown value fell through and returned the raw
        # input, which only failed later and far from the cause.
        raise ValueError(
            f"modal_type must be '2D' or '3D', got {modal_type!r}")

def create_vit(vit, image2D_size, image3D_size, slices, hidden_dim, output_dim, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    """Build a `VisionTransformer` from a named preset.

    Args:
        vit: Preset name, ``'base'`` or ``'large'``.
        image2D_size: 2D input image size.
        image3D_size: In-plane size of the 3D input.
        slices: Size of the third spatial axis of the 3D input.
        hidden_dim: Hidden width of the model's embedding head.
        output_dim: Output width of the model's embedding head.
        use_grad_checkpointing: Forwarded to the model.
        ckpt_layer: Forwarded to the model.
        drop_path_rate: Stochastic depth rate. When left at 0 the preset's
            default is used (0 for ``'base'``, 0.1 for ``'large'``).

    Returns:
        The constructed `VisionTransformer`.

    Raises:
        AssertionError: If ``vit`` is not a known preset.
    """
    assert vit in ['base', 'large'], "vit parameter must be base or large"

    # (embed_dim, depth, num_heads, default drop-path rate) per preset.
    if vit == 'base':
        vision_width, depth, num_heads, default_drop_path = 768, 6, 3, 0
    else:
        vision_width, depth, num_heads, default_drop_path = 1024, 24, 16, 0.1

    # The caller's value comes first so an explicit rate is honoured; the
    # previous `0.1 or drop_path_rate` always evaluated to 0.1.
    effective_drop_path = drop_path_rate or default_drop_path

    visual_encoder = VisionTransformer(
        image2D_size=image2D_size, patch2D_size=16,
        image3D_size=image3D_size, patch3D_size=16,
        slices=slices, slice_patch_size=4,
        embed_dim=vision_width, hidden_dim=hidden_dim, output_dim=output_dim,
        depth=depth, num_heads=num_heads,
        use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
        drop_path_rate=effective_drop_path,
    )
    return visual_encoder


def load_state_with_same_shape(model, weights):
    """Select the checkpoint entries that fit ``model``.

    A leading ``'module.'`` and then a leading ``'image_encoder_q.'`` are
    removed from each key that has them. An entry is kept only if its
    stripped key exists in ``model.state_dict()`` with the same size.

    Args:
        model: Object exposing ``state_dict()``.
        weights: Mapping from parameter name to tensor; may be empty.

    Returns:
        A dict with the matching entries, keyed by the stripped names.
    """
    model_state = model.state_dict()

    # Strip per key and only at the start of the key, so an empty mapping,
    # keys without the prefix, and any key order are all handled.
    renamed = {}
    for key, value in weights.items():
        for prefix in ('module.', 'image_encoder_q.'):
            if key.startswith(prefix):
                key = key[len(prefix):]
        renamed[key] = value

    filtered_weights = {
        k: v for k, v in renamed.items()
        if k in model_state and v.size() == model_state[k].size()
    }
    print("Loading weights:" + ', '.join(filtered_weights.keys()))
    return filtered_weights


if __name__ == '__main__':
    # Smoke test: build the base encoder and run one batch per modality.
    # hidden_dim and output_dim are required by create_vit; the values here
    # match the VisionTransformer defaults.
    vision_encoder = create_vit('base', 224, 128, 32, hidden_dim=2048, output_dim=128)

    x = torch.randn((4, 3, 224, 224))
    out = vision_encoder(x, '2D')
    print(out.shape)

    # In '3D' mode forward() regroups a stack of slices whose leading axis is
    # batch * slices, so 4 volumes of 32 slices are passed as 128 slices.
    x = torch.randn((4 * 32, 1, 128, 128))
    out = vision_encoder(x, '3D')
    print(out.shape)
    print('Hello World')
    