import random
import pickle
import pandas as pd
from Bio import SeqIO
from collections import defaultdict
from model.utils.utils import load_config

"""
根据家族信息进行划分训练集和测试集
"""


# Families whose member count lies in [MIN, MAX] (inclusive) are held out as
# test families; every other family goes to the training set.
MIN_TEST_FAMILY_SIZE = 50
MAX_TEST_FAMILY_SIZE = 100


def is_unknown_family(family):
    """Return True if the family key mentions an unknown subclass.

    The check is case-insensitive so that the same rule is used both when
    selecting candidate families and when merging unknown entries back into
    the training set.
    """
    return 'unknown' in family.lower()


def build_entry_id(pdb, hchain, lchain, antigen_chain):
    """Build the ``pdb_H_L_antigen`` identifier used to look up FASTA records.

    Only the first antigen chain (the part before ``' | '``) is used. A
    missing antigen chain stringifies to ``'nan'``, which is stripped so the
    identifier ends with a trailing underscore.
    """
    entry_id = '%s_%s_%s_%s' % (pdb, hchain, lchain, str(antigen_chain).split(' | ')[0])
    # NOTE(review): this removes *every* 'nan' substring, not only a missing
    # antigen chain. Kept as-is because the FASTA keys were presumably
    # generated with the same rule -- confirm before tightening.
    return entry_id.replace('nan', '')


def group_ids_by_family(df, valid_ids):
    """Group entry identifiers by their ``heavy_subclass-light_subclass`` key.

    Args:
        df: DataFrame with the columns ``pdb``, ``Hchain``, ``Lchain``,
            ``antigen_chain``, ``heavy_subclass`` and ``light_subclass``.
        valid_ids: Container supporting ``in``; rows whose identifier is not
            contained in it are skipped.

    Returns:
        A ``defaultdict(list)`` mapping family key to the list of entry
        identifiers, in row order.
    """
    family_ids = defaultdict(list)
    for _, row in df.iterrows():
        entry_id = build_entry_id(row['pdb'], row['Hchain'], row['Lchain'], row['antigen_chain'])
        if entry_id not in valid_ids:
            continue
        family = '%s-%s' % (row['heavy_subclass'], row['light_subclass'])
        family_ids[family].append(entry_id)
    return family_ids


def split_ids_by_family(family_ids,
                        min_threshold=MIN_TEST_FAMILY_SIZE,
                        max_threshold=MAX_TEST_FAMILY_SIZE):
    """Split entry identifiers into train and test lists by family size.

    Known families with ``min_threshold <= size <= max_threshold`` become
    test families. All other known families are training families. Entries
    of families containing "unknown" (case-insensitive) are appended to the
    end of the training list.

    Args:
        family_ids: Mapping of family key to list of entry identifiers.
        min_threshold: Minimum family size (inclusive) for a test family.
        max_threshold: Maximum family size (inclusive) for a test family.

    Returns:
        A ``(train_ids, test_ids)`` tuple of lists.
    """
    train_ids, test_ids, unknown_ids = [], [], []
    for family, ids in family_ids.items():
        if is_unknown_family(family):
            unknown_ids.extend(ids)
        elif min_threshold <= len(ids) <= max_threshold:
            test_ids.extend(ids)
        else:
            train_ids.extend(ids)
    # Unknown-family entries go last, after all known training families.
    train_ids.extend(unknown_ids)
    return train_ids, test_ids


def main():
    """Read the summary table, split the ids by family and pickle both lists."""
    config, _ = load_config('../configs/base_pretrain_ddp.yaml')
    df = pd.read_csv(config.dataset.summary_path_dir, sep='\t')
    # Drop entries whose heavy or light chain is unknown.
    df = df.dropna(subset=['Hchain', 'Lchain'])
    fasta_dict = SeqIO.to_dict(SeqIO.parse(config.dataset.root_path_dir + '/processed_Chothia_onlyV/SAbDab_processed_onlyV.fasta', 'fasta'))

    family_ids = group_ids_by_family(df, fasta_dict)
    train_ids, test_ids = split_ids_by_family(family_ids)

    # NOTE(review): ids are validated against the Chothia FASTA above but are
    # written to the IMGT directory -- confirm this is intentional.
    # Context managers make sure both pickle files are flushed and closed.
    with open(config.dataset.root_path_dir + '/processed_IMGT_onlyV/train_ids.pkl', 'wb') as handle:
        pickle.dump(train_ids, handle)
    with open(config.dataset.root_path_dir + '/processed_IMGT_onlyV/test_ids.pkl', 'wb') as handle:
        pickle.dump(test_ids, handle)


if __name__ == '__main__':
    main()
