import os
import glob
import torch
import argparse
from tqdm import tqdm



from pdb_graph_builder import PDBGraphBuilder





# Feature-column ranges that the preprocessing below standardizes (z-scores)
# across the whole dataset. Each is a half-open [start, stop) column slice:
#   * NODE_CHEM_SLICE and NODE_RHO_SLICE are applied to the node feature
#     matrix ``x`` of both chain graphs;
#   * EDGE_CHEM_DIFF_SLICE is applied to ``edge_attr`` of the interface graph.
# NOTE(review): these indices must agree with the feature layout emitted by
# PDBGraphBuilder.build; the meaning of the individual columns is defined
# there -- confirm against that module whenever its layout changes.
NODE_CHEM_SLICE: slice = slice(28, 32)
NODE_RHO_SLICE: slice = slice(38, 43)
EDGE_CHEM_DIFF_SLICE: slice = slice(29, 31)


def _output_path_for(pdb_path, input_root, output_root):
    """Return the ``.pt`` output path for *pdb_path*, mirroring its sub-folder.

    The path of *pdb_path* relative to *input_root* is re-rooted under
    *output_root*, and only the file's own extension is replaced by ``.pt``.
    Directory names (or an *output_root*) that happen to contain ``.pdb`` are
    left untouched.
    """
    relative_path = os.path.relpath(pdb_path, input_root)
    stem, _ext = os.path.splitext(relative_path)
    return os.path.join(output_root, stem + '.pt')


def _mean_std(chunks):
    """Concatenate row-chunks along dim 0 and return column-wise ``(mean, std)``.

    ``torch.std`` applies Bessel's correction, which yields NaN when only one
    row is available. That NaN would propagate into every normalized value
    and into the saved statistics, so in that degenerate case the std is
    reported as zeros instead. Normalization then reduces to mean-centering.
    """
    stacked = torch.cat(chunks, dim=0)
    mean = torch.mean(stacked, dim=0)
    if stacked.shape[0] > 1:
        std = torch.std(stacked, dim=0)
    else:
        std = torch.zeros_like(mean)
    return mean, std


def _standardize_(tensor, columns, mean, std):
    """Z-score ``tensor[:, columns]`` in place with the given statistics.

    ``1e-8`` is added to the std so constant columns (std == 0) become 0
    instead of dividing by zero.
    """
    tensor[:, columns] = (tensor[:, columns] - mean) / (std + 1e-8)


def main(args):
    """Build, normalize and save graphs for every PDB file under ``args.pdb_dir``.

    Pass 1 runs ``PDBGraphBuilder.build`` on each ``*.pdb`` file found
    recursively under ``args.pdb_dir``. Files for which it returns ``None`` are
    skipped with a warning. From each graph it collects the
    ``NODE_CHEM_SLICE`` / ``NODE_RHO_SLICE`` columns of both chain graphs'
    ``x`` and the ``EDGE_CHEM_DIFF_SLICE`` columns of the interface graph's
    ``edge_attr``. The dataset-wide column mean/std of each group is saved to
    ``<output_dir>/normalization_stats.pt``.

    Pass 2 z-scores those same columns in place and writes each graph dict
    with ``torch.save`` to ``<output_dir>/<relative sub-path>/<name>.pt``,
    preserving the input folder layout.

    Args:
        args: namespace with ``pdb_dir`` (input root directory) and
            ``output_dir`` (output root, created if missing).
    """
    input_root = os.path.abspath(args.pdb_dir)
    output_root = os.path.abspath(args.output_dir)
    os.makedirs(output_root, exist_ok=True)

    # Sorted so the processing order (and float summation order of the
    # statistics) is reproducible; raw glob order is filesystem-dependent.
    pdb_files = sorted(glob.glob(os.path.join(input_root, '**', '*.pdb'), recursive=True))

    if not pdb_files:
        print(f"éè¯¯ï¼å¨ç®å½ '{input_root}' åå¶å­ç®å½ä¸­æ²¡ææ¾å°ä»»ä½ .pdb æä»¶ã")
        return

    print(f"æ¾å° {len(pdb_files)} ä¸ªPDBæä»¶ãå¼å§ç¬¬ä¸é¶æ®µï¼ç¹å¾æå...")

    # ---- Pass 1: build graphs and collect the raw features to normalize ----
    builder = PDBGraphBuilder()
    # Graphs stay in memory until pass 2 so every PDB is parsed only once.
    raw_graphs_and_paths = []
    node_chem_features, node_rho_features = [], []
    interface_edge_chem_diff = []

    for pdb_path in tqdm(pdb_files, desc="Pass 1/2: Extracting features"):
        graph_data = builder.build(pdb_path)
        if graph_data is None:
            print(f"è­¦åï¼å¤ç {os.path.basename(pdb_path)} å¤±è´¥ï¼å·²è·³è¿ã")
            continue

        raw_graphs_and_paths.append({
            'graph': graph_data,
            'path': _output_path_for(pdb_path, input_root, output_root),
        })

        for chain_graph in (graph_data['chain_1_graph'], graph_data['chain_2_graph']):
            if chain_graph.x.shape[0] > 0:
                # These are views; torch.cat below copies them before pass 2
                # mutates the underlying tensors in place.
                node_chem_features.append(chain_graph.x[:, NODE_CHEM_SLICE])
                node_rho_features.append(chain_graph.x[:, NODE_RHO_SLICE])

        interface_graph = graph_data['interface_graph']
        if interface_graph.edge_attr is not None and interface_graph.edge_attr.shape[0] > 0:
            interface_edge_chem_diff.append(interface_graph.edge_attr[:, EDGE_CHEM_DIFF_SLICE])

    print("ç¬¬ä¸é¶æ®µå®æãå¼å§è®¡ç®å¨å±å½ä¸åç»è®¡æ°æ®...")

    # ---- Dataset-wide statistics (a group with no data gets no entry) ----
    stats = {}
    for prefix, chunks in (
        ('node_chem', node_chem_features),
        ('node_rho', node_rho_features),
        ('interface_edge_chem_diff', interface_edge_chem_diff),
    ):
        if chunks:
            stats[f'{prefix}_mean'], stats[f'{prefix}_std'] = _mean_std(chunks)

    stats_path = os.path.join(output_root, 'normalization_stats.pt')
    torch.save(stats, stats_path)
    print(f"å½ä¸åç»è®¡æ°æ®å·²ä¿å­è³: {stats_path}")

    print("\nå¼å§ç¬¬äºé¶æ®µï¼åºç¨å½ä¸åå¹¶ä¿å­å¤çå¥½çå¾...")

    # ---- Pass 2: normalize in place and save, mirroring the input layout ----
    for item in tqdm(raw_graphs_and_paths, desc="Pass 2/2: Normalizing and saving"):
        graph_data = item['graph']
        output_path = item['path']

        for chain_graph in (graph_data['chain_1_graph'], graph_data['chain_2_graph']):
            if chain_graph.x.shape[0] > 0:
                if 'node_chem_mean' in stats:
                    _standardize_(chain_graph.x, NODE_CHEM_SLICE,
                                  stats['node_chem_mean'], stats['node_chem_std'])
                if 'node_rho_mean' in stats:
                    _standardize_(chain_graph.x, NODE_RHO_SLICE,
                                  stats['node_rho_mean'], stats['node_rho_std'])

        interface_graph = graph_data['interface_graph']
        if (interface_graph.edge_attr is not None
                and interface_graph.edge_attr.shape[0] > 0
                and 'interface_edge_chem_diff_mean' in stats):
            _standardize_(interface_graph.edge_attr, EDGE_CHEM_DIFF_SLICE,
                          stats['interface_edge_chem_diff_mean'],
                          stats['interface_edge_chem_diff_std'])

        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        torch.save(graph_data, output_path)

    print(f"\nå¤çå®æï¼ææ {len(raw_graphs_and_paths)} ä¸ªå¾å·²å½ä¸åå¹¶ä¿å­è³ '{output_root}' ç®å½ä¸­ï¼å¹¶ä¿æäºåæå­æä»¶å¤¹ç»æã")


if __name__ == '__main__':
    # Command-line entry point: both directory options are mandatory and are
    # forwarded unchanged to main() as an argparse namespace.
    cli = argparse.ArgumentParser(
        description="Recursively process PDB files from a directory into normalized graph data, preserving the subfolder structure."
    )
    for flag, help_text in (
        ('--pdb_dir', "åå«PDBæä»¶çæºæä»¶å¤¹æ ¹ç®å½ã"),
        ('--output_dir', "ç¨äºä¿å­å¤çåç .pt æä»¶çç®æ æä»¶å¤¹æ ¹ç®å½ã"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)

    main(cli.parse_args())