from math import prod
import math
import os
from functools import partial

import timm.models.vision_transformer
import torch
import torch.nn as nn
from vc_models.models.vit import model_utils
from timm.models.vision_transformer import resize_pos_embed, DropPath, Mlp, LayerScale
from timm.models.crossvit import CrossAttention
import logging
from habitat_vc.models.freeze_batchnorm import convert_frozen_batchnorm
from torch.nn import MultiheadAttention
from torch.distributions import Categorical
import triton
import triton.language as tl
from .vit import VisionTransformer
from . import _VISUALIZE

logger = logging.getLogger(__name__)


class Attention(nn.Module):
    """Multi-head self-attention that can also rank patch tokens for pruning.

    Besides the usual attention output, ``forward`` can select the patch tokens
    (every token except index 0) that receive the most attention from token 0,
    averaged over heads. The caller uses the returned indices to drop the rest.

    Args:
        dim: embedding width; split evenly over ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
        keep_rate: default fraction of patch tokens to keep. A value ``>= 1``
            disables rate-based selection for this module (an explicit
            ``tokens`` budget in ``forward`` still forces selection).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., keep_rate=1.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5  # 1 / sqrt(head_dim)

        # Submodules are registered in the same order as before so parameter
        # initialisation order and state_dict keys are unchanged.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.keep_rate = keep_rate

    def forward(self, x, keep_rate=None, tokens=None):
        """Apply attention and optionally choose which patch tokens to keep.

        Args:
            x: token sequence of shape ``(B, N, C)``; position 0 along ``N`` is
                treated as the class token, positions ``1..N-1`` as patches.
            keep_rate: fraction of the ``N - 1`` patch tokens to keep; ``None``
                falls back to ``self.keep_rate``.
            tokens: explicit number of patch tokens to keep. When given it
                overrides ``keep_rate`` and always enters the selection path.

        Returns:
            ``(out, index, idx, cls_attn, left_tokens)``. ``out`` has the same
            shape as ``x``. When no selection happens, ``index``, ``idx`` and
            ``cls_attn`` are ``None``. Otherwise ``idx`` is ``(B, left_tokens)``
            with the kept patch positions (0-based over patches, ordered by
            descending score), ``index`` is ``idx`` broadcast to
            ``(B, left_tokens, C)`` for ``torch.gather``, and ``cls_attn`` is
            ``(B, N - 1)``: the head-averaged attention (taken after attention
            dropout) from token 0 to every patch.
        """
        rate = self.keep_rate if keep_rate is None else keep_rate
        batch, n_tok, dim = x.shape
        n_patches = n_tok - 1

        qkv = self.qkv(x).reshape(batch, n_tok, 3, self.num_heads, dim // self.num_heads)
        # (3, B, heads, N, head_dim) -> three views; unbind keeps TorchScript happy.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(torch.softmax(weights, dim=-1))

        out = torch.matmul(weights, v).transpose(1, 2).reshape(batch, n_tok, dim)
        out = self.proj_drop(self.proj(out))

        # Select only when this module allows pruning AND the requested rate is
        # below 1 ("double check" of the keep rate), or when an explicit token
        # budget is supplied.
        wants_selection = (self.keep_rate < 1 and rate < 1) or tokens is not None
        if not wants_selection:
            return out, None, None, None, n_patches

        left_tokens = math.ceil(rate * n_patches)
        if tokens is not None:
            left_tokens = tokens
        if left_tokens == n_patches:
            # Keeping everything: nothing to select.
            return out, None, None, None, left_tokens
        assert left_tokens >= 1

        # Attention paid by the class token (query 0) to each patch, averaged over heads.
        cls_attn = weights[:, :, 0, 1:].mean(dim=1)
        _, idx = torch.topk(cls_attn, left_tokens, dim=1, largest=True, sorted=True)
        index = idx.unsqueeze(-1).expand(-1, -1, dim)
        return out, index, idx, cls_attn, left_tokens

def complement_idx(idx, dim):
    """Return the sorted indices in ``range(dim)`` that are absent from ``idx``.

    The complement is taken along the last dimension of ``idx``; every leading
    dimension is treated as a batch dimension.

    Args:
        idx: integer index tensor of shape ``[..., K]``. Each row is expected to
            contain K distinct values in ``[0, dim)``; duplicates would give a
            wrong result because exactly K entries are discarded per row.
        dim: size of the full index range.

    Returns:
        Tensor of shape ``[..., dim - K]`` holding, per row, the indices not in
        ``idx`` in ascending order.
    """
    lead_shape = idx.shape[:-1]
    n_taken = idx.shape[-1]

    # Broadcast 0..dim-1 across all leading (batch) dimensions without copying.
    full = torch.arange(dim, device=idx.device)
    full = full.view(*((1,) * len(lead_shape)), dim).expand(*lead_shape, dim)

    # Overwrite the taken positions with 0. After an ascending sort the first
    # n_taken entries are exactly those overwritten zeros (a genuine 0 from the
    # complement, if present, lands right after them), so dropping them leaves
    # the complement.
    zeroed = torch.scatter(full, -1, idx, 0)
    ordered = torch.sort(zeroed, dim=-1, descending=False).values
    return ordered[..., n_taken:]

class Block(nn.Module):
    """Pre-norm transformer block that can prune patch tokens after attention.

    The attention sub-layer may select a subset of patch tokens (see
    ``Attention.forward``). When it does, only the class token and the
    selected patches continue to the MLP. With ``fuse_token`` the discarded
    patches are additionally merged into one extra token: their sum weighted
    by the class-token attention scores.

    ``init_values`` is accepted for signature compatibility but is not used by
    this block (no layer scaling is applied).
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, keep_rate=0., init_values=None,
                 fuse_token=True):
        super().__init__()
        # Construction order is kept (norm1, attn, drop_path, norm2, mlp) so
        # that seeded parameter initialisation and state_dict keys are unchanged.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            keep_rate=keep_rate,
        )
        # Stochastic depth on both residual branches; identity when disabled.
        self.drop_path = nn.Identity() if not drop_path > 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden, act_layer=act_layer, drop=drop)
        self.keep_rate = keep_rate
        self.mlp_hidden_dim = hidden
        self.fuse_token = fuse_token

    def _prune_tokens(self, x, index, idx, cls_attn):
        """Keep the class token plus the selected patches (and optionally one fused token)."""
        cls_tok = x[:, 0:1]
        patches = x[:, 1:]
        n_patches, width = patches.shape[1], patches.shape[-1]
        kept = torch.gather(patches, dim=1, index=index)  # [B, left_tokens, C]
        if not self.fuse_token:
            return torch.cat([cls_tok, kept], dim=1)

        dropped_idx = complement_idx(idx, n_patches)  # [B, N-1-left_tokens]
        dropped = torch.gather(
            patches, dim=1, index=dropped_idx.unsqueeze(-1).expand(-1, -1, width)
        )  # [B, N-1-left_tokens, C]
        weights = torch.gather(cls_attn, dim=1, index=dropped_idx)  # [B, N-1-left_tokens]
        fused = torch.sum(dropped * weights.unsqueeze(-1), dim=1, keepdim=True)  # [B, 1, C]
        return torch.cat([cls_tok, kept, fused], dim=1)

    def forward(self, x, keep_rate=None, tokens=None, get_idx=False):
        """Run attention (with optional pruning) and the MLP.

        Args:
            x: token sequence; position 0 is treated as the class token.
            keep_rate: fraction of patch tokens to keep; ``None`` uses the
                rate given at construction.
            tokens: explicit number of patch tokens to keep (overrides the rate).
            get_idx: when true and pruning happened, also return the kept
                patch indices.

        Returns:
            ``(x, n_tokens, idx)``: the output sequence, its number of
            non-class tokens, and the kept patch indices (``None`` unless
            ``get_idx`` is set and this block pruned).
        """
        rate = self.keep_rate if keep_rate is None else keep_rate

        attn_out, index, idx, cls_attn, _ = self.attn(self.norm1(x), rate, tokens)
        x = x + self.drop_path(attn_out)

        pruned = index is not None
        if pruned:
            x = self._prune_tokens(x, index, idx, cls_attn)

        x = x + self.drop_path(self.mlp(self.norm2(x)))
        returned_idx = idx if (get_idx and pruned) else None
        return x, x.shape[1] - 1, returned_idx


class EViT(VisionTransformer):
    """Vision Transformer with EViT-style attentive token pruning.

    Every block is a :class:`Block`. Blocks listed in ``drop_loc`` run with
    ``keep_ratio`` and keep only the patch tokens that receive the most
    attention from the class token. All other blocks run with a keep ratio of
    1.0, i.e. no pruning. After the last block the surviving tokens are
    scattered back into the full token grid (see ``forward_features``).

    Args:
        keep_ratio: fraction of patch tokens kept at each pruning block.
        drop_loc: indices of the blocks that prune.
        **kwargs: forwarded to ``VisionTransformer``; ``block_fn`` is always
            overridden with :class:`Block`.
    """

    def __init__(
        self, keep_ratio=0.7, drop_loc=(3, 6, 9), **kwargs
    ):
        kwargs['block_fn'] = Block
        super(EViT, self).__init__(**kwargs)
        # One keep ratio per block. NOTE(review): the fallback of 12 assumes it
        # matches the base class's default depth when ``depth`` is not passed.
        keep_ratios = [1.0] * kwargs.get('depth', 12)
        for loc in drop_loc:
            keep_ratios[loc] = keep_ratio
        self.keep_ratios = keep_ratios

    def forward_features(self, x):
        """Encode images, pruning tokens at ``drop_loc``, then restore the token grid.

        Tokens dropped at a pruning block keep the features they had when
        entering that block. When token fusion is enabled, the extra fused
        token produced by a pruning block is not written back, because
        ``scatter`` only reads the source positions covered by the index.

        Args:
            x: image batch, consumed by ``self.patch_embed``.

        Returns:
            Whatever ``self.handle_outcome`` (base class) returns for the
            restored token sequence.
        """
        B = x.shape[0]
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        # (``mask_ratio`` / ``random_masking`` come from the base class)
        if self.mask_ratio is not None:
            x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1)

        # For every block that actually pruned, remember its input sequence and
        # the kept patch indices so the pruning can be undone afterwards.
        idxs = []
        xs = []
        for i, blk in enumerate(self.blocks):
            nx, _, idx = blk(x, self.keep_ratios[i], None, True)
            if idx is not None:
                idxs.append(idx)
                xs.append(x)
            x = nx

        num_stages = len(xs)
        # prune_idx[b, t] = pruning stage at which token t was dropped, or
        # num_stages if it survived every stage.
        prune_idx = torch.zeros_like(x[:, :, 0])
        prune_idx.fill_(num_stages)
        # Recover the full feature map, innermost pruning stage first.
        for i in reversed(range(num_stages)):
            x_before = xs[i]
            keep_idx = idxs[i] + 1  # cls token is 0
            # Out-of-place scatter: xs[i] was the input of a block and may be
            # saved for backward (e.g. by LayerNorm); an in-place write would
            # break autograd during training.
            x = x_before.scatter(
                1, keep_idx.unsqueeze(-1).expand(-1, -1, x_before.shape[-1]), x[:, 1:, :]
            )

            prune_idx_before = torch.zeros_like(x_before[:, :, 0])
            prune_idx_before.fill_(i)
            prune_idx_before.scatter_(1, keep_idx, prune_idx[:, 1:])
            prune_idx = prune_idx_before

        if _VISUALIZE:
            forward_dict = dict(merge_function=self.merge_forward_dict_eval)
            # 1.0 for patches that survived every pruning stage. Uses prune_idx
            # (always defined) and the actual stage count instead of a fixed 3.
            forward_dict['visualize_mask'] = (prune_idx[:, 1:] >= num_stages).float()
            self.forward_dict = [forward_dict]

        return self.handle_outcome(x)

    def merge_forward_dict_eval(self, forward_dicts):
        """Merge per-step visualisation dicts; currently reports no metrics."""
        agent_metrics = dict()
        detailed_metrics = dict()

        return dict(agent_metrics=agent_metrics, detailed_metrics=detailed_metrics)

    def forward(self, x):
        """Alias for :meth:`forward_features`."""
        return self.forward_features(x)


def evit_small_patch16_224(**kwargs):
    """Build an EViT with DeiT-small dimensions (https://arxiv.org/abs/2012.12877).

    Fixed settings: patch size 16, width 384, 12 blocks, 6 heads. No
    pretrained weights are loaded here. Other keyword arguments are forwarded
    to ``EViT``. Passing one of the fixed settings again raises ``TypeError``.
    Unlike ``evit_tiny_patch16_224``, ``mlp_ratio``, ``qkv_bias`` and
    ``norm_layer`` are not set, so the callee's defaults apply.
    """
    return EViT(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)

def evit_tiny_patch16_224(**kwargs):
    """Build an EViT with DeiT-tiny dimensions.

    Fixed settings: patch size 16, width 192, 12 blocks, 3 heads, MLP ratio 4,
    biased q/k/v projection, and LayerNorm with ``eps=1e-6``. No pretrained
    weights are loaded here. Other keyword arguments are forwarded to
    ``EViT``. Repeating a fixed setting raises ``TypeError``.
    """
    tiny_cfg = dict(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return EViT(**tiny_cfg, **kwargs)