import torch
from torch.utils.data import Dataset, DataLoader
import json
import os

import random
from transformers import AutoTokenizer
from tqdm import tqdm

from torch.nn.utils.rnn import pad_sequence


# Module-level set of predicate strings; dataset classes in this module
# (e.g. ``qwen_dataset.is_valid``) add every predicate they process to it.
all_predicate = set()

class qwen_dataset(Dataset):
    """Relation-classification dataset with one sample per annotated relation.

    The input JSON is a list of items, each holding a ``text`` string and a
    ``relation_list`` of dicts with ``predicate``, ``subject``, ``object``,
    ``subj_tok_span`` and ``obj_tok_span``.  Every relation is expanded into
    ``{'text': ..., 'relations': ...}``; the expanded list can be cached as
    JSON so that later runs skip the preprocessing pass.
    """

    # Predicate names; the position of a name in this tuple is its class id.
    _PREDICATES = (
        '/business/company/advisors',
        '/business/company/founders',
        '/business/company/industry',
        '/business/company/major_shareholders',
        '/business/company/place_founded',
        '/business/company_shareholder/major_shareholder_of',
        '/business/person/company',
        '/location/administrative_division/country',
        '/location/country/administrative_divisions',
        '/location/country/capital',
        '/location/location/contains',
        '/location/neighborhood/neighborhood_of',
        '/people/deceased_person/place_of_death',
        '/people/ethnicity/geographic_distribution',
        '/people/ethnicity/people',
        '/people/person/children',
        '/people/person/ethnicity',
        '/people/person/nationality',
        '/people/person/place_lived',
        '/people/person/place_of_birth',
        '/people/person/profession',
        '/people/person/religion',
        '/sports/sports_team/location',
        '/sports/sports_team_location/teams',
    )
    # Sequence length every text is padded / truncated to.
    _MAX_LENGTH = 256
    # Shift applied to annotated token spans when the chat template is used.
    # NOTE(review): presumably the number of template tokens placed in front
    # of the user text by the default Qwen2.5 chat template -- confirm before
    # changing tokenizer or template.
    _TEMPLATE_TOKEN_OFFSET = 24

    def __init__(self, data_path,
                 tokenizer=None,
                 cache_path=None,
                 add_template=False
                 ):
        """Load the raw data and build (or load from cache) the sample list.

        Args:
            data_path: Path of the raw JSON annotation file.
            tokenizer: Tokenizer to use.  When ``None`` the
                ``Qwen/Qwen2.5-0.5B-Instruct`` tokenizer is loaded and the
                ``<|emb|>`` special token is registered on it.
            cache_path: JSON file holding the preprocessed samples.  It is
                read when it exists and written otherwise.  ``None`` disables
                caching entirely.
            add_template: Wrap every text in the tokenizer's chat template and
                shift the token spans accordingly.
        """
        super().__init__()
        self.data = self._read_json(data_path)
        self.add_template = add_template
        self.cache_path = cache_path
        self.relation_map = {name: idx for idx, name in enumerate(self._PREDICATES)}

        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            special_tokens = {"additional_special_tokens": ["<|emb|>"]}
            self.tokenizer.add_special_tokens(special_tokens)
            self.special_token_id = self.tokenizer.additional_special_tokens_ids[0]
        else:
            self.tokenizer = tokenizer

        if cache_path is not None and os.path.exists(cache_path):
            print(f"Loading cached filtered data from {self.cache_path}...")
            self.valid_data = self._read_json(cache_path)
        else:
            print("Loading and preprocessing data...")
            self.valid_data = []
            for item in tqdm(self.data):
                self.valid_data += self.is_valid(item)
            if cache_path is not None:
                # Persist the preprocessed samples for the next run.
                with open(cache_path, 'w', encoding='utf-8') as f:
                    json.dump(self.valid_data, f, indent=4, ensure_ascii=False)
                print(f"Filtered data saved to {self.cache_path}")
        print(f"Initial data size: {len(self.data)}")
        print(f"Filtered data size: {len(self.valid_data)}")

    @staticmethod
    def _read_json(path):
        """Parse the UTF-8 JSON file at *path*, closing the handle afterwards."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _tokenize(self, text):
        """Tokenize *text* padded / truncated to the fixed sequence length."""
        return self.tokenizer(text, padding='max_length', max_length=self._MAX_LENGTH,
                              truncation=True, return_tensors='pt')

    def __len__(self):
        """Return the number of preprocessed (text, relation) samples."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return the tokenized sample stored at position *idx*."""
        return self.preprocess(self.valid_data[idx])

    def is_valid(self, item):
        """Expand one raw item into per-relation samples.

        Each relation gets a ``rel_label`` class id and, when the chat
        template is enabled, token spans shifted by the template offset.  The
        tokens under each span are decoded and compared with the gold entity
        strings; mismatches are only reported on stdout, the sample is kept.
        The relation dicts of *item* are copied, not modified.

        Args:
            item: Raw annotation dict with ``text`` and ``relation_list``.

        Returns:
            A list of ``{'text': ..., 'relations': ...}`` dicts, one per
            relation, in annotation order.

        Raises:
            KeyError: If a predicate is missing from ``relation_map`` or a
                required field is absent.
        """
        text = item['text']
        use_template = self.add_template
        if use_template:
            messages = [{"role": "user", "content": text}]
            sample_text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
        else:
            sample_text = text
        token_ids = self._tokenize(sample_text)['input_ids'][0]

        result = []
        for relation in item['relation_list']:
            # Work on a copy so the raw annotation is left untouched.
            rel = dict(relation)
            rel['rel_label'] = self.relation_map[rel['predicate']]
            subj, obj = rel['subject'], rel['object']

            if use_template:
                shift = self._TEMPLATE_TOKEN_OFFSET
                rel["subj_tok_span"] = [x + shift for x in rel["subj_tok_span"]]
                rel["obj_tok_span"] = [x + shift for x in rel["obj_tok_span"]]
            subj_span, obj_span = rel["subj_tok_span"], rel["obj_tok_span"]

            # Sanity check: the spans should decode back to the gold entities.
            sub_decode = self.tokenizer.decode(token_ids[subj_span[0]:subj_span[1]])
            obj_decode = self.tokenizer.decode(token_ids[obj_span[0]:obj_span[1]])
            if sub_decode != subj:
                print("subj decode error: {}---gold: {}".format(sub_decode, subj))
            if obj_decode != obj:
                print("obj decode error: {}---gold: {}".format(obj_decode, obj))

            result.append({'text': sample_text, 'relations': rel})
            all_predicate.add(rel['predicate'])
        return result

    def preprocess(self, item):
        """Tokenize a stored sample.

        Args:
            item: A ``{'text': ..., 'relations': ...}`` dict from
                ``valid_data``.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` exactly as returned
            by the tokenizer (called with ``return_tensors='pt'``) and
            ``labels`` set to the sample's relation dict.
        """
        text_input = self._tokenize(item['text'])
        return {
            'input_ids': text_input['input_ids'],
            'attention_mask': text_input['attention_mask'],
            'labels': item['relations'],
        }



class qwen_dataset_list(Dataset):
    """Multi-label relation dataset with one sample per entity pair.

    The input JSON is a list of items, each holding a ``text`` string and a
    ``relation_list`` of dicts with ``predicate``, ``subject``, ``object``,
    ``subj_tok_span`` and ``obj_tok_span``.  Relations of one item that share
    the same (subject span, object span) pair are merged into a single sample
    whose ``rel_label`` is the list of their class ids.  The resulting samples
    can be cached as JSON so that later runs skip the preprocessing pass.
    """

    # Predicate names; the position of a name in this tuple is its class id.
    _PREDICATES = (
        '/business/company/advisors',
        '/business/company/founders',
        '/business/company/industry',
        '/business/company/major_shareholders',
        '/business/company/place_founded',
        '/business/company_shareholder/major_shareholder_of',
        '/business/person/company',
        '/location/administrative_division/country',
        '/location/country/administrative_divisions',
        '/location/country/capital',
        '/location/location/contains',
        '/location/neighborhood/neighborhood_of',
        '/people/deceased_person/place_of_death',
        '/people/ethnicity/geographic_distribution',
        '/people/ethnicity/people',
        '/people/person/children',
        '/people/person/ethnicity',
        '/people/person/nationality',
        '/people/person/place_lived',
        '/people/person/place_of_birth',
        '/people/person/profession',
        '/people/person/religion',
        '/sports/sports_team/location',
        '/sports/sports_team_location/teams',
    )
    # Sequence length every text is padded / truncated to.
    _MAX_LENGTH = 256
    # Shift applied to annotated token spans when the chat template is used.
    # NOTE(review): presumably the number of template tokens placed in front
    # of the user text by the default Qwen2.5 chat template -- confirm before
    # changing tokenizer or template.
    _TEMPLATE_TOKEN_OFFSET = 24

    def __init__(self, data_path,
                 tokenizer=None,
                 cache_path=None,
                 add_template=False
                 ):
        """Load the raw data and build (or load from cache) the sample list.

        Args:
            data_path: Path of the raw JSON annotation file.
            tokenizer: Tokenizer to use.  When ``None`` the
                ``Qwen/Qwen2.5-0.5B-Instruct`` tokenizer is loaded and the
                ``<|emb|>`` special token is registered on it.
            cache_path: JSON file holding the preprocessed samples.  It is
                read when it exists and written otherwise.  ``None`` disables
                caching entirely.
            add_template: Wrap every text in the tokenizer's chat template and
                shift the token spans accordingly.
        """
        super().__init__()
        self.data = self._read_json(data_path)
        self.add_template = add_template
        self.cache_path = cache_path
        self.relation_map = {name: idx for idx, name in enumerate(self._PREDICATES)}

        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            special_tokens = {"additional_special_tokens": ["<|emb|>"]}
            self.tokenizer.add_special_tokens(special_tokens)
            self.special_token_id = self.tokenizer.additional_special_tokens_ids[0]
        else:
            self.tokenizer = tokenizer

        if cache_path is not None and os.path.exists(cache_path):
            print(f"Loading cached filtered data from {self.cache_path}...")
            self.valid_data = self._read_json(cache_path)
        else:
            print("Loading and preprocessing data...")
            self.valid_data = []
            for item in tqdm(self.data):
                self.valid_data += self.is_valid(item)
            if cache_path is not None:
                # Persist the preprocessed samples for the next run.
                with open(cache_path, 'w', encoding='utf-8') as f:
                    json.dump(self.valid_data, f, indent=4, ensure_ascii=False)
                print(f"Filtered data saved to {self.cache_path}")
        print(f"Initial data size: {len(self.data)}")
        print(f"Filtered data size: {len(self.valid_data)}")

    @staticmethod
    def _read_json(path):
        """Parse the UTF-8 JSON file at *path*, closing the handle afterwards."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _tokenize(self, text):
        """Tokenize *text* padded / truncated to the fixed sequence length."""
        return self.tokenizer(text, padding='max_length', max_length=self._MAX_LENGTH,
                              truncation=True, return_tensors='pt')

    def __len__(self):
        """Return the number of preprocessed entity-pair samples."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return the tokenized sample stored at position *idx*."""
        return self.preprocess(self.valid_data[idx])

    def is_valid(self, item):
        """Expand one raw item into merged per-entity-pair samples.

        Relations are grouped by ``(subj_tok_span, obj_tok_span)``.  The first
        relation of a group provides the sample's relation dict; its
        ``rel_label`` becomes a list to which the class ids of later relations
        of the same group are appended.  When the chat template is enabled the
        token spans are shifted by the template offset.  The tokens under each
        span are decoded and compared with the gold entity strings; mismatches
        are only reported on stdout.  The relation dicts of *item* are copied,
        not modified.

        Args:
            item: Raw annotation dict with ``text`` and ``relation_list``.

        Returns:
            A list of ``{'text': ..., 'relations': ...}`` dicts, one per
            distinct span pair, in order of first appearance.

        Raises:
            KeyError: If a predicate is missing from ``relation_map`` or a
                required field is absent.
        """
        text = item['text']
        use_template = self.add_template
        if use_template:
            messages = [{"role": "user", "content": text}]
            sample_text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
        else:
            sample_text = text
        token_ids = self._tokenize(sample_text)['input_ids'][0]

        merged_items = {}
        for relation in item['relation_list']:
            # Work on a copy so the raw annotation is left untouched.
            rel = dict(relation)
            label = self.relation_map[rel['predicate']]
            subj, obj = rel['subject'], rel['object']

            if use_template:
                shift = self._TEMPLATE_TOKEN_OFFSET
                rel["subj_tok_span"] = [x + shift for x in rel["subj_tok_span"]]
                rel["obj_tok_span"] = [x + shift for x in rel["obj_tok_span"]]
            subj_span, obj_span = rel["subj_tok_span"], rel["obj_tok_span"]
            key = (tuple(subj_span), tuple(obj_span))

            # Sanity check: the spans should decode back to the gold entities.
            sub_decode = self.tokenizer.decode(token_ids[subj_span[0]:subj_span[1]])
            obj_decode = self.tokenizer.decode(token_ids[obj_span[0]:obj_span[1]])
            if sub_decode != subj:
                print("subj decode error: {}---gold: {}".format(sub_decode, subj))
            if obj_decode != obj:
                print("obj decode error: {}---gold: {}".format(obj_decode, obj))

            if key in merged_items:
                merged_items[key]['relations']['rel_label'].append(label)
            else:
                rel['rel_label'] = [label]
                merged_items[key] = {'text': sample_text, 'relations': rel}

        return list(merged_items.values())

    def preprocess(self, item):
        """Tokenize a stored sample.

        Args:
            item: A ``{'text': ..., 'relations': ...}`` dict from
                ``valid_data``.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` exactly as returned
            by the tokenizer (called with ``return_tensors='pt'``) and
            ``labels`` set to the sample's relation dict.
        """
        text_input = self._tokenize(item['text'])
        return {
            'input_ids': text_input['input_ids'],
            'attention_mask': text_input['attention_mask'],
            'labels': item['relations'],
        }

class qwen_dataset_anker(Dataset):
    """Relation dataset that wraps both entities in anchor marker tokens.

    The input JSON is a list of items, each holding a ``text`` string and a
    ``relation_list`` of dicts with ``predicate``, ``subject``, ``object``,
    ``subj_char_span`` and ``obj_char_span``.  For every relation the two
    entities are surrounded by ``<|emb|>`` / ``<|emb_end|>`` in the text, and
    the token positions of the markers replace the annotated token spans.
    The resulting samples can be cached as JSON so that later runs skip the
    preprocessing pass.
    """

    # Predicate names; the position of a name in this tuple is its class id.
    _PREDICATES = (
        '/business/company/advisors',
        '/business/company/founders',
        '/business/company/industry',
        '/business/company/major_shareholders',
        '/business/company/place_founded',
        '/business/company_shareholder/major_shareholder_of',
        '/business/person/company',
        '/location/administrative_division/country',
        '/location/country/administrative_divisions',
        '/location/country/capital',
        '/location/location/contains',
        '/location/neighborhood/neighborhood_of',
        '/people/deceased_person/place_of_death',
        '/people/ethnicity/geographic_distribution',
        '/people/ethnicity/people',
        '/people/person/children',
        '/people/person/ethnicity',
        '/people/person/nationality',
        '/people/person/place_lived',
        '/people/person/place_of_birth',
        '/people/person/profession',
        '/people/person/religion',
        '/sports/sports_team/location',
        '/sports/sports_team_location/teams',
    )
    # Sequence length every text is padded / truncated to.
    _MAX_LENGTH = 256
    # Marker strings inserted before / after each entity mention.
    _MARK_START = "<|emb|>"
    _MARK_END = "<|emb_end|>"

    def __init__(self, data_path,
                 tokenizer=None,
                 cache_path=None,
                 add_template=False
                 ):
        """Load the raw data and build (or load from cache) the sample list.

        Args:
            data_path: Path of the raw JSON annotation file.
            tokenizer: Tokenizer to use.  When ``None`` the
                ``Qwen/Qwen2.5-0.5B-Instruct`` tokenizer is loaded and both
                marker tokens are registered on it.
            cache_path: JSON file holding the preprocessed samples.  It is
                read when it exists and written otherwise.  ``None`` disables
                caching entirely.
            add_template: Wrap every marked text in the tokenizer's chat
                template.
        """
        super().__init__()
        self.data = self._read_json(data_path)
        self.add_template = add_template
        self.cache_path = cache_path
        self.relation_map = {name: idx for idx, name in enumerate(self._PREDICATES)}

        # Marker ids used with a caller-supplied tokenizer.
        # NOTE(review): assumes that tokenizer maps the two markers to exactly
        # these ids -- confirm against the caller.
        self.special_token_id = 151665
        self.special_token_end_id = 151666
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            # Both markers must be single tokens, otherwise the end marker can
            # never be located in the encoded sequence.
            special_tokens = {"additional_special_tokens": [self._MARK_START, self._MARK_END]}
            self.tokenizer.add_special_tokens(special_tokens)
            self.special_token_id = self.tokenizer.convert_tokens_to_ids(self._MARK_START)
            self.special_token_end_id = self.tokenizer.convert_tokens_to_ids(self._MARK_END)
        else:
            self.tokenizer = tokenizer

        if cache_path is not None and os.path.exists(cache_path):
            print(f"Loading cached filtered data from {self.cache_path}...")
            self.valid_data = self._read_json(cache_path)
        else:
            print("Loading and preprocessing data...")
            self.valid_data = []
            for item in tqdm(self.data):
                self.valid_data += self.is_valid(item)
            if cache_path is not None:
                # Persist the preprocessed samples for the next run.
                with open(cache_path, 'w', encoding='utf-8') as f:
                    json.dump(self.valid_data, f, indent=4, ensure_ascii=False)
                print(f"Filtered data saved to {self.cache_path}")
        print(f"Initial data size: {len(self.data)}")
        print(f"Filtered data size: {len(self.valid_data)}")

    @staticmethod
    def _read_json(path):
        """Parse the UTF-8 JSON file at *path*, closing the handle afterwards."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _tokenize(self, text):
        """Tokenize *text* padded / truncated to the fixed sequence length."""
        return self.tokenizer(text, padding='max_length', max_length=self._MAX_LENGTH,
                              truncation=True, return_tensors='pt')

    def _mark_entities(self, text, first, second):
        """Return *text* with the character spans *first* and *second* wrapped
        in the start / end markers; *first* must not start after *second*."""
        return ''.join((
            text[:first[0]],
            self._MARK_START + text[first[0]:first[1]] + self._MARK_END,
            text[first[1]:second[0]],
            self._MARK_START + text[second[0]:second[1]] + self._MARK_END,
            text[second[1]:],
        ))

    def __len__(self):
        """Return the number of preprocessed (text, relation) samples."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return the tokenized sample stored at position *idx*."""
        return self.preprocess(self.valid_data[idx])

    def is_valid(self, item):
        """Expand one raw item into per-relation samples with marked entities.

        For every relation the two entity mentions are wrapped in marker
        tokens (left-most mention first), the marked text is tokenized, and
        the marker positions become the token spans.  The span of the
        left-most mention is stored as ``subj_tok_span`` and the other one as
        ``obj_tok_span``, regardless of which entity is the subject.  The
        tokens between each marker pair are decoded and reported on stdout
        when they match neither gold entity.  Relations whose markers do not
        all survive truncation are reported and skipped.  The stored ``text``
        is always the marked text (with the chat template when enabled), so
        that the stored spans index its tokenization.  The relation dicts of
        *item* are copied, not modified.

        Args:
            item: Raw annotation dict with ``text`` and ``relation_list``.

        Returns:
            A list of ``{'text': ..., 'relations': ...}`` dicts, one per kept
            relation, in annotation order.

        Raises:
            KeyError: If a predicate is missing from ``relation_map`` or a
                required field is absent.
        """
        origin_text = item['text']
        use_template = self.add_template

        result = []
        for relation in item['relation_list']:
            # Work on a copy so the raw annotation is left untouched.
            rel = dict(relation)
            rel['rel_label'] = self.relation_map[rel['predicate']]

            # Left-most mention first; on equal starts the longer one first.
            first, second = sorted(
                (rel['subj_char_span'], rel['obj_char_span']),
                key=lambda span: (span[0], -span[1]),
            )
            sample_text = self._mark_entities(origin_text, first, second)

            if use_template:
                messages = [{"role": "user", "content": sample_text}]
                sample_text = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=False
                )
            input_ids = self._tokenize(sample_text)['input_ids']

            # Column indices of the start / end markers in the encoded row.
            start_pos = (input_ids == self.special_token_id).nonzero(as_tuple=True)[1]
            end_pos = (input_ids == self.special_token_end_id).nonzero(as_tuple=True)[1]
            if len(start_pos) < 2 or len(end_pos) < 2:
                # Truncation cut off at least one marker: no usable spans.
                print("skip sample, entity markers truncated: {}".format(sample_text))
                continue
            first_span = [start_pos[0].item(), end_pos[0].item()]
            second_span = [start_pos[1].item(), end_pos[1].item()]

            sub_word = rel['subject']
            obj_word = rel['object']
            # Decode the tokens strictly between each marker pair.
            sub_decode = self.tokenizer.decode(input_ids[0][first_span[0] + 1:first_span[1]])
            obj_decode = self.tokenizer.decode(input_ids[0][second_span[0] + 1:second_span[1]])
            if sub_decode not in sub_word and sub_decode not in obj_word:
                print("subj decode error: {}---gold: {}".format(sub_decode, sub_word))
            if obj_decode not in obj_word and obj_decode not in sub_word:
                print("obj decode error: {}---gold: {}".format(obj_decode, obj_word))

            rel['subj_tok_span'] = first_span
            rel['obj_tok_span'] = second_span
            result.append({'text': sample_text, 'relations': rel})
            all_predicate.add(rel['predicate'])
        return result

    def preprocess(self, item):
        """Tokenize a stored sample.

        Args:
            item: A ``{'text': ..., 'relations': ...}`` dict from
                ``valid_data``.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` exactly as returned
            by the tokenizer (called with ``return_tensors='pt'``) and
            ``labels`` set to the sample's relation dict.
        """
        text_input = self._tokenize(item['text'])
        return {
            'input_ids': text_input['input_ids'],
            'attention_mask': text_input['attention_mask'],
            'labels': item['relations'],
        }



class Bert_dataset(Dataset):
    """Relation-classification dataset with one sample per annotated relation.

    The input JSON is a list of items, each holding a ``text`` string and a
    ``relation_list`` of dicts with at least a ``predicate``.  Every relation
    is expanded into ``{'text': ..., 'relations': ...}``; the expanded list
    can be cached as JSON so that later runs skip the preprocessing pass.
    """

    # Predicate names; the position of a name in this tuple is its class id.
    _PREDICATES = (
        '/business/company/advisors',
        '/business/company/founders',
        '/business/company/industry',
        '/business/company/major_shareholders',
        '/business/company/place_founded',
        '/business/company_shareholder/major_shareholder_of',
        '/business/person/company',
        '/location/administrative_division/country',
        '/location/country/administrative_divisions',
        '/location/country/capital',
        '/location/location/contains',
        '/location/neighborhood/neighborhood_of',
        '/people/deceased_person/place_of_death',
        '/people/ethnicity/geographic_distribution',
        '/people/ethnicity/people',
        '/people/person/children',
        '/people/person/ethnicity',
        '/people/person/nationality',
        '/people/person/place_lived',
        '/people/person/place_of_birth',
        '/people/person/profession',
        '/people/person/religion',
        '/sports/sports_team/location',
        '/sports/sports_team_location/teams',
    )
    # Sequence length every text is padded / truncated to.
    _MAX_LENGTH = 256

    def __init__(self, data_path,
                 tokenizer=None,
                 cache_path=None
                 ):
        """Load the raw data and build (or load from cache) the sample list.

        Args:
            data_path: Path of the raw JSON annotation file.
            tokenizer: Tokenizer to use.  When ``None`` the
                ``Qwen/Qwen2.5-0.5B-Instruct`` tokenizer is loaded and the
                ``<|emb|>`` special token is registered on it.
            cache_path: JSON file holding the preprocessed samples.  It is
                read when it exists and written otherwise.  ``None`` disables
                caching entirely.
        """
        super().__init__()
        self.data = self._read_json(data_path)
        self.cache_path = cache_path
        self.relation_map = {name: idx for idx, name in enumerate(self._PREDICATES)}

        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            special_tokens = {"additional_special_tokens": ["<|emb|>"]}
            self.tokenizer.add_special_tokens(special_tokens)
            self.special_token_id = self.tokenizer.additional_special_tokens_ids[0]
        else:
            self.tokenizer = tokenizer

        if cache_path is not None and os.path.exists(cache_path):
            print(f"Loading cached filtered data from {self.cache_path}...")
            self.valid_data = self._read_json(cache_path)
        else:
            print("Loading and preprocessing data...")
            self.valid_data = []
            for item in tqdm(self.data):
                self.valid_data += self.is_valid(item)
            if cache_path is not None:
                # Persist the preprocessed samples for the next run.
                with open(cache_path, 'w', encoding='utf-8') as f:
                    json.dump(self.valid_data, f, indent=4, ensure_ascii=False)
                print(f"Filtered data saved to {self.cache_path}")
        print(f"Initial data size: {len(self.data)}")
        print(f"Filtered data size: {len(self.valid_data)}")

    @staticmethod
    def _read_json(path):
        """Parse the UTF-8 JSON file at *path*, closing the handle afterwards."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def __len__(self):
        """Return the number of preprocessed (text, relation) samples."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return the tokenized sample stored at position *idx*."""
        return self.preprocess(self.valid_data[idx])

    def is_valid(self, item):
        """Expand one raw item into per-relation samples.

        Each relation is copied and given a ``rel_label`` class id looked up
        from its ``predicate``.  No tokenization or span check is performed
        here; the text is tokenized later in ``preprocess``.

        Args:
            item: Raw annotation dict with ``text`` and ``relation_list``.

        Returns:
            A list of ``{'text': ..., 'relations': ...}`` dicts, one per
            relation, in annotation order.

        Raises:
            KeyError: If a predicate is missing from ``relation_map``.
        """
        text = item['text']
        result = []
        for relation in item['relation_list']:
            # Work on a copy so the raw annotation is left untouched.
            rel = dict(relation)
            rel['rel_label'] = self.relation_map[rel['predicate']]
            result.append({'text': text, 'relations': rel})
        return result

    def preprocess(self, item):
        """Tokenize a stored sample.

        Args:
            item: A ``{'text': ..., 'relations': ...}`` dict from
                ``valid_data``.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` exactly as returned
            by the tokenizer (called with ``return_tensors='pt'``) and
            ``labels`` set to the sample's relation dict.
        """
        text_input = self.tokenizer(item['text'], padding='max_length',
                                    max_length=self._MAX_LENGTH, truncation=True,
                                    return_tensors='pt')
        return {
            'input_ids': text_input['input_ids'],
            'attention_mask': text_input['attention_mask'],
            'labels': item['relations'],
        }


class Bert_dataset_list(Dataset):
    """Relation-classification dataset with one sample per entity-span pair.

    Every raw item provides a ``text`` and a ``relation_list``. Relations of
    one item that share the same (subject span, object span) pair are merged
    into a single sample whose ``rel_label`` is a list of class indices, so a
    sample may carry several labels. The merged samples can be cached as JSON
    so later runs skip the merging step.
    """

    def __init__(self, data_path,
                tokenizer  = None,
                cache_path = None
                ):
        """Load the raw data and build (or reload) the merged sample list.

        Args:
            data_path: Path to a JSON file containing a list of raw items.
            tokenizer: Tokenizer used by ``preprocess``. When ``None`` the
                "Qwen/Qwen2.5-0.5B-Instruct" tokenizer is loaded and the
                special token "<|emb|>" is added to it.
            cache_path: Optional JSON file for the merged samples. If it
                exists it is loaded instead of re-processing the data;
                otherwise the merged samples are written to it. ``None``
                disables caching.
        """
        super().__init__()
        # Read the raw data once; the context manager closes the handle.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.cache_path = cache_path
        self.relation_map = {'/business/company/advisors':0,
                            '/business/company/founders':1,
                            '/business/company/industry':2,
                            '/business/company/major_shareholders':3,
                            '/business/company/place_founded':4,
                            '/business/company_shareholder/major_shareholder_of':5,
                            '/business/person/company':6,
                            '/location/administrative_division/country':7,
                            '/location/country/administrative_divisions':8,
                            '/location/country/capital':9,
                            '/location/location/contains':10,
                            '/location/neighborhood/neighborhood_of':11,
                            '/people/deceased_person/place_of_death':12,
                            '/people/ethnicity/geographic_distribution':13,
                            '/people/ethnicity/people':14,
                            '/people/person/children':15,
                            '/people/person/ethnicity':16,
                            '/people/person/nationality':17,
                            '/people/person/place_lived':18,
                            '/people/person/place_of_birth':19,
                            '/people/person/profession':20,
                            '/people/person/religion':21,
                            '/sports/sports_team/location':22,
                            '/sports/sports_team_location/teams':23}
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            special_tokens = {"additional_special_tokens": ["<|emb|>"]}
            self.tokenizer.add_special_tokens(special_tokens)
            # NOTE: special_token_id is only defined on this default path.
            self.special_token_id = self.tokenizer.additional_special_tokens_ids[0]
        else:
            self.tokenizer = tokenizer

        self.valid_data = []
        # cache_path defaults to None, which os.path.exists() rejects with a
        # TypeError, so test for None before touching the filesystem.
        use_cache = self.cache_path is not None
        if use_cache and os.path.exists(self.cache_path):
            print(f"Loading cached filtered data from {self.cache_path}...")
            # The cache is written as UTF-8 below; read it back the same way.
            with open(self.cache_path, 'r', encoding='utf-8') as f:
                self.valid_data = json.load(f)
        else:
            print("Loading and preprocessing data...")
            for item in tqdm(self.data):
                self.valid_data += self.is_valid(item)
            if use_cache:
                # Persist the preprocessed data.
                with open(self.cache_path, 'w', encoding='utf-8') as f:
                    json.dump(self.valid_data, f, indent=4, ensure_ascii=False)
                print(f"Filtered data saved to {self.cache_path}")
        print(f"Initial data size: {len(self.data)}")
        print(f"Filtered data size: {len(self.valid_data)}")

    def __len__(self):
        """Return the number of merged samples."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return the tokenized sample stored at position ``idx``."""
        return self.preprocess(self.valid_data[idx])

    def is_valid(self, item):
        """Expand one raw item into samples, merging relations by span pair.

        Relations with identical ``subj_tok_span`` and ``obj_tok_span`` are
        collapsed into one sample. The sample keeps the fields of the first
        such relation and collects the class index of every merged relation
        in its ``rel_label`` list. The input ``item`` is not modified.

        Args:
            item: Mapping with ``'text'`` and ``'relation_list'``; each
                relation needs ``'predicate'``, ``'subj_tok_span'`` and
                ``'obj_tok_span'``.

        Returns:
            list of ``{'text': ..., 'relations': ...}`` dicts, ordered by the
            first occurrence of each span pair.

        Raises:
            KeyError: If a predicate is missing from ``self.relation_map``.
        """
        text = item['text']
        merged_items = {}
        for relation in item['relation_list']:
            rel_label = self.relation_map[relation['predicate']]
            key = (tuple(relation["subj_tok_span"]), tuple(relation["obj_tok_span"]))
            if key in merged_items:
                merged_items[key]['relations']['rel_label'].append(rel_label)
            else:
                # Work on a copy so the caller's relation dict stays untouched.
                merged_relation = dict(relation)
                merged_relation['rel_label'] = [rel_label]
                merged_items[key] = {'text': text, 'relations': merged_relation}
        return list(merged_items.values())

    def preprocess(self, item):
        """Tokenize a merged sample and attach its relation record as labels.

        Args:
            item: One element of ``self.valid_data``.

        Returns:
            dict with ``'input_ids'``, ``'attention_mask'`` (tokenizer output,
            padded / truncated to 256 tokens) and ``'labels'`` (the sample's
            ``'relations'`` dict), in that order.
        """
        text = item['text']
        text_input = self.tokenizer(text, padding='max_length',max_length=256,truncation=True,return_tensors='pt')
        # Key order matters: the collate functions in this file address the
        # fields by position (first two = model inputs, last = labels).
        result = {}
        result['input_ids'] = text_input['input_ids']
        result['attention_mask'] = text_input['attention_mask']
        result['labels'] = item['relations']
        return result


def convert_to_multi_hot(labels, num_classes = 24):
    """Convert per-sample class indices into a multi-hot matrix.

    Args:
        labels: One entry per sample. Each entry is either a 1-D tensor of
            class indices (e.g. a row of a padded 2-D tensor) or a sequence
            of ints / 0-d tensors.
        num_classes (int): Total number of classes (width of the output).

    Returns:
        torch.Tensor: ``(len(labels), num_classes)`` tensor from
        ``torch.zeros`` with 1 at every valid class index of each sample.
        Entries that are not ints in ``[0, num_classes)`` -- such as a ``-1``
        padding value -- are ignored.
    """
    multi_hot = torch.zeros(len(labels), num_classes)
    for row, sample_labels in enumerate(labels):
        # Turn the row into Python scalars once instead of calling .item()
        # several times per element; plain Python ints pass through as-is.
        if torch.is_tensor(sample_labels):
            candidates = sample_labels.tolist()
        else:
            candidates = [
                value.item() if torch.is_tensor(value) else value
                for value in sample_labels
            ]
        valid_labels = [
            value for value in candidates
            if isinstance(value, int) and 0 <= value < num_classes
        ]
        if valid_labels:
            multi_hot[row, valid_labels] = 1
    return multi_hot

def data_list_collator(batch):
    """Collate a list of sample dicts into one batch dict with multi-hot labels.

    The fields are addressed by position: the first two keys of the first
    sample are stacked as model inputs and the last key is taken as the
    label record, whose ``'rel_label'`` list is converted to a multi-hot row
    using the default class count of ``convert_to_multi_hot``.

    Args:
        batch: list of sample dicts sharing the same key order.

    Returns:
        dict with the two stacked input fields, ``'labels'`` (multi-hot
        tensor) and ``'span_info'`` (the untouched label records).
    """
    field_names = list(batch[0].keys())
    collated = {}
    for name in field_names[:2]:
        stacked = torch.stack([sample[name] for sample in batch])
        # Drop the singleton dimension at index 1 left by per-sample tensors.
        collated[name] = stacked.squeeze(1)
    label_field = field_names[-1]
    per_sample = [torch.tensor(sample[label_field]['rel_label']) for sample in batch]
    # Pad with -1 so the padding is discarded by convert_to_multi_hot.
    padded = pad_sequence(per_sample, batch_first=True, padding_value=-1)
    collated['labels'] = convert_to_multi_hot(padded)
    collated['span_info'] = [sample[label_field] for sample in batch]

    return collated

def data_list_collator_tacred(batch):
    """Collate sample dicts into a batch dict with 42-class multi-hot labels.

    The fields are addressed by position: the first two keys of the first
    sample are stacked as model inputs and the last key is taken as the
    label record, whose ``'rel_label'`` list is converted to a multi-hot row
    of width 42.

    Args:
        batch: list of sample dicts sharing the same key order.

    Returns:
        dict with the two stacked input fields, ``'labels'`` (multi-hot
        tensor) and ``'span_info'`` (the untouched label records).
    """
    field_names = list(batch[0].keys())
    collated = {}
    for name in field_names[:2]:
        stacked = torch.stack([sample[name] for sample in batch])
        # Drop the singleton dimension at index 1 left by per-sample tensors.
        collated[name] = stacked.squeeze(1)
    label_field = field_names[-1]
    per_sample = [torch.tensor(sample[label_field]['rel_label']) for sample in batch]
    # Pad with -1 so the padding is discarded by convert_to_multi_hot.
    padded = pad_sequence(per_sample, batch_first=True, padding_value=-1)
    collated['labels'] = convert_to_multi_hot(padded, num_classes=42)
    collated['span_info'] = [sample[label_field] for sample in batch]

    return collated

def data_list_collator_retacred(batch):
    """Collate sample dicts into a batch dict with 40-class multi-hot labels.

    The fields are addressed by position: the first two keys of the first
    sample are stacked as model inputs and the last key is taken as the
    label record, whose ``'rel_label'`` list is converted to a multi-hot row
    of width 40.

    Args:
        batch: list of sample dicts sharing the same key order.

    Returns:
        dict with the two stacked input fields, ``'labels'`` (multi-hot
        tensor) and ``'span_info'`` (the untouched label records).
    """
    field_names = list(batch[0].keys())
    collated = {}
    for name in field_names[:2]:
        stacked = torch.stack([sample[name] for sample in batch])
        # Drop the singleton dimension at index 1 left by per-sample tensors.
        collated[name] = stacked.squeeze(1)
    label_field = field_names[-1]
    per_sample = [torch.tensor(sample[label_field]['rel_label']) for sample in batch]
    # Pad with -1 so the padding is discarded by convert_to_multi_hot.
    padded = pad_sequence(per_sample, batch_first=True, padding_value=-1)
    collated['labels'] = convert_to_multi_hot(padded, num_classes=40)
    collated['span_info'] = [sample[label_field] for sample in batch]

    return collated


def my_data_collator(batch):
    """Collate a list of sample dicts into one batch dict with index labels.

    The fields are addressed by position: the first two keys of the first
    sample are stacked as model inputs and the last key is taken as the
    label record. Each record's ``'rel_label'`` goes into one label tensor.

    Args:
        batch: list of sample dicts sharing the same key order.

    Returns:
        dict with the two stacked input fields, ``'labels'`` (tensor built
        from every ``'rel_label'``) and ``'span_info'`` (the untouched label
        records).
    """
    field_names = list(batch[0].keys())
    collated = {}
    for name in field_names[:2]:
        stacked = torch.stack([sample[name] for sample in batch])
        # Drop the singleton dimension at index 1 left by per-sample tensors.
        collated[name] = stacked.squeeze(1)
    label_field = field_names[-1]

    records = [sample[label_field] for sample in batch]
    collated['labels'] = torch.tensor([record['rel_label'] for record in records])
    collated['span_info'] = records

    return collated


if __name__ == "__main__":
    # Smoke test: build the dataset (or reload its cache) and print one batch.
    mydataset = Bert_dataset_list(data_path="./unirel/UniRel/data_preprocess/data4qwen/nyt_qwen/test_data.json",
                                    cache_path="./dataset/cache_0620_bert.json")
    # Each sample's 'labels' entry is a dict whose 'rel_label' list varies in
    # length. The default collate function raises when those lengths differ
    # within a batch, so use the collator written for these samples.
    dataloader = DataLoader(mydataset,
                            batch_size=2,
                            shuffle=True,
                            collate_fn=data_list_collator)
    for batch in dataloader:
        print(batch)
        break