import os
import json
import numpy as np
from tqdm import tqdm
import random
import torch.utils.data as data
from .utils import cached_property
from .bio_tokenizer import BioTokenizer


class PAIRWISE_CATH_MixDataset(data.Dataset):
    """Dataset mixing an external training set with CATH validation/test sets.

    The training split is read from a separate JSON-lines file
    (``path_train``, resolved relative to ``path``), while the validation and
    test splits are taken from ``chain_set.jsonl`` (or ``chain_set_afdb.jsonl``
    when ``test_source == 'afdb'``) and selected through
    ``chain_set_splits.json``.

    Each sample is a dict with the keys ``title``, ``seq``, ``CA``, ``C``,
    ``O``, ``N``, ``chain_mask`` and ``chain_encoding``; test samples
    additionally carry ``category`` and ``score``.

    Args:
        path: Directory holding the CATH files and the split definitions.
        path_train: Training JSON-lines file, joined onto ``path``.
        split: Key of the split to expose initially (``'train'``,
            ``'valid'`` or ``'test'``).
        max_length: Maximum sequence length; longer entries are dropped while
            loading and cropped in ``__getitem__``.
        test_name: ``'L100'`` or ``'sc'`` swap in an alternative test split;
            any other value keeps the default one.
        data: Optional pre-built list of samples. When given, nothing is read
            from disk until ``change_mode`` / ``cache_data`` is used.
        removeTS: When truthy, entries whose name is listed in
            ``remove.json`` are skipped.
        version: Dataset version; only 4.2 and 4.3 have split files.
        test_source: ``'afdb'`` selects the AFDB-style validation/test file.
    """

    # The 20 standard amino-acid letters; sequences with anything else are dropped.
    _ALPHABET = frozenset('ACDEFGHIKLMNPQRSTVWY')
    # Per-residue fields that must be cropped together.
    _CROPPED_KEYS = ('seq', 'CA', 'C', 'O', 'N', 'chain_mask', 'chain_encoding')
    # test_name -> file providing a replacement 'test' split.
    _TEST_SPLIT_FILES = {
        'L100': 'test_split_L100.json',
        'sc': 'test_split_sc.json',
    }

    def __init__(self, path='./', path_train='./', split='train', max_length=500, test_name='All', data=None, removeTS=0, version=4.2,test_source='default'):
        self.version = version
        self.path = path
        # Extra path used to read the training set.
        self.path_train = os.path.join(path, path_train)
        self.mode = split
        self.max_length = max_length
        self.test_name = test_name
        self.removeTS = removeTS
        self.test_source = test_source
        if self.removeTS:
            # Context manager so the file handle is closed deterministically.
            with open(os.path.join(self.path, 'remove.json'), 'r') as f:
                self.remove = json.load(f)['remove']

        if data is None:
            self.data = self.cache_data[split]
        else:
            self.data = data

    @staticmethod
    def _iter_jsonl(file_path):
        """Yield one parsed JSON object per line of ``file_path``."""
        with open(file_path) as f:
            lines = f.readlines()
        for line in tqdm(lines):
            yield json.loads(line)

    def _passes_filters(self, seq, length):
        """Return True if ``seq`` uses only standard residues and fits ``max_length``."""
        return not (set(seq) - self._ALPHABET) and length <= self.max_length

    @staticmethod
    def _make_record(title, seq, chain_length, ca, c, o, n):
        """Assemble one sample dict; the chain mask is all ones (single chain)."""
        chain_mask = np.ones(chain_length)
        return {
            'title': title,
            'seq': seq,
            'CA': ca,
            'C': c,
            'O': o,
            'N': n,
            'chain_mask': chain_mask,
            'chain_encoding': 1 * chain_mask,
        }

    def _load_array_coords(self, file_path, tokenizer, removed, store_decoded):
        """Load entries whose ``coords`` is a single array indexed ``[:, atom, :]``.

        The atom axis is read as 0=N, 1=CA, 2=C, 3=O. ``store_decoded``
        selects whether the decoded string or the raw ``entry['seq']`` is kept.
        """
        records = []
        for entry in self._iter_jsonl(file_path):
            if removed and entry['name'] in removed:
                continue
            raw_seq = entry['seq']
            for key, val in entry.items():
                if key not in ('name', 'seq'):
                    entry[key] = np.asarray(val)

            # Token-id sequences are decoded so they can be checked against
            # the amino-acid alphabet.
            if isinstance(raw_seq, list):
                seq = tokenizer.decode([raw_seq])[0]
            else:
                seq = raw_seq
            if not self._passes_filters(seq, len(raw_seq)):
                continue

            coords = entry['coords']
            records.append(self._make_record(
                entry['name'],
                seq if store_decoded else raw_seq,
                len(raw_seq),
                ca=coords[:, 1, :],
                c=coords[:, 2, :],
                o=coords[:, 3, :],
                n=coords[:, 0, :],
            ))
        return records

    def _load_dict_coords(self, file_path, removed):
        """Load entries whose ``coords`` is a dict keyed by atom name."""
        records = []
        for entry in self._iter_jsonl(file_path):
            if removed and entry['name'] in removed:
                continue
            seq = entry['seq']
            coords = entry['coords']
            for key, val in coords.items():
                coords[key] = np.asarray(val)

            if not self._passes_filters(seq, len(seq)):
                continue
            records.append(self._make_record(
                entry['name'],
                seq,
                len(seq),
                ca=coords['CA'],
                c=coords['C'],
                o=coords['O'],
                n=coords['N'],
            ))
        return records

    def _load_splits(self):
        """Read the validation/test name lists, applying any ``test_name`` override.

        Raises:
            ValueError: if ``version`` has no associated split file.
        """
        if self.version not in (4.2, 4.3):
            # Previously this fell through to a NameError on an unbound local.
            raise ValueError(
                f"Unsupported dataset version {self.version!r}: "
                "split files are only defined for versions 4.2 and 4.3"
            )
        with open(os.path.join(self.path, 'chain_set_splits.json')) as f:
            dataset_splits = json.load(f)

        # Replace the test split according to test_name.
        override_file = self._TEST_SPLIT_FILES.get(self.test_name)
        if override_file is not None:
            with open(os.path.join(self.path, override_file)) as f:
                dataset_splits['test'] = json.load(f)['test']
        return dataset_splits

    @cached_property
    def cache_data(self):
        """Load every split from disk and return ``{'train', 'valid', 'test'}`` lists.

        Raises:
            Exception: if ``self.path`` does not exist.
            ValueError: if ``self.version`` is not 4.2 or 4.3.
        """
        if not os.path.exists(self.path):
            raise Exception(f"No such file: {self.path} !!!")

        tokenizer = BioTokenizer()
        # Build the removal set once so each membership test is O(1).
        removed = set(self.remove) if self.removeTS else set()

        # 1. Training set. isfile (not exists) so that the default
        # path_train='./', which resolves to a directory, yields an empty
        # training set instead of failing inside open().
        train_data_list = []
        if os.path.isfile(self.path_train):
            # Entries are filtered by 'name', the same key used for the record
            # title and by the other loaders (the old 'title' lookup did not
            # match the fields read from the entry).
            # NOTE(review): the raw entry['seq'] (possibly token ids) is
            # stored here while the AFDB loader stores the decoded string -
            # confirm which form downstream code expects.
            train_data_list = self._load_array_coords(
                self.path_train, tokenizer, removed, store_decoded=False)

        # 2. Validation and test sets from CATH.
        if self.test_source == 'afdb':
            data_list = self._load_array_coords(
                os.path.join(self.path, 'chain_set_afdb.jsonl'),
                tokenizer, removed, store_decoded=True)
        else:
            data_list = self._load_dict_coords(
                os.path.join(self.path, 'chain_set.jsonl'), removed)

        # 3. Load the split definition.
        dataset_splits = self._load_splits()

        # 4. Assign records to splits; the training names in the split file
        # are deliberately ignored because training data comes from path_train.
        name2set = {name: 'valid' for name in dataset_splits['validation']}
        name2set.update({name: 'test' for name in dataset_splits['test']})

        data_dict = {'train': train_data_list, 'valid': [], 'test': []}
        for record in data_list:
            target = name2set.get(record['title'])
            if target == 'valid':
                data_dict['valid'].append(record)
            elif target == 'test':
                record['category'] = 'Unknown'
                record['score'] = 100.0
                data_dict['test'].append(record)

        return data_dict

    def change_mode(self, mode):
        """Switch the exposed split to ``mode`` (a key of ``cache_data``)."""
        self.data = self.cache_data[mode]

    def __len__(self):
        """Return the number of samples in the current split."""
        return len(self.data)

    def get_item(self, index):
        """Return the stored sample at ``index`` without any cropping."""
        return self.data[index]

    def __getitem__(self, index):
        """Return the sample at ``index``, randomly cropped to ``max_length``.

        Samples that already fit are returned as stored. Longer samples are
        returned as a cropped shallow copy, so the stored sample keeps its
        full length and a fresh window is drawn on every access.
        """
        item = self.data[index]
        length = len(item['seq'])
        if length <= self.max_length:
            return item

        start = random.randint(0, length - self.max_length)
        window = slice(start, start + self.max_length)
        cropped = dict(item)
        for key in self._CROPPED_KEYS:
            cropped[key] = item[key][window]
        return cropped
