# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# adapted from:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from collections import OrderedDict
import contextlib
from math import prod
import os
from functools import partial

import numpy as np
import timm.models.vision_transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
from vc_models.models.vit import model_utils
from timm.models.vision_transformer import resize_pos_embed, Mlp, Attention, Block
from timm.models.crossvit import CrossAttention
import logging
from habitat_vc.models.freeze_batchnorm import convert_frozen_batchnorm
from .vit import VisionTransformer
from . import _VISUALIZE

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


class AttentionWithMask(Attention):
    """timm ``Attention`` whose softmax can be restricted by a per-token keep policy.

    With ``policy=None`` this computes ordinary multi-head self-attention, like the
    parent class. With a policy, each query effectively attends only to keys whose
    policy value is 1, plus itself. Pruned tokens can therefore be masked out without
    changing the sequence length.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        # Pass by keyword: the positional order of timm's ``Attention.__init__``
        # is not stable across releases (some add ``qk_norm`` before ``attn_drop``).
        super().__init__(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

    def softmax_with_policy(self, attn, policy, eps=1e-6):
        """Softmax over the key axis, renormalised over the tokens kept by ``policy``.

        Args:
            attn: raw attention logits of shape ``(B, H, N, N)``.
            policy: keep mask of shape ``(B, N, 1)``. Key ``j`` contributes only
                where ``policy[:, j] == 1``, but every token always keeps its
                self-attention entry (diagonal).
            eps: numerical stabiliser. It is added to the denominator, and spread
                as ``eps / N`` over the numerator.

        Returns:
            Attention weights with the same dtype as ``attn``.
        """
        B, N, _ = policy.size()
        B, H, N, N = attn.size()
        # Broadcast the key mask over heads and queries.
        attn_policy = policy.reshape(B, 1, 1, N)
        eye = torch.eye(N, dtype=attn_policy.dtype, device=attn_policy.device).view(1, 1, N, N)
        # Force the diagonal to 1 so that a dropped token still attends to itself.
        attn_policy = attn_policy + (1.0 - attn_policy) * eye
        max_att = torch.max(attn, dim=-1, keepdim=True)[0]
        attn = attn - max_att

        # Exponentiate and normalise in float32 for stable (mixed-precision) training.
        attn = attn.to(torch.float32).exp_() * attn_policy.to(torch.float32)
        attn = (attn + eps / N) / (attn.sum(dim=-1, keepdim=True) + eps)
        return attn.type_as(max_att)

    def forward(self, x, policy=None):
        """Apply (optionally policy-masked) multi-head self-attention.

        Args:
            x: token features of shape ``(B, N, C)``.
            policy: optional ``(B, N, 1)`` keep mask; see ``softmax_with_policy``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if policy is None:
            attn = attn.softmax(dim=-1)
        else:
            attn = self.softmax_with_policy(attn, policy)
        # Previously missing: the ``attn_drop`` constructor argument was ignored.
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class BlockWithAttention(Block):
    """timm ``Block`` whose attention sub-layer understands an optional token ``policy``.

    Everything is inherited from the parent block except two things: ``self.attn``
    is swapped for ``AttentionWithMask``, and ``forward`` threads ``policy``
    through to it.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
    ):
        block_cfg = dict(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop=drop,
            attn_drop=attn_drop,
            init_values=init_values,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        super().__init__(**block_cfg)

        # Replace the stock attention with the policy-aware variant.
        del self.attn
        masked_attention = AttentionWithMask(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.attn = masked_attention

    def forward(self, x, policy=None):
        """Pre-norm residual attention (with optional policy) followed by a residual MLP."""
        attended = self.attn(self.norm1(x), policy=policy)
        x = x + self.drop_path1(self.ls1(attended))
        mixed = self.mlp(self.norm2(x))
        return x + self.drop_path2(self.ls2(mixed))


class PredictorLG(nn.Module):
    """Per-token keep/drop scorer mixing local and policy-pooled global features.

    ``in_conv`` projects every token, and its output channels are split in two.
    The first half is kept per token (local feature). The second half is averaged
    over the tokens selected by ``policy`` into one global feature, which is then
    broadcast back to every token. ``out_conv`` maps the concatenation to 2-way
    log-probabilities.

    Args:
        embed_dim: input feature size of ``in_conv``. When ``forward`` receives a
            ``task`` vector, this must equal the token dim plus the task dim.
    """

    # Lower bound on the number of kept tokens used as the pooling denominator.
    _MIN_KEPT = 1e-6

    def __init__(self, embed_dim=384):
        super().__init__()
        self.in_conv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )

        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 2),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, x, policy, task=None):
        """Score every token.

        Args:
            x: token features of shape ``(B, N, D)``.
            policy: ``(B, N, 1)`` keep mask (1 = token still alive). It is used
                to pool the global feature.
            task: optional ``(B, T)`` vector. It is broadcast to every token and
                concatenated to ``x`` before projection.

        Returns:
            Tensor of shape ``(B, N, 2)`` holding log-probabilities.
        """
        N = x.shape[1]
        if task is not None:
            x = torch.cat([x, task.unsqueeze(1).expand(-1, N, -1)], dim=-1)
        x = self.in_conv(x)
        half = x.shape[-1] // 2
        local_x = x[:, :, :half]
        # Mean of the global channels over kept tokens. Clamping the count keeps
        # a sample whose tokens were all dropped at 0 instead of 0/0 = NaN.
        num_kept = torch.sum(policy, dim=1, keepdim=True).clamp(min=self._MIN_KEPT)
        global_x = (x[:, :, half:] * policy).sum(dim=1, keepdim=True) / num_kept
        # expand(-1, N, -1) keeps all global channels, so odd embed dims work too.
        x = torch.cat([local_x, global_x.expand(-1, N, -1)], dim=-1)
        return self.out_conv(x)


def batch_index_select(x, idx):
    """Gather positions along dim 1 independently for every batch element.

    ``idx`` is sorted per row first, so the output preserves the original token
    order no matter how ``idx`` was ordered (e.g. by descending score).

    Args:
        x: tensor of shape ``(B, N, ...)`` with at least two dimensions.
        idx: integer tensor of shape ``(B, N_new)`` holding positions in ``[0, N)``.

    Returns:
        Tensor of shape ``(B, N_new, ...)`` with ``out[b, j] = x[b, sorted(idx[b])[j]]``.

    Raises:
        NotImplementedError: if ``x`` has fewer than two dimensions.
        IndexError: if an index is out of range. The previous flat-offset
            implementation silently read a neighbouring sample instead.
    """
    if x.dim() < 2:
        raise NotImplementedError(
            f"batch_index_select expects x with >= 2 dims, got shape {tuple(x.shape)}")
    idx = torch.sort(idx, dim=1)[0].to(torch.long)
    batch_ids = torch.arange(x.shape[0], dtype=torch.long, device=x.device).unsqueeze(1)
    # Advanced indexing pairs batch_ids (B, 1) with idx (B, N_new) -> (B, N_new, ...).
    return x[batch_ids, idx]
    

class DynamicVisionTransformer(VisionTransformer):
    """Vision Transformer with learned, DynamicViT-style token pruning.

    At every block index listed in ``reduction_layers``, a ``PredictorLG`` scores
    the spatial (non-cls) tokens before the block runs and decides which survive.

    * Training (``forward_features_train``): a hard Gumbel-softmax decision is
      turned into an attention ``policy``. Dropped tokens are masked out of
      attention while the sequence length stays fixed. The decisions are
      accumulated in ``self.forward_dict`` for the keep-ratio loss.
    * Eval (``forward_features_eval``): every sample is processed on its own, and
      the top ``int(num_patches * keep_ratios[k])`` tokens are physically
      selected. In the returned sequence, a dropped token keeps the features it
      had when it was dropped.

    ``keep_ratios`` are fractions of the *original* number of patches, so they
    should be non-increasing.

    Args:
        reduction_layers: ascending block indices at which pruning happens.
        keep_ratios: per reduction layer, the fraction of the original patches kept.
        hidden_state_dim: if > 0, ``hidden_states`` whose last dim has this size
            are embedded by ``task_embed`` and fed to the score predictors.
        freeze_backbone: in train mode, disable gradients for everything except
            the score predictors (and ``task_embed``), and run the backbone
            under ``torch.no_grad``.
        freeze_batchnorm: together with ``freeze_backbone``, pass the backbone
            norm modules through ``convert_frozen_batchnorm`` in train mode.
        **kwargs: forwarded to ``VisionTransformer``. ``block_fn`` is reserved.
    """

    def __init__(
        self, reduction_layers=(3, 6, 9), keep_ratios=(0.7, 0.49, 0.343), hidden_state_dim=0,
        freeze_backbone=False, freeze_batchnorm=False, **kwargs):
        assert 'block_fn' not in kwargs, "block_fn should not be set in kwargs, use reduction_layers instead"
        kwargs['block_fn'] = BlockWithAttention
        super().__init__(**kwargs)
        self.reduction_layers = reduction_layers
        self.keep_ratios = keep_ratios
        self.embedding_dim = self.num_features
        self.freeze_backbone = freeze_backbone
        self.freeze_batchnorm = freeze_batchnorm
        self.use_task = hidden_state_dim > 0
        # Per-forward bookkeeping. It is a dict in training and a list of
        # per-sample dicts after an eval forward. Each carries a
        # 'merge_function' entry, presumably invoked by the caller.
        self.forward_dict = dict()

        # Score predictors see token features, optionally concatenated with the
        # embedded task vector.
        score_feature_dim = self.embedding_dim
        if self.use_task:
            self.task_embed = nn.Linear(hidden_state_dim, self.embedding_dim)
            score_feature_dim += self.embedding_dim
        score_predictors = [
            PredictorLG(score_feature_dim) for _ in reduction_layers
        ]
        self.score_predictors = nn.ModuleList(score_predictors)

    def train(self, mode=True):
        """Set train/eval mode and re-apply backbone freezing when training.

        Returns ``self``, as ``nn.Module.train`` does. ``nn.Module.eval()``
        returns whatever ``train(False)`` returns, so ``None`` would break
        ``model = model.eval()``.
        """
        super().train(mode)

        if mode and self.freeze_backbone:
            for p in self.parameters():
                p.requires_grad = False
            if self.freeze_batchnorm:
                # NOTE(review): this runs on every train() call; assumes
                # convert_frozen_batchnorm is safe to re-apply — confirm.
                self.norm_pre = convert_frozen_batchnorm(self.norm_pre)
                self.blocks = convert_frozen_batchnorm(self.blocks)
                self.norm = convert_frozen_batchnorm(self.norm)
                self.fc_norm = convert_frozen_batchnorm(self.fc_norm)
            # Only the pruning heads remain trainable.
            for p in self.score_predictors.parameters():
                p.requires_grad = True
            if self.use_task:
                for p in self.task_embed.parameters():
                    p.requires_grad = True
        return self

    def _frozen_backbone_ctx(self):
        """Return a fresh context that disables autograd when the backbone is frozen."""
        return torch.no_grad() if self.freeze_backbone else contextlib.nullcontext()

    def _embed_with_cls(self, x):
        """Patch-embed ``x``, add positional embeddings, optionally mask, prepend the cls token."""
        B = x.shape[0]
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if self.mask_ratio is not None:
            x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        return torch.cat((cls_token.expand(B, -1, -1), x), dim=1)

    def forward_vit(self, x):
        """Plain ViT forward through all blocks, without any pruning."""
        return self.handle_outcome(self.blocks(self._embed_with_cls(x)))

    def forward_features_train(self, x, hidden_states=None):
        """Training forward pass: prune with hard Gumbel-softmax attention masks.

        Args:
            x: input images, as accepted by ``self.patch_embed``.
            hidden_states: task vectors whose last dim is ``hidden_state_dim``.
                Only used when ``hidden_state_dim > 0``.

        Returns:
            ``self.handle_outcome`` applied to the final token sequence.

        Side effects:
            Appends each reduction layer's ``(B, num_patches)`` keep decisions to
            ``self.forward_dict`` and sets ``'merge_function'``.

        Raises:
            NotImplementedError: if called while not in training mode.
            RuntimeError: if the output contains NaN or Inf.
        """
        forward_dict = dict()
        B = x.shape[0]
        task = self.task_embed(hidden_states) if self.use_task else None
        with self._frozen_backbone_ctx():
            x = self._embed_with_cls(x)

        # forward blocks
        reduction_id = 0
        patch_num = x.shape[1] - 1
        prev_decision = torch.ones(B, patch_num, 1, dtype=x.dtype, device=x.device)
        # Until the first reduction every token (cls included) is kept.
        policy = torch.ones(B, patch_num + 1, 1, dtype=x.dtype, device=x.device)

        for i, block in enumerate(self.blocks):
            if i not in self.reduction_layers:
                with self._frozen_backbone_ctx():
                    x = block(x, policy=policy)
                continue

            assert i == self.reduction_layers[reduction_id], f"reduction layer {i} not in {self.reduction_layers}"
            if not self.training:
                # Physical token dropping lives in forward_features_eval.
                raise NotImplementedError(
                    "forward_features_train requires training mode; use forward_features_eval")
            spatial_x = x[:, 1:]
            pred_score = self.score_predictors[reduction_id](spatial_x, prev_decision, task).reshape(B, -1, 2)
            # Channel 0 is the "keep" class. Multiplying by prev_decision means a
            # token dropped at an earlier reduction layer can never come back.
            hard_keep_decision = F.gumbel_softmax(pred_score, hard=True)[:, :, 0:1] * prev_decision
            forward_dict[f'keep_decision_{i}'] = hard_keep_decision.reshape(B, patch_num)
            cls_policy = torch.ones(B, 1, 1, dtype=hard_keep_decision.dtype, device=hard_keep_decision.device)
            policy = torch.cat([cls_policy, hard_keep_decision], dim=1)
            # NOTE(review): with freeze_backbone this block (and the later ones)
            # runs under no_grad, so the downstream task loss cannot reach
            # `policy`. The score predictors then presumably learn only from the
            # keep-ratio loss in merge_forward_dict_train — confirm intended.
            with self._frozen_backbone_ctx():
                x = block(x, policy=policy)
            prev_decision = hard_keep_decision
            reduction_id += 1

        # An eval forward replaces forward_dict with a list; start a fresh dict.
        if not isinstance(self.forward_dict, dict):
            self.forward_dict = dict()
        for k, v in forward_dict.items():
            self.forward_dict.setdefault(k, []).append(v)
        self.forward_dict['merge_function'] = self.merge_forward_dict_train

        out = self.handle_outcome(x)

        # Fail loudly on NaN / Inf. Do not dump to a machine-specific path and exit.
        if not torch.isfinite(out).all():
            raise RuntimeError(
                "DynamicVisionTransformer produced non-finite features in training forward "
                f"(reduction_layers={self.reduction_layers}, keep_ratios={self.keep_ratios})")
        return out

    def merge_forward_dict_train(self, forward_dict):
        """Turn accumulated training keep decisions into losses/metrics, then clear them.

        Args:
            forward_dict: the dict built by ``forward_features_train``. It is
                emptied in place.

        Returns:
            A dict with, per reduction layer ``l``:
            ``'loss/prune_{l}'`` is the mean squared error between each sample's
            kept fraction and the target keep ratio; ``'keep_ratio_{l}'`` is the
            overall kept fraction.
        """
        merged_dict = dict()
        for reduction_id, l in enumerate(self.reduction_layers):
            # Concatenate the (B, num_patches) decisions of every forward call
            # along the batch dim.
            keep_mask = torch.cat(forward_dict[f'keep_decision_{l}'], dim=0)
            merged_dict[f'loss/prune_{l}'] = ((keep_mask.mean(dim=1) - self.keep_ratios[reduction_id]) ** 2).mean()
            merged_dict[f'keep_ratio_{l}'] = keep_mask.mean()

        # clear the forward_dict for next batch
        for k in list(forward_dict.keys()):
            del forward_dict[k]
        return merged_dict

    def _forward_features_eval_single_wo_handle(self, x, task=None):
        """Run one already-embedded sample through the blocks, physically dropping tokens.

        Args:
            x: embedded tokens, cls first, with batch size 1.
            task: optional embedded task vector for this sample.

        Returns:
            ``(out_x, forward_dict)``. ``out_x`` has the same shape as ``x``;
            every token appears at its original position. Surviving tokens hold
            their final features, and dropped tokens hold the features they had
            when dropped.
        """
        assert x.shape[0] == 1, "x should have batch size 1 in eval mode"
        # Original positions of the tokens currently in `x`.
        current_idxes = torch.arange(x.shape[1], device=x.device)
        out_x = torch.zeros_like(x).reshape(-1, x.shape[-1])
        forward_dict = dict(merge_function=self.merge_forward_dict_eval)

        reduction_id = 0
        B = 1
        patch_num = x.shape[1] - 1
        prev_decision = torch.ones(B, patch_num, 1, dtype=x.dtype, device=x.device)
        # 0/1 flags over the original positions, used only for visualisation.
        mask_visualize = torch.ones((B, x.shape[1]), device=x.device, dtype=torch.float32).view(-1)
        for i, block in enumerate(self.blocks):
            if i in self.reduction_layers:
                assert i == self.reduction_layers[reduction_id], f"reduction layer {i} not in {self.reduction_layers}"

                spatial_x = x[:, 1:]
                pred_score = self.score_predictors[reduction_id](spatial_x, prev_decision, task).reshape(B, -1, 2)
                score = pred_score[:, :, 0]
                # keep_ratios are relative to the original patch count.
                num_keep_node = int(patch_num * self.keep_ratios[reduction_id])
                keep_policy = torch.argsort(score, dim=1, descending=True)[:, :num_keep_node]
                # Always keep the cls token (position 0); spatial indices shift by 1.
                cls_policy = torch.zeros(B, 1, dtype=keep_policy.dtype, device=keep_policy.device)
                now_policy = torch.cat([cls_policy, keep_policy + 1], dim=1)

                mask = torch.zeros((B, x.shape[1]), device=x.device, dtype=torch.bool).view(-1)
                mask[now_policy] = 1
                mask = mask.view(B, x.shape[1])

                assert mask_visualize.sum() == x.shape[1], f"{mask_visualize.sum()}, {x.shape}"
                cur_idx = torch.nonzero(mask_visualize > 0)[:, 0]
                new_idx = cur_idx[now_policy]
                mask_visualize.fill_(0)
                mask_visualize[new_idx] = 1

                # NOTE(review): hard-coded to the third reduction layer.
                if _VISUALIZE and reduction_id == 2:
                    forward_dict['visualize_mask'] = mask_visualize.view(B, -1)[:, 1:].clone()

                # Write the tokens dropped now back to their original positions.
                small_target_idx = current_idxes[~mask.flatten()]
                out_x[small_target_idx] = x[~mask].reshape(-1, x.shape[-1])
                current_idxes = current_idxes[mask.flatten()]

                x = batch_index_select(x, now_policy)
                prev_decision = batch_index_select(prev_decision, keep_policy)
                reduction_id += 1
            x = block(x)
        out_x[current_idxes] = x.reshape(-1, x.shape[-1])
        out_x = out_x.reshape(1, -1, out_x.shape[-1])
        return out_x, forward_dict

    def forward_features_eval(self, x, hidden_states=None):
        """Eval forward pass: prune every sample independently by physically dropping tokens.

        Side effect: ``self.forward_dict`` becomes a list with one dict per sample.
        """
        B = x.shape[0]
        x = self._embed_with_cls(x)
        task = self.task_embed(hidden_states) if self.use_task else None

        xs = []
        forward_dicts = []
        # Each sample may keep a different token set, so process them one by one.
        for b in range(B):
            ctask = task[b:b + 1] if task is not None else None
            cur_x, forward_dict = self._forward_features_eval_single_wo_handle(x[b:b + 1], ctask)
            xs.append(cur_x)
            forward_dicts.append(forward_dict)
        self.forward_dict = forward_dicts
        return self.handle_outcome(torch.cat(xs, dim=0))

    def merge_forward_dict_eval(self, forward_dicts):
        """Eval counterpart of ``merge_forward_dict_train``. It currently reports no metrics."""
        agent_metrics = dict()
        detailed_metrics = dict()

        return dict(agent_metrics=agent_metrics, detailed_metrics=detailed_metrics)

    def forward_features(self, x, hidden_states=None, masks=None):
        """Dispatch to the training or eval forward depending on ``self.training``.

        ``masks`` is accepted for interface compatibility and is unused.
        """
        if self.training:
            return self.forward_features_train(x, hidden_states)
        return self.forward_features_eval(x, hidden_states)

    def forward(self, x, hidden_states=None, masks=None):
        return self.forward_features(x, hidden_states, masks)


def dynamic_deit_tiny_patch16_224(**kwargs):
    """Build a DeiT-Tiny (16x16 patches) ``DynamicVisionTransformer``.

    The architecture hyper-parameters are fixed here. Every other keyword
    argument is forwarded, e.g. the pruning settings. ``requires_state_keys``
    is accepted and discarded.
    """
    kwargs.pop('requires_state_keys', None)
    tiny_config = dict(
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return DynamicVisionTransformer(**tiny_config)
