import torch
import transformers
from utils.utils import model_id
import os


class GenerateDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over pre-tokenized inputs, used for generating token embeddings.

    ``encodings`` maps field names to indexable columns (e.g. the output of a
    tokenizer called with ``return_tensors='pt'``). Item ``idx`` is a dict with
    the same keys, each holding that column's ``idx``-th entry. The dataset
    length is the first dimension of ``encodings['input_ids']``.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        item = {}
        for name, column in self.encodings.items():
            item[name] = column[idx]
        return item

    def __len__(self):
        input_ids = self.encodings['input_ids']
        return input_ids.size(0)
    
class ReductionDataset(torch.utils.data.Dataset):
    """
    Dataset for training the reduction module.

    Holds four parallel, row-aligned containers and returns the ``idx``-th
    row of each as the tuple
    ``(embeddings, attention_mask, neighbor_embeddings, labels)``.
    The dataset length is the first dimension of ``embeddings``.
    """

    def __init__(self, embeddings, attention_mask, neighbor_embeddings, labels):
        self.embeddings = embeddings
        self.attention_mask = attention_mask
        self.neighbor_embeddings = neighbor_embeddings
        self.labels = labels

    def __getitem__(self, idx):
        columns = (
            self.embeddings,
            self.attention_mask,
            self.neighbor_embeddings,
            self.labels,
        )
        return tuple(column[idx] for column in columns)

    def __len__(self):
        return self.embeddings.shape[0]

class NCDataset(torch.utils.data.Dataset):
    """
    Dataset for fine-tuning language models on node classification.

    Each ego graph becomes one token sequence: the text of every non-root
    node (in ``graph.original_idx`` order) followed by the root node's text.
    Every node contributes at most ``config.max_length // graph.num_nodes - 1``
    tokens plus one ``[SEP]``. ``root_mask`` is 1 on the root node's text
    tokens and 0 elsewhere (including the root's ``[SEP]``).
    """

    # Local directory of the reduction LM tokenizer. Empty in this checkout;
    # when empty, ``model_id[config.reduction_lm_type]`` is used instead.
    reduction_tokenizer_path = ""

    def __init__(self, graphs, labels, tokenizer, text, config):
        """
        Args:
            graphs: ego graphs exposing ``original_idx``, ``root_n_index`` and
                ``num_nodes``. They are not modified.
            labels: one label per graph; ``labels[idx]`` is returned as-is.
            tokenizer: target LM tokenizer (``encode``, ``sep_token_id`` and
                ``pad_token_id`` are used).
            text: raw node texts indexed by original node id; only read when
                ``config.use_reduction`` is false.
            config: reads ``use_reduction``, ``reduction_lm_type``,
                ``dataset``, ``lm_type`` and ``max_length``.
        """
        self.labels = labels
        self.tokenizer = tokenizer
        self.config = config
        # The reduction tokenizer is loaded on first use (see the
        # ``reduction_tokenizer`` property): it is only needed when the
        # reduction LM and the target LM differ.
        self._reduction_tokenizer = None

        if self.config.use_reduction:
            reduction_file = os.path.join('out', 'reduction_out', f'{config.reduction_lm_type}', f'{config.dataset}.pt')
            print(f"Loading reduction file : {reduction_file}")
            save_reduction = torch.load(reduction_file)
            self.scores = save_reduction['score']
            self.attention_masks = save_reduction['attention_mask']
            self.encodings = save_reduction['encodings']
        else:
            # Tokenize every node text once, without special tokens.
            self.encodings = [tokenizer.encode(txt, add_special_tokens=False) for txt in text]

        input_ids, attention_mask, root_mask = [], [], []
        for graph in graphs:
            ids, input_mask, graph_root_mask = self._process(graph)
            input_ids.append(torch.tensor(ids))
            attention_mask.append(torch.tensor(input_mask))
            root_mask.append(torch.tensor(graph_root_mask))

        # Right-pad every per-graph sequence to the longest one.
        self.input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        self.attention_mask = torch.nn.utils.rnn.pad_sequence(
            attention_mask, batch_first=True, padding_value=0
        )
        self.root_mask = torch.nn.utils.rnn.pad_sequence(
            root_mask, batch_first=True, padding_value=0
        )

    @property
    def reduction_tokenizer(self):
        """Tokenizer of the reduction LM, loaded from local files on first access."""
        if getattr(self, '_reduction_tokenizer', None) is None:
            path = self.reduction_tokenizer_path or model_id[self.config.reduction_lm_type]
            self._reduction_tokenizer = transformers.AutoTokenizer.from_pretrained(
                path,
                local_files_only=True,  # never download; use local files only
            )
        return self._reduction_tokenizer

    @reduction_tokenizer.setter
    def reduction_tokenizer(self, tokenizer):
        self._reduction_tokenizer = tokenizer

    def _token_reduction_select(self, original_idx, doc_max_length):
        """
        Keep each node's highest-scoring tokens according to the reduction outputs.

        Args:
            original_idx: original node ids to select tokens for.
            doc_max_length: per-node budget including one slot reserved for
                ``[SEP]``, so at most ``doc_max_length - 1`` target-LM tokens
                are returned per node.

        Returns:
            One list of target-LM token ids per entry of ``original_idx``; the
            selected tokens keep their original text order.
        """
        # Clamp so a zero budget selects nothing instead of topk(-1) / [:-1].
        budget = max(doc_max_length - 1, 0)
        select_encodings = []

        for idx in original_idx:
            score = self.scores[idx]
            encoding = self.encodings[idx]
            # int() so topk gets an integer even if the saved mask is float.
            valid_token_len = int(self.attention_masks[idx].sum().item())

            if self.config.lm_type == self.config.reduction_lm_type:
                # Same tokenizer: take the top-k scored token ids directly.
                _, indices = score.topk(min(budget, valid_token_len))
                indices, _ = indices.sort()  # restore original token order
                select_encodings.append(encoding[indices].tolist())
            else:
                # Different tokenizer: pick int(doc_max_length * 1.2) reduction-LM
                # tokens (over-selecting, presumably to absorb tokenization
                # differences), decode them to text, re-encode with the target
                # tokenizer and truncate to the budget.
                num_tokens = int(doc_max_length * 1.2)
                _, indices = score.topk(min(num_tokens, valid_token_len))
                indices, _ = indices.sort()
                txt = self.reduction_tokenizer.decode(encoding[indices], skip_special_tokens=True)
                enc = self.tokenizer.encode(txt, add_special_tokens=False)
                select_encodings.append(enc[:budget])

        return select_encodings

    def _process(self, graph):
        """
        Flatten one ego graph into token ids, an attention mask and a root mask.

        The root node is moved to the end of the node order. ``graph`` itself is
        left untouched: its node ordering is not changed, so rewriting its
        ``root_n_index``/``center`` would make them point at the wrong node
        (and break re-processing the same graph).

        Returns:
            ``(ids, input_mask, root_mask)`` as equal-length lists of ints.
        """
        origin_idx = graph.original_idx.tolist()
        root_origin_idx = graph.original_idx[graph.root_n_index].item()
        origin_idx.remove(root_origin_idx)
        origin_idx.append(root_origin_idx)
        root_pos = len(origin_idx) - 1

        # Equal token budget per node; one slot is reserved for [SEP].
        doc_max_length = self.config.max_length // graph.num_nodes
        budget = max(doc_max_length - 1, 0)

        if self.config.use_reduction:
            # Select important tokens via the reduction module.
            node_tokens = self._token_reduction_select(origin_idx, doc_max_length)
        else:
            node_tokens = [self.encodings[idx][:budget] for idx in origin_idx]

        ids, root_mask = [], []
        for pos, text_ids in enumerate(node_tokens):
            text_ids = list(text_ids)
            ids.extend(text_ids)
            # Add [SEP] token between nodes' text.
            ids.append(self.tokenizer.sep_token_id)
            # Root text tokens get 1; its [SEP] and all other tokens get 0.
            root_mask.extend([1 if pos == root_pos else 0] * len(text_ids) + [0])

        input_mask = [1] * len(ids)
        return ids, input_mask, root_mask

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx],
            'root_mask': self.root_mask[idx],
        }

    def __len__(self):
        return len(self.labels)

class B_NCDataset(torch.utils.data.Dataset):
    """
    Dataset for fine-tuning language models with a block-wise attention mask.

    Each ego graph becomes one token sequence ordered by BFS distance from
    the root: the farthest level first, the root last. No ``[SEP]`` is
    inserted. A per-graph square attention mask lets a node's tokens attend
    only to the nodes listed for it by ``get_block_rules``. All masks are
    zero-padded into one ``(num_graphs, max_len, max_len)`` float32 tensor.
    """

    # Local directory of the reduction LM tokenizer. Empty in this checkout;
    # when empty, ``model_id[config.reduction_lm_type]`` is used instead.
    reduction_tokenizer_path = ""

    def __init__(self, graphs, labels, tokenizer, text, config, hop):
        """
        Args:
            graphs: ego graphs exposing ``x``, ``edge_index``,
                ``original_idx`` and ``root_n_index``.
            labels: one label per graph; ``labels[idx]`` is returned as-is.
            tokenizer: target LM tokenizer (``encode`` and ``pad_token_id``
                are used).
            text: raw node texts indexed by original node id; only read when
                ``config.use_reduction`` is false.
            config: reads ``use_reduction``, ``reduction_lm_type``,
                ``dataset``, ``lm_type`` and ``max_length``.
            hop: number of BFS levels below the root expected in each graph.
        """
        self.labels = labels
        self.tokenizer = tokenizer
        self.config = config
        self.hop = hop
        # Loaded on first use; only needed when the reduction LM and the
        # target LM differ.
        self._reduction_tokenizer = None

        if self.config.use_reduction:
            reduction_file = os.path.join('out', 'reduction_out', f'{config.reduction_lm_type}', f'{config.dataset}.pt')
            print(f"Loading reduction file : {reduction_file}")
            save_reduction = torch.load(reduction_file)
            self.scores = save_reduction['score']
            self.attention_masks = save_reduction['attention_mask']
            self.encodings = save_reduction['encodings']
        else:
            # Tokenize every node text once, without special tokens.
            self.encodings = [tokenizer.encode(txt, add_special_tokens=False) for txt in text]

        input_ids, root_mask = [], []
        self.origin_attention_mask = []
        for graph in graphs:
            ids, input_mask, graph_root_mask = self._process(graph)
            input_ids.append(torch.tensor(ids))
            # _process already returns a fresh tensor; wrapping it in
            # torch.tensor() again would only copy it (with a UserWarning).
            self.origin_attention_mask.append(input_mask)
            root_mask.append(torch.tensor(graph_root_mask))

        # Right-pad every per-graph sequence to the longest one.
        self.input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        self.root_mask = torch.nn.utils.rnn.pad_sequence(
            root_mask, batch_first=True, padding_value=0
        )

        # Place each square mask in the top-left corner of a zero tensor so
        # padded positions neither attend nor are attended to.
        max_len = self.input_ids.size(1)
        padded_attention = torch.zeros(
            (len(self.origin_attention_mask), max_len, max_len),
            dtype=torch.float32,
            device=self.input_ids.device,
        )
        for i, mask in enumerate(self.origin_attention_mask):
            seq_len = mask.size(0)
            padded_attention[i, :seq_len, :seq_len] = mask
        self.attention_mask = padded_attention

    @property
    def reduction_tokenizer(self):
        """Tokenizer of the reduction LM, loaded from local files on first access."""
        if getattr(self, '_reduction_tokenizer', None) is None:
            path = self.reduction_tokenizer_path or model_id[self.config.reduction_lm_type]
            self._reduction_tokenizer = transformers.AutoTokenizer.from_pretrained(
                path,
                local_files_only=True,  # never download; use local files only
            )
        return self._reduction_tokenizer

    @reduction_tokenizer.setter
    def reduction_tokenizer(self, tokenizer):
        self._reduction_tokenizer = tokenizer

    def _token_reduction_select(self, original_idx, doc_max_length):
        """
        Keep each node's highest-scoring tokens according to the reduction outputs.

        Args:
            original_idx: original node ids to select tokens for.
            doc_max_length: per-node budget; at most ``doc_max_length - 1``
                target-LM tokens are returned per node.

        Returns:
            One list of target-LM token ids per entry of ``original_idx``; the
            selected tokens keep their original text order.
        """
        # Clamp so a zero budget selects nothing instead of topk(-1) / [:-1].
        budget = max(doc_max_length - 1, 0)
        select_encodings = []

        for idx in original_idx:
            score = self.scores[idx]
            encoding = self.encodings[idx]
            # int() so topk gets an integer even if the saved mask is float.
            valid_token_len = int(self.attention_masks[idx].sum().item())

            if self.config.lm_type == self.config.reduction_lm_type:
                # Same tokenizer: take the top-k scored token ids directly.
                _, indices = score.topk(min(budget, valid_token_len))
                indices, _ = indices.sort()  # restore original token order
                select_encodings.append(encoding[indices].tolist())
            else:
                # Different tokenizer: over-select int(doc_max_length * 1.2)
                # reduction-LM tokens, decode to text, re-encode with the target
                # tokenizer and truncate to the budget.
                num_tokens = int(doc_max_length * 1.2)
                _, indices = score.topk(min(num_tokens, valid_token_len))
                indices, _ = indices.sort()
                txt = self.reduction_tokenizer.decode(encoding[indices], skip_special_tokens=True)
                enc = self.tokenizer.encode(txt, add_special_tokens=False)
                select_encodings.append(enc[:budget])

        return select_encodings

    def get_block_rules(self, ego_graph, center_idx):
        """
        Run a BFS from ``center_idx`` and derive which nodes may see each other.

        Args:
            ego_graph: graph with ``x`` (one entry per node) and
                ``edge_index`` (a ``row, col`` pair of local node indices).
            center_idx: local index of the root node (int or 0-d tensor).

        Returns:
            ``(block_rules, level)``. ``block_rules[v]`` lists the local nodes
            whose tokens node ``v`` may attend to: ``v`` itself, the node that
            first discovered it, the nodes it discovered, and neighbors at the
            same BFS distance. Other cross-level edges are ignored.
            ``level[d]`` lists the nodes at BFS distance ``d``. It has at least
            ``self.hop + 1`` entries and grows if the graph is deeper. Nodes
            unreachable from the center appear in no level.
        """
        # Normalize so dict lookups below compare ints with ints.
        center_idx = int(center_idx)
        block_rules = [[] for _ in range(len(ego_graph.x))]
        level = [[] for _ in range(self.hop + 1)]
        row, col = ego_graph.edge_index
        distances = {center_idx: 0}
        queue = [center_idx]
        level[0].append(center_idx)

        while queue:
            current = queue.pop(0)
            block_rules[current].append(current)
            for neighbor_tensor in col[row == current]:
                neighbor = neighbor_tensor.item()
                if neighbor not in distances:
                    # First discovery: BFS tree edge, visible in both directions.
                    depth = distances[current] + 1
                    distances[neighbor] = depth
                    if depth >= len(level):
                        level.append([])
                    level[depth].append(neighbor)
                    block_rules[neighbor].append(current)
                    block_rules[current].append(neighbor)
                    queue.append(neighbor)
                elif distances[current] == distances[neighbor] and current not in block_rules[neighbor]:
                    # Same-level edge, recorded once in both directions.
                    block_rules[current].append(neighbor)
                    block_rules[neighbor].append(current)
        return block_rules, level

    def _process(self, graph):
        """
        Flatten one ego graph into token ids, a block attention mask and a root mask.

        Returns:
            ``(ids, input_mask, root_mask)``. ``ids`` is a list of token ids.
            ``input_mask`` is a float32 ``(len(ids), len(ids))`` tensor that is
            1 where attention is allowed. ``root_mask`` marks the root node's
            tokens with 1.
        """
        origin_idx = graph.original_idx.tolist()
        root = int(graph.root_n_index)
        block_rules, level = self.get_block_rules(graph, root)

        # Farthest level first, root (level 0) last. Iterating over every
        # level instead of a fixed three supports any ``hop``.
        mapped_nodes = [node for nodes in reversed(level) for node in nodes]
        index_map = {node: pos for pos, node in enumerate(mapped_nodes)}

        doc_max_length = self.config.max_length // len(mapped_nodes)

        if self.config.use_reduction:
            _encodings = self._token_reduction_select(origin_idx, doc_max_length)
        else:
            # Same per-node budget as the reduction path.
            budget = max(doc_max_length - 1, 0)
            _encodings = [self.encodings[idx][:budget] for idx in origin_idx]

        ids, root_mask = [], []
        block_boundaries = [0]  # block_boundaries[i]:block_boundaries[i+1] = tokens of mapped_nodes[i]
        for index in mapped_nodes:
            token_ids = _encodings[index]
            root_mask.extend([1 if index == root else 0] * len(token_ids))
            ids.extend(token_ids)
            block_boundaries.append(len(ids))

        input_mask = torch.zeros((len(ids), len(ids)), dtype=torch.float32)
        for i, index in enumerate(mapped_nodes):
            rows = slice(block_boundaries[i], block_boundaries[i + 1])
            for visible in block_rules[index]:
                j = index_map[visible]
                input_mask[rows, block_boundaries[j]:block_boundaries[j + 1]] = 1

        return ids, input_mask, root_mask

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx],
            'root_mask': self.root_mask[idx],
        }

    def __len__(self):
        return len(self.labels)

class H_NCDataset(torch.utils.data.Dataset):
    """
    Dataset for fine-tuning language models with per-node token boundaries.

    Each ego graph becomes one token sequence: the text of every non-root
    node (in ``graph.original_idx`` order) followed by the root node's text,
    each followed by ``[SEP]``. For every node the inclusive ``[start, end]``
    token span of its text (``[SEP]`` excluded) is stored, padded with
    ``[-1, -1]`` up to ``max_nodes`` entries. ``root`` is the root's position
    in that node order (always the last one).
    """

    # Local directory of the reduction LM tokenizer. Empty in this checkout;
    # when empty, ``model_id[config.reduction_lm_type]`` is used instead.
    reduction_tokenizer_path = ""
    # Number of node slots every ``boundary`` list is padded to.
    max_nodes = 40

    def __init__(self, graphs, labels, tokenizer, text, config):
        """
        Args:
            graphs: ego graphs exposing ``original_idx``, ``root_n_index`` and
                ``num_nodes``. They are not modified.
            labels: one label per graph; ``labels[idx]`` is returned as-is.
            tokenizer: target LM tokenizer (``encode``, ``sep_token_id`` and
                ``pad_token_id`` are used).
            text: raw node texts indexed by original node id; only read when
                ``config.use_reduction`` is false.
            config: reads ``use_reduction``, ``reduction_lm_type``,
                ``dataset``, ``lm_type`` and ``max_length``.
        """
        self.labels = labels
        self.tokenizer = tokenizer
        self.config = config
        # Loaded on first use; only needed when the reduction LM and the
        # target LM differ.
        self._reduction_tokenizer = None

        if self.config.use_reduction:
            reduction_file = os.path.join('out', 'reduction_out', f'{config.reduction_lm_type}', f'{config.dataset}.pt')
            print(f"Loading reduction file : {reduction_file}")
            save_reduction = torch.load(reduction_file)
            self.scores = save_reduction['score']
            self.attention_masks = save_reduction['attention_mask']
            self.encodings = save_reduction['encodings']
        else:
            # Tokenize every node text once, without special tokens.
            self.encodings = [tokenizer.encode(txt, add_special_tokens=False) for txt in text]

        input_ids, attention_mask = [], []
        self.boundaries = []
        self.roots = []
        for graph in graphs:
            ids, input_mask, boundary, root = self._process(graph)
            input_ids.append(torch.tensor(ids))
            attention_mask.append(torch.tensor(input_mask))
            self.boundaries.append(torch.tensor(boundary))
            self.roots.append(torch.tensor(root))

        # Right-pad every per-graph sequence to the longest one.
        self.input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        self.attention_mask = torch.nn.utils.rnn.pad_sequence(
            attention_mask, batch_first=True, padding_value=0
        )

    @property
    def reduction_tokenizer(self):
        """Tokenizer of the reduction LM, loaded from local files on first access."""
        if getattr(self, '_reduction_tokenizer', None) is None:
            path = self.reduction_tokenizer_path or model_id[self.config.reduction_lm_type]
            self._reduction_tokenizer = transformers.AutoTokenizer.from_pretrained(
                path,
                local_files_only=True,  # never download; use local files only
            )
        return self._reduction_tokenizer

    @reduction_tokenizer.setter
    def reduction_tokenizer(self, tokenizer):
        self._reduction_tokenizer = tokenizer

    def _token_reduction_select(self, original_idx, doc_max_length):
        """
        Keep each node's highest-scoring tokens according to the reduction outputs.

        Args:
            original_idx: original node ids to select tokens for.
            doc_max_length: per-node budget including one slot reserved for
                ``[SEP]``, so at most ``doc_max_length - 1`` target-LM tokens
                are returned per node.

        Returns:
            One list of target-LM token ids per entry of ``original_idx``; the
            selected tokens keep their original text order.
        """
        # Clamp so a zero budget selects nothing instead of topk(-1) / [:-1].
        budget = max(doc_max_length - 1, 0)
        select_encodings = []

        for idx in original_idx:
            score = self.scores[idx]
            encoding = self.encodings[idx]
            # int() so topk gets an integer even if the saved mask is float.
            valid_token_len = int(self.attention_masks[idx].sum().item())

            if self.config.lm_type == self.config.reduction_lm_type:
                # Same tokenizer: take the top-k scored token ids directly.
                _, indices = score.topk(min(budget, valid_token_len))
                indices, _ = indices.sort()  # restore original token order
                select_encodings.append(encoding[indices].tolist())
            else:
                # Different tokenizer: over-select int(doc_max_length * 1.2)
                # reduction-LM tokens, decode to text, re-encode with the target
                # tokenizer and truncate to the budget.
                num_tokens = int(doc_max_length * 1.2)
                _, indices = score.topk(min(num_tokens, valid_token_len))
                indices, _ = indices.sort()
                txt = self.reduction_tokenizer.decode(encoding[indices], skip_special_tokens=True)
                enc = self.tokenizer.encode(txt, add_special_tokens=False)
                select_encodings.append(enc[:budget])

        return select_encodings

    def _process(self, graph):
        """
        Flatten one ego graph into token ids, a mask, node boundaries and the root position.

        The root node is moved to the end of the node order. ``graph`` itself is
        left untouched: its node ordering is not changed, so rewriting its
        ``root_n_index``/``center`` would make them point at the wrong node
        (and break re-processing the same graph).

        Returns:
            ``(ids, input_mask, boundary, root)``. ``boundary[i]`` is the
            inclusive ``[start, end]`` token span of node ``i``'s text; it
            equals ``[start, start - 1]`` for an empty text. ``boundary`` is
            padded with ``[-1, -1]`` to ``max_nodes`` entries.
        """
        origin_idx = graph.original_idx.tolist()
        root_origin_idx = graph.original_idx[graph.root_n_index].item()
        origin_idx.remove(root_origin_idx)
        origin_idx.append(root_origin_idx)
        root = len(origin_idx) - 1

        # Equal token budget per node; one slot is reserved for [SEP].
        doc_max_length = self.config.max_length // graph.num_nodes
        budget = max(doc_max_length - 1, 0)

        if self.config.use_reduction:
            # Select important tokens via the reduction module.
            node_tokens = self._token_reduction_select(origin_idx, doc_max_length)
        else:
            node_tokens = [self.encodings[idx][:budget] for idx in origin_idx]

        ids, boundary = [], []
        for text_ids in node_tokens:
            start = len(ids)
            ids.extend(text_ids)
            # Add [SEP] token between nodes' text.
            ids.append(self.tokenizer.sep_token_id)
            # Inclusive end of the node's text, excluding its [SEP].
            boundary.append([start, len(ids) - 2])

        input_mask = [1] * len(ids)

        boundary.extend([-1, -1] for _ in range(self.max_nodes - len(boundary)))
        if len(boundary) > self.max_nodes:
            print(f"wow !!! {len(boundary)}")

        return ids, input_mask, boundary, root

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx],
            'root': self.roots[idx],
            'boundary': self.boundaries[idx],
        }

    def __len__(self):
        return len(self.labels)

class A_NCDataset(torch.utils.data.Dataset):
    """
    Dataset for fine-tuning language models with node-to-node structure rules.

    Each ego graph becomes one token sequence: every non-root node in local
    index order, then the root, each followed by ``[SEP]``. Alongside the
    tokens, each item carries:

    - ``block_rules``: a ``max_nodes x max_nodes`` matrix in sequence order.
      0 = same node or joined by an edge, 4 = not joined, -1 = padding.
    - ``boundary``: the inclusive ``[start, end]`` text span of each node,
      padded with ``[-1, -1]``.
    - ``sorted_sequence``: the original node ids in sequence order, padded
      with -1.
    - ``root``: the root's position in the sequence (always last).
    - ``root_mask``: 1 on the root node's text tokens.
    """

    # Local directory of the reduction LM tokenizer. Empty in this checkout;
    # when empty, ``model_id[config.reduction_lm_type]`` is used instead.
    reduction_tokenizer_path = ""
    # Fixed number of node slots used for block_rules, boundary and sorted_sequence.
    max_nodes = 40
    # Per-node budget (one slot reserved) for the separate root-only encoding.
    root_max_length = 50

    def __init__(self, graphs, labels, tokenizer, text, config):
        """
        Args:
            graphs: ego graphs exposing ``x``, ``edge_index``,
                ``original_idx``, ``root_n_index`` and ``num_nodes``.
            labels: one label per graph; ``labels[idx]`` is returned as-is.
            tokenizer: target LM tokenizer (``encode``, ``sep_token_id`` and
                ``pad_token_id`` are used).
            text: raw node texts indexed by original node id; only read when
                ``config.use_reduction`` is false.
            config: reads ``use_reduction``, ``reduction_lm_type``,
                ``dataset``, ``lm_type`` and ``max_length``.
        """
        self.labels = labels
        self.tokenizer = tokenizer
        self.config = config
        # Loaded on first use; only needed when the reduction LM and the
        # target LM differ.
        self._reduction_tokenizer = None

        if self.config.use_reduction:
            reduction_file = os.path.join('out', 'reduction_out', f'{config.reduction_lm_type}', f'{config.dataset}.pt')
            print(f"Loading reduction file : {reduction_file}")
            save_reduction = torch.load(reduction_file)
            self.scores = save_reduction['score']
            self.attention_masks = save_reduction['attention_mask']
            self.encodings = save_reduction['encodings']
        else:
            # Tokenize every node text once, without special tokens.
            self.encodings = [tokenizer.encode(txt, add_special_tokens=False) for txt in text]

        input_ids, attention_mask, input_root_ids, root_mask = [], [], [], []
        self.all_block_rules = []
        self.boundaries = []
        self.roots = []
        self.lengths = []
        self.sorted_sequences = []

        for graph in graphs:
            (ids, block_rules, boundary, root, input_mask, root_ids, length,
             graph_root_mask, sorted_sequence) = self._process(graph)
            input_ids.append(torch.tensor(ids))
            self.all_block_rules.append(torch.tensor(block_rules))
            self.boundaries.append(torch.tensor(boundary))
            self.roots.append(torch.tensor(root))
            attention_mask.append(torch.tensor(input_mask))
            input_root_ids.append(torch.tensor(root_ids))
            self.lengths.append(torch.tensor(length))
            root_mask.append(torch.tensor(graph_root_mask))
            self.sorted_sequences.append(torch.tensor(sorted_sequence))

        # Right-pad every per-graph sequence to the longest one.
        pad_id = self.tokenizer.pad_token_id
        self.input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=pad_id
        )
        self.attention_mask = torch.nn.utils.rnn.pad_sequence(
            attention_mask, batch_first=True, padding_value=0
        )
        self.input_root_ids = torch.nn.utils.rnn.pad_sequence(
            input_root_ids, batch_first=True, padding_value=pad_id
        )
        self.root_mask = torch.nn.utils.rnn.pad_sequence(
            root_mask, batch_first=True, padding_value=0
        )

    @property
    def reduction_tokenizer(self):
        """Tokenizer of the reduction LM, loaded from local files on first access."""
        if getattr(self, '_reduction_tokenizer', None) is None:
            path = self.reduction_tokenizer_path or model_id[self.config.reduction_lm_type]
            self._reduction_tokenizer = transformers.AutoTokenizer.from_pretrained(
                path,
                local_files_only=True,  # never download; use local files only
            )
        return self._reduction_tokenizer

    @reduction_tokenizer.setter
    def reduction_tokenizer(self, tokenizer):
        self._reduction_tokenizer = tokenizer

    def _token_reduction_select(self, original_idx, doc_max_length):
        """
        Keep each node's highest-scoring tokens according to the reduction outputs.

        Args:
            original_idx: original node ids to select tokens for.
            doc_max_length: per-node budget including one reserved slot, so at
                most ``doc_max_length - 1`` target-LM tokens are returned.

        Returns:
            One list of target-LM token ids per entry of ``original_idx``; the
            selected tokens keep their original text order.
        """
        # Clamp so a zero budget selects nothing instead of topk(-1) / [:-1].
        budget = max(doc_max_length - 1, 0)
        select_encodings = []

        for idx in original_idx:
            score = self.scores[idx]
            encoding = self.encodings[idx]
            # int() so topk gets an integer even if the saved mask is float.
            valid_token_len = int(self.attention_masks[idx].sum().item())

            if self.config.lm_type == self.config.reduction_lm_type:
                # Same tokenizer: take the top-k scored token ids directly.
                _, indices = score.topk(min(budget, valid_token_len))
                indices, _ = indices.sort()  # restore original token order
                select_encodings.append(encoding[indices].tolist())
            else:
                # Different tokenizer: over-select int(doc_max_length * 1.2)
                # reduction-LM tokens, decode to text, re-encode with the target
                # tokenizer and truncate to the budget.
                num_tokens = int(doc_max_length * 1.2)
                _, indices = score.topk(min(num_tokens, valid_token_len))
                indices, _ = indices.sort()
                txt = self.reduction_tokenizer.decode(encoding[indices], skip_special_tokens=True)
                enc = self.tokenizer.encode(txt, add_special_tokens=False)
                select_encodings.append(enc[:budget])

        return select_encodings

    def get_block_rules(self, ego_graph, center_idx):
        """
        Build the node-to-node rule matrix in sequence order (center last).

        Rules: 0 for a node with itself or for any pair joined by an edge in
        either direction, 4 for unconnected pairs, -1 for padding up to
        ``max_nodes x max_nodes``.

        Args:
            ego_graph: graph with ``x`` (one entry per node) and
                ``edge_index`` (a ``row, col`` pair of local node indices).
            center_idx: local index of the root node (int or 0-d tensor).

        Returns:
            ``(sorted_idx, sorted_block_rules)``. ``sorted_idx`` holds the local
            node indices in sequence order: all non-center nodes ascending,
            then the center. ``sorted_block_rules`` is the rule matrix indexed
            in that order.

        Raises:
            ValueError: if the graph has more than ``max_nodes`` nodes.
        """
        center_idx = int(center_idx)
        node_num = len(ego_graph.x)
        if node_num > self.max_nodes:
            raise ValueError(
                f"ego graph has {node_num} nodes, more than max_nodes={self.max_nodes}"
            )

        # Square node_num x node_num matrix: 0 on the diagonal, 4 ("no edge") elsewhere.
        block_rules = [[0 if i == j else 4 for j in range(node_num)] for i in range(node_num)]
        row, col = ego_graph.edge_index
        for src, dst in zip(row.tolist(), col.tolist()):
            block_rules[src][dst] = 0
            block_rules[dst][src] = 0

        sorted_idx = [i for i in range(node_num) if i != center_idx]
        sorted_idx.append(center_idx)

        sorted_block_rules = [[-1] * self.max_nodes for _ in range(self.max_nodes)]
        for new_i, old_i in enumerate(sorted_idx):
            for new_j, old_j in enumerate(sorted_idx):
                sorted_block_rules[new_i][new_j] = block_rules[old_i][old_j]

        return sorted_idx, sorted_block_rules

    def _process(self, graph):
        """
        Flatten one ego graph into all per-graph fields stored by the dataset.

        Returns:
            ``(ids, block_rules, boundary, root, input_mask, root_ids, length,
            root_mask, sorted_sequence)``. ``root_ids`` is a separate encoding
            of the root node alone, limited to ``root_max_length - 1`` tokens,
            and ``length`` is its length.
        """
        sorted_idx, block_rules = self.get_block_rules(graph, graph.root_n_index)
        root_local = int(graph.root_n_index)

        origin_idx = graph.original_idx.tolist()
        # Equal token budget per node; one slot is reserved for [SEP].
        doc_max_length = self.config.max_length // graph.num_nodes
        root = len(sorted_idx) - 1  # the center is always placed last

        if self.config.use_reduction:
            # Indexed by local node index, like origin_idx.
            _encodings = self._token_reduction_select(origin_idx, doc_max_length)
        else:
            budget = max(doc_max_length - 1, 0)
            _encodings = [self.encodings[idx][:budget] for idx in origin_idx]

        ids, boundary, root_mask, sorted_sequence = [], [], [], []
        for sorted_i in sorted_idx:
            start = len(ids)
            sorted_sequence.append(origin_idx[sorted_i])
            text_ids = list(_encodings[sorted_i])
            ids.extend(text_ids)
            # Add [SEP] token between nodes' text.
            ids.append(self.tokenizer.sep_token_id)
            # Root text tokens get 1; its [SEP] and all other tokens get 0.
            root_mask.extend([1 if sorted_i == root_local else 0] * len(text_ids) + [0])
            # Inclusive end of the node's text, excluding its [SEP].
            boundary.append([start, len(ids) - 2])

        input_mask = [1] * len(ids)

        padding = self.max_nodes - len(boundary)
        boundary.extend([-1, -1] for _ in range(padding))
        sorted_sequence.extend([-1] * padding)

        # Separate encoding of the root node alone, with its own budget.
        root_origin_idx = origin_idx[root_local]
        if self.config.use_reduction:
            root_ids = self._token_reduction_select([root_origin_idx], self.root_max_length)[0]
        else:
            root_ids = self.encodings[root_origin_idx][:self.root_max_length - 1]
        root_ids = list(root_ids)
        length = len(root_ids)

        return ids, block_rules, boundary, root, input_mask, root_ids, length, root_mask, sorted_sequence

    def __getitem__(self, idx):
        # attention_mask, input_root_ids and lengths are kept on the dataset
        # but are not part of the per-item output.
        return {
            'input_ids': self.input_ids[idx],
            'block_rules': self.all_block_rules[idx],
            'labels': self.labels[idx],
            'root': self.roots[idx],
            'boundary': self.boundaries[idx],
            'root_mask': self.root_mask[idx],
            'sorted_sequence': self.sorted_sequences[idx],
        }

    def __len__(self):
        return len(self.labels)

class Faster_A_NCDataset(torch.utils.data.Dataset):
    """
    Dataset for fine-tuning language models on ego-graph node classification.

    For every graph the texts of all nodes are concatenated in the order
    returned by ``get_block_rules`` (root node last), each node's tokens being
    followed by ``tokenizer.sep_token_id``.  Per sample the dataset stores:

    * ``input_ids`` / ``attention_mask`` -- concatenated tokens (padded with
      ``tokenizer.pad_token_id`` / 0).
    * ``block_rules`` -- ``max_nodes x max_nodes`` node connectivity matrix
      (``CONNECTED`` / ``NOT_CONNECTED``) in the same node order.
    * ``boundary`` -- ``[start, end]`` token span of each node (separator
      excluded), padded with ``[-1, -1]`` up to ``max_nodes`` rows.
    * ``root`` -- position of the root node in that order.
    * ``root_mask`` -- 1 on the root node's text tokens, 0 elsewhere.
    * ``input_root_ids`` / ``length`` -- the root text truncated separately
      to at most ``root_max_length - 1`` tokens, and its length.
    """

    # Fixed node dimension of ``block_rules`` and ``boundary``.
    max_nodes = 40
    # Token budget used for ``input_root_ids``.
    root_max_length = 50
    # Entries of ``block_rules``: edge or self-loop vs. no edge.
    CONNECTED = 0
    NOT_CONNECTED = 4

    def __init__(self, graphs, labels, tokenizer, text, config):
        self.labels = labels
        self.tokenizer = tokenizer
        self.config = config

        # Hub alternative: transformers.AutoTokenizer.from_pretrained(model_id[self.config.reduction_lm_type])
        # NOTE(review): model_path is empty; presumably it must point to a local
        # copy of the reduction LM's tokenizer for from_pretrained() to succeed.
        model_path = f""

        self.reduction_tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path,
            local_files_only=True  # only use local files, never download
        )

        if self.config.use_reduction:
            reduction_file = os.path.join('out', 'reduction_out', f'{config.reduction_lm_type}', f'{config.dataset}.pt')
            print(f"Loading reduction file : {reduction_file}")
            save_reduction = torch.load(reduction_file)
            self.scores = save_reduction['score']
            self.attention_masks = save_reduction['attention_mask']
            self.encodings = save_reduction['encodings']
        else:
            # tokenize the raw node texts with the target tokenizer
            self.encodings = [tokenizer.encode(txt, add_special_tokens=False) for txt in text]

        self.input_ids = []
        self.all_block_rules = []
        self.boundaries = []
        self.roots = []
        self.attention_mask = []
        self.input_root_ids = []
        self.lengths = []
        self.root_mask = []

        for graph in graphs:
            ids, block_rules, boundary, root, input_mask, root_ids, length, graph_root_mask = self._process(graph)
            self.input_ids.append(torch.tensor(ids))
            self.all_block_rules.append(torch.tensor(block_rules))
            self.boundaries.append(torch.tensor(boundary))
            self.roots.append(torch.tensor(root))
            self.attention_mask.append(torch.tensor(input_mask))
            self.input_root_ids.append(torch.tensor(root_ids))
            self.lengths.append(torch.tensor(length))
            self.root_mask.append(torch.tensor(graph_root_mask))

        # Pad the variable-length sequences to a common length.
        self.input_ids = torch.nn.utils.rnn.pad_sequence(
            self.input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        self.attention_mask = torch.nn.utils.rnn.pad_sequence(
            self.attention_mask, batch_first=True, padding_value=0
        )
        self.input_root_ids = torch.nn.utils.rnn.pad_sequence(
            self.input_root_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        self.root_mask = torch.nn.utils.rnn.pad_sequence(
            self.root_mask, batch_first=True, padding_value=0
        )

    def _token_reduction_select(self, original_idx, doc_max_length):
        """
        Keep the highest-scoring tokens of each requested node text.

        Args:
            original_idx: indices into ``self.scores``, ``self.attention_masks``
                and ``self.encodings``.
            doc_max_length: per-node token budget; at most ``doc_max_length - 1``
                ids are returned so that a separator can be appended.

        Returns:
            One list of target-tokenizer ids per entry of ``original_idx``, with
            the kept tokens in their original text order.
        """
        select_encodings = []

        for idx in original_idx:
            score = self.scores[idx]
            attention_mask = self.attention_masks[idx]
            encoding = self.encodings[idx]

            valid_token_len = attention_mask.sum().item()

            if self.config.lm_type == self.config.reduction_lm_type:
                # Same vocabulary: take the top-k scored ids directly.
                _, indices = score.topk(min(doc_max_length - 1, valid_token_len))
                indices, _ = indices.sort()  # restore original token order
                select_encodings.append(encoding[indices].tolist())
            else:
                # Different vocabularies: over-select by 20%, decode with the
                # reduction tokenizer, re-encode with the target tokenizer, truncate.
                num_tokens = int(doc_max_length * 1.2)
                _, indices = score.topk(min(num_tokens, valid_token_len))
                indices, _ = indices.sort()
                txt = self.reduction_tokenizer.decode(encoding[indices], skip_special_tokens=True)
                enc = self.tokenizer.encode(txt, add_special_tokens=False)
                select_encodings.append(enc[:doc_max_length - 1])

        return select_encodings

    def get_block_rules(self, ego_graph, center_idx):
        """
        Build the node connectivity matrix with the center node moved last.

        Args:
            ego_graph: graph exposing ``x`` (one entry per node) and a two-row
                ``edge_index`` of (source, target) node indices.
            center_idx: local index of the center (root) node.

        Returns:
            ``(sorted_idx, sorted_block_rules)``: ``sorted_idx`` lists the local
            node indices with ``center_idx`` moved to the end, and
            ``sorted_block_rules`` is a ``max_nodes x max_nodes`` list matrix
            where entry ``[i][j]`` is ``CONNECTED`` if nodes ``sorted_idx[i]``
            and ``sorted_idx[j]`` are equal or share an (undirected) edge, and
            ``NOT_CONNECTED`` otherwise (including all padding cells).

        Raises:
            ValueError: if the graph has more than ``max_nodes`` nodes.
        """
        size = self.max_nodes
        node_num = len(ego_graph.x)
        if node_num > size:
            raise ValueError(f"graph has {node_num} nodes, more than max_nodes={size}")

        block_rules = [[self.NOT_CONNECTED] * size for _ in range(size)]
        for i in range(node_num):
            block_rules[i][i] = self.CONNECTED

        row, col = ego_graph.edge_index
        # .tolist() already yields plain Python ints (they have no .item()).
        for src, dst in zip(row.tolist(), col.tolist()):
            if src != dst:  # self-loops are already set above
                block_rules[src][dst] = self.CONNECTED
                block_rules[dst][src] = self.CONNECTED  # treat the graph as undirected

        sorted_idx = [i for i in range(node_num) if i != center_idx]
        sorted_idx.append(center_idx)

        sorted_block_rules = [[self.NOT_CONNECTED] * size for _ in range(size)]
        for index_i in range(node_num):
            for index_j in range(node_num):
                sorted_block_rules[index_i][index_j] = block_rules[sorted_idx[index_i]][sorted_idx[index_j]]

        return sorted_idx, sorted_block_rules

    def _process(self, graph):
        """
        Turn one ego graph into the per-sample lists stored by ``__init__``.

        Returns:
            ``(ids, block_rules, boundary, root, input_mask, root_ids, length,
            root_mask)`` as plain Python lists / ints.
        """
        sorted_idx, block_rules = self.get_block_rules(graph, graph.root_n_index)

        origin_idx = graph.original_idx.tolist()
        doc_max_length = self.config.max_length // graph.num_nodes
        # get_block_rules moves the root node to the last position.
        root = len(sorted_idx) - 1
        ids = []
        boundary = []
        root_mask = []

        if self.config.use_reduction:
            _encodings = self._token_reduction_select(origin_idx, doc_max_length)

        for sorted_i in sorted_idx:
            start = len(ids)

            if self.config.use_reduction:
                token_ids = _encodings[sorted_i]
            else:
                # slicing copies, so appending below never mutates self.encodings
                token_ids = self.encodings[origin_idx[sorted_i]][:doc_max_length - 1]

            # Separator between the texts of consecutive nodes.
            token_ids.append(self.tokenizer.sep_token_id)

            if sorted_i == graph.root_n_index:
                # mark the root's text tokens, but not its separator
                root_mask.extend([1] * (len(token_ids) - 1) + [0])
            else:
                root_mask.extend([0] * len(token_ids))

            ids.extend(token_ids)

            # last text token of this node (the separator is excluded)
            boundary.append([start, len(ids) - 2])

        input_mask = [1] * len(ids)

        while len(boundary) < self.max_nodes:
            boundary.append([-1, -1])

        root_origin = origin_idx[graph.root_n_index]
        if self.config.use_reduction:
            root_ids = list(self._token_reduction_select([root_origin], self.root_max_length)[0])
        else:
            # previously undefined (NameError); mirror the reduction path's budget
            root_ids = list(self.encodings[root_origin][:self.root_max_length - 1])
        length = len(root_ids)

        return ids, block_rules, boundary, root, input_mask, root_ids, length, root_mask

    def __getitem__(self, idx):
        item = {'input_ids': self.input_ids[idx],
                'block_rules': self.all_block_rules[idx],
                'attention_mask': self.attention_mask[idx],
                'labels': self.labels[idx],
                'root': self.roots[idx],
                'boundary': self.boundaries[idx],
                'input_root_ids': self.input_root_ids[idx],
                'length': self.lengths[idx],
                'root_mask': self.root_mask[idx]
                }

        return item

    def __len__(self):
        return len(self.labels)
    
# class Node_NCDataset(torch.utils.data.Dataset):
#     """
#     Dataset for fine-tuning language models
#     """
#     def __init__(self, graphs, labels, tokenizer, text, config, best_token_embeddings, best_low_embeddings, best_boundary, best_block_rules, best_root):
#         self.labels = labels
#         self.tokenizer = tokenizer        
#         self.config = config
        
#         self.low_embeddings = best_low_embeddings
#         self.token_embeddings = best_token_embeddings
#         self.boundary = best_boundary
#         self.block_rules = best_block_rules
#         self.root = best_root
        
#     def __getitem__(self, idx):
#         item = {'frozen_node_embeddings': self.low_embeddings[idx],
#                 'token_embeddings': self.token_embeddings[idx],
#                 'block_rules': self.block_rules[idx],
#                 'labels': self.labels[idx],
#                 'root': self.root[idx],
#                 'boundary': self.boundary[idx]}
        
#         return item
    
#     def __len__(self):
#         return len(self.labels)

class Node_NCDataset(torch.utils.data.Dataset):
    """
    Dataset that serves precomputed language-model outputs.

    ``__init__`` loads five objects with ``torch.load`` from
    ``out/<subdir>/<config.lm_type>/<config.dataset>.pt``.  ``graphs`` and
    ``text`` are accepted for signature compatibility with the other datasets
    but are not used.
    """

    # (attribute name, sub-directory of ``out``), in loading order.
    _SAVED_FIELDS = (
        ('token_embeddings', 'lm_token_embeddings'),
        ('low_embeddings', 'lm_low_embeddings'),
        ('boundary', 'lm_boundary'),
        ('root', 'lm_root'),
        ('block_rules', 'lm_block_rules'),
    )

    def __init__(self, graphs, labels, tokenizer, text, config):
        self.labels = labels
        self.tokenizer = tokenizer
        self.config = config

        file_name = f'{config.dataset}.pt'
        for attr_name, sub_dir in self._SAVED_FIELDS:
            saved_path = os.path.join('out', sub_dir, config.lm_type, file_name)
            setattr(self, attr_name, torch.load(saved_path))

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of the stored per-sample entries."""
        return dict(
            frozen_node_embeddings=self.low_embeddings[idx],
            token_embeddings=self.token_embeddings[idx],
            block_rules=self.block_rules[idx],
            labels=self.labels[idx],
            root=self.root[idx],
            boundary=self.boundary[idx],
        )

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

class L_NCDataset(torch.utils.data.Dataset):
    """
    Dataset for fine-tuning language models on the root node's text only.

    Each sample holds the token ids of the root node of an ego graph,
    truncated (or token-reduced) to ``root_max_length`` tokens, together with
    that length.  ``input_ids`` are padded with ``tokenizer.pad_token_id``.
    """

    # Token budget for the root node's text.
    root_max_length = 50

    def __init__(self, graphs, labels, tokenizer, text, config):
        self.tokenizer = tokenizer
        self.config = config
        self.labels = labels

        # Hub alternative: transformers.AutoTokenizer.from_pretrained(model_id[self.config.reduction_lm_type])
        # NOTE(review): model_path is empty; presumably it must point to a local
        # copy of the reduction LM's tokenizer for from_pretrained() to succeed.
        model_path = f""

        self.reduction_tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path,
            add_special_tokens=False,
            local_files_only=True  # only use local files, never download
        )

        if self.config.use_reduction:
            reduction_file = os.path.join('out', 'reduction_out', f'{config.reduction_lm_type}', f'{config.dataset}.pt')
            print(f"Loading reduction file : {reduction_file}")
            save_reduction = torch.load(reduction_file)
            self.scores = save_reduction['score']
            self.attention_masks = save_reduction['attention_mask']
            self.encodings = save_reduction['encodings']
        else:
            # tokenize the raw node texts with the target tokenizer
            self.encodings = [tokenizer.encode(txt, add_special_tokens=False) for txt in text]

        self.input_ids = []
        self.lengths = []

        for graph in graphs:
            ids, length = self._process(graph)
            self.input_ids.append(torch.tensor(ids))
            self.lengths.append(torch.tensor(length))

        # Pad to a common length.
        self.input_ids = torch.nn.utils.rnn.pad_sequence(
            self.input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )

    def _token_reduction_select(self, original_idx, doc_max_length):
        """
        Keep the highest-scoring tokens of each requested node text.

        Args:
            original_idx: indices into ``self.scores``, ``self.attention_masks``
                and ``self.encodings``.
            doc_max_length: token budget.  The same-tokenizer path returns up to
                ``doc_max_length`` ids; the cross-tokenizer path up to
                ``doc_max_length - 1``.

        Returns:
            One list of target-tokenizer ids per entry of ``original_idx``, in
            original text order.
        """
        select_encodings = []

        for idx in original_idx:
            score = self.scores[idx]
            attention_mask = self.attention_masks[idx]
            encoding = self.encodings[idx]

            valid_token_len = attention_mask.sum().item()

            if self.config.lm_type == self.config.reduction_lm_type:
                # Same vocabulary: take the top-k scored ids directly.
                _, indices = score.topk(min(doc_max_length, valid_token_len))
                indices, _ = indices.sort()  # restore original token order
                select_encodings.append(encoding[indices].tolist())
            else:
                # Different vocabularies: over-select by 20%, decode with the
                # reduction tokenizer, re-encode with the target tokenizer, truncate.
                num_tokens = int(doc_max_length * 1.2)
                _, indices = score.topk(min(num_tokens, valid_token_len))
                indices, _ = indices.sort()
                txt = self.reduction_tokenizer.decode(encoding[indices], skip_special_tokens=True)
                enc = self.tokenizer.encode(txt, add_special_tokens=False)
                select_encodings.append(enc[:doc_max_length - 1])

        return select_encodings

    def _process(self, graph):
        """
        Return ``(ids, length)`` for the root node of ``graph``.

        With ``use_reduction`` the ids come from ``_token_reduction_select``;
        otherwise the root's pre-tokenized text is truncated to
        ``root_max_length`` (previously this path raised ``NameError``).
        """
        origin_idx = graph.original_idx.tolist()
        root_origin = origin_idx[graph.root_n_index]
        doc_max_length = self.root_max_length

        if self.config.use_reduction:
            ids = list(self._token_reduction_select([root_origin], doc_max_length)[0])
        else:
            ids = list(self.encodings[root_origin][:doc_max_length])

        return ids, len(ids)

    def __getitem__(self, idx):
        item = {'input_ids': self.input_ids[idx],
                'labels': self.labels[idx],
                'length': self.lengths[idx]}

        return item

    def __len__(self):
        return len(self.labels)

class S_NCDataset(torch.utils.data.Dataset):
    """
    Dataset of frozen token embeddings of each graph's root-node text.

    For every graph the root node's token ids (token-reduced or truncated to
    ``root_max_length``) are run through ``token_model`` once.  The hidden
    states of its last layer become ``input_encodings``, and their number of
    tokens becomes ``length``.

    ``reduction_tokenizer`` is only needed when ``config.use_reduction`` is set
    and ``config.lm_type != config.reduction_lm_type`` (token ids must then be
    decoded with the reduction LM's tokenizer).
    """

    # Token budget for the root node's text.
    root_max_length = 50

    def __init__(self, graphs, labels, tokenizer, token_model, text, config, reduction_tokenizer=None):
        self.labels = labels
        self.tokenizer = tokenizer
        self.token_encoder = token_model
        self.config = config
        # previously never set, so the cross-tokenizer path raised AttributeError
        self.reduction_tokenizer = reduction_tokenizer

        if self.config.use_reduction:
            reduction_file = os.path.join('out', 'reduction_out', f'{config.reduction_lm_type}', f'{config.dataset}.pt')
            print(f"Loading reduction file : {reduction_file}")
            save_reduction = torch.load(reduction_file)
            self.scores = save_reduction['score']
            self.attention_masks = save_reduction['attention_mask']
            self.encodings = save_reduction['encodings']
        else:
            # tokenize the raw node texts with the target tokenizer
            self.encodings = [tokenizer.encode(txt, add_special_tokens=False) for txt in text]

        self.input_encodings = []
        self.lengths = []
        self.root_encodings = []

        for graph in graphs:
            input_encoding, root_encoding, length = self._process(graph)
            # _process computes under no_grad, so these are already detached tensors
            self.input_encodings.append(input_encoding)
            self.lengths.append(torch.tensor(length))
            self.root_encodings.append(root_encoding)

        # Pad the per-graph (tokens, hidden) encodings to a common token count.
        # NOTE(review): embeddings are padded with the pad *token id* value; confirm intended.
        self.input_encodings = torch.nn.utils.rnn.pad_sequence(
            self.input_encodings, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )

    def _token_reduction_select(self, original_idx, doc_max_length):
        """
        Keep the highest-scoring tokens of each requested node text.

        Args:
            original_idx: indices into ``self.scores``, ``self.attention_masks``
                and ``self.encodings``.
            doc_max_length: token budget.  The same-tokenizer path returns up to
                ``doc_max_length`` ids; the cross-tokenizer path up to
                ``doc_max_length - 1``.

        Returns:
            One list of target-tokenizer ids per entry of ``original_idx``, in
            original text order.

        Raises:
            ValueError: on the cross-tokenizer path when no ``reduction_tokenizer``
                was supplied.
        """
        select_encodings = []

        for idx in original_idx:
            score = self.scores[idx]
            attention_mask = self.attention_masks[idx]
            encoding = self.encodings[idx]

            valid_token_len = attention_mask.sum().item()

            if self.config.lm_type == self.config.reduction_lm_type:
                # Same vocabulary: take the top-k scored ids directly.
                _, indices = score.topk(min(doc_max_length, valid_token_len))
                indices, _ = indices.sort()  # restore original token order
                select_encodings.append(encoding[indices].tolist())
            else:
                reduction_tokenizer = getattr(self, 'reduction_tokenizer', None)
                if reduction_tokenizer is None:
                    raise ValueError(
                        "S_NCDataset needs a reduction_tokenizer when "
                        "config.lm_type differs from config.reduction_lm_type"
                    )
                # Different vocabularies: over-select by 20%, decode with the
                # reduction tokenizer, re-encode with the target tokenizer, truncate.
                num_tokens = int(doc_max_length * 1.2)
                _, indices = score.topk(min(num_tokens, valid_token_len))
                indices, _ = indices.sort()
                txt = reduction_tokenizer.decode(encoding[indices], skip_special_tokens=True)
                enc = self.tokenizer.encode(txt, add_special_tokens=False)
                select_encodings.append(enc[:doc_max_length - 1])

        return select_encodings

    def _process(self, graph, return_dict=None):
        """
        Encode the root node's text of ``graph`` with ``self.token_encoder``.

        Returns:
            ``(input_encoding, root_encoding, length)``: the last hidden layer for
            the single input sequence (tokens x hidden), its mean over tokens,
            and the number of tokens.
        """
        device = self.token_encoder.device

        origin_idx = graph.original_idx.tolist()
        root_origin = origin_idx[graph.root_n_index]

        if self.config.use_reduction:
            input_tokens = self._token_reduction_select([root_origin], self.root_max_length)[0]
        else:
            # previously undefined (NameError); truncate to the same budget
            input_tokens = list(self.encodings[root_origin][:self.root_max_length])

        input_ids = torch.tensor(input_tokens).unsqueeze(0).to(device)
        input_mask = torch.ones_like(input_ids)

        # Embeddings are stored as fixed features, so no autograd graph is needed.
        with torch.no_grad():
            outputs = self.token_encoder(input_ids=input_ids,
                                         attention_mask=input_mask,
                                         return_dict=return_dict,
                                         output_hidden_states=True)

        input_encoding = outputs['hidden_states'][-1][0]
        length = input_encoding.size(0)
        root_encoding = torch.mean(input_encoding, dim=0)

        return input_encoding, root_encoding, length

    def __getitem__(self, idx):
        item = {'input_encodings': self.input_encodings[idx],
                'length': self.lengths[idx],
                'labels': self.labels[idx]}

        return item

    def __len__(self):
        return len(self.labels)


class TL_NCDataset(torch.utils.data.Dataset):
    """
    Dataset that prepends every label name to the root node's text.

    Each sample's ``input_ids`` are the token ids of all label names (in
    ``label_text`` order) followed by the root node's (token-reduced or
    truncated) text.  ``boundary`` holds the inclusive ``[start, end]`` span of
    each label and, last, of the root text.  ``label_length`` is the total
    number of label tokens, so the root text budget is
    ``config.max_length - label_length``.
    """

    def __init__(self, graphs, labels, tokenizer, text, config, label_text):
        self.labels = labels
        self.tokenizer = tokenizer
        self.config = config
        self.label_text = label_text

        # Hub alternative: transformers.AutoTokenizer.from_pretrained(model_id[self.config.reduction_lm_type])
        # NOTE(review): model_path is empty; presumably it must point to a local
        # copy of the reduction LM's tokenizer for from_pretrained() to succeed.
        model_path = f""

        self.reduction_tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_path,
            add_special_tokens=False,
            local_files_only=True  # only use local files, never download
        )

        # Label names are concatenated with target-tokenizer ids in _process, so
        # they must be encoded with the target tokenizer as well (using the
        # reduction tokenizer mixed vocabularies when the two LMs differ).
        self.label_tokens = [self.tokenizer.encode(label, add_special_tokens=False) for label in label_text]
        self.label_length = sum(len(label_encoding) for label_encoding in self.label_tokens)

        if self.config.use_reduction:
            reduction_file = os.path.join('out', 'reduction_out', f'{config.reduction_lm_type}', f'{config.dataset}.pt')
            print(f"Loading reduction file : {reduction_file}")
            save_reduction = torch.load(reduction_file)
            self.scores = save_reduction['score']
            self.attention_masks = save_reduction['attention_mask']
            self.encodings = save_reduction['encodings']
        else:
            # tokenize the raw node texts with the target tokenizer
            self.encodings = [tokenizer.encode(txt, add_special_tokens=False) for txt in text]

        self.input_ids = []
        self.boundaries = []

        for graph in graphs:
            ids, boundary = self._process(graph)
            self.input_ids.append(torch.tensor(ids))
            self.boundaries.append(torch.tensor(boundary))

        # Pad to a common length.
        self.input_ids = torch.nn.utils.rnn.pad_sequence(
            self.input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )

    def _token_reduction_select(self, original_idx, doc_max_length):
        """
        Keep the highest-scoring tokens of each requested node text.

        Args:
            original_idx: indices into ``self.scores``, ``self.attention_masks``
                and ``self.encodings``.
            doc_max_length: token budget.  The same-tokenizer path returns up to
                ``doc_max_length`` ids; the cross-tokenizer path up to
                ``doc_max_length - 1``.

        Returns:
            One list of target-tokenizer ids per entry of ``original_idx``, in
            original text order.
        """
        select_encodings = []

        for idx in original_idx:
            score = self.scores[idx]
            attention_mask = self.attention_masks[idx]
            encoding = self.encodings[idx]

            valid_token_len = attention_mask.sum().item()

            if self.config.lm_type == self.config.reduction_lm_type:
                # Same vocabulary: take the top-k scored ids directly.
                _, indices = score.topk(min(doc_max_length, valid_token_len))
                indices, _ = indices.sort()  # restore original token order
                select_encodings.append(encoding[indices].tolist())
            else:
                # Different vocabularies: over-select by 20%, decode with the
                # reduction tokenizer, re-encode with the target tokenizer, truncate.
                num_tokens = int(doc_max_length * 1.2)
                _, indices = score.topk(min(num_tokens, valid_token_len))
                indices, _ = indices.sort()
                txt = self.reduction_tokenizer.decode(encoding[indices], skip_special_tokens=True)
                enc = self.tokenizer.encode(txt, add_special_tokens=False)
                select_encodings.append(enc[:doc_max_length - 1])

        return select_encodings

    def _process(self, graph):
        """
        Return ``(ids, boundary)`` for one graph: all label tokens followed by
        the root node's text, and the inclusive token span of each segment.
        """
        origin_idx = graph.original_idx.tolist()
        root_origin = origin_idx[graph.root_n_index]
        doc_max_length = self.config.max_length - self.label_length

        if self.config.use_reduction:
            root_encodings = self._token_reduction_select([root_origin], doc_max_length)[0]
        else:
            # previously undefined (NameError); a negative budget keeps nothing
            root_encodings = self.encodings[root_origin][:max(doc_max_length, 0)]

        ids = []
        boundary = []
        for segment in [*self.label_tokens, root_encodings]:
            start = len(ids)
            ids.extend(segment)
            boundary.append([start, len(ids) - 1])

        return ids, boundary

    def __getitem__(self, idx):
        item = {'input_ids': self.input_ids[idx],
                'labels': self.labels[idx],
                'boundary': self.boundaries[idx],
                'label_length': self.label_length}

        return item

    def __len__(self):
        return len(self.labels)