#!/usr/bin/env python
# coding: utf-8

# In[1]:


from tqdm import tqdm
import os
import pandas as pd
import numpy as np
import torch
import pickle
import gc


# In[ ]:


# Parameter settings: input locations, batch size and output directory.
smi_prot_file = "./multidata/smi_prot.xlsx"
id2refseq_file = "./multidata/ID2Refseq.xlsx"
mrna_id_file = os.path.join("./multidata/gene_seq_data", "mRNA_ID.txt")
mrna_seq_file = os.path.join("./multidata/gene_seq_data", "mRNA_Seq.txt")
prot_emb_dir = "./multidata/prot_emb_data"
prot_graph_dir = "./multidata/prot_graph_data"
batch_size = 1000  # number of records written per batch file; tune to the data volume
# Raw string: the Windows path contains "\B", "\d" and "\m", which are invalid
# escape sequences in a normal string literal (a warning today, an error in a
# future Python). The resulting value is unchanged.
output_dir = r"E:\BIBM2025\data\multidata/output_batches"
os.makedirs(output_dir, exist_ok=True)

# 1. Load the drug table and the UniprotID -> GeneID mapping.
# Only the first three columns are read, to keep memory usage down.
smi_prot_df = pd.read_excel(smi_prot_file, header=None, usecols=[0, 1, 2])
smi_prot_df.columns = ['smiles', 'target1', 'target2']

# Load the ID mapping file and clean surrounding whitespace / newlines.
id2refseq_df = pd.read_excel(id2refseq_file, header=None, usecols=[0, 1])
id2refseq_df.columns = ['uniprot', 'geneid']
# Drop incomplete rows first: astype(str) would otherwise turn an empty cell
# into the literal string "nan" and add a bogus entry to the mapping.
id2refseq_df = id2refseq_df.dropna(subset=['uniprot', 'geneid'])
id2refseq_df['uniprot'] = id2refseq_df['uniprot'].astype(str).str.strip()
id2refseq_df['geneid'] = id2refseq_df['geneid'].astype(str).str.strip()
# Build the dictionary UniprotID -> GeneID (for duplicate UniprotIDs the last
# row wins).
uniprot_to_gene = dict(zip(id2refseq_df['uniprot'], id2refseq_df['geneid']))
# Drop the intermediate frame to release memory.
del id2refseq_df
gc.collect()

# 2. Load the nucleotide sequences. The ID file and the sequence file are read
# line by line and paired by line number.
with open(mrna_id_file, "r") as f:
    # strip() here already removes spaces and newlines around each ID.
    mrna_ids = [line.strip() for line in f]
with open(mrna_seq_file, "r") as f:
    mrna_seqs = [line.strip().upper() for line in f]

# zip() stops at the shorter list, so a line-count mismatch would silently
# drop entries; report it instead of hiding it.
if len(mrna_ids) != len(mrna_seqs):
    print(
        f"Warning: {mrna_id_file} has {len(mrna_ids)} lines but "
        f"{mrna_seq_file} has {len(mrna_seqs)}; extra lines are ignored."
    )

gene_to_seq = dict(zip(mrna_ids, mrna_seqs))
# Release the intermediate lists.
del mrna_ids, mrna_seqs
gc.collect()

# Lookup table from nucleotide letter to integer token.
nuc_to_token = {'A': 1, 'C': 2, 'G': 3, 'T': 4}

def convert_seq_to_tokens(seq):
    """Translate a nucleotide sequence into its list of integer tokens.

    Returns None as soon as a symbol outside A/C/G/T is encountered; an
    empty sequence yields an empty list.
    """
    try:
        return [nuc_to_token[symbol] for symbol in seq]
    except KeyError:
        # At least one symbol has no entry in the lookup table.
        return None

def _clean_target_id(value):
    """Return *value* as a stripped string, or None when it is missing or blank."""
    # Empty spreadsheet cells arrive as NaN; str(NaN) would be the truthy
    # string "nan", so missing values must be detected before converting.
    if pd.isna(value):
        return None
    text = str(value).strip()
    return text or None


def _load_target_features(target):
    """Collect (token_seq, prot_emb, prot_graph) for one UniprotID.

    Returns None when the ID has no gene mapping, the gene has no sequence,
    the sequence contains a non-ACGT character, or either feature file is
    missing or cannot be loaded.
    """
    geneid = uniprot_to_gene.get(target)
    if geneid is None:
        return None
    seq = gene_to_seq.get(geneid)
    if seq is None:
        return None

    # Check that both feature files exist before doing any expensive work
    # (tokenising a long sequence, loading arrays from disk).
    emb_path = os.path.join(prot_emb_dir, f"{target}.npy")
    graph_path = os.path.join(prot_graph_dir, f"alphafold_{target}.pt")
    if not (os.path.exists(emb_path) and os.path.exists(graph_path)):
        return None

    token_seq = convert_seq_to_tokens(seq)
    if token_seq is None:
        return None

    try:
        prot_emb = np.load(emb_path)
        prot_graph = torch.load(graph_path, map_location=torch.device('cpu'))
    except Exception as exc:
        # Best-effort: an unreadable file only discards this record, but the
        # reason is reported so systematic failures do not go unnoticed.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which may reject pickled graph objects - confirm
        # against the installed version if every row ends up being skipped.
        print(f"Skipping target {target}: failed to load features ({exc!r})")
        return None

    return token_seq, prot_emb, prot_graph


def process_row(row):
    """Build one training record from a row of the drug table.

    Args:
        row: mapping-like object with the keys 'smiles', 'target1' and
            'target2' (the two targets are UniprotIDs).

    Returns:
        A 7-tuple ``(smiles, tokens1, emb1, graph1, tokens2, emb2, graph2)``,
        or None when either target is missing/blank or any of its features
        is unavailable.
    """
    smiles = row['smiles']
    # Clean the UniprotIDs of both target proteins; both must be present.
    target1 = _clean_target_id(row['target1'])
    target2 = _clean_target_id(row['target2'])
    if target1 is None or target2 is None:
        return None

    first = _load_target_features(target1)
    if first is None:
        return None
    second = _load_target_features(target2)
    if second is None:
        return None

    return (smiles, *first, *second)

# 3. Process the table row by row and write the valid records out in batches.

def _save_batch(records, index):
    """Pickle *records* to ``batch_<index>.pkl`` in output_dir and return the path."""
    path = os.path.join(output_dir, f"batch_{index}.pkl")
    with open(path, "wb") as f:
        pickle.dump(records, f)
    print(f"保存了 {len(records)} 条数据到 {path}")
    return path


batch = []
batch_index = 0  # number of full batches written so far (file names are 1-based)
valid_count = 0
total_count = 0

# The DataFrame is already fully in memory; iterating row by row only keeps
# the loaded features bounded to one batch at a time. Passing total= lets tqdm
# show a percentage and an ETA.
for _, row in tqdm(smi_prot_df.iterrows(), total=len(smi_prot_df)):
    total_count += 1
    record = process_row(row)
    if record is None:
        continue
    valid_count += 1
    batch.append(record)

    # Once a full batch has accumulated, write it out and start a new one.
    if len(batch) >= batch_size:
        batch_index += 1
        batch_filename = _save_batch(batch, batch_index)
        batch = []
        gc.collect()

# Write the final, partially filled batch (if any).
if batch:
    batch_filename = _save_batch(batch, batch_index + 1)

print(f"总共处理 {total_count} 条药物数据，符合条件的数据有 {valid_count} 条。")


# In[ ]:




