import torch
import copy 
from ...logic_engine import LogicEngine

class VARProc(LogicEngine):
    """Attention-redistribution processor built on top of ``LogicEngine``.

    All state lives on the class and is changed through classmethods, so the
    configuration is shared by every user of ``VARProc``. The main entry point
    is :meth:`redistribution`. It scales attention paid to "sink" tokens by
    ``P`` and hands the removed ``(1 - P)`` mass to non-sink image tokens.
    """

    # Fraction of attention kept on sink tokens; the remaining (1 - P) is redistributed.
    P = .4
    # Not read within this class; presumably consumed by LogicEngine or sibling code — confirm.
    THRES = .5
    # Head / layer targets. The default 0 appears to mean "not configured";
    # config_head / config_layer replace them with lists of ints.
    TARGET_HEADS = 0
    TARGET_LAYERS = 0
    # Free-form step label, set via set_step() and read via what_step().
    STEP = ""
    # Stored by set_selected_token(); not read within this class.
    selected_tokens = []
    # Not read within this class.
    vis_only = False
    # Ablation switch: when True, the final decoder layer is left untouched.
    except_last_layer = False

    @classmethod
    def activate(cls, p):
        """Enable the processor and set the sink keep-ratio ``P`` to ``p``."""
        cls.set_flag(True)
        cls.P = p
        print(f"{cls.__name__} flag set to {cls._flag()}, p: {cls.P}")

    @classmethod
    def config_last_layer(cls, flag):
        """Set whether the final decoder layer is excluded from redistribution."""
        cls.except_last_layer = flag

    @classmethod
    def config_head(cls, head_num):
        """Parse a whitespace-separated string of head ids (e.g. ``"0 3 7"``) into ``TARGET_HEADS``."""
        cls.TARGET_HEADS = list(map(int, head_num.split()))

    @classmethod
    def config_layer(cls, layer_num):
        """Parse a whitespace-separated string of layer ids into ``TARGET_LAYERS``."""
        cls.TARGET_LAYERS = list(map(int, layer_num.split()))

    @classmethod
    def set_selected_token(cls, selected_tokens):
        """Store ``selected_tokens`` on the class."""
        cls.selected_tokens = selected_tokens

    @classmethod
    def check_target_layer(cls):
        """Return ``TARGET_LAYERS`` (0 until :meth:`config_layer` is called)."""
        return cls.TARGET_LAYERS

    @classmethod
    def what_step(cls):
        """Return the current step label."""
        return cls.STEP

    @classmethod
    def set_step(cls, step):
        """Set the current step label."""
        cls.STEP = step

    @classmethod
    def _resolve_sink_indices(cls, layer):
        """Return the sink-token indices for ``layer``, or ``None`` if ``layer``
        is not one of ``cls.sink_select_layers``.

        ``cls.sink_select_layers`` may be a list of layer ids or a single int.

        Raises:
            NotImplementedError: if ``cls.indices`` is empty/unset.
        """
        if not cls.indices:
            raise NotImplementedError("indices not found")
        selected = cls.sink_select_layers
        if isinstance(selected, list) and layer in selected:
            return cls.indices[layer]
        if isinstance(selected, int) and layer == selected:
            # Here layer == sink_select_layers, so the layer itself is the key.
            return cls.indices[layer]
        return None

    @staticmethod
    def _redistribute_rows(rows, vis_indices, text_indices, im, pa, p):
        """Redistribute sink attention within a ``(R, K)`` block of attention rows.

        Returns ``(new_rows, valid)``:

        * ``new_rows`` has the sink columns (``vis_indices`` and
          ``text_indices``) scaled by ``p``. The freed ``(1 - p)`` mass is
          added to columns ``[im, im + pa)`` in proportion to each column's
          non-sink attention.
        * ``valid`` is a bool ``(R,)`` mask of rows whose non-sink image
          attention is non-zero. Other rows have no target for the budget
          (the ratio would be 0/0), and the caller should leave them unchanged.

        ``rows`` itself is not modified.
        """
        new_rows = rows.clone()
        # Keep only a fraction p of the attention on every sink token.
        new_rows[:, text_indices] *= p
        new_rows[:, vis_indices] *= p

        # Attention mass removed from visual and textual sinks.
        budget_vis = rows[:, vis_indices].sum(dim=1) * (1 - p)
        budget_text = rows[:, text_indices].sum(dim=1) * (1 - p)

        # Share of attention held by each non-sink image token (visual sinks zeroed out).
        non_sink = rows.clone()
        non_sink[:, vis_indices] = 0
        image_part = non_sink[:, im:im + pa]
        denom = image_part.sum(dim=1, keepdim=True).to(rows.dtype)
        valid = denom.squeeze(1) > 0
        safe_denom = torch.where(valid.unsqueeze(1), denom, torch.ones_like(denom))
        ratios = image_part / safe_denom

        # Both budgets go to the image span only.
        new_rows[:, im:im + pa] += (budget_vis + budget_text).view(-1, 1) * ratios
        return new_rows, valid

    @classmethod
    def redistribution(cls, attention_map):
        """Move a share of sink-token attention onto non-sink image tokens.

        The rows edited are the (batch, head, query) entries reported by the
        HeadFork engine for the current layer. On each row:

        * attention on the layer's sink tokens is multiplied by ``P``;
        * the removed ``(1 - P)`` mass, from both visual and textual sinks, is
          added to the image span ``[im, im + vis_len)``, proportionally to the
          attention each non-sink image token already had.

        Nothing is changed in these cases:

        * decoder layers 0-1;
        * the last layer when ``except_last_layer`` is set;
        * layers without sink indices;
        * when the HeadFork flag is off;
        * when there are no forked rows.

        Args:
            attention_map: attention probabilities, modified in place.
                Presumably shaped ``(bsz, heads, queries, keys)``, e.g.
                ``torch.Size([1, 32, 632 (or 1 while decoding), 576])`` — confirm.

        Returns:
            ``attention_map`` (the same tensor object).

        Raises:
            NotImplementedError: if ``cls.indices`` is empty/unset.
        """
        layer = cls.current_decoder_layer
        # The first two decoder layers are never modified.
        if layer < 2:
            return attention_map

        # Ablation study: optionally leave the final decoder layer untouched.
        if cls.except_last_layer and layer == cls.model_config.num_hidden_layers - 1:
            return attention_map

        im, pa = cls.begin_pos["image"], cls.vis_len
        p = cls.P

        indices = cls._resolve_sink_indices(layer)
        # Rows to edit come only from HeadFork. Without it, or without sink
        # indices for this layer, there is nothing to redistribute.
        if indices is None or not cls._flag("HeadFork"):
            return attention_map
        coord = cls.forked_head[layer]
        if len(coord) == 0:
            return attention_map

        # coord is presumably an (N, 3) integer tensor of (batch, head, query)
        # rows — inferred from the column indexing used here; confirm against HeadFork.
        num_heads = cls.model_config.num_attention_heads
        keep = (coord[:, 1] >= 0) & (coord[:, 1] < num_heads)
        if cls.output_token_count == 0:
            # Prefill: only edit queries at or after the end of the image span.
            # TODO(review): decoding steps currently keep every forked query.
            keep = keep & (coord[:, 2] >= im + pa)
        # One mask for all columns keeps batch/head/query aligned per row.
        rows = coord[keep]
        if rows.shape[0] == 0:
            return attention_map
        bsz_idx, head_idx, query_idx = rows[:, 0], rows[:, 1], rows[:, 2]

        # Split sink indices into visual (inside the image span) and textual ones.
        indices = indices.to(attention_map.device)
        vis_mask = (im <= indices) & (indices < im + pa)
        vis_indices = indices[vis_mask]
        text_indices = indices[~vis_mask]

        # Advanced indexing returns a copy: (R, K).
        current = attention_map[bsz_idx, head_idx, query_idx, :]
        updated, valid = cls._redistribute_rows(current, vis_indices, text_indices, im, pa, p)

        # Rows with no non-sink image attention cannot absorb the budget;
        # leave them untouched instead of writing NaNs.
        valid_rows = valid.to(rows.device)
        attention_map[bsz_idx[valid_rows], head_idx[valid_rows], query_idx[valid_rows], :] = updated[valid]
        return attention_map
