# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# adapted from:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from collections import OrderedDict
from math import prod
import os
from functools import partial

import numpy as np
import timm.models.vision_transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
from vc_models.models.vit import model_utils
from timm.models.vision_transformer import resize_pos_embed, Mlp, Attention, Block
from timm.models.crossvit import CrossAttention
import logging
from habitat_vc.models.freeze_batchnorm import convert_frozen_batchnorm
from .vit import VisionTransformer
from . import _VISUALIZE

logger = logging.getLogger(__name__)


class AttentionWithMask(Attention):
    """Multi-head self-attention with optional key masking and additive bias.

    The projections, dropout layers and ``scale`` used here are inherited from
    the parent ``Attention`` module; only ``forward`` is replaced.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__(dim, num_heads, qkv_bias, attn_drop, proj_drop)
        # Placeholder attribute; nothing in this class reads it.
        self.mask = None

    def forward(self, x, key_mask=None, attn_mask=None):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``.

        Args:
            x: token tensor of shape ``(B, N, C)``.
            key_mask: optional ``(B, N)`` mask; positions that are True are
                filled with ``-inf`` before the softmax, so no query attends
                to them.
            attn_mask: optional ``(B, N, N)`` tensor added to the attention
                logits of every head.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # 3 x B x num_heads x N x head_dim; unbind (rather than tuple
        # unpacking of a tensor) keeps TorchScript happy.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        logits = (q @ k.transpose(-2, -1)) * self.scale  # B x num_heads x N x N
        if key_mask is not None:
            assert key_mask.shape == (batch, tokens), "key_mask should have shape (B, N)"
            # Broadcast over heads and queries: B x 1 x 1 x N.
            logits = logits.masked_fill(key_mask[:, None, None, :], float("-inf"))
        if attn_mask is not None:
            assert attn_mask.shape == (batch, tokens, tokens), "attn_mask should have shape (B, N, N)"
            # Broadcast over heads: B x 1 x N x N.
            logits = logits + attn_mask[:, None]
        weights = self.attn_drop(logits.softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))
        
class BlockWithAttention(Block):
    """Transformer block whose attention layer accepts key / attention masks.

    The parent ``Block`` is built normally and its ``attn`` sub-module is then
    replaced by an ``AttentionWithMask`` configured with the same arguments.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
    ):
        # Arguments shared by the parent block and the replacement attention.
        attention_kwargs = dict(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
        )
        super().__init__(
            mlp_ratio=mlp_ratio,
            drop=drop,
            init_values=init_values,
            drop_path=drop_path,
            act_layer=act_layer,
            norm_layer=norm_layer,
            **attention_kwargs,
        )
        self.attn = AttentionWithMask(proj_drop=drop, **attention_kwargs)

    def forward(self, x, key_mask=None, attn_mask=None):
        """Run the attention and MLP residual branches, forwarding both masks.

        Args:
            x: token tensor passed to ``norm1`` and the attention layer.
            key_mask: forwarded unchanged to ``AttentionWithMask.forward``.
            attn_mask: forwarded unchanged to ``AttentionWithMask.forward``.

        Returns:
            The block output, same shape as ``x``.
        """
        attended = self.attn(self.norm1(x), key_mask=key_mask, attn_mask=attn_mask)
        x = x + self.drop_path1(self.ls1(attended))
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path2(self.ls2(transformed))
    

class TaskVisionTransformer(VisionTransformer):
    def __init__(
        self,
        reduction_layers=(3, 6),
        hidden_state_dim=2048+32,
        num_prototypes=10,
        selection_tau=1.0,
        cluster_tau=1.0,
        freeze_backbone=False,
        freeze_batchnorm=False,
        loss_cfg=None,
        **kwargs,
    ):
        """Build the backbone plus the prototype-based token selection head.

        Args:
            reduction_layers: block indices at which tokens are pruned.
            hidden_state_dim: input width of the prototype gate MLP.
            num_prototypes: number of prototypes (and gate outputs).
            selection_tau: temperature of the Gumbel-softmax keep decision.
            cluster_tau: temperature dividing the token/prototype similarity.
            freeze_backbone: read by ``train`` to freeze all parameters except
                the gate and the prototypes.
            freeze_batchnorm: read by ``train`` (only when the backbone is
                frozen) to convert norm layers via ``convert_frozen_batchnorm``.
            loss_cfg: configuration read by
                ``calculate_loss_from_forward_dict``.
            **kwargs: forwarded to the parent ``VisionTransformer``; must not
                contain ``block_fn``.
        """
        assert 'block_fn' not in kwargs, "block_fn should not be set in kwargs, use reduction_layers instead"
        super().__init__(block_fn=BlockWithAttention, **kwargs)

        # Plain configuration attributes.
        self.reduction_layers = reduction_layers
        self.hidden_state_dim = hidden_state_dim
        self.embedding_dim = self.num_features
        self.selection_tau = selection_tau
        self.cluster_tau = cluster_tau
        self.freeze_backbone = freeze_backbone
        self.freeze_batchnorm = freeze_batchnorm
        self.loss_cfg = loss_cfg

        # Learnable state. Creation order is kept fixed so that random
        # initialisation and parameter registration order do not change.
        num_reductions = len(reduction_layers)
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, self.embedding_dim))
        # Per-layer offset added to the shared prototypes; starts at zero.
        self.layer_prototype = nn.Parameter(
            torch.zeros(num_reductions, num_prototypes, self.embedding_dim)
        )
        self.prototype_gate = Mlp(
            hidden_state_dim,
            self.embedding_dim,
            num_prototypes,
            act_layer=nn.GELU,
            drop=0.0,
        )
        # Accumulates per-forward tensors for the merge functions.
        self.forward_dict = dict()
    
    def train(self, mode = True):
        """Switch train/eval mode and re-apply the backbone freezing policy.

        When entering training mode with ``freeze_backbone`` set, every
        parameter is first marked as not requiring gradients; the prototype
        gate, the layer prototypes and the shared prototypes are then
        re-enabled. If ``freeze_batchnorm`` is also set, the norm layers and
        blocks are passed through ``convert_frozen_batchnorm``.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract so that calls
            such as ``model = model.train()`` keep working.
        """
        super().train(mode)

        if mode and self.freeze_backbone:
            for p in self.parameters():
                p.requires_grad = False
            if self.freeze_batchnorm:
                self.norm_pre = convert_frozen_batchnorm(self.norm_pre)
                self.blocks = convert_frozen_batchnorm(self.blocks)
                self.norm = convert_frozen_batchnorm(self.norm)
                self.fc_norm = convert_frozen_batchnorm(self.fc_norm)
            # Only the token-selection head stays trainable.
            for p in self.prototype_gate.parameters():
                p.requires_grad = True
            self.layer_prototype.requires_grad = True
            self.prototypes.requires_grad = True
        return self
    
    def forward_vit(self, x):
        """Run the backbone without any prototype-based token pruning.

        Args:
            x: input images accepted by ``self.patch_embed``.

        Returns:
            The result of ``self.handle_outcome`` on the final block output.
        """
        batch_size = x.shape[0]
        # pos_embed slot 0 belongs to the cls token, the rest to the patches.
        cls_pos = self.pos_embed[:, :1, :]
        patch_pos = self.pos_embed[:, 1:, :]

        tokens = self.patch_embed(x) + patch_pos

        # Optional random masking: length -> length * mask_ratio.
        if self.mask_ratio is not None:
            tokens, _, _ = self.random_masking(tokens, mask_ratio=self.mask_ratio)

        cls_tokens = (self.cls_token + cls_pos).expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        return self.handle_outcome(self.blocks(tokens))
    
    def forward_features_train(self, x, hidden_states):
        """Training-mode forward pass with soft, prototype-driven token pruning.

        At every block index listed in ``self.reduction_layers`` each token is
        softly assigned to the prototypes; its keep logit is the
        assignment-weighted sum of the per-prototype scores produced by
        ``self.prototype_gate(hidden_states)``. A Gumbel-softmax sample,
        thresholded at 0.5 and made differentiable with a straight-through
        estimator, decides whether the token stays. Dropped tokens are hidden
        from later blocks through ``key_mask``, and the final feature of each
        token is the one it had at the layer where it stopped.

        Intermediate tensors are appended to ``self.forward_dict`` so that
        ``merge_forward_dict_train`` can turn them into losses and metrics.

        Args:
            x: input images accepted by ``self.patch_embed``.
            hidden_states: input of ``self.prototype_gate``; presumably shaped
                ``(B, hidden_state_dim)`` -- confirm against the caller.

        Returns:
            The result of ``self.handle_outcome`` on the merged tokens.

        Raises:
            SystemExit: if the output contains NaN or Inf, after a best-effort
                dump of debugging tensors. The dump directory can be
                overridden with the ``TASK_VIT_DEBUG_DIR`` environment
                variable.
        """
        forward_dict = dict()
        B = x.shape[0]
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if self.mask_ratio is not None:
            x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        x = torch.cat((cls_token.expand(B, -1, -1), x), dim=1)
        prototype_scores = self.prototype_gate(hidden_states)  # B x NP
        # Shared prototypes plus per-layer offsets, unit-normalised and
        # broadcast over the batch: num_reductions x B x NP x D.
        prototypes = F.normalize(self.layer_prototype + self.prototypes.unsqueeze(0), dim=-1).unsqueeze(1).expand(
            len(self.reduction_layers), B, self.prototypes.shape[0], self.embedding_dim)
        forward_dict['prototype_scores'] = prototype_scores

        # forward blocks
        xs = []
        reduction_id = 0
        # keep_probs[j] is the cumulative keep mask after j reductions; before
        # any reduction every token is kept.
        keep_probs = [torch.ones(B, x.shape[1], device=x.device, dtype=x.dtype)]
        for i, block in enumerate(self.blocks):
            if i in self.reduction_layers:
                assert i == self.reduction_layers[reduction_id], f"reduction layer {i} not in {self.reduction_layers}"
                xs.append(x)
                token_norm = F.normalize(x, dim=-1)      # [B, L, D]
                prototype_match_sim = torch.einsum("bld,bkd->blk", token_norm, prototypes[reduction_id]) / self.cluster_tau  # [B, L, NP]
                prototype_match_prob = torch.softmax(prototype_match_sim, dim=-1)  # B x L x NP
                keep_logits_cur = torch.sum(prototype_match_prob * prototype_scores.unsqueeze(1), dim=-1)  # B x L

                # Gumbel softmax over the two classes (drop, keep).
                keep_logits_cur = torch.stack([torch.zeros_like(keep_logits_cur), keep_logits_cur], dim=-1)  # B x L x 2
                keep_probs_cur = torch.nn.functional.gumbel_softmax(keep_logits_cur, tau=self.selection_tau, dim=-1)[..., 1]  # B x L
                keep_mask_cur = keep_probs_cur > 0.5
                # The cls token is always kept.
                keep_mask_cur[:, 0] = True
                # Samples that keep fewer than 10 tokens get 10 randomly
                # chosen positions switched on.
                for b in range(B):
                    if keep_mask_cur[b].sum() < 10:
                        keep_mask_cur[b][torch.randperm(keep_mask_cur[b].numel(), device=keep_mask_cur.device)[:10]] = True
                # Straight-through estimator: hard mask in the forward pass,
                # gradient of the soft probability in the backward pass.
                keep_probs_cur = keep_mask_cur.float() - keep_probs_cur.detach() + keep_probs_cur
                # A token dropped at an earlier layer stays dropped.
                keep_probs_cur = keep_probs_cur * keep_probs[-1]

                forward_dict[f'keep_probs_{i}'] = keep_probs_cur
                forward_dict[f'prototype_match_sim_{i}'] = prototype_match_sim
                forward_dict[f'prototype_match_prob_{i}'] = prototype_match_prob
                forward_dict[f'features_{i}'] = x
                forward_dict[f'prototypes_{i}'] = prototypes[reduction_id]
                keep_probs.append(keep_probs_cur)  # B x L
                reduction_id += 1
            x = block(x, key_mask=~keep_probs[-1].bool())
        xs.append(x)

        keep_probs = torch.stack(keep_probs, dim=0)  # (num_reductions + 1) x B x L
        # stoped_layer_mask: one-hot map indicating which layer each token stops at
        stoped_layer_mask = torch.cat([keep_probs[:-1] - keep_probs[1:], keep_probs[-1:]], dim=0)
        xs = torch.stack(xs, dim=0)
        x = (xs * stoped_layer_mask.unsqueeze(-1)).sum(dim=0)  # B x L x D

        # forward_features_eval rebinds self.forward_dict to a list; start
        # from a fresh dict so that switching from eval back to train does not
        # fail on string-key assignment.
        if not isinstance(self.forward_dict, dict):
            self.forward_dict = dict()
        for k, v in forward_dict.items():
            self.forward_dict.setdefault(k, []).append(v)
        self.forward_dict['merge_function'] = self.merge_forward_dict_train

        out = self.handle_outcome(x)

        # check nan / inf
        if torch.isnan(out).any() or torch.isinf(out).any():
            debug_dir = os.environ.get(
                'TASK_VIT_DEBUG_DIR',
                '/home/temp/robitic_fundation/eai-vc/cortexbench/habitat_vc',
            )
            debug_path = os.path.join(debug_dir, f'debug_info_{torch.randint(1000, (1,)).item()}.pth')
            # The dump is best effort: a failure to write it must not hide
            # the fact that the features became non-finite.
            try:
                os.makedirs(debug_dir, exist_ok=True)
                torch.save({
                    'xs': xs,
                    'x': x,
                    'stoped_layer_mask': stoped_layer_mask,
                    'forward_dict': forward_dict,
                    'prototypes': prototypes,
                    'prototype_scores': prototype_scores,
                    'hidden_states': hidden_states,
                }, debug_path)
                logger.error("Non-finite features in forward_features_train; debug tensors saved to %s", debug_path)
            except (OSError, RuntimeError):
                logger.exception("Non-finite features in forward_features_train; could not save debug tensors to %s", debug_path)
            raise SystemExit(1)
        return out
    
    def merge_forward_dict_train(self, forward_dict):
        """Turn accumulated training tensors into losses and pruning metrics.

        Losses come from ``calculate_loss_from_forward_dict``; on top of them
        this adds, per reduction layer, statistics of how tokens are spread
        over the prototypes. ``forward_dict`` is emptied in place afterwards
        so the next batch starts from a clean state.

        Args:
            forward_dict: mapping from key to a list of tensors, as filled by
                ``forward_features_train``.

        Returns:
            The ordered dict of losses extended with ``prune/...`` metrics.
        """
        merged_dict = self.calculate_loss_from_forward_dict(forward_dict)
        num_prototypes = self.prototypes.shape[0]

        with torch.no_grad():
            for l in self.reduction_layers:
                match_sim = torch.cat(forward_dict[f'prototype_match_sim_{l}'], dim=0)  # B x L x NP
                # Hard assignment of each token; computed once per layer
                # instead of once per prototype.
                assigned = torch.argmax(match_sim, dim=-1)  # B x L
                match_ratios = [
                    (assigned == pid).float().mean().item()
                    for pid in range(num_prototypes)
                ]
                merged_dict[f'prune/match_ratio/layer{l}_std'] = np.std(match_ratios)
                merged_dict[f'prune/match_ratio/layer{l}_min'] = np.min(match_ratios)
                merged_dict[f'prune/match_ratio/layer{l}_max'] = np.max(match_ratios)

                match_probs = torch.cat(forward_dict[f'prototype_match_prob_{l}'], dim=0)  # B x L x NP
                # Average over tokens of the smallest / largest assignment
                # probability of each token.
                merged_dict[f'prune/match_prob/layer{l}_min'] = match_probs.min(dim=-1).values.mean().item()
                merged_dict[f'prune/match_prob/layer{l}_max'] = match_probs.max(dim=-1).values.mean().item()

        # clear the forward_dict for next batch
        forward_dict.clear()
        return merged_dict
    
    def calculate_loss_from_forward_dict(self, forward_dict):
        """
        Calculate the weighted auxiliary losses from the forward_dict.

        Each loss is computed per reduction layer and only when its name is
        present in ``self.loss_cfg``:

        * ``keep_loss``: squared distance between the mean keep probability
          and the configured target ratio.
        * ``entropy_loss``: mean entropy of the token-to-prototype assignment.
        * ``usage_dist_loss``: KL term between a uniform distribution and the
          average prototype usage.
        * ``cluster_loss``: cross entropy of the similarities against their
          own argmax (pseudo labels).
        * ``repel_loss``: hinge on pairwise prototype cosine similarity
          above 0.1.

        Args:
            forward_dict: mapping from key to a list of tensors, as filled by
                ``forward_features_train``.

        Returns:
            OrderedDict with ``loss/<name>_<layer>`` tensors and
            ``keep_ratio_<layer>`` floats. Empty when ``self.loss_cfg`` is
            None.
        """
        result = OrderedDict()
        loss_cfg = self.loss_cfg
        if loss_cfg is None:
            # The constructor default: no auxiliary loss is configured.
            return result

        def gathered(name, layer):
            # Concatenate the per-forward tensors along the batch dimension.
            return torch.cat(forward_dict[f'{name}_{layer}'], dim=0)

        if 'keep_loss' in loss_cfg:
            for reduction_id, l in enumerate(self.reduction_layers):
                keep_ratio = gathered('keep_probs', l).mean()  # keep_probs: B x L
                result[f'keep_ratio_{l}'] = keep_ratio.item()
                result[f'loss/keep_{l}'] = ((keep_ratio - loss_cfg.keep_loss.ratios[reduction_id])**2) * loss_cfg.keep_loss.weights[reduction_id]
        if 'entropy_loss' in loss_cfg:
            for reduction_id, l in enumerate(self.reduction_layers):
                prototype_match_sim = gathered('prototype_match_sim', l)  # B x L x NP
                prototype_match_prob = gathered('prototype_match_prob', l)  # B x L x NP
                prototype_match_prob_log = torch.log_softmax(prototype_match_sim, dim=-1)  # B x L x NP
                entropy = -(prototype_match_prob * prototype_match_prob_log).sum(dim=-1).mean()
                result[f'loss/entropy_{l}'] = entropy * loss_cfg.entropy_loss.weights[reduction_id]
        if 'usage_dist_loss' in loss_cfg:
            for reduction_id, l in enumerate(self.reduction_layers):
                prototype_match_prob = gathered('prototype_match_prob', l)  # B x L x NP
                # Average over all tokens (dim 0) to get one usage value per
                # prototype. Averaging over the prototype axis instead would
                # give the constant 1/NP for every token, because each softmax
                # row sums to one, and the loss would carry no gradient.
                prototype_match_prob = prototype_match_prob.flatten(0, 1).mean(dim=0)  # NP
                uniform_dist = torch.ones_like(prototype_match_prob) / prototype_match_prob.shape[-1]
                # With a 1-D input, 'batchmean' divides the summed KL by NP.
                kl_div = torch.nn.functional.kl_div(
                    input=torch.log(prototype_match_prob + 1e-8),
                    target=uniform_dist,
                    reduction='batchmean',
                )
                result[f'loss/usage_dist_{l}'] = kl_div * loss_cfg.usage_dist_loss.weights[reduction_id]
        if 'cluster_loss' in loss_cfg:
            for reduction_id, l in enumerate(self.reduction_layers):
                prototype_match_sim = gathered('prototype_match_sim', l)  # B x L x NP
                pseudo_label = torch.argmax(prototype_match_sim, dim=-1)  # B x L
                cluster_loss = torch.nn.functional.cross_entropy(
                    input=prototype_match_sim.reshape(-1, prototype_match_sim.shape[-1]),  # B*L x NP
                    target=pseudo_label.reshape(-1),  # B*L
                    reduction='mean',
                )
                result[f'loss/cluster_{l}'] = cluster_loss * loss_cfg.cluster_loss.weights[reduction_id]
        if 'repel_loss' in loss_cfg:
            for reduction_id, l in enumerate(self.reduction_layers):
                prototypes = F.normalize(gathered('prototypes', l), dim=-1)  # B x NP x D
                sim = torch.einsum("bid,bjd->bij", prototypes, prototypes)  # B x NP x NP
                # Push the diagonal far below the hinge so self-similarity
                # never contributes.
                sim = sim - torch.eye(sim.shape[-1], device=sim.device).unsqueeze(0) * 1e6
                repel_loss = torch.nn.functional.relu(sim-0.1).mean()
                result[f'loss/repel_{l}'] = repel_loss * loss_cfg.repel_loss.weights[reduction_id]

        return result

    def _forward_features_eval_single_wo_handle(self, x, prototypes, prototype_scores):
        """Run the blocks on one sample, physically removing pruned tokens.

        At each reduction layer, tokens whose keep probability (sigmoid of the
        assignment-weighted prototype score) is not above 0.5 are written to
        the output at their original position and removed from the sequence;
        the survivors continue through the remaining blocks.

        Args:
            x: embedded tokens of a single sample, shape ``(1, L, D)``.
            prototypes: normalised prototypes, indexed by reduction layer.
            prototype_scores: gate scores for this sample, ``(1, NP)``.

        Returns:
            ``(out_x, forward_dict)`` where ``out_x`` has shape ``(1, L, D)``
            and ``forward_dict`` holds the per-layer selection tensors.
        """
        # NOTE(review): unlike forward_features_train, nothing here forces the
        # cls token (index 0) to be kept -- confirm this is intended.
        assert x.shape[0] == 1, "x should have batch size 1 in eval mode"
        feat_dim = x.shape[-1]
        # Original positions of the tokens that are still in the sequence.
        alive_idx = torch.arange(x.shape[1], device=x.device)
        out_x = torch.zeros_like(x).reshape(-1, feat_dim)

        forward_dict = dict(merge_function=self.merge_forward_dict_eval)
        forward_dict['prototype_scores'] = prototype_scores

        reduction_id = 0
        for i, block in enumerate(self.blocks):
            if i in self.reduction_layers:
                assert i == self.reduction_layers[reduction_id], f"reduction layer {i} not in {self.reduction_layers}"
                unit_tokens = F.normalize(x, dim=-1)  # [1, L, D]
                prototype_match_sim = torch.einsum("bld,kd->blk", unit_tokens, prototypes[reduction_id]) / self.cluster_tau  # [1, L, NP]
                prototype_match_prob = torch.softmax(prototype_match_sim, dim=-1)  # 1 x L x NP
                keep_logits = (prototype_match_prob * prototype_scores.unsqueeze(1)).sum(dim=-1)  # 1 x L
                keep_probs_cur = torch.sigmoid(keep_logits)  # 1 x L

                keep = keep_probs_cur > 0.5
                keep_flat = keep.flatten()

                # Park the dropped tokens at their original positions.
                dropped_idx = alive_idx[~keep_flat]
                out_x[dropped_idx] = x[~keep].reshape(-1, feat_dim)
                # Continue with the survivors only.
                alive_idx = alive_idx[keep_flat]
                x = x[keep].reshape(1, -1, feat_dim)

                forward_dict[f'keep_probs_{i}'] = keep_probs_cur
                forward_dict[f'prototype_match_sim_{i}'] = prototype_match_sim
                forward_dict[f'prototype_match_prob_{i}'] = prototype_match_prob
                # Overwritten at every reduction layer; first column skipped.
                forward_dict['visualize_mask'] = keep_probs_cur[:, 1:]
                reduction_id += 1
            x = block(x)

        out_x[alive_idx] = x.reshape(-1, feat_dim)
        return out_x.reshape(1, -1, out_x.shape[-1]), forward_dict
    

    def forward_features_eval(self, x, hidden_states):
        """Evaluation-mode forward pass with hard token pruning per sample.

        The batch is embedded once, then every sample is processed separately
        by ``_forward_features_eval_single_wo_handle``. ``self.forward_dict``
        is replaced by the list of per-sample dictionaries.

        Args:
            x: input images accepted by ``self.patch_embed``.
            hidden_states: input of ``self.prototype_gate``.

        Returns:
            The result of ``self.handle_outcome`` on the re-assembled tokens.
        """
        batch_size = x.shape[0]
        patch_pos = self.pos_embed[:, 1:, :]
        cls_pos = self.pos_embed[:, :1, :]

        # Patch tokens plus their positional embedding (cls slot excluded).
        tokens = self.patch_embed(x) + patch_pos

        # masking: length -> length * mask_ratio
        if self.mask_ratio is not None:
            tokens, _, _ = self.random_masking(tokens, mask_ratio=self.mask_ratio)

        # Prepend the cls token.
        cls_tokens = (self.cls_token + cls_pos).expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        prototype_scores = self.prototype_gate(hidden_states)  # B x NP
        prototypes = F.normalize(self.layer_prototype + self.prototypes.unsqueeze(0), dim=-1)  # L x NP x D

        run_single = self._forward_features_eval_single_wo_handle
        if batch_size == 1:
            # A single sample needs neither slicing nor concatenation.
            merged, sample_dict = run_single(tokens, prototypes, prototype_scores)
            self.forward_dict = [sample_dict]
            return self.handle_outcome(merged)

        per_sample_tokens = []
        per_sample_dicts = []
        for idx in range(batch_size):
            sample_tokens, sample_dict = run_single(
                tokens[idx:idx + 1], prototypes, prototype_scores[idx:idx + 1]
            )
            per_sample_tokens.append(sample_tokens)
            per_sample_dicts.append(sample_dict)
        merged = torch.cat(per_sample_tokens, dim=0)  # B x L x D
        self.forward_dict = per_sample_dicts
        return self.handle_outcome(merged)
    
    def forward_features_demo(self, x, hidden_states=None):
        """Single-sample pass that uses all-zero prototype scores.

        No gate input is used: the prototype scores passed to
        ``_forward_features_eval_single_wo_handle`` are zeros, and the raw
        tokens are returned without going through ``handle_outcome``.

        Args:
            x: a batch of exactly one image.
            hidden_states: must be None.

        Returns:
            ``(x_out, forward_dict)`` as produced by
            ``_forward_features_eval_single_wo_handle``.
        """
        assert hidden_states is None, "hidden_states should be None in demo mode"
        assert x.shape[0] == 1, "x should have batch size 1 in demo mode"

        batch_size = x.shape[0]
        patch_pos = self.pos_embed[:, 1:, :]
        cls_pos = self.pos_embed[:, :1, :]

        # Patch tokens plus their positional embedding (cls slot excluded).
        tokens = self.patch_embed(x) + patch_pos

        # masking: length -> length * mask_ratio
        if self.mask_ratio is not None:
            tokens, _, _ = self.random_masking(tokens, mask_ratio=self.mask_ratio)

        # Prepend the cls token.
        cls_tokens = (self.cls_token + cls_pos).expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        unit_prototypes = F.normalize(self.layer_prototype + self.prototypes.unsqueeze(0), dim=-1)
        zero_scores = torch.zeros(1, self.prototypes.shape[0], device=tokens.device, dtype=tokens.dtype)
        return self._forward_features_eval_single_wo_handle(
            tokens,
            prototypes=unit_prototypes,
            prototype_scores=zero_scores,
        )

    def merge_forward_dict_eval(self, forward_dicts):
        """Aggregate per-sample evaluation dictionaries into metrics.

        Args:
            forward_dicts: iterable of per-sample dictionaries with the keys
                written by ``_forward_features_eval_single_wo_handle``. The
                values are combined with ``np.concatenate``, so they are
                presumably already array-like on the host -- confirm against
                the caller.

        Returns:
            ``dict(agent_metrics=..., detailed_metrics=...)``. Both hold the
            keep ratio per reduction layer, the share of tokens assigned to
            each prototype per layer and the share of samples whose prototype
            score is positive; ``detailed_metrics`` additionally lists the raw
            scores per prototype.
        """
        agent_metrics = dict()
        detailed_metrics = dict()
        num_prototypes = self.prototypes.shape[0]

        for l in self.reduction_layers:
            keep_probs = np.concatenate([forward_dict[f'keep_probs_{l}'] for forward_dict in forward_dicts], axis=0)
            keep_ratio = float((keep_probs > 0.5).mean())
            agent_metrics[f'keep_ratio_{l}'] = keep_ratio
            detailed_metrics[f'keep_ratio_{l}'] = keep_ratio

            match_sim = np.concatenate([forward_dict[f'prototype_match_sim_{l}'] for forward_dict in forward_dicts], axis=0)  # B x L x NP
            # Hard assignment of each token; computed once per layer instead
            # of once per prototype.
            assigned = np.argmax(match_sim, axis=-1)
            for pid in range(num_prototypes):
                match_ratio = float((assigned == pid).mean())
                agent_metrics[f'layer{l}_proto{pid}_match_ratio'] = match_ratio
                detailed_metrics[f'layer{l}_proto{pid}_match_ratio'] = match_ratio

        prototype_scores = np.concatenate([forward_dict['prototype_scores'] for forward_dict in forward_dicts], axis=0)
        for pid in range(num_prototypes):
            scores = prototype_scores[:, pid]
            positive_ratio = float((scores > 0).mean())
            agent_metrics[f'proto{pid}_keep_ratio'] = positive_ratio
            detailed_metrics[f'proto{pid}_keep_ratio'] = positive_ratio

            detailed_metrics[f'proto{pid}_scores'] = scores.tolist()

        return dict(agent_metrics=agent_metrics, detailed_metrics=detailed_metrics)
    
    def forward_features(self, x, hidden_states, masks):
        """Dispatch to the training or evaluation feature extractor.

        Args:
            x: input images.
            hidden_states: input of the prototype gate.
            masks: accepted for interface compatibility; not used here.

        Returns:
            The features from ``forward_features_train`` when the module is in
            training mode, otherwise from ``forward_features_eval``.
        """
        extractor = self.forward_features_train if self.training else self.forward_features_eval
        return extractor(x, hidden_states)

    def forward(self, x, hidden_states, masks):
        """Module entry point; delegates unchanged to ``forward_features``."""
        features = self.forward_features(x, hidden_states, masks)
        return features

def task_deit_tiny_patch16_224(**kwargs):
    """Build a DeiT-tiny sized ``TaskVisionTransformer`` (patch 16).

    Args:
        **kwargs: forwarded to ``TaskVisionTransformer``. A
            ``requires_state_keys`` entry is dropped if present, since the
            model constructor is not given it.

    Returns:
        The constructed ``TaskVisionTransformer``.
    """
    # Tolerate callers that do not supply the key instead of raising KeyError.
    kwargs.pop('requires_state_keys', None)
    model = TaskVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
    )
    return model
