import torch
from torch.utils.data import Dataset, DataLoader
import json
import os

import random
from transformers import AutoTokenizer
from tqdm import tqdm

# Module-level collection of predicate names. qwen_dataset.is_valid adds the
# 'predicate' value of every relation it processes to this set.
all_predicate = set()

class qwen_dataset(Dataset):
    """Relation-extraction dataset that keeps one record per input text.

    Every record stored in ``valid_data`` has the form
    ``{'text': str, 'relations': [relation, ...]}`` where the relation dicts
    are taken unchanged from the ``relation_list`` of the raw sample.
    Tokenization happens lazily in ``__getitem__``.
    """

    def __init__(self, data_path,
                 tokenizer=None,
                 cache_path=None,
                 add_template=False
                 ):
        """Load the raw JSON file and build (or reload) the record list.

        Args:
            data_path: Path of the JSON file holding the raw samples; each
                sample must provide ``'text'`` and ``'relation_list'``.
            tokenizer: Callable tokenizer. When ``None`` the
                ``Qwen/Qwen2.5-0.5B-Instruct`` tokenizer is loaded and
                extended with the ``<|emb|>`` special token.
            cache_path: JSON file used to cache the records. When the file
                exists it is loaded instead of re-processing the raw data;
                ``None`` disables caching altogether.
            add_template: When true, the chat template of the tokenizer is
                applied to each text before it is stored.
        """
        super().__init__()
        # Parse the raw file exactly once and close the handle right away.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.add_template = add_template
        self.cache_path = cache_path
        # Predicate name -> class index. Not used by this class itself; kept
        # as a public attribute.
        self.relation_map = {
            'per:other_family': 0, 'per:charges': 1, 'per:stateorprovinces_of_residence': 2,
            'per:title': 3, 'org:founded': 4, 'per:alternate_names': 5,
            'per:date_of_birth': 6, 'no_relation': 7, 'org:parents': 8,
            'per:stateorprovince_of_birth': 9, 'per:religion': 10, 'org:country_of_headquarters': 11,
            'per:age': 12, 'per:countries_of_residence': 13, 'per:country_of_birth': 14,
            'org:number_of_employees/members': 15, 'org:dissolved': 16, 'org:members': 17,
            'per:spouse': 18, 'per:cities_of_residence': 19, 'org:subsidiaries': 20,
            'per:city_of_death': 21, 'per:origin': 22, 'org:website': 23,
            'per:cause_of_death': 24, 'org:stateorprovince_of_headquarters': 25, 'org:alternate_names': 26,
            'per:city_of_birth': 27, 'per:employee_of': 28, 'per:country_of_death': 29,
            'per:parents': 30, 'org:top_members/employees': 31, 'org:political/religious_affiliation': 32,
            'per:children': 33, 'per:schools_attended': 34, 'per:siblings': 35,
            'per:date_of_death': 36, 'org:founded_by': 37, 'org:shareholders': 38,
            'org:city_of_headquarters': 39, 'per:stateorprovince_of_death': 40, 'org:member_of': 41,
        }
        # Predicates seen while processing raw data (stays empty on a cache hit).
        self.all_predicate = set()
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            special_tokens = {"additional_special_tokens": ["<|emb|>"]}
            self.tokenizer.add_special_tokens(special_tokens)
            # NOTE: only defined when the default tokenizer is created here.
            self.special_token_id = self.tokenizer.additional_special_tokens_ids[0]
        else:
            self.tokenizer = tokenizer
        self.valid_data = []
        # os.path.exists(None) raises TypeError, so test for None first.
        if self.cache_path is not None and os.path.exists(self.cache_path):
            print(f"Loading cached filtered data from {self.cache_path}...")
            # The cache is written as UTF-8 with ensure_ascii=False, so it
            # must be read back with the same encoding.
            with open(self.cache_path, 'r', encoding='utf-8') as f:
                self.valid_data = json.load(f)
        else:
            print("Loading and preprocessing data...")
            for item in tqdm(self.data):
                self.valid_data += self.is_valid(item)
            print(self.all_predicate)
            if self.cache_path is not None:
                # Save the processed records, creating the folder if needed.
                cache_dir = os.path.dirname(self.cache_path)
                if cache_dir:
                    os.makedirs(cache_dir, exist_ok=True)
                with open(self.cache_path, 'w', encoding='utf-8') as f:
                    json.dump(self.valid_data, f, indent=4, ensure_ascii=False)
                print(f"Filtered data saved to {self.cache_path}")
        print(f"Initial data size: {len(self.data)}")
        print(f"Filtered data size: {len(self.valid_data)}")

    def __len__(self):
        """Return the number of stored records."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return the tokenized form of record ``idx`` (see ``preprocess``)."""
        return self.preprocess(self.valid_data[idx])

    def is_valid(self, item):
        """Turn one raw sample into a one-element list of records.

        Despite its name this method does not filter anything: it always
        returns exactly one record containing the (optionally chat-templated)
        text and all relations of the sample. Each predicate is recorded in
        ``self.all_predicate`` and in the module-level ``all_predicate`` set.

        NOTE(review): the relations are stored without a ``'rel_label'``
        entry unless the raw data already provides one; ``my_data_collator``
        reads that key, so confirm the input files carry it.
        """
        text = item['text']
        if self.add_template:
            messages = [{"role": "user", "content": text}]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )

        relations = list(item['relation_list'])
        for relation in relations:
            predicate = relation['predicate']
            self.all_predicate.add(predicate)
            all_predicate.add(predicate)

        return [{'text': text, 'relations': relations}]

    def preprocess(self, item):
        """Tokenize one stored record.

        The text is padded/truncated to 256 tokens. The returned dict has the
        keys ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` in that
        order; ``my_data_collator`` depends on this ordering.
        """
        encoded = self.tokenizer(item['text'], padding='max_length', max_length=256,
                                 truncation=True, return_tensors='pt')
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': item['relations'],
        }

class Bert_dataset(Dataset):
    """Relation-extraction dataset that keeps one record per relation.

    Every record stored in ``valid_data`` has the form
    ``{'text': str, 'relations': relation}`` where ``relation`` is a single
    relation dict extended with an integer ``'rel_label'``.
    """

    def __init__(self, data_path,
                 tokenizer=None,
                 cache_path=None
                 ):
        """Load the raw JSON file and build (or reload) the record list.

        Args:
            data_path: Path of the JSON file holding the raw samples; each
                sample must provide ``'text'`` and ``'relation_list'``.
            tokenizer: Callable tokenizer. When ``None`` the
                ``Qwen/Qwen2.5-0.5B-Instruct`` tokenizer is loaded (despite
                the class name) and extended with the ``<|emb|>`` token.
            cache_path: JSON file used to cache the records. When the file
                exists it is loaded instead of re-processing the raw data;
                ``None`` disables caching altogether.
        """
        super().__init__()
        # Parse the raw file exactly once and close the handle right away.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.cache_path = cache_path
        # Predicate name -> class index used to fill 'rel_label'.
        self.relation_map = {'/business/company/advisors': 0,
                             '/business/company/founders': 1,
                             '/business/company/industry': 2,
                             '/business/company/major_shareholders': 3,
                             '/business/company/place_founded': 4,
                             '/business/company_shareholder/major_shareholder_of': 5,
                             '/business/person/company': 6,
                             '/location/administrative_division/country': 7,
                             '/location/country/administrative_divisions': 8,
                             '/location/country/capital': 9,
                             '/location/location/contains': 10,
                             '/location/neighborhood/neighborhood_of': 11,
                             '/people/deceased_person/place_of_death': 12,
                             '/people/ethnicity/geographic_distribution': 13,
                             '/people/ethnicity/people': 14,
                             '/people/person/children': 15,
                             '/people/person/ethnicity': 16,
                             '/people/person/nationality': 17,
                             '/people/person/place_lived': 18,
                             '/people/person/place_of_birth': 19,
                             '/people/person/profession': 20,
                             '/people/person/religion': 21,
                             '/sports/sports_team/location': 22,
                             '/sports/sports_team_location/teams': 23}
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            special_tokens = {"additional_special_tokens": ["<|emb|>"]}
            self.tokenizer.add_special_tokens(special_tokens)
            # NOTE: only defined when the default tokenizer is created here.
            self.special_token_id = self.tokenizer.additional_special_tokens_ids[0]
        else:
            self.tokenizer = tokenizer
        self.valid_data = []
        # os.path.exists(None) raises TypeError, so test for None first.
        if self.cache_path is not None and os.path.exists(self.cache_path):
            print(f"Loading cached filtered data from {self.cache_path}...")
            # The cache is written as UTF-8 with ensure_ascii=False, so it
            # must be read back with the same encoding.
            with open(self.cache_path, 'r', encoding='utf-8') as f:
                self.valid_data = json.load(f)
        else:
            print("Loading and preprocessing data...")
            for item in tqdm(self.data):
                self.valid_data += self.is_valid(item)
            if self.cache_path is not None:
                # Save the processed records, creating the folder if needed.
                cache_dir = os.path.dirname(self.cache_path)
                if cache_dir:
                    os.makedirs(cache_dir, exist_ok=True)
                with open(self.cache_path, 'w', encoding='utf-8') as f:
                    json.dump(self.valid_data, f, indent=4, ensure_ascii=False)
                print(f"Filtered data saved to {self.cache_path}")
        print(f"Initial data size: {len(self.data)}")
        print(f"Filtered data size: {len(self.valid_data)}")

    def __len__(self):
        """Return the number of stored records."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return the tokenized form of record ``idx`` (see ``preprocess``)."""
        return self.preprocess(self.valid_data[idx])

    def is_valid(self, item):
        """Expand one raw sample into one record per relation.

        Despite its name this method does not filter anything. Each relation
        dict of ``item['relation_list']`` is annotated in place with
        ``'rel_label'`` (looked up in ``relation_map``) and paired with the
        sample text.

        Raises:
            KeyError: If a predicate is missing from ``relation_map``.
        """
        text = item['text']
        result = []
        for relation in item['relation_list']:
            cur_item = {'text': text}
            relation['rel_label'] = self.relation_map[relation['predicate']]
            cur_item['relations'] = relation
            result.append(cur_item)
        return result

    def preprocess(self, item):
        """Tokenize one stored record.

        The text is padded/truncated to 256 tokens. The returned dict has the
        keys ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` in that
        order; ``my_data_collator`` depends on this ordering.
        """
        encoded = self.tokenizer(item['text'], padding='max_length', max_length=256,
                                 truncation=True, return_tensors='pt')
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': item['relations'],
        }






def my_data_collator(batch):
    """Collate a list of sample dicts into a single batch dict.

    The first two keys of each sample are treated as tensor fields and the
    last key as the label field, so the samples must be ordered like the
    ``preprocess`` output of the datasets in this module
    (``input_ids``, ``attention_mask``, ``labels``).

    Args:
        batch: Non-empty list of sample dicts. The label field is either a
            list of relation dicts or a single relation dict; every relation
            must carry a ``'rel_label'`` entry.

    Returns:
        A dict holding the two stacked tensor fields (dimension 1 is
        squeezed away when it has size 1), ``'labels'`` (tensor built from
        all ``rel_label`` values in batch order) and ``'span_info'`` (the
        untouched label field of every sample).
    """
    keys = list(batch[0].keys())
    label = keys[-1]

    # Convert list[dict] into dict[tensor] for the two tensor fields.
    result = {
        key: torch.stack([item[key] for item in batch]).squeeze(1)
        for key in keys[:2]
    }

    all_rel_labels = []
    for item in batch:
        relations = item[label]
        # A single relation dict (one relation per sample) would otherwise be
        # iterated key by key; treat it as a one-element list instead.
        if isinstance(relations, dict):
            relations = [relations]
        all_rel_labels.extend(relation['rel_label'] for relation in relations)

    # NOTE(review): when 'rel_label' is itself a list, torch.tensor needs all
    # of them to have the same length - confirm against the dataset in use.
    result['labels'] = torch.tensor(all_rel_labels)
    result['span_info'] = [item[label] for item in batch]

    return result

class qwen_dataset_list(Dataset):
    """Relation-extraction dataset with one record per entity pair.

    Relations of a sample that share the same subject/object token spans are
    merged into a single record whose ``'rel_label'`` is a list of class
    indices. Every record stored in ``valid_data`` has the form
    ``{'text': str, 'relations': relation}``.
    """

    def __init__(self, data_path,
                 tokenizer=None,
                 cache_path=None,
                 add_template=False
                 ):
        """Load the raw JSON file and build (or reload) the record list.

        Args:
            data_path: Path of the JSON file holding the raw samples; each
                sample must provide ``'text'`` and ``'relation_list'``.
            tokenizer: Tokenizer object (callable, with ``decode`` and, for
                ``add_template``, ``apply_chat_template``). When ``None`` the
                ``Qwen/Qwen2.5-0.5B-Instruct`` tokenizer is loaded and
                extended with the ``<|emb|>`` special token.
            cache_path: JSON file used to cache the records. When the file
                exists it is loaded instead of re-processing the raw data;
                ``None`` disables caching altogether.
            add_template: When true, the chat template of the tokenizer is
                applied to each text; otherwise an instruction sentence is
                appended to it.
        """
        super().__init__()
        # Parse the raw file exactly once and close the handle right away.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.add_template = add_template
        self.cache_path = cache_path
        # Predicate name -> class index used to fill 'rel_label'.
        self.relation_map = {
            'per:other_family': 0, 'per:charges': 1, 'per:stateorprovinces_of_residence': 2,
            'per:title': 3, 'org:founded': 4, 'per:alternate_names': 5,
            'per:date_of_birth': 6, 'no_relation': 7, 'org:parents': 8,
            'per:stateorprovince_of_birth': 9, 'per:religion': 10, 'org:country_of_headquarters': 11,
            'per:age': 12, 'per:countries_of_residence': 13, 'per:country_of_birth': 14,
            'org:number_of_employees/members': 15, 'org:dissolved': 16, 'org:members': 17,
            'per:spouse': 18, 'per:cities_of_residence': 19, 'org:subsidiaries': 20,
            'per:city_of_death': 21, 'per:origin': 22, 'org:website': 23,
            'per:cause_of_death': 24, 'org:stateorprovince_of_headquarters': 25, 'org:alternate_names': 26,
            'per:city_of_birth': 27, 'per:employee_of': 28, 'per:country_of_death': 29,
            'per:parents': 30, 'org:top_members/employees': 31, 'org:political/religious_affiliation': 32,
            'per:children': 33, 'per:schools_attended': 34, 'per:siblings': 35,
            'per:date_of_death': 36, 'org:founded_by': 37, 'org:shareholders': 38,
            'org:city_of_headquarters': 39, 'per:stateorprovince_of_death': 40, 'org:member_of': 41,
        }
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            special_tokens = {"additional_special_tokens": ["<|emb|>"]}
            self.tokenizer.add_special_tokens(special_tokens)
            # NOTE: only defined when the default tokenizer is created here.
            self.special_token_id = self.tokenizer.additional_special_tokens_ids[0]
        else:
            self.tokenizer = tokenizer
        self.valid_data = []
        # os.path.exists(None) raises TypeError, so test for None first.
        if self.cache_path is not None and os.path.exists(self.cache_path):
            print(f"Loading cached filtered data from {self.cache_path}...")
            # The cache is written as UTF-8 with ensure_ascii=False, so it
            # must be read back with the same encoding.
            with open(self.cache_path, 'r', encoding='utf-8') as f:
                self.valid_data = json.load(f)
        else:
            print("Loading and preprocessing data...")
            for item in tqdm(self.data):
                self.valid_data += self.is_valid(item)
            if self.cache_path is not None:
                # Save the processed records, creating the folder if needed.
                cache_dir = os.path.dirname(self.cache_path)
                if cache_dir:
                    os.makedirs(cache_dir, exist_ok=True)
                with open(self.cache_path, 'w', encoding='utf-8') as f:
                    json.dump(self.valid_data, f, indent=4, ensure_ascii=False)
                print(f"Filtered data saved to {self.cache_path}")
        print(f"Initial data size: {len(self.data)}")
        print(f"Filtered data size: {len(self.valid_data)}")

    def __len__(self):
        """Return the number of stored records."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return the tokenized form of record ``idx`` (see ``preprocess``)."""
        return self.preprocess(self.valid_data[idx])

    def is_valid(self, item):
        """Convert one raw sample into records, one per distinct entity pair.

        Despite its name this method does not filter anything. It
        1. builds the record text (chat template, or text + instruction),
        2. decodes the subject/object token spans and prints a message when
           they differ from the gold strings (mismatches are only reported),
        3. merges relations with identical ``(subj_tok_span, obj_tok_span)``
           into one record; the first relation dict of each pair is kept and
           its ``'rel_label'`` becomes the list of all class indices.

        The kept relation dicts are annotated in place.

        NOTE(review): the span check tokenizes with ``max_length=512`` while
        ``preprocess`` uses 256, and the spans are not shifted when the chat
        template is prepended - confirm both are intended.

        Raises:
            KeyError: If a predicate is missing from ``relation_map``.
        """
        text = item['text']

        if self.add_template:
            messages = [{"role": "user", "content": text}]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
            text_input = self.tokenizer(text, padding='max_length', max_length=512,
                                        truncation=True, return_tensors='pt')
        else:
            # Spans refer to the original text, so tokenize before the
            # instruction is appended.
            text_input = self.tokenizer(text, padding='max_length', max_length=512,
                                        truncation=True, return_tensors='pt')
            try:
                sub = item['entity_list'][0]['text']
                obj = item['entity_list'][1]['text']
                # NOTE(review): no space before {sub}; left untouched because
                # the string ends up in cached data and model input.
                add_instruction = f"predict the relation between the entity pair{sub} and {obj}:"
            except (KeyError, IndexError, TypeError):
                # Entity list missing or incomplete: use the generic prompt.
                add_instruction = "predict the relation between the entity pair:"
            text = text + add_instruction

        token_ids = text_input['input_ids'][0]
        merged_items = {}
        for relation in item['relation_list']:
            rel_label = self.relation_map[relation['predicate']]
            subj, obj = relation['subject'], relation['object']
            subj_tok_span, obj_tok_span = relation["subj_tok_span"], relation["obj_tok_span"]
            key = (tuple(subj_tok_span), tuple(obj_tok_span))

            # Check that the token spans decode back to the gold entities.
            sub_decode = self.tokenizer.decode(token_ids[subj_tok_span[0]:subj_tok_span[1]])
            obj_decode = self.tokenizer.decode(token_ids[obj_tok_span[0]:obj_tok_span[1]])
            if sub_decode != subj:
                print("subj decode error: {}---gold: {}".format(sub_decode, subj))
            if obj_decode != obj:
                print("obj decode error: {}---gold: {}".format(obj_decode, obj))

            if key in merged_items:
                # Same entity pair seen before: only collect the extra label.
                merged_items[key]['relations']['rel_label'].append(rel_label)
            else:
                relation['rel_label'] = [rel_label]
                merged_items[key] = {'text': text, 'relations': relation}

        return list(merged_items.values())

    def preprocess(self, item):
        """Tokenize one stored record.

        The text is padded/truncated to 256 tokens. The returned dict has the
        keys ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` in that
        order.
        """
        encoded = self.tokenizer(item['text'], padding='max_length', max_length=256,
                                 truncation=True, return_tensors='pt')
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': item['relations'],
        }

class qwen_dataset_list_addtype(Dataset):
    """Relation-extraction dataset with one record per entity pair.

    Relations of a sample that share the same subject/object token spans are
    merged into a single record whose ``'rel_label'`` is a list of class
    indices. Every record stored in ``valid_data`` has the form
    ``{'text': str, 'relations': relation}``.

    NOTE(review): the class name suggests that entity types are added to the
    input, but nothing in this class reads entity types - confirm whether
    that feature is still pending.
    """

    def __init__(self, data_path,
                 tokenizer=None,
                 cache_path=None,
                 add_template=False
                 ):
        """Load the raw JSON file and build (or reload) the record list.

        Args:
            data_path: Path of the JSON file holding the raw samples; each
                sample must provide ``'text'`` and ``'relation_list'``.
            tokenizer: Tokenizer object (callable, with ``decode`` and, for
                ``add_template``, ``apply_chat_template``). When ``None`` the
                ``Qwen/Qwen2.5-0.5B-Instruct`` tokenizer is loaded and
                extended with the ``<|emb|>`` special token.
            cache_path: JSON file used to cache the records. When the file
                exists it is loaded instead of re-processing the raw data;
                ``None`` disables caching altogether.
            add_template: When true, the chat template of the tokenizer is
                applied to each text; otherwise an instruction sentence is
                appended to it.
        """
        super().__init__()
        # Parse the raw file exactly once and close the handle right away.
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.add_template = add_template
        self.cache_path = cache_path
        # Predicate name -> class index used to fill 'rel_label'.
        self.relation_map = {
            'per:other_family': 0, 'per:charges': 1, 'per:stateorprovinces_of_residence': 2,
            'per:title': 3, 'org:founded': 4, 'per:alternate_names': 5,
            'per:date_of_birth': 6, 'no_relation': 7, 'org:parents': 8,
            'per:stateorprovince_of_birth': 9, 'per:religion': 10, 'org:country_of_headquarters': 11,
            'per:age': 12, 'per:countries_of_residence': 13, 'per:country_of_birth': 14,
            'org:number_of_employees/members': 15, 'org:dissolved': 16, 'org:members': 17,
            'per:spouse': 18, 'per:cities_of_residence': 19, 'org:subsidiaries': 20,
            'per:city_of_death': 21, 'per:origin': 22, 'org:website': 23,
            'per:cause_of_death': 24, 'org:stateorprovince_of_headquarters': 25, 'org:alternate_names': 26,
            'per:city_of_birth': 27, 'per:employee_of': 28, 'per:country_of_death': 29,
            'per:parents': 30, 'org:top_members/employees': 31, 'org:political/religious_affiliation': 32,
            'per:children': 33, 'per:schools_attended': 34, 'per:siblings': 35,
            'per:date_of_death': 36, 'org:founded_by': 37, 'org:shareholders': 38,
            'org:city_of_headquarters': 39, 'per:stateorprovince_of_death': 40, 'org:member_of': 41,
        }
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
            special_tokens = {"additional_special_tokens": ["<|emb|>"]}
            self.tokenizer.add_special_tokens(special_tokens)
            # NOTE: only defined when the default tokenizer is created here.
            self.special_token_id = self.tokenizer.additional_special_tokens_ids[0]
        else:
            self.tokenizer = tokenizer
        self.valid_data = []
        # os.path.exists(None) raises TypeError, so test for None first.
        if self.cache_path is not None and os.path.exists(self.cache_path):
            print(f"Loading cached filtered data from {self.cache_path}...")
            # The cache is written as UTF-8 with ensure_ascii=False, so it
            # must be read back with the same encoding.
            with open(self.cache_path, 'r', encoding='utf-8') as f:
                self.valid_data = json.load(f)
        else:
            print("Loading and preprocessing data...")
            for item in tqdm(self.data):
                self.valid_data += self.is_valid(item)
            if self.cache_path is not None:
                # Save the processed records, creating the folder if needed.
                cache_dir = os.path.dirname(self.cache_path)
                if cache_dir:
                    os.makedirs(cache_dir, exist_ok=True)
                with open(self.cache_path, 'w', encoding='utf-8') as f:
                    json.dump(self.valid_data, f, indent=4, ensure_ascii=False)
                print(f"Filtered data saved to {self.cache_path}")
        print(f"Initial data size: {len(self.data)}")
        print(f"Filtered data size: {len(self.valid_data)}")

    def __len__(self):
        """Return the number of stored records."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return the tokenized form of record ``idx`` (see ``preprocess``)."""
        return self.preprocess(self.valid_data[idx])

    def is_valid(self, item):
        """Convert one raw sample into records, one per distinct entity pair.

        Despite its name this method does not filter anything. It
        1. builds the record text (chat template, or text + instruction),
        2. decodes the subject/object token spans and prints a message when
           they differ from the gold strings (mismatches are only reported),
        3. merges relations with identical ``(subj_tok_span, obj_tok_span)``
           into one record; the first relation dict of each pair is kept and
           its ``'rel_label'`` becomes the list of all class indices.

        The kept relation dicts are annotated in place.

        NOTE(review): the span check tokenizes with ``max_length=512`` while
        ``preprocess`` uses 256, and the spans are not shifted when the chat
        template is prepended - confirm both are intended.

        Raises:
            KeyError: If a predicate is missing from ``relation_map``.
        """
        text = item['text']

        if self.add_template:
            messages = [{"role": "user", "content": text}]
            text = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
            text_input = self.tokenizer(text, padding='max_length', max_length=512,
                                        truncation=True, return_tensors='pt')
        else:
            # Spans refer to the original text, so tokenize before the
            # instruction is appended.
            text_input = self.tokenizer(text, padding='max_length', max_length=512,
                                        truncation=True, return_tensors='pt')
            try:
                sub = item['entity_list'][0]['text']
                obj = item['entity_list'][1]['text']
                # NOTE(review): no space before {sub}; left untouched because
                # the string ends up in cached data and model input.
                add_instruction = f"predict the relation between the entity pair{sub} and {obj}:"
            except (KeyError, IndexError, TypeError):
                # Entity list missing or incomplete: use the generic prompt.
                add_instruction = "predict the relation between the entity pair:"
            text = text + add_instruction

        token_ids = text_input['input_ids'][0]
        merged_items = {}
        for relation in item['relation_list']:
            rel_label = self.relation_map[relation['predicate']]
            subj, obj = relation['subject'], relation['object']
            subj_tok_span, obj_tok_span = relation["subj_tok_span"], relation["obj_tok_span"]
            key = (tuple(subj_tok_span), tuple(obj_tok_span))

            # Check that the token spans decode back to the gold entities.
            sub_decode = self.tokenizer.decode(token_ids[subj_tok_span[0]:subj_tok_span[1]])
            obj_decode = self.tokenizer.decode(token_ids[obj_tok_span[0]:obj_tok_span[1]])
            if sub_decode != subj:
                print("subj decode error: {}---gold: {}".format(sub_decode, subj))
            if obj_decode != obj:
                print("obj decode error: {}---gold: {}".format(obj_decode, obj))

            if key in merged_items:
                # Same entity pair seen before: only collect the extra label.
                merged_items[key]['relations']['rel_label'].append(rel_label)
            else:
                relation['rel_label'] = [rel_label]
                merged_items[key] = {'text': text, 'relations': relation}

        return list(merged_items.values())

    def preprocess(self, item):
        """Tokenize one stored record.

        The text is padded/truncated to 256 tokens. The returned dict has the
        keys ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` in that
        order.
        """
        encoded = self.tokenizer(item['text'], padding='max_length', max_length=256,
                                 truncation=True, return_tensors='pt')
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            'labels': item['relations'],
        }


if __name__ == "__main__":
    # Smoke test: build the dataset (or reload it from its cache file) and
    # print the first shuffled batch of two samples.
    smoke_dataset = qwen_dataset_list_addtype(
        data_path="./unirel/UniRel/data_preprocess/tacrev_data4qwen_fullchar/tacrev/test.json",
        cache_path="./qwen_nyt/dataset_addtype/tacred/08102valid_data_cached.json",
    )
    loader = DataLoader(smoke_dataset, batch_size=2, shuffle=True)
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        print(first_batch)