from functools import partial
import os
import warnings
from typing import Callable, Dict, Optional, Tuple, Union
import math
import random

import torch
from torch import nn


from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_layers import GenericForSequenceClassification, GenericForTokenClassification, GradientCheckpointingLayer
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.models.qwen2.modeling_qwen2 import ALL_ATTENTION_FUNCTIONS, eager_attention_forward
from transformers.masking_utils import (
    create_causal_mask,
    create_sliding_window_causal_mask,
)
from transformers.utils import (
    TransformersKwargs,
)
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Config,
    Qwen2Model,
    Qwen2PreTrainedModel,
    Qwen2RotaryEmbedding,
    Qwen2RMSNorm,
    Qwen2DecoderLayer,
    Qwen2Attention,
    Qwen2MLP,
    apply_rotary_pos_emb,
    rotate_half,
    repeat_kv,
)
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
from .configuration_multi_seq_calibrator import MultiSeqCalibratorConfig
from .mask_utils import create_attention_masks

class MultiSeqCalibratorPreTrainedModel(Qwen2PreTrainedModel):
    """Qwen2 pre-trained base class bound to ``MultiSeqCalibratorConfig``."""
    # Configuration class associated with every model deriving from this base.
    config_class = MultiSeqCalibratorConfig

class ModifiedQwen2Attention(nn.Module):
    def __init__(self, module: Qwen2Attention):
        super().__init__()
        self.config = module.config
        self.layer_idx = module.layer_idx
        self.head_dim = module.head_dim
        self.num_key_value_groups = module.num_key_value_groups
        self.scaling = module.scaling
        self.attention_dropout = module.attention_dropout
        self.is_causal = module.is_causal
        self.q_proj = module.q_proj
        self.k_proj = module.k_proj
        self.v_proj = module.v_proj
        self.o_proj = module.o_proj
        self.sliding_window = module.sliding_window

        self.attn_types = [attn_type.strip() for attn_type in self.config.attn_types.split(",")]
        self.gating_lambdas = dict()
        for attn_type in self.attn_types:
            assert attn_type in ["omni", "omni_intra", "omni_bin", "omni_indiv", "causal", "causal_intra", "causal_bin", "causal_indiv", "recent", "recent_bin", "last", "last_bin"]
            self.gating_lambdas[attn_type] = nn.Parameter(torch.zeros(self.config.num_attention_heads))
        self.gating_lambdas = nn.ParameterDict(self.gating_lambdas)
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        # if random.random() < 0.1:
        #     print(f"Gating lambdas: {[self.gating_lambdas[key] for key in self.gating_lambdas]}")
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        
        if self.config._attn_implementation == "flex_attention":
            kernel_options = {
                "BLOCK_M": 64,
                "BLOCK_N": 64,
                "BLOCK_M1": 32,
                "BLOCK_N1": 64,
                "BLOCK_M2": 64,
                "BLOCK_N2": 32,
            }
            kwargs["kernel_options"] = kernel_options


        mask_mapping = kwargs["mask_mapping"]
        if len(self.gating_lambdas) == 0:
            raise ValueError("No attention types are enabled")

        attn_outputs = dict()
        for key in self.gating_lambdas:
            attn_outputs[key], _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                mask_mapping[key],
                scaling=self.scaling,
                **kwargs,
            )
        
        attn_outputs = torch.stack([attn_outputs[key] for key in self.gating_lambdas]) # [num_gates, batch_size, seq_len, num_heads, head_dim]
        gatings = torch.softmax(torch.stack([self.gating_lambdas[key] for key in self.gating_lambdas]), dim=0)[:,None,None,:,None] # [num_gates, 1, 1, num_heads, 1]
        attn_output = (attn_outputs * gatings).sum(dim=0) # [batch_size, seq_len, num_heads, head_dim]

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, _



class MultiSeqCalibratorModel(MultiSeqCalibratorPreTrainedModel):
    def __init__(self, config: MultiSeqCalibratorConfig):
        """Create the sub-modules required by ``config``.

        Depending on ``config.architecture`` this builds either a small probe
        (optional linear layer + activation) or a stack of Qwen2 decoder
        layers whose attention is replaced by ``ModifiedQwen2Attention``.
        Optional extras (agent embeddings, node-feature projections, input
        projection, group-softmax pooling parameters) are created only when
        the corresponding config switch is set.
        """
        super().__init__(config)

        # Plain mirrors of configuration values.
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.input_embeds_size = config.input_embeds_size
        self.architecture = config.architecture
        self.num_hidden_layers = config.num_hidden_layers
        self.group_size = config.group_size

        # Used by the "probe" architecture.
        self.mlp_hidden_size = config.mlp_hidden_size

        # Used by the "full" architecture.
        self.increment_position_ids = config.increment_position_ids
        self.max_context_len = config.max_context_len
        self.agent_emb = config.agent_emb

        if self.agent_emb:
            # One learned vector per slot of a group, scaled by 1/sqrt(hidden_size).
            agent_init = torch.randn(self.group_size, self.config.hidden_size)
            self.agent_embeddings = nn.Parameter(agent_init / math.sqrt(self.config.hidden_size))

        self.rotary_emb = Qwen2RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        if config.architecture == "probe":
            if config.mlp_hidden_size is None:
                # Always create at least one parameter to avoid Transformers loading issues;
                # this dummy parameter ensures the base model is never empty.
                self.dummy_param = nn.Parameter(torch.zeros(1))
            else:
                self.mlp = nn.Linear(config.input_embeds_size, config.mlp_hidden_size, bias=True)
                self.mlp_activation = ACT2FN[config.hidden_act]

        if config.architecture == "full":
            decoder_stack = [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
            self.layers = nn.ModuleList(decoder_stack)
            # Swap the stock attention for the gated multi-mask variant.
            for decoder_layer in self.layers:
                decoder_layer.self_attn = ModifiedQwen2Attention(decoder_layer.self_attn)
            self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        if config.node_features != "":
            self.node_features = [name.strip() for name in config.node_features.split(",")]
        else:
            self.node_features = []

        def build_feature_projection() -> nn.Sequential:
            # Two-layer MLP: num_features -> 2 * hidden_size -> hidden_size.
            return nn.Sequential(
                nn.Linear(len(self.node_features), config.hidden_size * 2, bias=True),
                ACT2FN[config.hidden_act],
                nn.Linear(config.hidden_size * 2, config.hidden_size, bias=True),
            )

        if len(self.node_features) > 0:
            if not config.no_early_node_features_projection:
                self.node_features_projection = build_feature_projection()
            if config.late_node_features_projection:
                self.node_features_projection_late = build_feature_projection()
                if config.late_node_features_projection_norm:
                    self.post_features_norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Only non-probe models project inputs whose width differs from hidden_size.
        if config.architecture != "probe" and config.input_embeds_size != config.hidden_size:
            self.input_embeds_proj = nn.Linear(config.input_embeds_size, config.hidden_size, bias=False)
        else:
            self.input_embeds_proj = None

        if config.group_softmax:
            # Learned query plus key/value projections for attention pooling over a group.
            query_init = torch.randn(config.hidden_size)
            self.query_abstain = nn.Parameter(query_init / math.sqrt(config.hidden_size))
            self.key_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
            self.value_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        self.bin_aggregate = config.bin_aggregate
        self.causal_bin_aggregate = config.causal_bin_aggregate

        # Initialize weights and apply final processing
        self.post_init()
    
    def get_count_features(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the ``*count*`` node features for every position.

        The batch of ``B`` rows is treated as ``N = B // G`` groups of ``G``
        consecutive rows. Inside each group, positions whose ``bin_idx`` is not
        ``-100`` are compared pairwise; every feature counts how many other
        valid positions share the same bin, optionally restricted by a
        causal / most-recent-per-agent / last-section mask. The value computed
        at a valid position is then copied to every position of the same
        section of the same agent.

        Args:
            position_ids: ``[B, T]`` position ids (required).
            inputs_embeds: ``[B, T, *]`` tensor; only its shape and dtype are used.
            **kwargs: Must contain ``bin_idx`` and ``section_mask``, both ``[B, T]``.
            Other parameters are accepted for signature compatibility and ignored.

        Returns:
            Tensor of shape ``[B, T, num_count_features]`` in ``inputs_embeds.dtype``.
            Feature names containing ``"count"`` that are not recognised yield
            an all-zero column.
        """
        B, T, _ = inputs_embeds.shape
        G = self.group_size
        N = B // G

        concat_position_ids = position_ids.view(N, G * T)
        concat_bin_idx = kwargs["bin_idx"].view(N, G * T)
        # Index (0..G-1) of the row inside its group for every flattened slot.
        agent_idx = (
            torch.arange(0, G, device=concat_bin_idx.device)[:, None].expand(-1, T).long().reshape(1, G * T)
        )  # [1, G*T]
        section_mask = kwargs["section_mask"]  # [B, T]
        assert section_mask.shape == (B, T)
        # True where a position belongs to the highest-numbered section of its row.
        is_last = section_mask == section_mask.max(dim=1, keepdim=True).values  # [B, T]
        is_last = is_last.reshape(N, G * T)  # [N, G*T]

        def latest_visible_per_agent(causal_mask: torch.Tensor, agent_of_token: torch.Tensor) -> torch.Tensor:
            # For each row i and each agent g, mark the highest-index column j with
            # agent_of_token[j] == g and causal_mask[i, j] set (if any exists).
            num_valid = causal_mask.shape[0]
            recent = torch.zeros_like(causal_mask)
            if num_valid == 0:
                return recent
            rows = torch.arange(num_valid, device=causal_mask.device)
            one_based_cols = rows + 1
            for g in range(G):
                visible = causal_mask & (agent_of_token == g)[None, :]  # [valid, valid]
                # Row-wise max of (column index + 1) over visible columns; 0 means "none".
                last_col = (visible * one_based_cols[None, :]).max(dim=-1).values - 1  # [valid]
                has_any = last_col >= 0
                recent[rows[has_any], last_col[has_any]] = True
            return recent

        count_features = [x for x in self.node_features if "count" in x]
        requested = set(count_features)
        needs_recent = "recent_count" in requested or "recent_log_count" in requested
        needs_causal = needs_recent or "causal_count" in requested

        # One float32 accumulator per feature, each [N, G*T].
        feature_tensors = [torch.zeros_like(concat_bin_idx).float() for _ in count_features]

        for n in range(N):
            valid_pos = concat_bin_idx[n] != -100  # [G*T]
            valid_bin_idx = concat_bin_idx[n][valid_pos]
            valid_is_last = is_last[n][valid_pos]  # [valid]
            same_bins_edge_matrix = valid_bin_idx[:, None] == valid_bin_idx[None, :]  # [valid, valid]

            # The masks are shared by several features, so build them once per group.
            causal_mask = None
            recent_mask = None
            if needs_causal:
                valid_position_ids = concat_position_ids[n][valid_pos]
                causal_mask = valid_position_ids[:, None] >= valid_position_ids[None, :]  # [valid, valid]
            if needs_recent:
                recent_mask = latest_visible_per_agent(causal_mask, agent_idx[0][valid_pos])

            for feature, feature_tensor in zip(count_features, feature_tensors):
                if feature == "omni_count":
                    values = same_bins_edge_matrix.sum(dim=-1) / G
                elif feature == "omni_log_count":
                    values = torch.log(torch.clamp(same_bins_edge_matrix.sum(dim=-1) / G, min=1e-4))
                elif feature == "causal_count":
                    values = (causal_mask & same_bins_edge_matrix).sum(dim=-1) / causal_mask.sum(dim=-1)
                elif feature == "recent_count":
                    values = (recent_mask & same_bins_edge_matrix).sum(dim=-1) / G
                elif feature == "recent_log_count":
                    values = torch.log(torch.clamp((recent_mask & same_bins_edge_matrix).sum(dim=-1) / G, min=1e-4))
                elif feature == "last_count":
                    values = (valid_is_last[None, :] & same_bins_edge_matrix).sum(dim=-1) / G
                else:
                    # Unrecognised names keep their all-zero column.
                    continue
                feature_tensor[n][valid_pos] = values  # [valid]

        features = torch.stack(
            [feature_tensor.to(inputs_embeds.dtype) for feature_tensor in feature_tensors], dim=-1
        )  # [N, G*T, num_features]

        # Key that is unique per (agent, section) as long as section ids stay below 10_000.
        section_mask_reshaped = section_mask.view(N, G * T) + agent_idx * 10_000  # [N, G*T]

        # Broadcast the value of each valid position to its whole section.
        for n in range(N):
            valid_pos = concat_bin_idx[n] != -100
            for p in torch.where(valid_pos)[0]:
                features[n, section_mask_reshaped[n] == section_mask_reshaped[n, p]] = features[n, p].clone()

        features = features.reshape(B, T, -1)
        return features


    def get_node_features(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Assemble all configured node features into a single tensor.

        Columns are concatenated in a fixed order: every ``*count*`` feature
        (from ``get_count_features``), then ``max_log_probs`` (read from
        ``kwargs``), then ``position_ids`` divided by ``max_context_len``.
        Each part is cast to ``inputs_embeds.dtype``.

        Returns:
            Tensor of shape ``[B, T, num_features]``.
        """
        requested = self.node_features
        columns = []

        if any("count" in name for name in requested):
            # [B, T, num_count_features]
            columns.append(
                self.get_count_features(
                    input_ids,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    inputs_embeds,
                    use_cache,
                    cache_position,
                    **kwargs,
                )
            )

        if "max_log_probs" in requested:
            log_probs = kwargs["max_log_probs"]  # [B, T]
            columns.append(log_probs.unsqueeze(2).to(inputs_embeds.dtype))  # [B, T, 1]

        if "position_ids" in requested:
            scaled_positions = position_ids.unsqueeze(2).to(inputs_embeds.dtype) / self.max_context_len
            columns.append(scaled_positions)  # [B, T, 1]

        return torch.cat(columns, dim=-1)  # [B, T, num_features]




    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """Encode grouped sequences and return their hidden states.

        The ``B`` input rows are interpreted as ``N = B // G`` groups of
        ``G = group_size`` consecutive rows. Steps, each enabled by config:

        1. optional input projection, early node-feature projection and
           per-agent embedding;
        2. the backbone: a linear probe (``architecture == "probe"``) or the
           decoder stack run on each group's concatenated, attended tokens
           (``architecture == "full"``);
        3. optional group-softmax pooling (adds one extra row per group, so the
           output has ``B + N`` rows), late node-feature projection and
           bin-wise aggregation of hidden states.

        Args:
            input_ids: Must be ``None``; only ``inputs_embeds`` is supported.
            attention_mask: ``[B, T]`` mask; required for the "full"
                architecture and for attend-all group softmax.
            position_ids: ``[B, T]`` positions; required for the "full"
                architecture and for causal bin aggregation.
            past_key_values: Not used by the computation; echoed in the output
                when ``use_cache`` is truthy.
            inputs_embeds: ``[B, T, input_embeds_size]`` input embeddings.
            use_cache: See ``past_key_values``.
            cache_position: Only forwarded to ``get_node_features``.
            **kwargs: Read keys include ``bin_idx``, ``section_mask`` and
                ``max_log_probs`` depending on config; all kwargs are also
                forwarded to the decoder layers.

        Returns:
            ``BaseModelOutputWithPast`` with ``last_hidden_state`` of shape
            ``[B, T, hidden]`` (``[B + N, T, hidden]`` with group softmax).
            NOTE: for the "full" architecture the string ``"too_long"`` is
            returned instead when any group has more than 20,000 attended
            tokens; callers must check for this sentinel.

        Raises:
            ValueError: If late node-feature projection is enabled but no node
                features are configured.
        """
        assert input_ids is None and inputs_embeds is not None

        hidden_states = inputs_embeds
        B, T, _ = hidden_states.shape
        G = self.group_size
        N = B // G

        assert N > 0

        if self.input_embeds_proj is not None:
            hidden_states = self.input_embeds_proj(hidden_states)

        features = None
        if len(self.node_features) > 0:
            features = self.get_node_features(
                input_ids,
                attention_mask,
                position_ids,
                past_key_values,
                inputs_embeds,
                use_cache,
                cache_position,
                **kwargs,
            )
            if not self.config.no_early_node_features_projection:
                hidden_states = hidden_states + self.node_features_projection(features)

        if self.agent_emb:
            # Add the embedding of slot g to every token of the g-th row of each group.
            hidden_states = hidden_states.reshape(N, G, T, -1) + self.agent_embeddings[None, :, None, :]
            hidden_states = hidden_states.reshape(B, T, -1)

        if self.architecture == "probe":
            if self.mlp_hidden_size is not None:
                hidden_states = self.mlp(hidden_states)
                hidden_states = self.mlp_activation(hidden_states)
        elif self.architecture == "full":
            assert position_ids is not None
            assert attention_mask is not None

            # Concatenate the G rows of each group into one long sequence.
            concat_position_ids = position_ids.view(N, G * T)
            hidden_states = hidden_states.view(N, G * T, -1)
            attention_mask = attention_mask.view(N, G * T).bool()

            # Refuse groups with too many attended tokens (string sentinel, see docstring).
            if bool((attention_mask.sum(dim=-1) > 20_000).any()):
                return "too_long"

            if self.increment_position_ids:
                # Shift the positions of row g by g * max_context_len so rows do not overlap.
                agent_offset = (
                    torch.arange(0, B, device=position_ids.device).unsqueeze(1) % G
                ) * self.max_context_len
                rope_position_ids = (position_ids + agent_offset).view(N, G * T)
            else:
                rope_position_ids = concat_position_ids
            position_embeddings = self.rotary_emb(hidden_states, rope_position_ids)

            # Get attention types from the first layer (all layers should have the same types)
            attn_types = self.layers[0].self_attn.attn_types

            # Create attention masks based on required types
            mask_mapping = create_attention_masks(
                attention_mask=attention_mask,
                position_ids=concat_position_ids,
                bin_idx=kwargs["bin_idx"],
                section_mask=kwargs["section_mask"],
                attn_types=attn_types,
                group_size=G,
                use_block_mask=(self.config._attn_implementation == "flex_attention"),
            )

            # Per-group inputs are identical for every layer, so build them once.
            group_inputs = []
            for n in range(N):
                attn_pos = attention_mask[n]  # [G*T]
                if attn_pos.sum() == 0:
                    continue
                cos_valid = position_embeddings[0][n][attn_pos].unsqueeze(0)  # [1, valid_tokens, head_dim]
                sin_valid = position_embeddings[1][n][attn_pos].unsqueeze(0)  # [1, valid_tokens, head_dim]
                masks_for_group = {key: mask[n] for key, mask in mask_mapping.items()}
                group_inputs.append((n, attn_pos, (cos_valid, sin_valid), masks_for_group))

            for decoder_layer in self.layers[: self.num_hidden_layers]:
                # Process only valid tokens for each group to save memory;
                # positions that are not attended stay zero.
                hidden_states_processed = torch.zeros_like(hidden_states)
                for n, attn_pos, position_embeddings_valid, masks_for_group in group_inputs:
                    hidden_states_valid = hidden_states[n][attn_pos]  # [valid_tokens, hidden_size]
                    hidden_states_valid = decoder_layer(
                        hidden_states_valid.unsqueeze(0),  # Add batch dimension
                        position_embeddings=position_embeddings_valid,
                        mask_mapping=masks_for_group,
                        **kwargs,
                    ).squeeze(0)  # Remove batch dimension
                    # Put processed tokens back
                    hidden_states_processed[n][attn_pos] = hidden_states_valid
                hidden_states = hidden_states_processed

            hidden_states = hidden_states.view(B, T, -1)
            hidden_states = self.norm(hidden_states)

        def perform_group_softmax(hidden_states):
            # Attention-pool each group with the learned query and append the pooled
            # vector as position 0 of an extra (G+1)-th row; the rest of that row is zero.
            hidden_states = hidden_states.view(N, G * T, -1)
            hidden_states_append = hidden_states.new_zeros(N, T, hidden_states.shape[-1])  # [N, T, hidden_size]
            scale = math.sqrt(self.config.hidden_size)
            if not self.config.attend_all_group_softmax:
                # Pool only over positions that carry a bin label.
                bin_idx = kwargs["bin_idx"].view(N, G * T)
                query = self.query_abstain[None, :]  # [1, hidden_size]
                for n in range(N):
                    valid_pos = bin_idx[n] != -100  # [G*T]
                    keys = self.key_proj(hidden_states[n][valid_pos])  # [valid, hidden_size]
                    values = self.value_proj(hidden_states[n][valid_pos])  # [valid, hidden_size]
                    attention_scores = torch.matmul(query, keys.transpose(-2, -1)) / scale  # [1, valid]
                    attention_weights = torch.softmax(attention_scores, dim=-1)  # [1, valid]
                    pooled_vector = torch.matmul(attention_weights, values)  # [1, hidden_size]
                    hidden_states_append[n, 0] = pooled_vector[0]
            else:
                # Pool over every attended position of the group.
                attention_mask_processed = attention_mask.view(N, G * T).bool()
                keys = self.key_proj(hidden_states)  # [N, G*T, hidden_size]
                values = self.value_proj(hidden_states)  # [N, G*T, hidden_size]
                attention_scores = torch.einsum("d,nkd->nk", self.query_abstain, keys) / scale  # [N, G*T]
                attention_scores = attention_scores.masked_fill(~attention_mask_processed, float("-inf"))
                attention_weights = torch.softmax(attention_scores, dim=-1)  # [N, G*T]
                pooled_vector = torch.einsum("nk,nkd->nd", attention_weights, values)  # [N, hidden_size]
                hidden_states_append[:, 0] = pooled_vector

            hidden_states = torch.cat([hidden_states, hidden_states_append], dim=1)  # [N, G*T+T, hidden_size]
            hidden_states = hidden_states.reshape(N, G + 1, T, -1)
            return hidden_states.reshape(B + N, T, -1)  # [B+N, T, hidden_size]

        def aggregate_within_bins(hidden_states, causal, normalize):
            # Replace the hidden state of every labelled position by the (optionally
            # normalised) sum over labelled positions of the same group and bin.
            # With `causal`, only positions with a position id <= the target's are used.
            grouped = hidden_states.view(N, G * T, -1)
            # Write into a copy: `grouped` may alias the caller's `inputs_embeds`.
            aggregated = grouped.clone()
            bin_idx = kwargs["bin_idx"].view(N, G * T)
            concat_position_ids = position_ids.view(N, G * T) if causal else None
            for n in range(N):
                valid_pos = bin_idx[n] != -100  # [G*T]
                bin_idx_valid = bin_idx[n][valid_pos]  # [valid]
                if not causal:
                    assert bin_idx_valid.shape[0] == G
                same_bins_edge_matrix = bin_idx_valid[:, None] == bin_idx_valid[None, :]  # [valid, valid]
                if causal:
                    position_ids_valid = concat_position_ids[n][valid_pos]  # [valid]
                    causal_mask = position_ids_valid[:, None] >= position_ids_valid[None, :]  # [valid, valid]
                    same_bins_edge_matrix = same_bins_edge_matrix & causal_mask
                weights = same_bins_edge_matrix.to(grouped.dtype)
                if normalize:
                    weights = weights / weights.sum(dim=-1, keepdim=True)  # rows sum to 1
                aggregated[n][valid_pos] = torch.einsum(
                    "ab,bd->ad", weights, grouped[n][valid_pos]
                )  # [valid, hidden_size]
            return aggregated.reshape(B, T, -1)

        early_group_softmax = self.config.group_softmax and not self.config.late_group_softmax
        if early_group_softmax:
            hidden_states = perform_group_softmax(hidden_states)

        if self.config.late_node_features_projection:
            if features is None:
                raise ValueError("late_node_features_projection requires at least one node feature")
            if early_group_softmax:
                # Only the G real rows of each group receive the projection;
                # the pooled extra row is left untouched.
                hidden_states = hidden_states.reshape(N, G + 1, T, -1)
                projection = self.node_features_projection_late(features)  # [B, T, hidden_size]
                projection = projection.reshape(N, G, T, -1)
                hidden_states[:, :G] += projection
                hidden_states = hidden_states.reshape(B + N, T, -1)
            else:
                hidden_states = hidden_states + self.node_features_projection_late(features)
            if self.config.late_node_features_projection_norm:
                hidden_states = self.post_features_norm(hidden_states)

        if self.config.group_softmax and self.config.late_group_softmax:
            hidden_states = perform_group_softmax(hidden_states)

        if self.bin_aggregate:
            assert not self.config.group_softmax
            hidden_states = aggregate_within_bins(
                hidden_states, causal=False, normalize=not self.config.sum_bin_aggregate
            )

        if self.causal_bin_aggregate:
            assert not self.config.group_softmax
            hidden_states = aggregate_within_bins(hidden_states, causal=True, normalize=True)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class MultiSeqCalibratorForSequenceClassification(GenericForSequenceClassification, MultiSeqCalibratorPreTrainedModel):
    """Sequence-classification model combining ``GenericForSequenceClassification`` with the calibrator base class.

    The class body adds nothing; all behaviour comes from the two parents.
    """
    pass

class MultiSeqCalibratorForTokenClassification(GenericForTokenClassification, MultiSeqCalibratorPreTrainedModel):
    """Token-level calibrator head with optional weighted-majority voting.

    ``config.architecture`` selects how per-token logits are produced in
    :meth:`forward` (``"constant"``, ``"probs"``, or anything else, which runs
    the backbone followed by ``self.score``).

    ``wm_group_size`` and ``wm_type`` start as ``None``; while
    ``wm_group_size`` is ``None`` no voting is applied. They are presumably
    assigned by the caller after construction to enable voting.
    """

    def __init__(self, config):
        """Build the head and, for the ``"probe"`` architecture, resize ``self.score``.

        Args:
            config: model configuration. Reads ``architecture``,
                ``mlp_hidden_size``, ``input_embeds_size``, ``num_labels`` and
                ``torch_dtype``.
        """
        super().__init__(config)

        if config.architecture == "probe":
            # The probe's scoring layer takes the MLP width when one is
            # configured, otherwise the raw input-embedding width.
            if config.mlp_hidden_size is not None:
                self.score = nn.Linear(config.mlp_hidden_size, config.num_labels)
            else:
                self.score = nn.Linear(config.input_embeds_size, config.num_labels)

        if config.torch_dtype is not None:
            # Keep the scoring layer in the configured dtype.
            for p in self.score.parameters():
                p.data = p.data.to(config.torch_dtype)

        self.architecture = config.architecture
        # Weighted-majority voting is disabled until these are set.
        self.wm_group_size = None
        self.wm_type = None

    @staticmethod
    def _weighted_majority_probs(probs, bin_idx, position_ids, agent_idx, group_size, wm_type):
        """Aggregate per-position probabilities by probability-weighted bin voting.

        All tensor arguments are flat over the valid positions of one group.
        For each position ``a`` the result is the sum of ``probs`` over the
        positions that share ``a``'s bin, divided by the sum of ``probs`` over
        all positions ``a`` may look at, clipped to ``[1e-4, 1 - 1e-4]``.

        Args:
            probs: ``[valid]`` probabilities to aggregate.
            bin_idx: ``[valid]`` bin identifier of each position.
            position_ids: ``[valid]`` position id of each position.
            agent_idx: ``[valid]`` index (``0..group_size-1``) of the group
                member each position belongs to.
            group_size: number of members in the group.
            wm_type: ``"causal"`` restricts each position to positions whose
                position id is not larger than its own; ``"recent"`` further
                keeps only the last such position of every group member; any
                other value lets every position see all valid positions.

        Returns:
            ``[valid]`` tensor of aggregated, clipped probabilities.
        """
        same_bins = bin_idx[:, None] == bin_idx[None, :]  # [valid, valid]
        sum_candidates = probs[None, :]  # [1, valid]

        mask = None
        if wm_type == "causal":
            mask = position_ids[:, None] >= position_ids[None, :]  # [valid, valid]
        elif wm_type == "recent":
            causal_mask = position_ids[:, None] >= position_ids[None, :]  # [valid, valid]
            mask = torch.zeros_like(causal_mask)
            for i in range(mask.shape[0]):
                for g in range(group_size):
                    # Last causally visible entry of member g, in flattened order.
                    agent_g_pos = torch.nonzero((agent_idx == g) & causal_mask[i])
                    if len(agent_g_pos) > 0:
                        mask[i, agent_g_pos[-1]] = True

        if mask is not None:
            # Rows that see fewer than `group_size` entries fall back to
            # looking only at themselves. Rows are independent, so this is
            # done for all rows at once.
            short_rows = torch.nonzero(mask.sum(dim=-1) < group_size).flatten()
            mask[short_rows] = False
            mask[short_rows, short_rows] = True

            same_bins = same_bins & mask
            sum_candidates = sum_candidates.expand_as(same_bins).clone()
            sum_candidates[~mask] = 0.0

        probs_sum = sum_candidates.sum(dim=-1, keepdim=True)  # [valid, 1] or [1, 1]
        if mask is not None:
            # A row that only sees itself keeps its own probability unchanged.
            probs_sum[mask.sum(dim=-1) <= 1] = 1.0

        aggregated = torch.einsum("ab,b->a", same_bins.to(probs.dtype), probs)  # [valid]
        aggregated = aggregated / probs_sum.flatten()  # [valid]
        eps = 1e-4
        return aggregated.clip(min=eps, max=1 - eps)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> TokenClassifierOutput:
        """Compute per-token logits and optionally apply weighted-majority voting.

        Logit computation depends on ``self.architecture``:

        * ``"constant"``: all-zero logits of shape ``[B, T, 1]``.
        * ``"probs"``: the log of the maximum next-token probability from
          ``self.lm_head`` is averaged over tokens sharing the same
          ``kwargs["section_mask"]`` value and converted to a logit.
          ``self.lm_head`` is not created in this class and must be attached
          elsewhere.
        * otherwise: ``self.model`` is run and its last hidden state goes
          through ``self.dropout`` and ``self.score``. If the backbone returns
          the string ``"too_long"``, zero logits are returned instead and no
          loss is computed.

        When ``self.wm_group_size`` is set, the batch is split into groups of
        that size and the logits at positions where ``kwargs["bin_idx"]`` is
        not ``-100`` are replaced by the logit of the aggregated probability
        (see :meth:`_weighted_majority_probs`).

        Args:
            input_ids: forwarded to the backbone.
            attention_mask: forwarded to the backbone; in the ``"probs"``
                branch its boolean value selects the valid tokens.
            position_ids: forwarded to the backbone; required when voting is
                enabled.
            past_key_values: forwarded to the backbone.
            inputs_embeds: ``[B, T, hidden]`` embeddings. Required by every
                path except the plain backbone path without voting.
            labels: if given on the normal backbone path, a loss is computed
                with ``self.loss_function``.
            use_cache: forwarded to the backbone.
            **kwargs: forwarded to the backbone; ``"section_mask"`` and
                ``"bin_idx"`` are read as described above.

        Returns:
            A ``TokenClassifierOutput`` whose ``logits`` may have been
            replaced by the voted logits of shape ``[B, T, 1]``.
        """
        if self.architecture == "constant":
            logits = torch.zeros(
                *inputs_embeds.shape[:2], 1, dtype=inputs_embeds.dtype, device=inputs_embeds.device
            ).requires_grad_(True)  # [B, T, 1]
            outputs = TokenClassifierOutput(
                loss=None,
                logits=logits,
                hidden_states=None,
                attentions=None,
            )

        elif self.architecture == "probs":
            # inputs_embeds: [B, T, hidden_size]
            valid_mask = attention_mask.bool()
            inputs_embeds_valid = inputs_embeds[valid_mask]  # [valid, hidden_size]
            with torch.no_grad():
                logits_valid = self.lm_head(inputs_embeds_valid).float()  # [valid, vocab_size]
                max_logits_valid = logits_valid.max(dim=-1, keepdim=True).values  # [valid, 1]
                max_log_probs_valid = max_logits_valid - logits_valid.logsumexp(dim=-1, keepdim=True)  # [valid, 1]
                del logits_valid  # free the vocab-sized tensor early
            max_log_probs_valid = max_log_probs_valid.to(inputs_embeds.dtype)
            section_mask = kwargs["section_mask"]  # [B, T]
            section_mask_valid = section_mask[valid_mask]  # [valid]
            # Row-normalised "same section" matrix: averages over each section.
            aggregate_mask = (section_mask_valid[:, None] == section_mask_valid[None, :]).to(max_log_probs_valid.dtype)  # [valid, valid]
            aggregate_mask = aggregate_mask / aggregate_mask.sum(dim=-1, keepdim=True)  # [valid, valid]
            max_log_probs_valid_aggregated = torch.einsum("ab,bc->ac", aggregate_mask, max_log_probs_valid)  # [valid, 1]
            x = max_log_probs_valid_aggregated
            # logit(exp(x)) = x - log(1 - exp(x)); the clamp keeps 1 - exp(x) >= 1e-6.
            # NOTE(review): in low-precision dtypes the clamp bound may round to -1,
            # which would make this infinite -- confirm whether float32 is needed here.
            logits_valid = x - torch.log1p(torch.clamp(-torch.exp(x), min=-(1 - 1e-6)))

            # Fill the buffer BEFORE enabling grad: an in-place write into a leaf
            # tensor that already requires grad raises a RuntimeError in grad mode.
            logits = torch.zeros(
                *inputs_embeds.shape[:2], 1, dtype=inputs_embeds.dtype, device=inputs_embeds.device
            )  # [B, T, 1]
            logits[valid_mask] = logits_valid  # [valid, 1]
            logits.requires_grad_(True)

            outputs = TokenClassifierOutput(
                loss=None,
                logits=logits,
                hidden_states=None,
                attentions=None,
            )

        else:
            # Custom replacement for super().forward() so the "too_long" signal
            # from the backbone can be handled.
            model_outputs = self.model.forward(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                **kwargs,
            )

            if isinstance(model_outputs, str) and model_outputs == "too_long":
                # Placeholder zero logits with the shape of a regular output.
                B, T, _ = inputs_embeds.shape
                G = self.config.group_size
                N = B // G
                # NOTE(review): the voting path below asserts 2 * B rows for
                # group_softmax while this placeholder has B + N rows -- confirm
                # the two agree for the intended group sizes.
                rows = B + N if self.config.group_softmax else B
                logits = torch.zeros(
                    rows, T, self.config.num_labels, dtype=inputs_embeds.dtype, device=inputs_embeds.device
                )

                # Mark the zeros as requiring grad (they are NOT detached).
                logits = logits.requires_grad_(True)

                outputs = TokenClassifierOutput(
                    loss=None,
                    logits=logits,
                    hidden_states=None,
                    attentions=None,
                )
            else:
                # Normal case: standard token-classification head.
                sequence_output = model_outputs.last_hidden_state
                sequence_output = self.dropout(sequence_output)
                logits = self.score(sequence_output)

                loss = None
                if labels is not None:
                    loss = self.loss_function(logits, labels, self.config)

                outputs = TokenClassifierOutput(
                    loss=loss,
                    logits=logits,
                    hidden_states=model_outputs.hidden_states,
                    attentions=model_outputs.attentions,
                )

        if self.wm_group_size is None:
            # No weighted majority voting
            return outputs

        group_softmax = bool(self.config.group_softmax)

        logits = outputs.logits.squeeze(-1)
        B, T, _ = inputs_embeds.shape
        if group_softmax:
            # Two logit rows (class, abstain) are expected per input sequence.
            assert logits.shape[0] == B * 2
        assert isinstance(self.wm_group_size, int)
        G = self.wm_group_size
        assert B % G == 0
        N = B // G

        bin_idx = kwargs["bin_idx"].reshape(N, G, T)  # [N, G, T]
        position_ids = position_ids.reshape(N, G, T)
        agent_idx = torch.arange(0, G, device=bin_idx.device)[None, :, None].expand(N, -1, T).long()  # [N, G, T]

        if group_softmax:
            logits = logits.reshape(B, 2, T)
            logits_class = logits[:, 0].reshape(N, G, T)
            logits_abstain = logits[:, 1].reshape(N, G, T)
            # Positions that are not voted on stay at zero.
            voted = torch.zeros(B, T, dtype=logits.dtype, device=logits.device).reshape(N, G, T)
        else:
            # Clone so the in-place writes below never touch a (view of a) leaf
            # tensor that requires grad; unvoted positions keep their logits.
            voted = logits.clone().reshape(N, G, T)

        for n in range(N):
            valid_pos = bin_idx[n] != -100  # [G, T]

            if group_softmax:
                pair = torch.stack([logits_class[n][valid_pos], logits_abstain[n][valid_pos]], dim=-1)  # [valid, 2]
                probs_valid = torch.softmax(pair, dim=-1)[:, 0]  # [valid]
            else:
                probs_valid = torch.sigmoid(voted[n][valid_pos])  # [valid]

            probs_aggregated = self._weighted_majority_probs(
                probs_valid,
                bin_idx[n][valid_pos],
                position_ids[n][valid_pos],
                agent_idx[n][valid_pos],
                G,
                self.wm_type,
            )  # [valid]
            voted[n][valid_pos] = torch.logit(probs_aggregated)

        outputs.logits = voted.reshape(B, T, 1)
        return outputs

