import os
import math
import collections.abc
from itertools import repeat
from functools import partial, reduce
from operator import mul

import random
import torch
import torch.nn as nn

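# NOTE: BertEncoder (defined later in this file) calls `BertModel`, which is not imported
# anywhere else in this module; presumably the first import below has to be re-enabled
# before BertEncoder can be instantiated.
# from models.backbones.med import BertModel
# from models.head import iBOTHead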

from timm.models.layers import trunc_normal_, DropPath
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from transformers import AutoTokenizer, BertConfig, BertTokenizer, logging

# helpers
# Absolute path of the directory containing this file; BertEncoder joins it with a
# relative path to locate its BERT config JSON.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Restrict the `transformers` library logger to error-level messages.
logging.set_verbosity_error()

def _ntuple(n):
    """Return a converter that broadcasts a non-iterable value to an ``n``-tuple.

    Any iterable argument (strings included) is handed back unchanged: it is
    neither converted to a tuple nor checked for length.
    """
    def parse(x):
        already_iterable = isinstance(x, collections.abc.Iterable)
        return x if already_iterable else (x,) * n
    return parse


# Ready-made converters for the common arities.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = (_ntuple(k) for k in (1, 2, 3, 4))
to_ntuple = _ntuple

class GlobalEmbedding(nn.Module):
    """Projection head: Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d.

    The final BatchNorm1d has no learnable affine parameters. Layers live in
    ``self.head`` (an ``nn.Sequential``) at indices 0..4 in the order above.
    """

    def __init__(self,
                 input_dim: int = 512,
                 hidden_dim: int = 2048,
                 output_dim: int = 128) -> None:
        super().__init__()

        # Alternatives left disabled by the original author: nn.LayerNorm in
        # place of each BatchNorm1d and nn.GELU in place of the ReLU.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim, affine=False),  # output layer
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the projection head to ``x`` and return the result."""
        projected = self.head(x)
        return projected

class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Patches are embedded with a convolution whose kernel size and stride both
    equal the patch size, optionally followed by a normalization layer.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        # Scalars are broadcast to pairs; iterables are used as given.
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Embed the patches of ``x``; with ``flatten`` the spatial dims are merged and moved to axis 1."""
        # The four-way unpack only serves to reject inputs that are not 4-D.
        _batch, _chans, _height, _width = x.shape
        patches = self.proj(x)
        if self.flatten:
            # (B, C, H', W') -> (B, H'*W', C)
            patches = patches.flatten(2).transpose(1, 2)
        return self.norm(patches)

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks

    Structure: fc1 -> activation -> dropout -> fc2 -> dropout, where both
    dropout steps share the single ``self.drop`` module.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Falsy widths (None or 0) fall back to the input width.
        hidden_width = hidden_features if hidden_features else in_features
        out_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Attention(nn.Module):
    """Multi-head self-attention that can record its attention map and the map's gradient.

    Calling ``forward(..., register_hook=True)`` stores the (post-dropout)
    attention weights on the module and, if they require grad, registers a
    backward hook that stores their gradient as well.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # NOTE: defaults to head_dim ** -0.5; a (truthy) qk_scale can be passed
        # to stay compatible with weights trained under a different scale.
        self.scale = qk_scale if qk_scale else (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Populated only by forward(..., register_hook=True).
        self.attn_gradients = None
        self.attention_map = None

    # --- accessors for the recorded attention statistics ---
    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        """Apply self-attention to ``x`` of shape (batch, tokens, channels)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        if register_hook:
            self.save_attention_map(weights)
            if weights.requires_grad:
                weights.register_hook(self.save_attn_gradients)

        # Merge the heads back into the channel dimension.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

class Block(nn.Module):
    """Transformer encoder block: norm -> attention and norm -> MLP, each added back residually.

    With ``use_grad_checkpointing`` the attention and MLP sub-modules are
    wrapped with fairscale's ``checkpoint_wrapper``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path (stochastic depth) is shared by both residual branches;
        # a non-positive rate disables it.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        if use_grad_checkpointing:
            for child_name in ("attn", "mlp"):
                setattr(self, child_name, checkpoint_wrapper(getattr(self, child_name)))

    def forward(self, x, register_hook=False):
        """Run the block; ``register_hook`` is forwarded to the attention module."""
        attn_out = self.attn(self.norm1(x), register_hook=register_hook)
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)

class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
        https://arxiv.org/abs/2010.11929

    Additions on top of the plain encoder:
      * ``masked_im_modeling=True`` replaces a subset of patch embeddings with the learned
        ``masked_embed`` before the transformer blocks; ``forward`` then returns ``(tokens, mask)``.
      * A ``GlobalEmbedding`` head is created as ``self.global_embedding`` unless ``with_distill``
        is set without ``masked_im_modeling``. ``forward`` itself never applies that head.
    """
    def __init__(self, image2D_size=224, patch2D_size=16, in_chans2D=3, embed_dim=768, hidden_dim=2048, output_dim=128, patch_out_dim=8192, 
                 with_distill=False, masked_im_modeling=False, mask_ratio=0., depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., 
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=None, use_grad_checkpointing=False, ckpt_layer=0):
        """
        Args:
            image2D_size, patch2D_size, in_chans2D: forwarded to ``PatchEmbed``.
            embed_dim: token width of the transformer.
            hidden_dim, output_dim: hidden/output widths of the ``GlobalEmbedding`` head.
            patch_out_dim: accepted for interface compatibility; not used by this class (it was
                only consumed by a patch-level head that is currently disabled).
            with_distill: when True, the global head is only created if ``masked_im_modeling``.
            masked_im_modeling: enable mask-token replacement in ``forward``.
            mask_ratio: fraction of patches to mask.
            depth, num_heads, mlp_ratio, qkv_bias, qk_scale: forwarded to every ``Block``.
            drop_rate, attn_drop_rate: dropout rates (positional dropout / MLP / attention).
            drop_path_rate: final stochastic-depth rate; block ``i`` gets the ``i``-th value of
                ``linspace(0, drop_path_rate, depth)``.
            norm_layer: normalization factory; defaults to ``LayerNorm(eps=1e-6)``.
            use_grad_checkpointing, ckpt_layer: checkpoint the last ``ckpt_layer`` blocks.
        """
        super().__init__()
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(img_size=image2D_size, patch_size=patch2D_size, in_chans=in_chans2D, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic-depth rate grows linearly with block index.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                use_grad_checkpointing=(
                    use_grad_checkpointing and i >= depth-ckpt_layer)
            )
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        ### initialize the vit backbone
        # This runs before the head below is created, so the head keeps PyTorch's default init.
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

        # Mask parameters for vision transformer
        self.masked_im_modeling = masked_im_modeling
        self.mask_ratio = mask_ratio

        # MGCA's instance-level head: skipped only for distillation without masked image modeling.
        if not with_distill or self.masked_im_modeling:
            self.global_embedding = GlobalEmbedding(embed_dim, hidden_dim, output_dim)

        if self.masked_im_modeling:
            # Learned embedding written over masked patch positions.
            self.masked_embed = nn.Parameter(torch.zeros(1, embed_dim))

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; zeros/ones for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # now not need interpolate pos encoding
    def interpolate_pos_encoding(self, x, w, h):
        """Return position embeddings resized (bicubic) for an input of pixel size ``(h, w)``.

        Args:
            x: token sequence including the class token, shape ``(B, 1 + npatch, dim)``.
            w, h: input image width and height in pixels.

        The stored embedding is returned untouched when the patch count matches and ``w == h``.
        The stored patch grid is treated as square (side ``sqrt(N)``).
        """
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        # Rows use the patch height, columns the patch width (identical for square patches).
        h0 = h // self.patch_embed.patch_size[0]
        w0 = w // self.patch_embed.patch_size[1]
        # Small offset, presumably to guard against rounding in the scale factors below.
        h0, w0 = h0 + 0.1, w0 + 0.1
        side = math.sqrt(N)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(side), int(side), dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / side, w0 / side),
            mode='bicubic',
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def _mask_rand(self, input, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        input: [N, L, D], sequence

        Returns a float mask of shape [N, L] with 0 = keep and 1 = masked; every row has
        ``L - int(L * (1 - mask_ratio))`` masked positions.
        """
        N, L, D = input.shape # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=input.device) # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=input.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return mask

    ### AttMask
    def _mask_attn(self, input, mask_ratio, attn, masking_mode='attmask_low', masking_prob=0.8):
        """Build a boolean patch mask (True = masked) from per-patch attention scores.

        Args:
            input: unused; kept so the signature parallels ``_mask_rand``.
            mask_ratio: fraction of positions (along dim 1 of ``attn``) to mask.
            attn: 2-D score tensor, one row per sample (indexed as ``[batch, patch]``).
            masking_mode: ``'attmask_high'`` masks the highest-scoring patches,
                ``'attmask_low'`` the lowest-scoring ones, and ``'attnmask_lowmix'`` starts from
                the low mask and randomly swaps part of it with unmasked positions.
            masking_prob: per-sample probability of keeping the mask; otherwise the whole row
                is cleared so that sample is left unmasked.

        Raises:
            ValueError: if ``masking_mode`` is not one of the three supported modes.
        """
        if masking_mode == 'attmask_high':
            descending = True
        elif masking_mode in ('attmask_low', 'attnmask_lowmix'):
            descending = False
        else:
            raise ValueError(
                "Unknown masking_mode %r; use 'attmask_high', 'attmask_low' or 'attnmask_lowmix'"
                % (masking_mode,))

        num_masked = int(attn.shape[1] * mask_ratio)
        idx = torch.argsort(attn, descending=descending)[:, :num_masked]
        attn_mask = torch.zeros(attn.shape, dtype=torch.bool, device=attn.device)
        attn_mask.scatter_(1, idx, True)

        if masking_mode == 'attnmask_lowmix':
            # Work on the complement (currently visible positions).
            attn_mask = ~attn_mask
            ratio = 0.5
            tmp_masks = attn_mask.clone()
            # The count is read from the first row only; every row has the same count here.
            n_tokens = attn_mask.sum(-1)[0]
            reveal_tokens = int(ratio * n_tokens)
            selected_true = torch.multinomial(tmp_masks.float(), reveal_tokens)
            attn_mask.scatter_(1, selected_true, False)

            tmp_masks = ~tmp_masks
            selected_false = torch.multinomial(tmp_masks.float(), reveal_tokens)
            attn_mask.scatter_(1, selected_false, True)
            attn_mask = ~attn_mask

        # Clear the mask entirely for samples whose draw exceeds masking_prob.
        generator = torch.rand(attn.shape[0], device=attn.device)
        attn_mask[generator > masking_prob] = False
        return attn_mask

    def forward(self, x, attn=None):
        """Encode a batch of images.

        Args:
            x: image batch ``(B, C, H, W)``.
            attn: optional per-patch scores used for attention-guided masking; only consulted
                when ``masked_im_modeling`` is enabled (random masking is used if it is None).

        Returns:
            With ``masked_im_modeling``: ``(tokens, mask)`` where ``mask`` is the boolean patch
            mask. Otherwise just ``tokens``. ``tokens`` is the normalized sequence with the class
            token at index 0.
        """
        B = x.size(0)
        x = self.patch_embed(x)

        if self.masked_im_modeling:
            if attn is not None:
                mask = self._mask_attn(x, self.mask_ratio, attn).bool()
            else:
                mask = self._mask_rand(x, self.mask_ratio).bool()
            # Masked patches are overwritten rather than dropped, so the sequence length is unchanged.
            x[mask, :] = self.masked_embed.to(x.dtype)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Without masking, every block records its attention map (and gradient hook); the maps
        # stay retrievable afterwards through ``blk.attn.get_attention_map()``.
        register_hook = not self.masked_im_modeling
        for blk in self.blocks:
            x = blk(x, register_hook=register_hook)
        x = self.norm(x)

        if self.masked_im_modeling:
            return x, mask
        return x

def create_vit(image2D_size, hidden_dim, output_dim, patch_out_dim,
               with_distill, masked_im_modeling, mask_ratio, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    """Build the visual encoder.

    Only ViT-B/16 is supported: patch size 16, width 768, depth 12 and 12 heads
    are fixed here; every other argument is passed straight to ``VisionTransformer``.
    """
    ### only support VIT-B/16
    vit_b16 = dict(patch2D_size=16, embed_dim=768, depth=12, num_heads=12)

    return VisionTransformer(
        image2D_size=image2D_size,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        patch_out_dim=patch_out_dim,
        with_distill=with_distill,
        masked_im_modeling=masked_im_modeling,
        mask_ratio=mask_ratio,
        use_grad_checkpointing=use_grad_checkpointing,
        ckpt_layer=ckpt_layer,
        drop_path_rate=drop_path_rate,
        **vit_b16,
    )

class BertEncoder(nn.Module):
    """Text encoder built around a pretrained BERT model (``emilyalsentzer/Bio_ClinicalBERT``).

    ``forward`` returns a report-level feature, per-word features, per-word attention scores
    and the decoded words. With ``agg_tokens`` (always True as constructed here) WordPiece
    fragments are merged back into whole words first. A ``GlobalEmbedding`` head is created as
    ``self.global_embed`` but is not applied inside ``forward``.
    """
    def __init__(self,
                 tokenizer: BertTokenizer = None,
                 emb_dim: int = 768,
                 output_dim: int = 128,
                 hidden_dim: int = 2048,
                 freeze_bert: bool = True):
        """
        Args:
            tokenizer: tokenizer to use; when falsy one is loaded for ``self.bert_type``.
            emb_dim: width of the BERT hidden states (input width of the head).
            output_dim: output width of the ``GlobalEmbedding`` head.
            hidden_dim: hidden width of the ``GlobalEmbedding`` head.
            freeze_bert: when exactly ``True``, disable gradients for all BERT parameters.
        """
        super(BertEncoder, self).__init__()
        self.bert_type = "emilyalsentzer/Bio_ClinicalBERT"
        self.last_n_layers = 1
        self.aggregate_method = "sum"
        self.embedding_dim = emb_dim
        self.output_dim = output_dim
        self.freeze_bert = freeze_bert
        self.agg_tokens = True
        # self.max_sent_num = 10

        self.config = BertConfig.from_json_file(
            os.path.join(BASE_DIR, "../configs/bert_config.json"))
        # NOTE(review): `BertModel` is not imported in the visible part of this module (the
        # `models.backbones.med` import is commented out); it must be in scope for this call.
        self.model = BertModel.from_pretrained(
            self.bert_type,
            config=self.config,
            add_pooling_layer=False,
        )

        if tokenizer:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.bert_type)

        # Reverse vocabulary: token id -> token string.
        self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()}

        if self.freeze_bert is True:
            print("Freezing BERT model")
            for param in self.model.parameters():
                param.requires_grad = False

        # MGCA's head
        self.global_embed = GlobalEmbedding(
            self.embedding_dim, hidden_dim, self.output_dim)

    def aggregate_tokens(self, embeddings, caption_ids, last_layer_attn):
        '''
        Merge WordPiece fragments ("##..." tokens) into whole words.

        Embeddings and attention scores of the fragments of one word are summed. Processing
        of a caption stops at its first "[SEP]" (which is emitted as its own word); the
        remainder is padded with zero embeddings, zero scores and "[PAD]" up to the original
        sequence length.

        :param embeddings: bz, num_layers, num_words, dim (e.g. bz, 1, 112, 768)
        :param caption_ids: bz, num_words token ids (e.g. bz, 112)
        :param last_layer_attn: bz, num_words - 1 scores (e.g. bz, 111)
        :return: (aggregated embeddings [bz, num_layers, num_words, dim],
                  list of word lists, aggregated scores [bz, num_words])

        NOTE(review): the three inputs are zipped position by position although
        ``last_layer_attn`` is one element shorter (``forward`` slices off position 0), so the
        scores look shifted by one token and the last token is never visited. A caption whose
        "[SEP]" is not reached loses its final buffered word. Confirm whether this is intended.
        '''
        _, num_layers, num_words, dim = embeddings.shape
        embeddings = embeddings.permute(0, 2, 1, 3)
        agg_embs_batch = []
        sentences = []
        last_attns = []

        # loop over batch
        for embs, caption_id, last_attn in zip(embeddings, caption_ids, last_layer_attn):
            agg_embs, words, attns = [], [], []
            # Buffers holding the fragments of the word currently being assembled.
            token_bank, word_bank, attn_bank = [], [], []

            # loop over sentence; ids are converted to Python ints once per caption
            for word_emb, word_id, attn in zip(embs, caption_id.tolist(), last_attn):
                word = self.idxtoword[word_id]
                is_sep = word == "[SEP]"
                # This is because some words are divided into two words.
                starts_new_word = not word.startswith("##")

                if is_sep or (starts_new_word and word_bank):
                    # Emit the buffered word before handling the current token.
                    agg_embs.append(torch.stack(token_bank).sum(axis=0))
                    words.append("".join(word_bank))
                    attns.append(sum(attn_bank))
                    token_bank, word_bank, attn_bank = [], [], []

                if is_sep:
                    agg_embs.append(word_emb)
                    words.append(word)
                    attns.append(attn)
                    break

                token_bank.append(word_emb)
                word_bank.append(word if starts_new_word else word[2:])
                attn_bank.append(attn)

            agg_embs = torch.stack(agg_embs)
            padding_size = num_words - len(agg_embs)
            paddings = torch.zeros(padding_size, num_layers, dim)
            paddings = paddings.type_as(agg_embs)
            words = words + ["[PAD]"] * padding_size
            last_attns.append(
                torch.cat([torch.tensor(attns), torch.zeros(padding_size)], dim=0))
            agg_embs_batch.append(torch.cat([agg_embs, paddings]))
            sentences.append(words)

        agg_embs_batch = torch.stack(agg_embs_batch)
        agg_embs_batch = agg_embs_batch.permute(0, 2, 1, 3)
        last_atten_pt = torch.stack(last_attns)
        last_atten_pt = last_atten_pt.type_as(agg_embs_batch)

        return agg_embs_batch, sentences, last_atten_pt

    def forward(self, ids, attn_mask, token_type, get_local=False):
        """Encode tokenized reports.

        Args:
            ids, attn_mask, token_type: passed positionally to the BERT model.
            get_local: unused; kept for interface compatibility.

        Returns:
            ``(report_feat, word_feat, last_atten_pt, sents)``: the feature at position 0, the
            features of the remaining positions, the attention scores for those remaining
            positions, and the decoded words per sample.
        """
        outputs = self.model(ids, attn_mask, token_type,
                             return_dict=True, mode="text")

        # Last-layer attention from position 0 to all later positions, averaged over dim 1
        # (assumes the usual (batch, heads, query, key) layout — TODO confirm).
        last_layer_attn = outputs.attentions[-1][:, :, 0, 1:].mean(dim=1)
        all_feat = outputs.last_hidden_state.unsqueeze(1)

        if self.agg_tokens:
            all_feat, sents, last_atten_pt = self.aggregate_tokens(
                all_feat, ids, last_layer_attn)
            last_atten_pt = last_atten_pt[:, 1:].contiguous()
        else:
            sents = [[self.idxtoword[w] for w in sent]
                     for sent in ids.tolist()]
            # Previously left unassigned on this path (UnboundLocalError at return). The raw
            # scores already exclude position 0, matching ``word_feat`` below.
            last_atten_pt = last_layer_attn.contiguous()

        if self.last_n_layers == 1:
            all_feat = all_feat[:, 0]

        report_feat = all_feat[:, 0].contiguous()
        word_feat = all_feat[:, 1:].contiguous()

        return report_feat, word_feat, last_atten_pt, sents


