from typing import List, Optional, Tuple, Union
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import AutoConfig, AutoModelForCausalLM, \
                         LlamaConfig, LlamaModel, LlamaForCausalLM

from transformers.modeling_outputs import CausalLMOutputWithPast

from utils.constants import IGNORE_INDEX, GRAPH_TOKEN_INDEX, DEFAULT_GRAPH_START_TOKEN, DEFAULT_GRAPH_END_TOKEN, DEFAULT_GRAPH_PAD_ID
from torch_geometric.nn import GCNConv, GINConv, GATv2Conv, SAGEConv, LayerNorm, SAGPooling
from torch_geometric.nn.models import MLP, GCN, GAT, GIN
from torch_geometric.data import Data, Batch
from torch_geometric.utils import unbatch, subgraph, remove_isolated_nodes, dense_to_sparse, coalesce, to_dense_adj, add_remaining_self_loops, unbatch_edge_index, coalesce
from torch_geometric.transforms import GDC

# Lower / upper end of the interval that `sample_z_from_u` stretches the
# sigmoid output to, before `sample_z_from_log_alpha` clamps it to [0, 1].
LIMIT_LEFT = -0.1
LIMIT_RIGHT = 1.1
# Margin that keeps the uniform noise drawn in `sample_z_from_log_alpha`
# strictly inside (0, 1), so the logarithms in `sample_z_from_u` stay finite.
EPS = 1e-6
# Divisor applied to the noisy logits inside the sigmoid of `sample_z_from_u`.
TEMPERATURE = 2 / 3
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether it is still needed.
FACTOR = 0.8

def sample_z_from_u(u, log_alpha):
    """Turn noise ``u`` and logits ``log_alpha`` into a stretched gate value.

    The logit of ``u`` is added to ``log_alpha``, divided by ``TEMPERATURE``
    and squashed with a sigmoid; the result is then rescaled linearly from
    (0, 1) to (``LIMIT_LEFT``, ``LIMIT_RIGHT``).

    Args:
        u: Tensor of values strictly between 0 and 1 (both ``log(u)`` and
            ``log(1 - u)`` are taken).
        log_alpha: Tensor broadcastable against ``u``.

    Returns:
        Tensor of stretched gate values; not yet clamped to [0, 1].
    """
    logistic_noise = torch.log(u) - torch.log(1 - u)
    gate = torch.sigmoid((logistic_noise + log_alpha) / TEMPERATURE)
    return LIMIT_LEFT + gate * (LIMIT_RIGHT - LIMIT_LEFT)

def sample_z_from_log_alpha(log_alpha):
    """Draw a random gate in [0, 1] for every entry of ``log_alpha``.

    Uniform noise in [``EPS``, 1 - ``EPS``) is pushed through
    ``sample_z_from_u`` and the stretched result is clamped to [0, 1] with
    ``hardtanh``, so exact zeros and ones are possible.

    Args:
        log_alpha: Tensor of logits; the sample has the same shape.

    Returns:
        Tensor of gate values in [0, 1] on ``log_alpha``'s device.
    """
    # Sample float32 noise directly on the target device. This replaces the
    # deprecated ``torch.autograd.Variable`` wrapper and the CPU allocation
    # followed by a device copy. NOTE: for non-CPU inputs the noise now comes
    # from that device's RNG instead of the CPU generator, so seeded runs
    # produce a different (equally distributed) stream there.
    u = torch.empty(
        log_alpha.shape, dtype=torch.float32, device=log_alpha.device
    ).uniform_(EPS, 1 - EPS)
    z = sample_z_from_u(u, log_alpha)
    return F.hardtanh(z, 0, 1)

class LlagaConfig(LlamaConfig):
    """LLaMA configuration extended with class-level graph-adapter defaults."""

    model_type = "llaga"
    # Adapter architecture key; LlagaForCausalLM.__init__ branches on
    # 'linear', 'gcn', 'gat' and 'gin'.
    gnn_type = 'linear'
    # Default number of adapter layers. The model reads `config.num_adapter`,
    # so the default has to live under this (correctly spelled) name.
    num_adapter = 2
    # Original misspelled attribute, kept as an alias so any code that still
    # reads it keeps working.
    num_adapeter = num_adapter

class LlagaModel(LlamaModel):
    """``LlamaModel`` subclass whose ``config_class`` is ``LlagaConfig``."""

    config_class = LlagaConfig

    def __init__(self, config: LlamaConfig):
        """Construct the backbone by delegating entirely to ``LlamaModel``."""
        super().__init__(config)

class LlagaForCausalLM(LlamaForCausalLM):
    config_class = LlagaConfig

    # def __init__(self, config, gnn_type='gin', num_adapter=5, node_size=2048):
    def __init__(self, config):
        """Build the backbone, LM head, graph adapter and optional graph transform.

        Args:
            config: Model configuration. Reads ``hidden_size``, ``vocab_size``,
                ``gnn_type``, ``node_size``, ``num_adapter`` and, when
                present, ``graph_transform``.

        Raises:
            ValueError: If ``config.gnn_type`` is not one of ``'linear'``,
                ``'gcn'``, ``'gat'`` or ``'gin'``.
        """
        # Start the MRO lookup *after* LlamaForCausalLM so that its __init__
        # is skipped; the backbone and head are created explicitly below.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlagaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # All adapter variants share the same constructor arguments, so pick
        # the class from a table instead of repeating the call four times.
        adapter_classes = {'linear': MLP, 'gcn': GCN, 'gat': GAT, 'gin': GIN}
        gnn_type = self.config.gnn_type
        if gnn_type not in adapter_classes:
            # Previously this only printed a message and left `self.adapter`
            # undefined, which surfaced later as an AttributeError.
            raise ValueError(
                f"Need to specify the adapter type: unsupported gnn_type {gnn_type!r}, "
                f"expected one of {sorted(adapter_classes)}"
            )
        self.adapter = adapter_classes[gnn_type](
            in_channels=config.node_size,
            hidden_channels=config.node_size,
            out_channels=config.hidden_size,
            num_layers=config.num_adapter,
            norm='layer',
        )

        # `self.transform` is always defined (None unless a '*dc' mode is
        # configured) so later code can rely on the attribute existing.
        transform = None
        graph_transform = getattr(config, 'graph_transform', None)
        if graph_transform is not None:
            if 'dc' in graph_transform:
                transform = GDC(
                    self_loop_weight=1,
                    normalization_in='sym',
                    normalization_out='col',
                    diffusion_kwargs=dict(method='ppr', alpha=0.05),
                    sparsification_kwargs=dict(method='threshold', avg_degree=5),
                    exact=True,
                )
            elif 'attn' in graph_transform:
                # Modules used by the 'attn' branch of `transform_graph`:
                # a GATv2 layer over concatenated endpoint features, followed
                # by a linear map to one scalar per node.
                self.node_adapter = GATv2Conv(
                    in_channels=config.node_size,
                    out_channels=config.node_size,
                    edge_dim=config.node_size * 2,
                    add_self_loops=False,
                )
                self.attr_adapter = nn.Linear(config.node_size, 1, bias=False)
        self.transform = transform

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the decoder backbone stored in ``self.model``."""
        backbone = self.model
        return backbone
    
    def transform_graph(self, graph):
        """Apply the graph transform selected by ``config.graph_transform``.

        Supported modes:

        * ``'gdc'``: run ``self.transform`` on the graph, then keep only the
          nodes that share an edge with node 0 in the transformed graph.
        * ``'sgdc'``: as ``'gdc'``, but the graph's own edges are first merged
          with feature-similarity edges.
        * ``'sdc'``: as ``'gdc'``, but only the similarity edges are used.
        * any mode containing ``'attn'``: multiply every node's features by a
          learned weight in (0, 1); in training mode a binary cross-entropy
          between these weights and ``graph.sparsity_label`` is also computed.
        * anything else (including a missing / ``None`` setting): the graph is
          returned unchanged.

        Args:
            graph: A ``Data`` object or a ``Batch``. Objects exposing
                ``num_graphs`` are split into samples, transformed one by one
                and re-batched.

        Returns:
            Tuple ``(graph, edge_loss, None, None)``. ``edge_loss`` is only
            non-``None`` for the ``'attn'`` modes while ``self.training`` is
            true; the last two slots are always ``None``.
        """
        mode = getattr(self.config, 'graph_transform', None)

        if mode in ('gdc', 'sgdc', 'sdc'):
            if hasattr(graph, 'num_graphs'):
                samples = Batch.to_data_list(graph)
                graph = Batch.from_data_list(
                    [self._diffuse_ego_graph(sample, mode, batched=True) for sample in samples]
                )
            else:
                graph = self._diffuse_ego_graph(graph, mode, batched=False)
            return graph, None, None, None

        if mode is not None and 'attn' in mode:
            row, col = graph.edge_index
            # Edge features are the concatenated features of both endpoints.
            edge_attr = torch.cat([graph.x[row], graph.x[col]], dim=-1)
            node_hidden = F.silu(self.node_adapter(graph.x, graph.edge_index, edge_attr=edge_attr))
            # Drop only the trailing unit dimension: a bare `.squeeze()` would
            # collapse a single-node graph to a 0-d tensor and no longer match
            # the shape of `sparsity_label`.
            node_weights = torch.sigmoid(self.attr_adapter(node_hidden)).squeeze(-1)
            edge_loss = None
            if self.training:
                edge_loss = F.binary_cross_entropy(
                    node_weights, graph.sparsity_label.to(node_weights.dtype)
                )
            weighted_graph = Data(
                x=graph.x * node_weights.unsqueeze(-1),
                edge_index=graph.edge_index,
                batch=graph.batch,
            )
            return weighted_graph, edge_loss, None, None

        return graph, None, None, None

    def _similarity_edge_index(self, x):
        """Return edges (i, j) where cos-sim(x_i, x_j) exceeds row i's mean similarity."""
        similarity = F.cosine_similarity(x.unsqueeze(1), x.unsqueeze(0), dim=-1)
        above_row_mean = similarity > similarity.mean(dim=-1, keepdim=True)
        edge_index, _ = dense_to_sparse(torch.where(above_row_mean, 1, 0))
        return edge_index

    def _diffuse_ego_graph(self, sample, mode, batched):
        """Run ``self.transform`` on one graph and keep node 0's neighbourhood.

        Args:
            sample: A single ``Data`` graph.
            mode: ``'gdc'``, ``'sgdc'`` or ``'sdc'``; selects which edges are
                handed to ``self.transform``.
            batched: Whether ``sample`` was taken out of a ``Batch``.

        Returns:
            ``Data(x, edge_index)`` restricted to the nodes that share an edge
            with node 0 after the transform. If the transform raises, a
            fallback graph is returned instead (see the ``except`` branch).
        """
        # Single-node samples of a batch are passed through untouched.
        if batched and sample.x.size(0) == 1:
            return Data(x=sample.x, edge_index=sample.edge_index)

        if mode == 'gdc':
            # The transform is applied to the sample object itself.
            candidate = sample
        else:
            if hasattr(sample, 'similarity_index'):
                similarity_index = sample.similarity_index
            else:
                similarity_index = self._similarity_edge_index(sample.x)
            if mode == 'sgdc':
                edge_index = coalesce(torch.cat([sample.edge_index, similarity_index], dim=-1))
            else:
                edge_index = similarity_index
            candidate = Data(x=sample.x, edge_index=edge_index)

        try:
            diffused = self.transform(candidate)
        except Exception:
            # Best-effort: a failing transform must not abort the step.
            # Batched samples fall back to a bare (x, edge_index) graph so
            # that every element handed to `Batch.from_data_list` carries the
            # same keys; the unbatched path keeps its previous fallback (the
            # input graph for 'gdc', the similarity graph otherwise).
            if batched:
                return Data(x=sample.x, edge_index=sample.edge_index)
            return candidate

        row, col = diffused.edge_index
        touches_center = (row == 0) | (col == 0)
        # Both endpoints of every edge incident to node 0, sorted and unique.
        selected_nodes = torch.unique(torch.cat([row[touches_center], col[touches_center]], dim=-1))
        edge_index, _ = subgraph(
            selected_nodes, diffused.edge_index, relabel_nodes=True, num_nodes=sample.num_nodes
        )
        return Data(x=sample.x[selected_nodes], edge_index=edge_index)
    
    # def graph_transform(self, graph):
    #     if getattr(self.config, 'graph_transform') == 'attn':
    #         row, col = graph.edge_index
    #         edge_feature = torch.cat([graph.x[row], graph.x[col]], dim=-1)
    #         print(graph.x.shape, graph.edge_index.shape)
    #         edge_feature = self.edge_adapter_norm(F.silu(self.edge_adapter(edge_feature)))
    #         node_feature = F.silu(self.node_adapter(graph.x))
    #         if hasattr(graph, 'batch') and graph.batch is not None:
    #             node_features = unbatch(node_feature, graph.batch)
    #         else:
    #             node_features = [node_feature]
            
    #         node_activation = []
    #         for nodes in node_features:
    #             if nodes.ndim < 2:
    #                 nodes = nodes.unsqueeze(0)
    #             node_activation.append(F.gelu(F.cosine_similarity(nodes, nodes[0, :].squeeze().unsqueeze(0), dim=-1)))
    #         node_activation = torch.cat(node_activation)
            
    #         edge_activation = F.sigmoid(F.linear(edge_feature, graph.x.mean(dim=0)))
    #         adjacency_matrix = to_dense_adj(graph.edge_index, edge_attr=edge_activation, max_num_nodes=graph.x.size(0)).squeeze()
    #         adjacency_matrix = torch.where(to_dense_adj(graph.edge_index, max_num_nodes=graph.x.size(0)).squeeze() == 0, -10, adjacency_matrix)
    #         adjacency_matrix = sample_z_from_log_alpha(adjacency_matrix)
            
    #         edge_loss = adjacency_matrix.sum() / graph.edge_index.size(-1)
    #         graph = Data(x=graph.x * node_activation.squeeze().unsqueeze(-1), edge_index=graph.edge_index, batch=graph.batch)
    #         # print(adjacency_matrix.shape, edge_loss, adjacency_matrix)
    #         # exit()
    #     else:
    #         return graph, None, None, None
    
    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, first_graph, second_graph
    ):
        """Replace graph placeholder tokens with adapter-encoded node embeddings.

        In every sample the first ``GRAPH_TOKEN_INDEX`` position is replaced
        by the node embeddings of that sample's graph in ``first_graph`` and
        the second one by those from ``second_graph``. Each spliced sample is
        extended with copies of its last embedding so that all samples gain
        the same number of positions; labels and the attention mask are
        extended to match.

        Args:
            input_ids: Token ids, one row per sample.
            attention_mask: Mask aligned with ``input_ids``, or ``None``.
            past_key_values: KV cache from a previous step, or ``None``.
            labels: Labels aligned with ``input_ids``, or ``None``.
            first_graph: ``Data`` / ``Batch`` with the first graph of every
                sample, or ``None`` for a text-only call.
            second_graph: Optional second graph per sample.

        Returns:
            ``(input_ids, attention_mask, past_key_values, inputs_embeds,
            labels, graph_sparsity, graph2_sparsity)``. On the splice path
            ``input_ids`` is ``None`` and ``inputs_embeds`` holds the result;
            on the two pass-through paths (no graph, or a single-token step
            with a KV cache) ``input_ids`` is returned and ``inputs_embeds``
            is ``None``.
        """
        # Text-only call: nothing to splice, hand the inputs straight through
        # (this used to fail when dereferencing the missing graph).
        if first_graph is None:
            return input_ids, attention_mask, past_key_values, None, labels, None, None

        # Single-token step with a KV cache: only the mask has to grow by one.
        # Checked before encoding so the adapter is not run for nothing.
        if past_key_values is not None and input_ids.shape[1] == 1:
            attention_mask = torch.ones(
                (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            return input_ids, attention_mask, past_key_values, None, labels, None, None

        first_node_embeds, graph_sparsity = self._encode_graph(first_graph)
        if second_graph is not None:
            second_node_embeds, graph2_sparsity = self._encode_graph(second_graph)
            max_extra_length = max(
                first.size(0) + second.size(0)
                for first, second in zip(first_node_embeds, second_node_embeds)
            )
        else:
            second_node_embeds, graph2_sparsity = None, None
            max_extra_length = max(first.size(0) for first in first_node_embeds)

        embed_tokens = self.get_model().embed_tokens
        use_start_end = (
            getattr(self.config, 'tune_mm_mlp_adapter', False)
            and getattr(self.config, 'mm_use_graph_start_end', False)
        )
        use_special_token = getattr(self.config, 'mm_use_graph_special_token', False)

        new_input_embeds = []
        new_labels = [] if labels is not None else None

        for batch_idx, cur_input_ids in enumerate(input_ids):
            graph_token_indices = torch.where(cur_input_ids == GRAPH_TOKEN_INDEX)[0]
            if graph_token_indices.numel() == 0:
                # multimodal LLM, but the current sample is not multimodal
                # FIXME: this is a hacky fix, for deepspeed zero3 to work
                # (an empty slice of the graph features is concatenated so the
                # adapter output still takes part in this sample's graph).
                half_len = cur_input_ids.shape[0] // 2
                cur_graph_features = first_node_embeds[batch_idx]
                cur_input_embeds = torch.cat(
                    [
                        embed_tokens(cur_input_ids[:half_len]),
                        cur_graph_features[0:0],
                        embed_tokens(cur_input_ids[half_len:]),
                    ],
                    dim=0,
                )
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                continue

            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            graph_index = 0
            cur_extra_length = 0
            while graph_token_indices.numel() > 0:
                if graph_index == 0:
                    cur_graph_features = first_node_embeds[batch_idx]
                elif graph_index == 1:
                    cur_graph_features = second_node_embeds[batch_idx]
                if cur_graph_features.ndim < 2:
                    # `unsqueeze` is not in-place; the result must be kept.
                    cur_graph_features = cur_graph_features.unsqueeze(0)
                cur_extra_length += cur_graph_features.size(0)
                if use_special_token:
                    cur_graph_features = self.inject_special_token(cur_graph_features)

                graph_token_start = graph_token_indices[0]
                ignore_labels = None
                if labels is not None:
                    # Graph positions never contribute to the LM loss.
                    ignore_labels = torch.full(
                        (cur_graph_features.shape[0],), IGNORE_INDEX,
                        device=labels.device, dtype=labels.dtype,
                    )
                if use_start_end:
                    # The placeholder is wrapped by a start and an end token:
                    # only those two keep their gradient, the prefix is detached.
                    cur_new_input_embeds.append(embed_tokens(cur_input_ids[:graph_token_start-1]).detach())
                    cur_new_input_embeds.append(embed_tokens(cur_input_ids[graph_token_start-1:graph_token_start]))
                    cur_new_input_embeds.append(cur_graph_features)
                    cur_new_input_embeds.append(embed_tokens(cur_input_ids[graph_token_start+1:graph_token_start+2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:graph_token_start])
                        cur_new_labels.append(ignore_labels)
                        cur_new_labels.append(cur_labels[graph_token_start:graph_token_start+1])
                        cur_labels = cur_labels[graph_token_start+2:]
                    cur_input_ids = cur_input_ids[graph_token_start+2:]
                else:
                    cur_new_input_embeds.append(embed_tokens(cur_input_ids[:graph_token_start]))
                    cur_new_input_embeds.append(cur_graph_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:graph_token_start])
                        cur_new_labels.append(ignore_labels)
                        cur_labels = cur_labels[graph_token_start+1:]
                    cur_input_ids = cur_input_ids[graph_token_start+1:]
                graph_index += 1
                graph_token_indices = torch.where(cur_input_ids == GRAPH_TOKEN_INDEX)[0]

            if cur_input_ids.numel() > 0:
                tail_embeds = embed_tokens(cur_input_ids)
                cur_new_input_embeds.append(tail_embeds.detach() if use_start_end else tail_embeds)
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]

            # Samples with fewer nodes than the largest one are topped up with
            # copies of their last embedding (labelled IGNORE_INDEX below).
            missing = max_extra_length - cur_extra_length
            if missing > 0:
                filler = cur_new_input_embeds[-1][-1].squeeze(0).unsqueeze(0).repeat(missing, 1)
                cur_new_input_embeds.append(filler)
            new_input_embeds.append(torch.cat(cur_new_input_embeds, dim=0))
            if labels is not None:
                if missing > 0:
                    cur_new_labels.append(
                        torch.full((missing,), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)
                    )
                new_labels.append(torch.cat(cur_new_labels, dim=0))

        # Sample lengths before any right-padding; used to size the mask so
        # that it no longer depends on `labels` being provided.
        unpadded_lengths = [embeds.shape[0] for embeds in new_input_embeds]
        if any(embeds.shape != new_input_embeds[0].shape for embeds in new_input_embeds):
            max_len = max(unpadded_lengths)

            new_input_embeds = torch.stack(
                [
                    torch.cat(
                        (
                            embeds,
                            torch.zeros(
                                (max_len - embeds.shape[0], embeds.shape[1]),
                                dtype=embeds.dtype, device=embeds.device,
                            ),
                        ),
                        dim=0,
                    )
                    for embeds in new_input_embeds
                ],
                dim=0,
            )

            if labels is not None:
                new_labels = torch.stack(
                    [
                        torch.cat(
                            (
                                cur_new_label,
                                torch.full(
                                    (max_len - cur_new_label.shape[0],), IGNORE_INDEX,
                                    dtype=cur_new_label.dtype, device=cur_new_label.device,
                                ),
                            ),
                            dim=0,
                        )
                        for cur_new_label in new_labels
                    ],
                    dim=0,
                )

            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_len in zip(attention_mask, unpadded_lengths):
                    # Positions gained by splicing are prepended as attended;
                    # right-padding added above is appended as masked.
                    pad_left = torch.full(
                        (cur_len - input_ids.shape[1],), True,
                        dtype=attention_mask.dtype, device=attention_mask.device,
                    )
                    pad_right = torch.full(
                        (max_len - cur_len,), False,
                        dtype=attention_mask.dtype, device=attention_mask.device,
                    )
                    new_attention_mask.append(torch.cat((pad_left, cur_attention_mask, pad_right), dim=0))
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_input_embeds.shape[:2]
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                pad_left = torch.full(
                    (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True,
                    dtype=attention_mask.dtype, device=attention_mask.device,
                )
                attention_mask = torch.cat((pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels, graph_sparsity, graph2_sparsity

    def _encode_graph(self, graph):
        """Transform one graph input, run the adapter and split the result per sample.

        Returns:
            ``(node_embeds, sparsity_loss)`` where ``node_embeds`` is a
            sequence with one embedding tensor per sample and
            ``sparsity_loss`` is the second value returned by
            ``transform_graph`` (``None`` when no transform is configured).
        """
        if getattr(self.config, 'graph_transform', None) in ['gdc', 'attn', 'sgdc', 'sdc', 'biattn']:
            graph, sparsity_loss, edge_attr, edge_weight = self.transform_graph(graph)
        else:
            sparsity_loss = edge_attr = edge_weight = None

        if self.config.gnn_type == 'linear':
            # The MLP adapter works on node features only. Passing
            # `edge_index` positionally would land in the second parameter of
            # `MLP.forward`, which newer torch_geometric releases interpret as
            # the `batch` vector.
            node_embeds = self.adapter(graph.x)
        elif edge_attr is None:
            node_embeds = self.adapter(graph.x, graph.edge_index)
        else:
            node_embeds = self.adapter(
                graph.x, graph.edge_index, edge_weight=edge_weight, edge_attr=edge_attr
            )

        # Split on the batch vector of the graph that was actually encoded
        # (the transform may have rebuilt it).
        if graph.batch is not None:
            return unbatch(node_embeds, graph.batch), sparsity_loss
        return [node_embeds], sparsity_loss

    def adapter_forward(self, x, edge_index=None):
        """Run the residual adapter stack layer by layer, then project.

        Each of the ``config.num_adapter`` layers computes
        ``norm_i(relu(layer_i(x)) + x)``; ``edge_index`` is forwarded to the
        layer unless ``config.gnn_type`` is ``'linear'``. The final result is
        passed through ``self.adapter_projection``.

        NOTE(review): this reads ``self.adapter[i]``, ``self.adapter_norms``
        and ``self.adapter_projection``, none of which ``__init__`` in this
        file assigns in that form — confirm whether this method is still used.

        Args:
            x: Node feature tensor.
            edge_index: Edge index handed to non-linear adapter layers.

        Returns:
            Output of ``self.adapter_projection`` on the final hidden state.
        """
        uses_edges = self.config.gnn_type != 'linear'
        for layer_idx in range(self.config.num_adapter):
            norm = self.adapter_norms[layer_idx]
            layer = self.adapter[layer_idx]
            if uses_edges:
                hidden = layer(x, edge_index=edge_index)
            else:
                hidden = layer(x)
            x = norm(F.relu(hidden) + x)
        return self.adapter_projection(x)
                    
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        first_graph: Optional[Data] = None,
        second_graph: Optional[Data] = None,
        return_dict: Optional[bool] = None,
        target_sparsity: Optional[float] = 0.94,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model on graph-augmented inputs and compute the loss.

        The inputs are first passed through
        ``self.prepare_inputs_labels_for_multimodal``, whose results replace
        ``input_ids``, ``attention_mask``, ``past_key_values``, ``inputs_embeds``
        and ``labels`` and additionally provide two sparsity values. The
        backbone ``self.model`` is then run and its first output is projected
        to logits by ``self.lm_head``.

        Args:
            input_ids: Token ids handed to the multimodal preprocessing step.
            attention_mask: Attention mask handed to the preprocessing step.
            past_key_values: Cached key/values handed to the preprocessing step.
            inputs_embeds: Accepted for interface compatibility only. It is NOT
                forwarded to preprocessing; it is overwritten by the embeddings
                that preprocessing returns.
            labels: Optional targets. When the preprocessed labels are not
                ``None`` a shifted next-token cross-entropy loss is computed,
                ignoring positions equal to ``IGNORE_INDEX``.
            use_cache: Forwarded to ``self.model``.
            output_attentions: Forwarded to ``self.model``; falls back to
                ``self.config.output_attentions`` when ``None``.
            output_hidden_states: Forwarded to ``self.model``; falls back to
                ``self.config.output_hidden_states`` when ``None``.
            first_graph: First graph input for the preprocessing step.
            second_graph: Second graph input for the preprocessing step.
            return_dict: Forwarded unchanged to ``self.model``. It is not
                resolved against the config here, so any falsy value
                (including the default ``None``) selects the tuple return.
            target_sparsity: Target used only in the second-graph penalty term.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` is truthy.
            Otherwise a tuple ``(logits, *outputs[1:])``, prefixed with the
            loss when one was computed.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # NOTE: `return_dict` is intentionally left as passed by the caller;
        # the usual fallback to `self.config.use_return_dict` is disabled.

        (
            input_ids,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels,
            graph_sparsity,
            graph2_sparsity,
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids,
            attention_mask,
            past_key_values,
            labels,
            first_graph,
            second_graph,
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the
            # first label, then flatten both for the token-level loss.
            flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            flat_targets = labels[..., 1:].contiguous().view(-1)
            # Keep targets on the logits' device (model/pipeline parallelism).
            flat_targets = flat_targets.to(flat_logits.device)
            loss = CrossEntropyLoss(ignore_index=IGNORE_INDEX)(flat_logits, flat_targets)

            # First graph: penalty proportional to the returned sparsity value.
            if graph_sparsity is not None:
                loss += 0.1 * graph_sparsity
            # Second graph: linear + quadratic penalty on the gap between
            # `target_sparsity` and (1 - graph2_sparsity).
            if graph2_sparsity is not None:
                gap = target_sparsity - (1 - graph2_sparsity)
                loss += 0.1 * (gap + gap ** 2)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        tuple_output = (logits,) + outputs[1:]
        if loss is None:
            return tuple_output
        return (loss,) + tuple_output

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        Args:
            input_ids: Token ids generated so far. When ``past_key_values`` is
                truthy only the last position (``input_ids[:, -1:]``) is kept.
            past_key_values: Cache from earlier steps, passed through as-is.
            attention_mask: Passed through as-is.
            inputs_embeds: Used instead of ``input_ids`` only when it is given
                and ``past_key_values`` is ``None`` (i.e. the first step).
            **kwargs: ``use_cache``, ``first_graph`` and ``second_graph`` are
                read from here; missing entries become ``None``.

        Returns:
            A dict holding either ``"inputs_embeds"`` or ``"input_ids"``,
            followed by ``"past_key_values"``, ``"use_cache"``,
            ``"attention_mask"``, ``"first_graph"`` and ``"second_graph"``.
        """
        # Embeddings are only honoured on the very first step (no cache yet).
        if inputs_embeds is not None and past_key_values is None:
            primary_input = {"inputs_embeds": inputs_embeds}
        else:
            # With a non-empty cache only the newest token has to be fed.
            step_ids = input_ids[:, -1:] if past_key_values else input_ids
            primary_input = {"input_ids": step_ids}

        return {
            **primary_input,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "first_graph": kwargs.get("first_graph"),
            "second_graph": kwargs.get("second_graph"),
        }

# Import-time side effect: register the custom config under the "llaga" key
# (matching LlagaConfig.model_type) and associate LlagaConfig with
# LlagaForCausalLM in the causal-LM auto class.
AutoConfig.register("llaga", LlagaConfig)
AutoModelForCausalLM.register(LlagaConfig, LlagaForCausalLM)
