import torch
import torch.nn.functional as F
from torch import nn
import deepspeed
import transformers
import pdb
import cv2
import torch.nn.init as init
import torch.distributed as dist
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from transformers import LlamaForCausalLM, LlamaConfig, LlamaModel
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
from transformers import LlamaTokenizer, AutoTokenizer, AutoConfig, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType 
from dataset.constants import IGNORE_INDEX, EVENT_TOKEN_INDEX, DEFAULT_EVENT_TOKEN, DEFAULT_EV_START_TOKEN, DEFAULT_EV_END_TOKEN, DEFAULT_EVENT_PATCH_TOKEN
from typing import List, Optional, Tuple, Union, Dict, Callable
from transformers.generation.utils import GenerateOutput
from torch.nn.utils.rnn import pad_sequence
import numpy as np

def get_spatio_temporal_features(features, num_temporal_tokens=None):
    """Summarise a (t, s, c) feature volume as temporal tokens followed by spatial tokens.

    Args:
        features: tensor of shape (t, s, c), or a list of (s, c) tensors that
            is stacked along a new leading time axis.
        num_temporal_tokens: number of temporal tokens to emit. Defaults to
            ``t``. Larger values zero-pad at the end; smaller values keep the
            first ``num_temporal_tokens`` frames.

    Returns:
        Tensor of shape (num_temporal_tokens + s, c). The first rows are
        per-frame means over the spatial axis. The remaining ``s`` rows are
        per-location means over the time axis.

    Raises:
        ValueError: if the (stacked) input is not 3-D.
    """
    stacked = torch.stack(features) if isinstance(features, list) else features
    if stacked.ndim != 3:
        raise ValueError("Input features should be a 3D tensor with shape (t, s, c)")

    num_frames = stacked.shape[0]
    target = num_frames if num_temporal_tokens is None else num_temporal_tokens

    per_frame = stacked.mean(dim=1)  # (t, c)
    if target < num_frames:
        per_frame = per_frame[:target]
    elif target > num_frames:
        # Pad only the time axis (rows) with zeros.
        per_frame = F.pad(per_frame, (0, 0, 0, target - num_frames))

    per_location = stacked.mean(dim=0)  # (s, c)
    return torch.cat((per_frame, per_location), dim=0).to(stacked.device)

def sharpen_rgb_image(event_img):
    """Sharpen every channel of ``event_img`` independently with unsharp masking.

    Each channel is blurred with a 3x3 Gaussian kernel. The result is
    ``1.5 * channel - 0.5 * blurred``, and the channels are merged back in
    their original order.
    """
    def _unsharp(channel):
        smoothed = cv2.GaussianBlur(channel, (3, 3), 0)
        return cv2.addWeighted(channel, 1.5, smoothed, -0.5, 0)

    return cv2.merge([_unsharp(channel) for channel in cv2.split(event_img)])

def split_event_stream(event_stream, interval=10000):
    """Split a time-sorted event array into fixed-length time windows.

    Column 2 of ``event_stream`` is the timestamp. The first and last rows
    are taken as the start and end of the stream, so rows are assumed to be
    sorted by time. Window ``k`` holds the events with timestamp in
    ``[t0 + k * interval, t0 + (k + 1) * interval)``, where ``t0`` is the
    first timestamp. Windows without events are returned as empty arrays,
    so list positions stay aligned with time.

    Unlike a bin-edge/``np.digitize`` approach, events lying exactly on the
    final window boundary are kept. A stream whose events all share one
    timestamp yields one window instead of none.

    Args:
        event_stream: 2-D array with timestamps in column 2.
        interval: window length in timestamp units.

    Returns:
        List of row subsets of ``event_stream``, one per window. Returns
        ``[]`` for an empty stream.
    """
    if len(event_stream) == 0:
        return []

    t0 = event_stream[0, 2]
    window_ids = ((event_stream[:, 2] - t0) // interval).astype(np.int64)
    num_windows = int((event_stream[-1, 2] - t0) // interval) + 1
    return [event_stream[window_ids == k] for k in range(num_windows)]

def generate_event_image(x, y, p, height=240, width=320):
    """Render events as coloured pixels on a white ``(height, width, 3)`` uint8 image.

    Pixels hit by events with ``p == 0`` get channel values (0, 0, 255).
    Pixels hit by events with ``p == 1`` get (255, 0, 0). When both
    polarities hit the same pixel, the ``p == 1`` colour wins because it is
    written last. Other polarity values are not drawn.

    NOTE(review): (0, 0, 255) is blue when the array is read as RGB but red
    when saved with ``cv2.imwrite`` (BGR). Confirm which convention
    consumers expect.

    Args:
        x, y: column / row pixel coordinates. Float arrays are truncated to
            integers.
        p: polarity array, same length as ``x``/``y``.
        height, width: output image size (defaults keep the original
            240x320).

    Returns:
        The rendered image as a ``np.uint8`` array.
    """
    event_image = np.ones((height, width, 3), dtype=np.uint8) * 255

    # Fancy indexing requires integer coordinates; event arrays are sometimes float-typed.
    x = np.asarray(x).astype(np.intp, copy=False)
    y = np.asarray(y).astype(np.intp, copy=False)
    p = np.asarray(p)

    mask_p0 = p == 0
    mask_p1 = p == 1

    event_image[y[mask_p0], x[mask_p0]] = np.array([0, 0, 255])  # blue in RGB
    event_image[y[mask_p1], x[mask_p1]] = np.array([255, 0, 0])  # red in RGB

    return event_image

def compute_normalized_event_density_batch(event_batch, resolution=(320, 240), target_res=224, grid_size=16):
    """Compute a max-normalised event-count map on a ``grid_size`` x ``grid_size`` grid per sample.

    Event coordinates are rescaled from ``resolution`` (width, height) to a
    ``target_res`` x ``target_res`` canvas and binned into square cells of
    ``target_res // grid_size`` pixels. Counts are divided by the sample's
    maximum count.

    Args:
        event_batch: iterable of 2-D event arrays with x in column 0 and y in
            column 1.
        resolution: (width, height) of the original sensor coordinates.
        target_res: side length of the rescaled square canvas.
        grid_size: number of cells per side of the output map.

    Returns:
        List of ``(grid_size, grid_size)`` float arrays, one per sample.
        Empty samples give all-zero ``float32`` maps.
    """
    density_list = []
    patch_size = target_res // grid_size  # pixel side length of one grid cell

    for events in event_batch:
        if events.size == 0:  # no events: all-zero density map
            density_list.append(np.zeros((grid_size, grid_size), dtype=np.float32))
            continue

        x, y = events[:, 0], events[:, 1]
        x_scaled = (x / resolution[0] * target_res).clip(0, target_res - 1).astype(int)
        y_scaled = (y / resolution[1] * target_res).clip(0, target_res - 1).astype(int)

        # If target_res is not a multiple of grid_size, the trailing pixels map to
        # cell index grid_size, which would break the reshape below. Fold them
        # into the last cell instead.
        grid_x = np.minimum(x_scaled // patch_size, grid_size - 1)
        grid_y = np.minimum(y_scaled // patch_size, grid_size - 1)

        grid_indices = grid_y * grid_size + grid_x
        density_map = np.bincount(grid_indices, minlength=grid_size ** 2).reshape(grid_size, grid_size)

        max_density = density_map.max()
        if max_density > 0:
            density_map = density_map.astype(np.float32) / max_density
        else:
            density_map = density_map.astype(np.float32)

        density_list.append(density_map)

    return density_list

def token_merge(token_sequence, valid_token_indices, event_density, grid_size=16):
    """Merge kept patch tokens 2x2 on their spatial grid and pair each with its event density.

    For each sample, the kept patch tokens are scattered back onto their
    ``grid_size`` x ``grid_size`` patch grid (missing patches are zero) and
    average-pooled with a 2x2 window. Pooled cells whose token vector is all
    zeros are then dropped. The density map is pooled the same way and
    filtered with the SAME mask, so the returned tokens and densities always
    align one-to-one.

    Args:
        token_sequence: per-sample tensors of shape (1 + num_kept, token_dim).
            Row 0 is passed through unchanged (presumably the CLS token).
        valid_token_indices: per-sample index lists. The first entry refers
            to row 0. The remaining entries are 1-based patch positions: value
            ``v`` is grid cell ``v - 1`` in row-major order.
        event_density: per-sample ``(grid_size, grid_size)`` density maps
            (tensors or arrays).
        grid_size: side length of the patch grid.

    Returns:
        ``(merged_tokens_list, merged_density_list)``. For each sample, a
        ``(1 + M, token_dim)`` tensor and a matching ``(1 + M,)`` density
        tensor whose first entry is 1 (for row 0).

    Raises:
        IndexError: if a patch position falls outside the grid.
    """
    merged_tokens_list = []
    merged_density_list = []
    if len(token_sequence) == 0:
        return merged_tokens_list, merged_density_list

    token_dim = token_sequence[0].shape[-1]
    num_cells = grid_size * grid_size

    for i, tokens in enumerate(token_sequence):
        cls_token = tokens[0:1]  # (1, token_dim)
        patch_tokens = tokens[1:]  # (N_valid, token_dim)

        # Convert this sample's 1-based patch positions to 0-based grid cells.
        patch_positions = [int(idx) - 1 for idx in valid_token_indices[i][1:]]
        for pos in patch_positions:
            if not 0 <= pos < num_cells:
                raise IndexError(f"Invalid index {pos}: outside the {grid_size}x{grid_size} grid")

        # Scatter onto a flat row-major grid (cell = row * grid_size + col).
        # Index assignment requires matching dtypes, hence the explicit cast.
        # Assumes positions are unique, as produced by per-patch selection.
        flat_grid = torch.zeros(num_cells, token_dim, device=tokens.device)
        if patch_positions:
            index = torch.tensor(patch_positions, dtype=torch.long, device=tokens.device)
            flat_grid[index] = patch_tokens[: len(patch_positions)].to(flat_grid.dtype)

        token_grid = flat_grid.view(grid_size, grid_size, token_dim).permute(2, 0, 1).unsqueeze(0)  # (1, C, G, G)
        pooled = F.avg_pool2d(token_grid, kernel_size=2, stride=2)  # (1, C, G/2, G/2)
        pooled = pooled.squeeze(0).permute(1, 2, 0).reshape(-1, token_dim)  # (G/2 * G/2, C)

        density = torch.as_tensor(event_density[i], device=tokens.device)
        pooled_density = F.avg_pool2d(density.unsqueeze(0).unsqueeze(0), kernel_size=2, stride=2).squeeze().flatten()

        # One mask for both sequences keeps tokens and densities aligned.
        keep = ~torch.all(pooled == 0, dim=-1)
        non_zero_tokens = pooled[keep]
        non_zero_density = pooled_density[keep]

        cls_token_density = torch.ones(1, dtype=pooled_density.dtype, device=tokens.device)

        merged_tokens_list.append(torch.cat([cls_token, non_zero_tokens], dim=0))  # (1 + M, token_dim)
        merged_density_list.append(torch.cat([cls_token_density, non_zero_density], dim=0))  # (1 + M,)

    return merged_tokens_list, merged_density_list


def select_non_white_tokens_batch(image_np_list: list, image_features: torch.Tensor, patch_size=14):
    """Keep, per image, token 0 plus the tokens of patches that are not pure white.

    Each image is resized to 224x224 and tiled row-major into
    ``patch_size`` x ``patch_size`` patches. Patch ``k`` (0-based) corresponds
    to token ``k + 1``. A patch is kept when any of its values differs
    from 255.

    Args:
        image_np_list: images as HxWx3 arrays.
        image_features: per-image token features; ``image_features[b]`` must
            have ``1 + (224 // patch_size) ** 2`` rows (asserted).
        patch_size: patch side length in pixels.

    Returns:
        ``(tokens, indices)``. ``tokens[b]`` is the stacked tensor of kept
        rows; ``indices[b]`` lists their row indices, starting with 0.
    """
    kept_tokens_per_image = []
    kept_indices_per_image = []

    for sample_idx, raw_image in enumerate(image_np_list):
        resized = cv2.resize(raw_image, (224, 224))
        features = image_features[sample_idx]

        rows, cols, _ = resized.shape
        assert features.shape[0] == 1 + (rows // patch_size) * (cols // patch_size), \
            f"Token num {features.shape[0]} image size not match"

        kept_indices = [0]  # token 0 is always kept
        token_idx = 0
        for top in range(0, rows, patch_size):
            for left in range(0, cols, patch_size):
                token_idx += 1
                patch = resized[top:top + patch_size, left:left + patch_size]
                if np.any(patch != 255):
                    kept_indices.append(token_idx)

        kept_tokens_per_image.append(torch.stack([features[k] for k in kept_indices], dim=0))
        kept_indices_per_image.append(kept_indices)

    return kept_tokens_per_image, kept_indices_per_image

def split_event_by_time(event_npy, time_interval=10000):
    """Cut an event record into consecutive chunks that share a time bin.

    Bins are ``t // time_interval``, i.e. aligned to multiples of the
    interval rather than to the first event. A new chunk starts wherever the
    bin of consecutive events changes, so bins with no events produce no
    chunk.

    Args:
        event_npy: mapping (or structured array) with fields 't', 'x', 'y', 'p'.
        time_interval: bin width in timestamp units.

    Returns:
        List of dicts with keys 't', 'x', 'y', 'p', each holding that chunk's
        slice of the field.
    """
    fields = ('t', 'x', 'y', 'p')
    columns = {name: event_npy[name] for name in fields}

    bin_ids = columns['t'] // time_interval
    cut_points = np.flatnonzero(np.diff(bin_ids)) + 1

    pieces = {name: np.split(column, cut_points) for name, column in columns.items()}
    return [
        {name: pieces[name][k] for name in fields}
        for k in range(len(pieces['t']))
    ]

def convert_data(data):
    """Convert per-window field dicts into ``(N, 4)`` arrays with columns (x, y, t, p).

    Args:
        data: iterable of mappings with keys 'p', 't', 'x', 'y'.

    Returns:
        List of arrays built with ``np.column_stack``, one per input item.
    """
    def _to_rows(window):
        p, t, x, y = window['p'], window['t'], window['x'], window['y']
        return np.column_stack((x, y, t, p))

    return [_to_rows(window) for window in data]

# def spatial_max_pool(events: np.ndarray):
#     """
#     对 [N, 4] 的 events 做 spatial pooling（按 x, y 聚合，保留 t 最大的事件）

#     返回:
#         [M, 4] 的 numpy 数组（M <= N），每个 x, y 保留一个 t 最大的事件
#     """
#     if len(events) == 0:
#         return events

#     xyt = events[:, :3]
#     p = events[:, 3]
#     xyt = xyt.astype(np.int32)
#     xy = xyt[:, :2]

#     xy_str = [f"{x}_{y}" for x, y in xy]
#     index_map = {}

#     for idx, key in enumerate(xy_str):
#         if key not in index_map:
#             index_map[key] = idx
#         else:
#             if xyt[idx, 2] > xyt[index_map[key], 2]:  
#                 index_map[key] = idx

#     selected_indices = list(index_map.values())
#     return events[selected_indices]


def adaptive_aggregate_optimized(event_windows, features, sparsity_factors, min_group_len=1, percentile=25):
    """Merge consecutive event windows into groups, cutting where adjacent windows differ most.

    For each pair of neighbouring windows, the score is the cosine similarity
    of their feature vectors times their mean sparsity factor. A group is
    closed after window ``i`` when the score between ``i`` and ``i + 1`` is
    below the ``percentile``-th percentile of all scores and the group already
    has at least ``min_group_len`` windows. The final group may be shorter.

    Args:
        event_windows: sequence of event arrays, concatenated along axis 0
            within each group.
        features: tensor with one feature row per window.
        sparsity_factors: one factor per window (array-like; lists accepted).
        min_group_len: minimum windows per group before a cut is allowed.
        percentile: percentile of the scores used as the cut threshold.

    Returns:
        List of concatenated event arrays, one per group. Returns ``[]`` for
        no windows and a single group for one window.
    """
    n = len(event_windows)
    if n == 0:
        return []
    if n == 1:
        # No neighbouring pairs, so there is nothing to score
        # (np.percentile would fail on an empty array).
        return [np.concatenate(event_windows[0:1], axis=0)]

    # detach() so .numpy() also works when features carry autograd history.
    features = features.detach().to(torch.float32)
    sims = F.cosine_similarity(features[:-1], features[1:], dim=-1).cpu().numpy()

    # asarray so list inputs are added element-wise rather than list-concatenated.
    sparsity = np.asarray(sparsity_factors)
    avg_sparsity = (sparsity[:-1] + sparsity[1:]) / 2.0
    agg_scores = sims * avg_sparsity

    threshold = np.percentile(agg_scores, percentile)

    group_indices = []
    start = 0
    for i in range(n - 1):
        if (i + 1 - start >= min_group_len) and (agg_scores[i] < threshold):
            group_indices.append((start, i + 1))
            start = i + 1
    group_indices.append((start, n))

    return [np.concatenate(event_windows[s:e], axis=0) for s, e in group_indices]

def compute_sparsity_factor(window, alpha=0.01, H=240, W=320):
    """Return ``exp(-alpha * len(window) / (H * W))`` for an event window.

    An empty window gives 1.0. For ``alpha > 0`` the factor decreases as the
    number of events per pixel grows.
    """
    pixel_count = H * W
    events_per_pixel = len(window) / pixel_count
    return np.exp(-alpha * events_per_pixel)


def visualize_event_stream(event_windows, save_path="/mnt/data2/SyL/LongEvent/script/data_preprocess/test_img"):
    """Render each event window and save it as ``<save_path>/keyframe_XXX.jpg``.

    Each window's columns 0, 1 and 3 are used as x, y and polarity.
    ``save_path`` is created if missing. Frames are written one at a time
    rather than all being held in memory first.

    Raises:
        OSError: if the directory cannot be created or an image cannot be
            written. ``cv2.imwrite`` reports failure only through its return
            value, so it is checked explicitly.
    """
    import os  # local import: only this debugging helper needs it

    os.makedirs(save_path, exist_ok=True)
    for i, ev_win in enumerate(event_windows):
        frame = generate_event_image(ev_win[:, 0], ev_win[:, 1], ev_win[:, 3])
        out_path = f"{save_path}/keyframe_{i:03d}.jpg"
        if not cv2.imwrite(out_path, frame):
            raise OSError(f"Failed to write {out_path}")

class EventChatLLaMAConfig(LlamaConfig):
    """``LlamaConfig`` subclass whose ``model_type`` identifies the EventChat LLaMA variant."""

    model_type = "EventChat_llama"


class EventChatQwenConfig(Qwen2Config):
    """``Qwen2Config`` subclass whose ``model_type`` identifies the EventChat Qwen variant."""

    model_type = "EventChat_Qwen"


class VisualTower(nn.Module):
    """Frozen CLIP vision encoder used to embed event images.

    Stores the matching ``CLIPImageProcessor`` as ``event_processor`` and a
    bfloat16 ``CLIPVisionModel`` with gradients disabled as ``visual_tower``.
    """

    def __init__(self, visual_tower):
        super().__init__()

        self.visual_tower_name = visual_tower
        self.event_processor = CLIPImageProcessor.from_pretrained(self.visual_tower_name)
        self.visual_tower = CLIPVisionModel.from_pretrained(self.visual_tower_name, torch_dtype=torch.bfloat16)
        self.visual_tower.requires_grad_(False)

    def forward(self, event_tensor):
        """Encode ``event_tensor`` and return the vision model's last hidden state.

        If a ``visual_projector`` module has been attached to this tower, it
        is applied to the hidden states. Otherwise the raw hidden states are
        returned. This class never creates a projector itself; the
        EventChat models keep theirs as their own ``visual_projector``.
        """
        outputs = self.visual_tower.vision_model(event_tensor)
        events_feature = outputs.last_hidden_state

        projector = getattr(self, "visual_projector", None)
        if projector is not None:
            events_feature = projector(events_feature)

        return events_feature
    
class DensityGuidedCompressor(nn.Module):
    """Compress a token sequence into ``num_queries`` tokens via density-biased cross-attention.

    Learned queries attend over key/value projections of the input tokens.
    Each token's attention logit receives an additive bias that a small MLP
    computes from that token's scalar density.
    """

    def __init__(self, input_dim, hidden_dim, num_queries):
        super().__init__()
        self.num_queries = num_queries
        self.query_embed = nn.Parameter(torch.randn(num_queries, hidden_dim))
        self.key_proj = nn.Linear(input_dim, hidden_dim)
        self.value_proj = nn.Linear(input_dim, hidden_dim)
        self.scale = hidden_dim ** 0.5

        # Maps each scalar density to a scalar attention-logit bias.
        self.density_encoder = nn.Sequential(
            nn.Linear(1, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, 1)
        )

    def forward(self, token_features, token_densities, attention_mask=None):
        """Attend from the learned queries over the tokens.

        Args:
            token_features: [B, N, input_dim] token embeddings.
            token_densities: [B, N] per-token density values.
            attention_mask: [B, N]. Positions equal to 0 are ignored.
                ``None`` attends to every token.

        Returns:
            [B, num_queries, hidden_dim]. A sample whose mask is entirely zero
            yields zeros instead of NaN.
        """
        B, N, _ = token_features.shape

        K = self.key_proj(token_features)  # [B, N, hidden_dim]
        V = self.value_proj(token_features)  # [B, N, hidden_dim]
        Q = self.query_embed.unsqueeze(0).expand(B, -1, -1)  # [B, num_queries, hidden_dim]

        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [B, num_queries, N]

        density_bias = token_densities.unsqueeze(-1)  # [B, N, 1]
        density_bias = self.density_encoder(density_bias)  # [B, N, 1]
        density_bias = density_bias.transpose(1, 2)  # [B, 1, N]
        attn_scores = attn_scores + density_bias  # [B, num_queries, N]

        ignore = None
        if attention_mask is not None:
            ignore = attention_mask.unsqueeze(1) == 0  # [B, 1, N]
            attn_scores = attn_scores.masked_fill(ignore, -float('inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)  # [B, num_queries, N]

        if ignore is not None:
            # A softmax over an all -inf row is NaN; give fully masked samples zero weights.
            attn_weights = attn_weights.masked_fill(ignore.all(dim=-1, keepdim=True), 0.0)

        output = torch.matmul(attn_weights, V)  # [B, num_queries, hidden_dim]
        return output


class EventChatLlamaModel(LlamaModel):
    """``LlamaModel`` extended with an event-camera visual branch.

    Optional sub-modules are rebuilt in ``__init__`` from flags stored on the
    config:

    * ``mm_visual_tower`` -> ``visual_tower`` (``VisualTower``) and a
      bfloat16 MLP ``visual_projector`` (``text_hidden_size`` -> ``hidden_size``);
    * ``event_feature_adaptor`` -> ``feature_adaptor`` (square ``nn.Linear``);
    * ``use_event_qformer`` -> ``query_embeddings`` / ``attention_layers``.

    ``initialize_vision_modules`` builds the same parts from training
    arguments and optionally loads pretrained weights into them.
    """

    config_class = EventChatLLaMAConfig

    def __init__(self, config: LlamaConfig):
        super(EventChatLlamaModel, self).__init__(config)

        # NOTE(review): hard-coded sizes; 1024 presumably is the CLIP vision
        # width and 4096 the LLaMA hidden size. Confirm against the configs.
        self.mlp_depth = 2
        self.text_hidden_size = 1024
        self.hidden_size = 4096

        if hasattr(config, "mm_visual_tower"):
            self.visual_tower = self.build_visual_tower(config.mm_visual_tower)
            self.visual_projector = self.build_mlp_projector(self.text_hidden_size, self.hidden_size).to(dtype=torch.bfloat16)

        if hasattr(config, "event_feature_adaptor"):
            self.feature_adaptor = nn.Linear(self.hidden_size, self.hidden_size)

        if hasattr(config, "use_event_qformer"):
            # NOTE(review): build_event_qformer is not defined on this class;
            # confirm it is provided elsewhere, otherwise this raises AttributeError.
            self.query_embeddings, self.attention_layers = self.build_event_qformer(config)
            self.register_parameter('query_embeddings', self.query_embeddings)
            self.attention_layers = nn.ModuleList(self.attention_layers)

    def build_visual_tower(self, visual_tower):
        """Instantiate the CLIP-based ``VisualTower`` from a model name or path."""
        return VisualTower(visual_tower)

    def build_mlp_projector(self, text_hidden_size, hidden_dim):
        """Build an MLP with ``mlp_depth`` linear layers separated by GELU."""
        modules = [nn.Linear(text_hidden_size, hidden_dim)]
        for _ in range(1, self.mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_dim, hidden_dim))
        return nn.Sequential(*modules)

    def get_vision_tower(self):
        """Return the visual tower module, or ``None`` if none has been built.

        The tower is stored as ``visual_tower``. The old ``vision_tower``
        attribute name is still consulted as a fallback. A list is unwrapped
        to its first element.
        """
        vision_tower = getattr(self, 'visual_tower', None)
        if vision_tower is None:
            vision_tower = getattr(self, 'vision_tower', None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    @staticmethod
    def _read_event_checkpoint(path, prefix):
        """Load a checkpoint onto CPU and remove ``prefix`` from every key."""
        # map_location avoids failures when the checkpoint was saved on an unavailable device;
        # load_state_dict copies the values into the module's own parameters anyway.
        state_dict = torch.load(path, map_location="cpu")
        return {k.replace(prefix, ""): v for k, v in state_dict.items()}

    @staticmethod
    def _slice_attention_layer_weights(state_dict, layer_idx):
        """Return the entries of ``attention_layers.<layer_idx>`` with keys relative to that layer.

        Matching includes the trailing '.' so that layer 1 does not also pick
        up the keys of layers 10, 11, ...
        """
        layer_prefix = f"attention_layers.{layer_idx}."
        return {k.split(layer_prefix, 1)[1]: v for k, v in state_dict.items() if layer_prefix in k}

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build the event branch from ``model_args`` and optionally load pretrained weights.

        Records ``mm_visual_tower`` and ``event_feature_adaptor`` on the
        config so ``__init__`` rebuilds the tower, projector and adaptor on
        reload. ``fsdp`` is accepted but unused.
        """
        self.config.mm_visual_tower = model_args.vision_tower
        # NOTE(review): set even when use_feature_adaptor is False. Confirm intended.
        self.config.event_feature_adaptor = True

        # Build the visual tower
        self.visual_tower = self.build_visual_tower(model_args.vision_tower)

        # Build the visual projector
        self.visual_projector = self.build_mlp_projector(self.text_hidden_size, self.hidden_size).to(dtype=torch.bfloat16)

        if model_args.use_feature_adaptor:
            self.feature_adaptor = nn.Linear(self.hidden_size, self.hidden_size)

        if model_args.use_event_qformer:
            # NOTE(review): stored as query_embedder here but query_embeddings in __init__.
            self.query_embedder, self.attention_layers = self.build_event_qformer(model_args)
            self.add_module("query_embedder", self.query_embedder)
            self.attention_layers = nn.ModuleList(self.attention_layers)

        if model_args.pretrain_feature_adaptor is not None:
            print("Loading feature_adaptor pretrain weights...")
            weights = self._read_event_checkpoint(model_args.pretrain_feature_adaptor, "model.feature_adaptor.")
            self.feature_adaptor.load_state_dict(weights, strict=True)
            print("Pretrained weights loaded successfully into feature_adaptor.")

        if model_args.pretrain_mm_mlp_adapter is not None:
            print("Loading mm_projector pretrain weights...")
            weights = self._read_event_checkpoint(model_args.pretrain_mm_mlp_adapter, "model.visual_projector.")
            self.visual_projector.load_state_dict(weights, strict=True)
            print("Pretrained weights loaded successfully into visual_projector.")

        if model_args.pretrain_query_embedder is not None:
            print("Loading query_embedder pretrain weights...")
            weights = self._read_event_checkpoint(model_args.pretrain_query_embedder, "model.query_embedder.")
            self.query_embedder.load_state_dict(weights, strict=True)
            print("Pretrained weights loaded successfully into query_embedder.")

        if model_args.pretrain_attention_layers is not None:
            print("Loading attention_layers pretrain weights...")
            weights = torch.load(model_args.pretrain_attention_layers, map_location="cpu")
            for i, attention_layer in enumerate(self.attention_layers):
                attention_layer.load_state_dict(self._slice_attention_layer_weights(weights, i), strict=True)
            print("Pretrained weights loaded successfully into attention_layers.")


class EventChatQwenModel(Qwen2Model):
    """``Qwen2Model`` extended with an event-camera visual branch.

    Optional sub-modules are rebuilt in ``__init__`` from flags stored on the
    config:

    * ``mm_visual_tower`` -> ``visual_tower`` and a bfloat16 MLP ``visual_projector``;
    * ``event_feature_adaptor is True`` -> ``feature_adaptor``;
    * ``use_event_qformer`` -> ``query_embeddings`` / ``attention_layers``;
    * ``sparse_spatial_temporal`` -> ``SDGA_module`` (``DensityGuidedCompressor``).
    """

    config_class = EventChatQwenConfig

    def __init__(self, config: Qwen2Config):
        super(EventChatQwenModel, self).__init__(config)

        # NOTE(review): hard-coded sizes; 2048 presumably must equal the Qwen2
        # hidden size for projected features to match token embeddings. Confirm.
        self.mlp_depth = 2
        self.text_hidden_size = 1024
        self.hidden_size = 2048

        if hasattr(config, "mm_visual_tower"):
            self.visual_tower = self.build_visual_tower(config.mm_visual_tower)
            self.visual_projector = self.build_mlp_projector(self.text_hidden_size, self.hidden_size).to(dtype=torch.bfloat16)

        if hasattr(config, "event_feature_adaptor") and config.event_feature_adaptor is True:
            self.feature_adaptor = nn.Linear(self.hidden_size, self.hidden_size)

        if hasattr(config, "use_event_qformer"):
            # NOTE(review): build_event_qformer is not defined on this class;
            # confirm it is provided elsewhere, otherwise this raises AttributeError.
            self.query_embeddings, self.attention_layers = self.build_event_qformer(config)
            self.register_parameter('query_embeddings', self.query_embeddings)
            self.attention_layers = nn.ModuleList(self.attention_layers)

        if hasattr(config, "sparse_spatial_temporal"):
            self.SDGA_module = self.build_SDGA_module(config)

    def build_visual_tower(self, visual_tower):
        """Instantiate the CLIP-based ``VisualTower`` from a model name or path."""
        return VisualTower(visual_tower)

    def build_SDGA_module(self, model_args):
        """Build the density-guided compressor from ``hidden_size`` and ``num_queries``.

        NOTE(review): ``initialize_vision_modules`` passes training args here,
        whereas ``__init__`` passes the config. Confirm both carry the same
        ``hidden_size``, or reloaded weights will not fit.
        """
        return DensityGuidedCompressor(model_args.hidden_size, model_args.hidden_size, model_args.num_queries)

    def build_mlp_projector(self, text_hidden_size, hidden_dim):
        """Build an MLP with ``mlp_depth`` linear layers separated by GELU."""
        modules = [nn.Linear(text_hidden_size, hidden_dim)]
        for _ in range(1, self.mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_dim, hidden_dim))
        return nn.Sequential(*modules)

    def get_vision_tower(self):
        """Return the visual tower module, or ``None`` if none has been built.

        The tower is stored as ``visual_tower``. The old ``vision_tower``
        attribute name is still consulted as a fallback. A list is unwrapped
        to its first element.
        """
        vision_tower = getattr(self, 'visual_tower', None)
        if vision_tower is None:
            vision_tower = getattr(self, 'vision_tower', None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    @staticmethod
    def _read_event_checkpoint(path, prefix):
        """Load a checkpoint onto CPU and remove ``prefix`` from every key."""
        # map_location avoids failures when the checkpoint was saved on an unavailable device;
        # load_state_dict copies the values into the module's own parameters anyway.
        state_dict = torch.load(path, map_location="cpu")
        return {k.replace(prefix, ""): v for k, v in state_dict.items()}

    @staticmethod
    def _slice_attention_layer_weights(state_dict, layer_idx):
        """Return the entries of ``attention_layers.<layer_idx>`` with keys relative to that layer.

        Matching includes the trailing '.' so that layer 1 does not also pick
        up the keys of layers 10, 11, ...
        """
        layer_prefix = f"attention_layers.{layer_idx}."
        return {k.split(layer_prefix, 1)[1]: v for k, v in state_dict.items() if layer_prefix in k}

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build the event branch (tower, SDGA compressor, projector, ...) from ``model_args``.

        Records ``mm_visual_tower``, ``event_feature_adaptor``,
        ``sparse_spatial_temporal`` and ``num_queries`` on the config so
        ``__init__`` can rebuild these modules on reload. Optionally loads
        pretrained weights. ``fsdp`` is accepted but unused.
        """
        self.config.mm_visual_tower = model_args.vision_tower
        # NOTE(review): set even when use_feature_adaptor is False. Confirm intended.
        self.config.event_feature_adaptor = True

        visual_tower = self.build_visual_tower(model_args.vision_tower)
        self.SDGA_module = self.build_SDGA_module(model_args)
        self.SDGA_module.to(dtype=torch.bfloat16)

        self.config.sparse_spatial_temporal = True
        self.config.num_queries = model_args.num_queries

        self.visual_tower = visual_tower
        self.visual_projector = self.build_mlp_projector(self.text_hidden_size, self.hidden_size).to(dtype=torch.bfloat16)
        if model_args.use_feature_adaptor:
            self.feature_adaptor = nn.Linear(self.hidden_size, self.hidden_size)

        if model_args.use_event_qformer:
            # NOTE(review): stored as query_embedder here but query_embeddings in __init__.
            self.query_embedder, self.attention_layers = self.build_event_qformer(model_args)
            self.add_module("query_embedder", self.query_embedder)
            self.attention_layers = nn.ModuleList(self.attention_layers)

        if model_args.pretrain_feature_adaptor is not None:
            print("Loading feature_adaptor pretrain weights...")
            weights = self._read_event_checkpoint(model_args.pretrain_feature_adaptor, "model.feature_adaptor.")
            self.feature_adaptor.load_state_dict(weights, strict=True)
            print("Pretrained weights loaded successfully into feature_adaptor.")

        if model_args.pretrain_mm_mlp_adapter is not None:
            print("Loading mm_projector pretrain weights...")
            weights = self._read_event_checkpoint(model_args.pretrain_mm_mlp_adapter, "model.visual_projector.")
            self.visual_projector.load_state_dict(weights, strict=True)
            print("Pretrained weights loaded successfully into visual_projector.")

        if model_args.pretrain_query_embedder is not None:
            print("Loading query_embedder pretrain weights...")
            weights = self._read_event_checkpoint(model_args.pretrain_query_embedder, "model.query_embedder.")
            self.query_embedder.load_state_dict(weights, strict=True)
            print("Pretrained weights loaded successfully into query_embedder.")

        if model_args.pretrain_attention_layers is not None:
            print("Loading attention_layers pretrain weights...")
            weights = torch.load(model_args.pretrain_attention_layers, map_location="cpu")
            for i, attention_layer in enumerate(self.attention_layers):
                attention_layer.load_state_dict(self._slice_attention_layer_weights(weights, i), strict=True)
            print("Pretrained weights loaded successfully into attention_layers.")


# class EventChatLLaMAModel(LlamaForCausalLM):

#     config_class = EventChatLLaMAConfig

#     def __init__(self, config) -> None:
#         super(LlamaForCausalLM, self).__init__(config)
        
#         self.model = EventChatLlamaModel(config)
#         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

#         self.post_init()

#     def get_model(self):
#         return self.model
    
#     def get_visual_tower(self):
#         return self.get_model().visual_tower
    
    
#     def visval_encode(self, event_tensor):
#         with torch.no_grad():
#             outputs = self.get_model().visual_tower.visual_tower(event_tensor)
#         events_feature = outputs.last_hidden_state
#         events_feature = events_feature.detach().requires_grad_(True)
#         events_feature = self.get_model().visual_projector(events_feature)
#         return events_feature

#     def initialize_vision_tokenizer(self, model_args, tokenizer):
#         if model_args.mm_use_im_patch_token:
#             tokenizer.add_tokens([DEFAULT_EVENT_PATCH_TOKEN], special_tokens=True)
#             self.resize_token_embeddings(len(tokenizer))

#         if model_args.mm_use_im_start_end:
#             num_new_tokens = tokenizer.add_tokens([DEFAULT_EV_START_TOKEN, DEFAULT_EV_END_TOKEN], special_tokens=True)
#             self.resize_token_embeddings(len(tokenizer))

#             if num_new_tokens > 0:
#                 input_embeddings = self.get_input_embeddings().weight.data
#                 output_embeddings = self.get_output_embeddings().weight.data

#                 input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
#                     dim=0, keepdim=True)
#                 output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
#                     dim=0, keepdim=True)

#                 input_embeddings[-num_new_tokens:] = input_embeddings_avg
#                 output_embeddings[-num_new_tokens:] = output_embeddings_avg

#             if model_args.tune_mm_mlp_adapter:
#                 for p in self.get_input_embeddings().parameters():
#                     p.requires_grad = True
#                 for p in self.get_output_embeddings().parameters():
#                     p.requires_grad = False

#             if model_args.pretrain_mm_mlp_adapter:
#                 mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
#                 embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
#                 assert num_new_tokens == 2
#                 if input_embeddings.shape == embed_tokens_weight.shape:
#                     input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
#                 elif embed_tokens_weight.shape[0] == num_new_tokens:
#                     input_embeddings[-num_new_tokens:] = embed_tokens_weight
#                 else:
#                     raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
#         elif model_args.mm_use_im_patch_token:
#             if model_args.tune_mm_mlp_adapter:
#                 for p in self.get_input_embeddings().parameters():
#                     p.requires_grad = False
#                 for p in self.get_output_embeddings().parameters():
#                     p.requires_grad = False

    
#     def forward(self, 
#             event_tensors: Optional[torch.FloatTensor] = None,
#             input_ids: torch.LongTensor = None, 
#             labels: Optional[torch.LongTensor] = None,
#             events: Optional[torch.FloatTensor] = None,
#             events_list: Optional[torch.FloatTensor] = None,
#             inputs_embeds: Optional[torch.FloatTensor] = None,
#             position_ids: Optional[torch.LongTensor] = None,
#             attention_mask: Optional[torch.Tensor] = None,
#             past_key_values: Optional[List[torch.FloatTensor]] = None,
#             use_cache: Optional[bool] = None,
#             event_image_sizes : Optional[List[List[int]]] = None,
#             output_attentions: Optional[bool] = None,
#             output_hidden_states: Optional[bool] = None,
#             return_dict: Optional[bool] = None):
    
#         if inputs_embeds is None:
#             (
#                 input_ids,
#                 position_ids,
#                 attention_mask,
#                 past_key_values,
#                 inputs_embeds,
#                 labels
#             ) = self.prepare_inputs_labels_for_multimodal(
#                 input_ids, 
#                 position_ids,
#                 attention_mask,
#                 past_key_values,
#                 labels,
#                 event_tensors,
#                 event_image_sizes           
#             )
            
#         outputs = super().forward(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#             position_ids=position_ids,
#             past_key_values=past_key_values,
#             inputs_embeds=inputs_embeds,
#             labels=labels,
#             use_cache=use_cache,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             return_dict=return_dict
#         )      

#         torch.cuda.synchronize()
#         return outputs
    
#     @torch.no_grad()
#     def generate(
#         self,
#         inputs: Optional[torch.Tensor] = None,
#         event_tensors: Optional[torch.Tensor] = None,
#         event_image_sizes: Optional[torch.Tensor] = None,
#         event_feature = None,
#         **kwargs,
#     ) -> Union[GenerateOutput, torch.LongTensor]:
#         position_ids = kwargs.pop("position_ids", None)
#         attention_mask = kwargs.pop("attention_mask", None)
#         if "inputs_embeds" in kwargs:
#             raise NotImplementedError("`inputs_embeds` is not supported")
        
#         if event_tensors is not None:
#             (
#                 inputs,
#                 position_ids,
#                 attention_mask,
#                 _,
#                 inputs_embeds,
#                 _
#             ) = self.prepare_inputs_labels_for_multimodal(
#                 inputs,
#                 position_ids,
#                 attention_mask,
#                 None,
#                 None,
#                 event_tensors,
#                 event_image_sizes=event_image_sizes
#             )
#         else:
#             raise NotImplementedError("please input Event")
        
#         return super().generate(
#             position_ids=position_ids,
#             attention_mask=attention_mask,
#             inputs_embeds=inputs_embeds,
#             **kwargs
#         )
    
   
#     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
#                                       inputs_embeds=None, **kwargs):
#         event_tensors = kwargs.pop("event_tensors", None)
#         event_image_sizes = kwargs.pop("event_sizes", None)
#         inputs = super().prepare_inputs_for_generation(
#             input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
#         )
#         if event_tensors is not None:
#             inputs['event_tensors'] = event_tensors
#         if event_image_sizes is not None:
#             inputs['event_image_sizes'] = event_image_sizes
#         return inputs
        
    
#     def prepare_inputs_labels_for_multimodal(
#         self, input_ids, position_ids, attention_mask, past_key_values, labels,
#         event_tensors, event_image_sizes=None
#     ):  
#         if event_tensors is None or input_ids.shape[1] == 1:
#             return input_ids, position_ids, attention_mask, past_key_values, None, labels

#         if isinstance(event_tensors, list):
#             if not all(isinstance(item, list) for item in event_tensors):
#                 event_tensors = [event_tensors]

#             ev_features_list = []
#             ev_features_lengths = []  
#             for item in event_tensors:
#                 ev_feature = []
#                 for ev in item:
#                     ev = ev.unsqueeze(0)
#                     feature = self.visval_encode(ev)
#                     feature = self.get_model().feature_adaptor(feature)
#                     feature = feature.squeeze(0)
#                     ev_feature.append(feature)
#                 event_feature = get_spatio_temporal_features(ev_feature)
#                 ev_features_list.append(event_feature)
#                 ev_features_lengths.append(event_feature.shape[0])
#             padded_event_features = pad_sequence(ev_features_list, batch_first=True, padding_value=0)
#             event_features = padded_event_features
#         else:
#             event_features = self.visval_encode(event_tensors)
#             ev_features_lengths = [event_features.shape[0]]

#         _labels = labels
#         _position_ids = position_ids
#         _attention_mask = attention_mask
#         if attention_mask is None:
#             attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
#         else:
#             attention_mask = attention_mask.bool()
#         if position_ids is None:
#             position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
#         if labels is None:
#             labels = torch.full_like(input_ids, IGNORE_INDEX)

#         _input_ids = input_ids
#         input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
#         labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

#         new_input_embeds = []
#         new_labels = []
#         cur_event_idx = 0
#         for batch_idx, cur_input_ids in enumerate(input_ids):
#             num_events = (cur_input_ids == EVENT_TOKEN_INDEX).sum()
#             if num_events == 0:
#                 cur_event_features = event_features[cur_event_idx]
#                 cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
#                 cur_input_embeds = torch.cat([cur_input_embeds_1, cur_event_features[0:0]], dim=0)
#                 new_input_embeds.append(cur_input_embeds)
#                 new_labels.append(labels[batch_idx])
#                 cur_event_idx += 1
#                 continue

#             event_token_indices = [-1] + torch.where(cur_input_ids == EVENT_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
#             cur_input_ids_noim = []
#             cur_labels = labels[batch_idx]
#             cur_labels_noim = []
#             for i in range(len(event_token_indices) - 1):
#                 cur_input_ids_noim.append(cur_input_ids[event_token_indices[i]+1:event_token_indices[i+1]])
#                 cur_labels_noim.append(cur_labels[event_token_indices[i]+1:event_token_indices[i+1]])
#             split_sizes = [x.shape[0] for x in cur_labels_noim]
#             cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
#             cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
#             cur_new_input_embeds = []
#             cur_new_labels = []

#             for i in range(num_events + 1):
#                 cur_new_input_embeds.append(cur_input_embeds_no_im[i])
#                 cur_new_labels.append(cur_labels_noim[i])
#                 if i < num_events:
#                     valid_length = ev_features_lengths[cur_event_idx]
#                     cur_event_features = event_features[cur_event_idx][:valid_length]
#                     cur_event_idx += 1
#                     cur_new_input_embeds.append(cur_event_features)
#                     cur_new_labels.append(torch.full((valid_length,), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

#             cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
#             cur_new_input_embeds = torch.cat(cur_new_input_embeds)
#             cur_new_labels = torch.cat(cur_new_labels)

#             new_input_embeds.append(cur_new_input_embeds)
#             new_labels.append(cur_new_labels)

#         tokenizer_model_max_length = 20480
#         if tokenizer_model_max_length is not None:
#             new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
#             new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

#         new_input_embeds_padded = pad_sequence(new_input_embeds, batch_first=True, padding_value=0)
#         new_labels_padded = pad_sequence(new_labels, batch_first=True, padding_value=IGNORE_INDEX)

#         lengths = [seq.shape[0] for seq in new_labels]
#         batch_size, max_len = new_labels_padded.shape
#         attention_mask = torch.zeros((batch_size, max_len), dtype=torch.bool, device=new_labels_padded.device)
#         for i, l in enumerate(lengths):
#             attention_mask[i, :l] = True

#         position_ids = torch.arange(max_len, device=new_labels_padded.device).unsqueeze(0).expand(batch_size, max_len)
        
#         if _labels is None:
#             new_labels = None
#         else:
#             new_labels = new_labels_padded

#         if _attention_mask is None:
#             attention_mask = None
#         else:
#             attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

#         if _position_ids is None:
#             position_ids = None

#         return None, position_ids, attention_mask, past_key_values, new_input_embeds_padded, new_labels


class EventChatQwenCausalLM(Qwen2ForCausalLM):
    """Qwen2 causal LM that splices encoded event-camera features into the prompt.

    Event data is turned into event images, encoded by the visual tower and
    projector, filtered/merged, refined by ``SDGA_module`` and finally inserted
    at every ``EVENT_TOKEN_INDEX`` placeholder of ``input_ids``; the language
    model then runs on the resulting ``inputs_embeds``.
    """

    config_class = EventChatQwenConfig

    # Hard cap on the length of each spliced (text + event) sequence; longer
    # sequences are truncated. Override on a subclass or instance if needed.
    MAX_MULTIMODAL_SEQ_LEN = 10240

    def __init__(self, config) -> None:
        # Qwen2ForCausalLM.__init__ builds a stock backbone and lm_head; both are
        # replaced right below, so the parent's modules are discarded.
        Qwen2ForCausalLM.__init__(self, config)
        config.model_type = "EventChat_Qwen"
        config.rope_scaling = None
        self.model = EventChatQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_model(self):
        """Return the event-aware backbone (``EventChatQwenModel``)."""
        return self.model

    def get_visual_tower(self):
        """Return the backbone's visual tower wrapper."""
        return self.get_model().visual_tower

    def visval_encode(self, event_tensor):
        """Encode event images with the visual tower and project the features.

        Args:
            event_tensor: Either raw event images, which are preprocessed here
                by the tower's ``event_processor``, or an already-preprocessed
                pixel tensor.

        Returns:
            Tuple ``(event_tensor, events_feature)``: the (possibly
            preprocessed) input and the projected ``last_hidden_state``.

        Raises:
            Exception: wrapping any error raised by the visual tower.
        """
        first = event_tensor[0]
        # NOTE(review): this reads as "dim 0 is non-zero AND dim 1 != 224".
        # It is False for a preprocessed (C, 224, 224) item and presumably True
        # for a raw (H, W, C) image -- confirm before "fixing" the precedence.
        if first.shape[0] and first.shape[1] != 224:
            processor = self.get_model().visual_tower.event_processor
            event_tensor = processor(images=event_tensor, return_tensors="pt")['pixel_values']
            event_tensor = event_tensor.to(self.get_model().device)
            event_tensor = event_tensor.to(torch.bfloat16)
        with torch.no_grad():
            try:
                outputs = self.get_model().visual_tower.visual_tower(event_tensor)
            except Exception as e:
                # Message (Chinese): "an error occurred: <input>, error message: <e>, ...".
                raise Exception(f"报错了: {event_tensor}, 错误信息: {str(e)}, {event_tensor[0].shape[0]}, {event_tensor[0].shape[1]}, type(event_tensor): {type(event_tensor)}") from e
        # The tower ran under no_grad; its output becomes a fresh leaf that
        # requires grad, so no gradient ever reaches the tower itself.
        events_feature = outputs.last_hidden_state
        events_feature = events_feature.detach().requires_grad_(True)
        events_feature = self.get_model().visual_projector(events_feature)
        return event_tensor, events_feature

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add the event special tokens to ``tokenizer`` and resize embeddings.

        New start/end tokens are initialised to the mean of the pre-existing
        input/output embeddings, optionally overwritten from the checkpoint at
        ``model_args.pretrain_mm_mlp_adapter``. With ``tune_mm_mlp_adapter``
        the input embeddings are trainable and the output embeddings frozen
        (start/end case) or both are frozen (patch-token-only case).

        Raises:
            ValueError: if the pretrained embedding shape matches neither the
                full table nor just the new tokens.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_EVENT_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_EV_START_TOKEN, DEFAULT_EV_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                # Initialise the new rows with the mean of the existing rows.
                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

    def forward(self,
            event_tensors: Optional[torch.FloatTensor] = None,
            input_ids: torch.LongTensor = None,
            labels: Optional[torch.LongTensor] = None,
            event_data: Optional[torch.FloatTensor] = None,
            event_feature: Optional[torch.FloatTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            use_cache: Optional[bool] = None,
            event_image_sizes : Optional[List[List[int]]] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None):
        """Run the language model, splicing in event features when needed.

        When ``inputs_embeds`` is not supplied, it is built from ``input_ids``,
        ``event_data`` and ``event_feature`` by
        ``prepare_inputs_labels_for_multimodal``, which returns the inputs
        unchanged when ``event_data`` is ``None``. ``event_tensors`` is
        accepted for interface compatibility but not used here.

        Returns:
            Whatever ``Qwen2ForCausalLM.forward`` returns.
        """
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                event_data,
                event_feature,
                event_image_sizes
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        event_tensors: Optional[torch.Tensor] = None,
        event_image_sizes: Optional[torch.Tensor] = None,
        event_data=None,
        event_feature = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate text from prompt token ids conditioned on event data.

        The prompt (``inputs``) is converted to embeddings with the event
        features spliced in, then handed to ``GenerationMixin.generate`` via
        ``inputs_embeds``. If nothing is spliced (no ``event_data`` or a
        single-token prompt), the plain prompt embeddings are used.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed, or if neither
                ``event_tensors`` nor ``event_data`` is given.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if event_tensors is None and event_data is None:
            raise NotImplementedError("please input Event")

        (
            inputs,
            position_ids,
            attention_mask,
            _,
            inputs_embeds,
            _
        ) = self.prepare_inputs_labels_for_multimodal(
            inputs,
            position_ids,
            attention_mask,
            None,
            None,
            event_data,
            event_feature,
            event_image_sizes=event_image_sizes
        )

        # The helper returns no embeddings when there is nothing to splice;
        # embed the prompt so generate() does not silently ignore it.
        if inputs_embeds is None and inputs is not None:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Forward event-related kwargs through each generation step.

        Both the ``forward()`` keyword ``event_image_sizes`` and the legacy
        ``event_sizes`` key are accepted; ``event_sizes`` wins if both are set.
        """
        event_tensors = kwargs.pop("event_tensors", None)
        event_image_sizes = kwargs.pop("event_image_sizes", None)
        legacy_sizes = kwargs.pop("event_sizes", None)
        if legacy_sizes is not None:
            event_image_sizes = legacy_sizes
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if event_tensors is not None:
            inputs['event_tensors'] = event_tensors
        if event_image_sizes is not None:
            inputs['event_image_sizes'] = event_image_sizes
        return inputs

    def _aggregate_event_windows(self, event_data, event_feature):
        """Adaptively aggregate each sample's event windows.

        Returns:
            ``(event_windows, density_tensors)``, one entry per sample: the
            aggregated windows and a list of per-window bf16 density tensors
            on ``self.device``.
        """
        event_feature = list(event_feature)
        # A single (non-nested) sample is wrapped so the loop sees a batch of one.
        if not (isinstance(event_data, list) and all(isinstance(item, list) for item in event_data)):
            event_data = [event_data]
            event_feature = [torch.tensor(event_feature).to(self.device).to(torch.bfloat16)]

        event_windows = []
        density_tensors = []
        for ev_data, ev_feature in zip(event_data, event_feature):
            windows = convert_data(ev_data)
            sparsity_factors = np.array([compute_sparsity_factor(w) for w in windows])
            # Truncate the features to the number of windows.
            ev_feature = ev_feature[0:len(sparsity_factors)]
            windows = adaptive_aggregate_optimized(windows, ev_feature, sparsity_factors)
            ev_density = compute_normalized_event_density_batch(windows)
            density_tensors.append(
                [torch.tensor(d, device=self.device, dtype=torch.bfloat16) for d in ev_density]
            )
            event_windows.append(windows)
        return event_windows, density_tensors

    def _encode_event_windows(self, event_windows, density_tensors):
        """Render, encode, filter and merge event windows, then apply SDGA.

        Returns:
            The ``SDGA_module`` output for the right-padded batch of tokens.
        """
        # Columns 0, 1 and 3 of each window are passed as x, y and p.
        event_images_per_sample = [
            [generate_event_image(w[:, 0], w[:, 1], w[:, 3]) for w in windows]
            for windows in event_windows
        ]

        feature_list = []
        density_list = []
        lengths = []
        for event_images, ev_density in zip(event_images_per_sample, density_tensors):
            _, fea = self.visval_encode(event_images)
            filter_token, filter_token_index = select_non_white_tokens_batch(event_images, fea)
            merged_tokens_list, merged_density_list = token_merge(filter_token, filter_token_index, ev_density)
            density_list.append(torch.cat(merged_density_list, dim=0))
            tokens = torch.cat(merged_tokens_list, dim=0).to(torch.bfloat16)
            feature_list.append(tokens)
            lengths.append(tokens.shape[0])

        padded_features = pad_sequence(feature_list, batch_first=True, padding_value=0)
        padded_density = pad_sequence(density_list, batch_first=True, padding_value=0)
        # Padding mask from the true lengths (not from the values), so a
        # genuine all-zero token is never mistaken for padding.
        device = padded_features.device
        positions = torch.arange(padded_features.shape[1], device=device)
        valid = torch.tensor(lengths, device=device).unsqueeze(1)
        att_mask = (positions.unsqueeze(0) < valid).long()  # [B, N_max]
        return self.get_model().SDGA_module(padded_features, padded_density, att_mask)

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        event_data, event_feature, event_image_sizes=None
    ):
        """Splice encoded event features into the token embeddings.

        Each ``EVENT_TOKEN_INDEX`` placeholder is replaced by the next row of
        the encoded event batch (labels there are ``IGNORE_INDEX``). The row
        counter is shared across the batch, and text-only samples also
        consume a row. Sequences are truncated to ``MAX_MULTIMODAL_SEQ_LEN``,
        right-padded and given fresh attention masks / position ids.
        ``event_image_sizes`` is accepted but unused.

        Returns:
            ``(None, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``. ``position_ids``, ``attention_mask`` and
            ``labels`` are ``None`` when the caller passed ``None``. If
            ``event_data`` is ``None`` or the prompt is one token long, the
            inputs are returned unchanged with ``inputs_embeds=None``.
        """
        if event_data is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        event_windows, density_tensors = self._aggregate_event_windows(event_data, event_feature)
        event_features = self._encode_event_windows(event_windows, density_tensors)
        # Every row of the SDGA output is inserted at its full width.
        ev_features_lengths = [event_features.shape[1]] * event_features.shape[0]

        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Drop padded positions; sequences are re-padded after splicing.
        input_ids = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
        labels = [lab[mask] for lab, mask in zip(labels, attention_mask)]

        embed_tokens = self.get_model().embed_tokens
        new_input_embeds = []
        new_labels = []
        cur_event_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            cur_labels = labels[batch_idx]
            num_events = (cur_input_ids == EVENT_TOKEN_INDEX).sum()
            if num_events == 0:
                # An empty slice adds no rows but keeps event_features
                # connected to the graph for text-only samples.
                cur_event_features = event_features[cur_event_idx]
                cur_input_embeds = torch.cat([embed_tokens(cur_input_ids), cur_event_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(cur_labels)
                cur_event_idx += 1
                continue

            # Text segments between (and around) the event placeholders.
            bounds = [-1] + torch.where(cur_input_ids == EVENT_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            text_ids = []
            text_labels = []
            for start, end in zip(bounds[:-1], bounds[1:]):
                text_ids.append(cur_input_ids[start + 1:end])
                text_labels.append(cur_labels[start + 1:end])
            split_sizes = [x.shape[0] for x in text_labels]
            text_embeds = torch.split(embed_tokens(torch.cat(text_ids)), split_sizes, dim=0)

            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_events + 1):
                cur_new_input_embeds.append(text_embeds[i])
                cur_new_labels.append(text_labels[i])
                if i < num_events:
                    valid_length = ev_features_lengths[cur_event_idx]
                    cur_new_input_embeds.append(event_features[cur_event_idx][:valid_length])
                    cur_event_idx += 1
                    cur_new_labels.append(torch.full((valid_length,), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
            new_input_embeds.append(torch.cat(cur_new_input_embeds))
            new_labels.append(torch.cat(cur_new_labels))

        max_seq_len = self.MAX_MULTIMODAL_SEQ_LEN
        if max_seq_len is not None:
            new_input_embeds = [x[:max_seq_len] for x in new_input_embeds]
            new_labels = [x[:max_seq_len] for x in new_labels]

        new_input_embeds_padded = pad_sequence(new_input_embeds, batch_first=True, padding_value=0)
        new_labels_padded = pad_sequence(new_labels, batch_first=True, padding_value=IGNORE_INDEX)

        lengths = [seq.shape[0] for seq in new_labels]
        batch_size, max_len = new_labels_padded.shape
        attention_mask = torch.zeros((batch_size, max_len), dtype=torch.bool, device=new_labels_padded.device)
        for i, l in enumerate(lengths):
            attention_mask[i, :l] = True

        position_ids = torch.arange(max_len, device=new_labels_padded.device).unsqueeze(0).expand(batch_size, max_len)

        new_labels = None if _labels is None else new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds_padded, new_labels


# Register the config with the Auto classes and map it to the causal-LM wrapper
# (the class that owns lm_head and generate()), so that
# AutoModelForCausalLM.from_pretrained(...) does not build the bare backbone.
AutoConfig.register("EventChat_Qwen", EventChatQwenConfig)
AutoModelForCausalLM.register(EventChatQwenConfig, EventChatQwenCausalLM)
