from .logic_engine import LogicEngine
from .constants import FIX
import torch

class DimProspector(LogicEngine):
    """Locate token positions with an outsized activation in selected hidden dimensions.

    For each call to ``run_logic`` the hidden states are RMS-normalised, the
    absolute values at the dimensions listed in ``FIX`` are compared against
    ``tau``, and the matching token positions are stored per layer in the
    ``indices`` mapping reached through the base class.
    """

    # Threshold applied to the normalised magnitudes in run_logic().
    tau = 20
    # Dimensions to probe; stays None until fix_dim() (or activate()) is called.
    # NOTE: inside methods the bare name ``FIX`` refers to the imported
    # module-level table, not to this attribute.
    FIX = None

    @classmethod
    def activate(cls, tau=20):
        """Turn the engine on, load the probed dimensions, and set the threshold.

        Args:
            tau: threshold later used by ``run_logic``.
        """
        cls.set_flag(True)
        cls.fix_dim(cls.llm_name)
        cls.tau = tau
        print(f"{cls.__name__} flag set to {cls._flag()}")

    @classmethod
    def fix_dim(cls, llm_name="llama-7b"):
        """Look up the dimensions for ``llm_name`` in the imported ``FIX`` table.

        Args:
            llm_name: key into the ``FIX`` table.
        """
        cls.FIX = FIX[llm_name]
        print(f"{cls.__name__} fixed dims: {cls.FIX}")

    @classmethod
    def rmsnorm(cls, hidden_states, eps=1e-6):
        """Return ``hidden_states`` divided by its root-mean-square over the last axis.

        The input is cast to float32 first; no learned scale is applied.

        Args:
            hidden_states: tensor to normalise.
            eps: added to the mean square before the reciprocal square root.
        """
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return hidden_states * torch.rsqrt(variance + eps)

    @classmethod
    def run_logic(cls, hs, layer):
        """Record, for ``layer``, the token positions whose probed magnitude exceeds ``tau``.

        Args:
            hs: hidden states; indexed as three axes below (the original
                author's note gives them as [bsz, tok, dim]).
            layer: key under which the resulting positions are stored.

        Raises:
            RuntimeError: if no dimensions were configured via ``fix_dim`` /
                ``activate`` (``FIX`` is still ``None``).
        """
        if cls.FIX is None:
            # Without this guard the iteration below fails with an opaque
            # "'NoneType' object is not iterable".
            raise RuntimeError(
                f"{cls.__name__}.run_logic called before fix_dim()/activate(): "
                "no dimensions are configured"
            )

        rms_norm_hs = torch.abs(cls.rmsnorm(hs))  # [bsz, tok, dim]
        # One slice per probed dimension, stacked on a new trailing axis.
        rms_values = torch.stack(
            [rms_norm_hs[:, :, idx] for idx in cls.FIX], dim=-1
        )  # [bsz, tok, len(FIX)]
        max_rms_values = rms_values.max(dim=-1).values  # [bsz, tok]
        # nonzero() yields (batch, token) coordinate rows; keep the token column.
        indices = torch.nonzero(max_rms_values > cls.tau)[:, 1]
        cls.__base__.indices[layer] = indices

class HeadFork(LogicEngine):
    """Select (batch, head, query) entries by how their image attention is spread.

    An entry is kept when the share of its image attention that falls on the
    recorded sink positions is at most ``rho`` and its total image attention
    is at least ``summ``.
    """

    # Defaults; activate() overwrites both.
    rho = 0.2
    summ = 0.2  # FIX

    @classmethod
    def activate(cls, rho, summ):
        """Turn the engine on and store both thresholds.

        Args:
            rho: upper bound on the sink share of image attention.
            summ: lower bound on the total image attention.
        """
        cls.set_flag(True)
        cls.rho, cls.summ = rho, summ
        print(f"{cls.__name__} flag set to {cls._flag()}, rho: {cls.rho}, summ: {cls.summ}")

    @classmethod
    def run_logic(cls, attn):
        """Compute the selected coordinates for the current decoder layer and store them.

        The result is written to ``forked_head[layer]`` and mirrored into
        ``forked_head_per_token[output_token_count][layer]`` via the base
        class. It is a tensor of coordinates when at least one sink position
        lies inside the image span, otherwise an empty list.

        Args:
            attn: attention weights, indexed with four axes below (the last
                one is sliced by key position).
        """
        layer = cls.current_decoder_layer

        # With a list of selection layers use this layer's own sink positions;
        # otherwise use the positions stored under the configured key.
        if isinstance(cls.sink_select_layers, list):
            sink_inds = cls.indices[layer]
        else:
            sink_inds = cls.indices[cls.sink_select_layers]

        im = cls.begin_pos["image"]
        pa = cls.vis_len

        # Keep only the sink positions that fall inside [im, im + pa).
        inside = []
        for pos in sink_inds:
            if im <= pos < im + pa:
                inside.append(pos.unsqueeze(0))

        if inside:
            sink_cols = torch.cat(inside, dim=0)  # shape: [n]
            image_attn = attn[:, :, :, im:im + pa]

            # Attention on the sink positions, re-based to the image slice.
            sink_mass = image_attn[:, :, :, sink_cols - im].sum(dim=-1)
            denom = (image_attn + 1e-6).sum(dim=-1)
            portion = sink_mass / denom  # [bsz, head, query]
            summation = image_attn.sum(dim=-1)  # [bsz, head, query]

            keep = (portion <= cls.rho) & (summation >= cls.summ)
            forked = torch.nonzero(keep).clone()
        else:
            forked = []

        base = cls.__base__
        base.forked_head[layer] = forked
        base.forked_head_per_token[cls.output_token_count][layer] = base.forked_head[layer]
