import os
# Pin the common native threading backends to one thread each before any
# numeric library (torch below) is imported; the script parallelises with a
# process pool instead. Values the user already exported are left untouched.
for _thread_var in (
    'OMP_NUM_THREADS',
    'MKL_NUM_THREADS',
    'OPENBLAS_NUM_THREADS',
    'NUMEXPR_NUM_THREADS',
):
    os.environ.setdefault(_thread_var, '1')
del _thread_var

import torch
import pickle
import tqdm
import glob
import signal
import random
import uuid
from multiprocessing import Pool, set_start_method
from rdkit import Chem
from rdkit.Chem import rdDetermineBonds
from rdkit import RDLogger
try:
    from rdkit import rdBase
except ImportError:
    rdBase = None
from torch_geometric.data import Data

# Silence every RDKit "rdApp" log channel; per-molecule failures are reported
# through the status tuples produced by the worker functions instead.
RDLogger.DisableLog("rdApp.*")

# Debug cap: when truthy, only the first MAX_SAMPLES tasks of each split are
# processed and results go to debug-specific file names. None = full run.
MAX_SAMPLES: "int | None" = None
# Number of successfully aligned records buffered before a part file is written.
BATCH_SIZE: int = 5
# Per-molecule wall-clock budget in seconds, enforced via SIGALRM by
# process_molecule_safe.
TIMEOUT_SECONDS: float = 5.0

class TimeoutException(Exception):
    """Raised by ``timeout_handler`` when a per-molecule time budget expires."""

def timeout_handler(signum, frame):
    """Signal handler that aborts the current work with ``TimeoutException``.

    ``signum`` and ``frame`` follow the ``signal.signal`` handler signature and
    are not used.
    """
    expired = TimeoutException("Timeout")
    raise expired

def bleach_mol(mol: Chem.Mol) -> Chem.Mol:
    """Strips properties to allow pure topological matching.

    Returns a copy of ``mol`` in which every bond is a non-aromatic SINGLE
    bond without stereo. Every atom loses its formal charge, aromatic flag,
    chiral tag, isotope label and radical electrons, and is marked
    ``NoImplicit``. When two bleached molecules are substructure-matched, only
    element identities and the bond graph have to agree.
    """
    m = Chem.RWMol(mol)
    for bond in m.GetBonds():
        bond.SetBondType(Chem.BondType.SINGLE)
        bond.SetIsAromatic(False)
        bond.SetStereo(Chem.BondStereo.STEREONONE)
    for atom in m.GetAtoms():
        atom.SetFormalCharge(0)
        atom.SetIsAromatic(False)
        atom.SetChiralTag(Chem.ChiralType.CHI_UNSPECIFIED)
        # RDKit's molecule-as-query atom matching rejects a query atom whose
        # isotope or radical count is set but differs from the target atom's,
        # so e.g. a [2H] in a reference SMILES would never match a plain H.
        atom.SetIsotope(0)
        atom.SetNumRadicalElectrons(0)
        atom.SetNoImplicit(True)
    return m.GetMol()

def get_largest_fragment(mol: Chem.Mol) -> Chem.Mol:
    """Return the connected component of ``mol`` with the most atoms.

    Atoms are counted as stored in ``mol``, so explicit hydrogens count too.
    Fragments are split off without sanitization. A molecule with a single
    fragment, or with no atoms at all, is returned unchanged. On a tie, the
    first fragment reported by ``GetMolFrags`` wins.
    """
    frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    # An atom-less molecule has no fragments; max() would raise on ().
    if len(frags) <= 1:
        return mol
    return max(frags, key=lambda frag: frag.GetNumAtoms())

def verify_geometry(tgt_mol, positions, threshold=3.5):
    """Check that no bond of ``tgt_mol`` is longer than ``threshold``.

    ``positions`` holds one coordinate row per atom of ``tgt_mol``, in that
    molecule's atom order. In this file it is a torch tensor.

    Returns ``(ok, message)``. The result is ``(False, ...)`` when the row
    count differs from the atom count, or at the first bond (in
    ``GetBonds()`` order) longer than ``threshold``. Otherwise it is
    ``(True, ...)``, reporting the longest bond (0.00 when there are no
    bonds). Lengths are in the units of ``positions``; the messages label
    them "A".
    """
    atom_count = tgt_mol.GetNumAtoms()
    if positions.shape[0] != atom_count:
        return False, "Atom count mismatch"

    longest = 0.0
    for bond in tgt_mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        length = (positions[begin] - positions[end]).pow(2).sum().sqrt().item()
        longest = max(longest, length)
        if length > threshold:
            return False, f"Bond {begin}-{end} dist {length:.2f}A > {threshold}A"

    return True, f"Geometry OK (Max Bond: {longest:.2f}A)"

def _as_float_tensor(val):
    """Coerce a tensor, list/tuple or scalar to a float tensor (a scalar becomes shape (1,))."""
    if torch.is_tensor(val):
        return val.float()
    if isinstance(val, (list, tuple)):
        return torch.tensor(val, dtype=torch.float)
    return torch.tensor([val], dtype=torch.float)


def _encode_bonds(mol):
    """Return ``(edge_index, edge_type)`` long tensors built from ``mol``'s bonds.

    Every bond is emitted in both directions, so ``edge_index`` has shape
    (2, 2 * num_bonds). Bond types are coded SINGLE=1, DOUBLE=2, TRIPLE=3,
    AROMATIC=4, and anything else as 0.
    """
    bond_codes = {
        Chem.BondType.SINGLE: 1,
        Chem.BondType.DOUBLE: 2,
        Chem.BondType.TRIPLE: 3,
        Chem.BondType.AROMATIC: 4,
    }
    row, col, bond_types = [], [], []
    for bond in mol.GetBonds():
        u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        t_val = bond_codes.get(bond.GetBondType(), 0)
        row.extend([u, v])
        col.extend([v, u])
        bond_types.extend([t_val, t_val])
    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_type = torch.tensor(bond_types, dtype=torch.long)
    return edge_index, edge_type


def _process_logic(payload):
    """Reorder one conformer's coordinates into the atom order of its canonical SMILES.

    ``payload`` is ``(index, mol_bytes, original_pos, original_smiles,
    energy, weight)``. ``mol_bytes`` is an RDKit binary molecule; only its
    element sequence is used. ``original_pos`` has one coordinate row per
    atom in that same order (presumably a float tensor of shape
    (num_atoms, 3); TODO confirm against the dataset).

    Steps:
    1. Rebuild a bond graph purely from elements + coordinates.
    2. Build a reference molecule from the largest SMILES fragment, with all
       hydrogens explicit.
    3. Map reference atoms onto coordinate atoms by graph matching.
    4. Reorder the coordinates into reference atom order.

    Returns ``(index, status, detail)``. ``status`` is "OK" (``detail`` is a
    ``Data`` record), "MISMATCH" or "ERROR" (``detail`` is a message).
    ``TimeoutException`` is always re-raised, never converted into a status,
    so the caller can report it as a timeout.
    """
    index, mol_bytes, original_pos, original_smiles, energy, weight = payload

    if rdBase:
        # Best effort; any failure (e.g. the call missing from this RDKit
        # build) is ignored.
        try:
            rdBase.DisableMultithreading()
        except TimeoutException:
            raise
        except Exception:
            pass

    # A constructor never returns None; undecodable bytes raise instead.
    try:
        src_raw = Chem.Mol(mol_bytes)
    except TimeoutException:
        raise
    except Exception as e:
        return (index, "ERROR", f"Mol corruption: {e}")

    src_zs = [a.GetAtomicNum() for a in src_raw.GetAtoms()]
    if len(src_zs) != len(original_pos):
        return (index, "ERROR", f"Size Mismatch: Atoms {len(src_zs)} vs Pos {len(original_pos)}")

    # Rebuild the molecule from elements and coordinates only, so its
    # connectivity comes from the 3D geometry alone.
    rw_mol = Chem.RWMol()
    conf = Chem.Conformer(len(src_zs))
    for i, z in enumerate(src_zs):
        rw_mol.AddAtom(Chem.Atom(z))
        conf.SetAtomPosition(i, original_pos[i].tolist())
    rw_mol.AddConformer(conf)
    src_geom = rw_mol.GetMol()

    try:
        rdDetermineBonds.DetermineConnectivity(src_geom, useHueckel=False)
    except TimeoutException:
        raise
    except Exception as e:
        return (index, "ERROR", f"Connectivity calc failed: {str(e)}")

    try:
        params_raw = Chem.SmilesParserParams()
        params_raw.removeHs = False
        mol_raw = Chem.MolFromSmiles(original_smiles, params=params_raw)
        if mol_raw is None:
            return (index, "ERROR", "SMILES Parse failed")

        mol_frag = get_largest_fragment(mol_raw)

        # Best effort: keep the fragment as-is if hydrogen removal fails.
        try:
            mol_clean = Chem.RemoveHs(mol_frag)
        except TimeoutException:
            raise
        except Exception:
            mol_clean = mol_frag

        std_canon_smiles = Chem.MolToSmiles(mol_clean, isomericSmiles=True, canonical=True, allHsExplicit=False)

        # Re-parse the canonical SMILES so the reference atom order follows
        # it, then make every hydrogen an explicit atom.
        tgt = Chem.MolFromSmiles(std_canon_smiles)
        if tgt is None:
            return (index, "ERROR", "Standard Reconstruction failed")
        tgt = Chem.AddHs(tgt)

        if tgt.GetNumAtoms() != src_geom.GetNumAtoms():
            return (index, "MISMATCH", f"Atoms: Pos {src_geom.GetNumAtoms()} vs StdTgt {tgt.GetNumAtoms()}")
    except TimeoutException:
        raise
    except Exception as e:
        return (index, "ERROR", f"Target Prep failed: {str(e)}")

    src_skel = bleach_mol(src_geom)
    tgt_skel = bleach_mol(tgt)

    # Each match maps reference atom i -> coordinate atom match[i].
    # NOTE(review): tgt and src have the same atom count, so every match
    # covers all atoms and uniquify=True likely leaves a single mapping.
    # Confirm whether trying several mappings below was intended.
    matches = src_skel.GetSubstructMatches(tgt_skel, uniquify=True, useChirality=False)
    if not matches:
        return (index, "MISMATCH", "Graph Isomorphism failed")

    final_pos = None
    error_msg = "All perms failed geometry"
    for match in matches:
        candidate_pos = original_pos[torch.tensor(match, dtype=torch.long)]
        is_valid, msg = verify_geometry(tgt, candidate_pos, threshold=3.5)
        if is_valid:
            final_pos = candidate_pos
            break
        error_msg = msg

    if final_pos is None:
        return (index, "MISMATCH", f"Geo Check: {error_msg}")

    edge_index, edge_type = _encode_bonds(tgt)

    new_data = Data(
        sample_id=index,
        h=torch.tensor([a.GetAtomicNum() for a in tgt.GetAtoms()], dtype=torch.long),
        pos=final_pos,
        smiles=std_canon_smiles,
        original_smiles=original_smiles,
        edge_index=edge_index,
        edge_type=edge_type,
        totalenergy=_as_float_tensor(energy),
        boltzmannweight=_as_float_tensor(weight),
    )

    return (index, "OK", new_data)

def process_molecule_safe(payload):
    """Run ``_process_logic`` on one payload under a wall-clock limit.

    Returns the ``(index, status, detail)`` tuple from ``_process_logic``.
    If the limit expires, returns ``(index, "TIMEOUT", ...)``. Any other
    exception becomes ``(index, "ERROR", "Crash: ...")``.

    The limit is ``TIMEOUT_SECONDS``, delivered as SIGALRM via
    ``signal.setitimer``, so fractional seconds are honoured. When
    ``TIMEOUT_SECONDS`` is None or non-positive, or the platform lacks
    ``setitimer`` (e.g. Windows), the payload runs without a limit.

    NOTE: Python runs signal handlers only between bytecodes. A single
    long-running RDKit C++ call is therefore not interrupted; the timeout
    surfaces once that call returns.
    """
    index = payload[0]
    use_timer = (
        TIMEOUT_SECONDS is not None
        and TIMEOUT_SECONDS > 0
        and hasattr(signal, "setitimer")
    )
    try:
        try:
            if use_timer:
                signal.signal(signal.SIGALRM, timeout_handler)
                # setitimer accepts floats; signal.alarm(int(...)) would turn
                # a sub-second budget into 0, i.e. "no alarm at all".
                signal.setitimer(signal.ITIMER_REAL, TIMEOUT_SECONDS)
            return _process_logic(payload)
        finally:
            # Always cancel the timer. If it fires before this line runs, the
            # exception still reaches the outer TimeoutException handler.
            if use_timer:
                signal.setitimer(signal.ITIMER_REAL, 0)
    except TimeoutException:
        return (index, "TIMEOUT", "Processing timed out")
    except Exception as e:
        return (index, "ERROR", f"Crash: {str(e)}")


def _atomic_torch_save(obj, path):
    """``torch.save`` ``obj`` to ``path`` through a sibling ``.tmp`` file and ``os.replace``.

    A run killed mid-write leaves at most a stray ``.tmp`` file. It never
    leaves a truncated file under the final name, which the resume scan or
    the "output exists" check would otherwise trust.
    """
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


def _load_records(path):
    """Load a list of ``Data`` records that this script saved with ``torch.save``.

    The records are PyG ``Data`` objects rather than plain tensors, so they
    need full unpickling (``weights_only=False``). torch >= 2.6 defaults to
    ``weights_only=True`` and would reject them. Full unpickling can execute
    code, so only use this on files this script wrote. Older torch releases
    without the ``weights_only`` keyword fall back to a plain load.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError as e:
        if "weights_only" not in str(e):
            raise
        return torch.load(path)


def _scan_completed(part_paths):
    """Return the ``sample_id`` values stored in the readable part files."""
    processed_indices = set()
    for p in tqdm.tqdm(part_paths, desc="Scanning"):
        try:
            chunk = _load_records(p)
            for d in chunk:
                if hasattr(d, 'sample_id'):
                    processed_indices.add(int(d.sample_id))
        except Exception as e:
            print(f"Warning: Could not read {p}: {e}")
    return processed_indices


def _run_pending(remaining_tasks, num_workers, base_dir, part_prefix):
    """Align ``remaining_tasks`` in a process pool, saving successes every ``BATCH_SIZE`` records."""
    print(f"Processing remaining {len(remaining_tasks)} items with {num_workers} workers...")

    aligned_buffer = []
    stats = {"OK": 0, "MISMATCH": 0, "ERROR": 0, "TIMEOUT": 0}
    # A fresh id per run keeps this run's part names distinct from parts left
    # behind by earlier, interrupted runs.
    run_id = str(uuid.uuid4())[:8]
    batch_counter = 0

    def flush(records, counter):
        part_path = os.path.join(base_dir, f'{part_prefix}{run_id}_{counter}.pt')
        _atomic_torch_save(records, part_path)

    with Pool(processes=num_workers) as pool:
        iterator = pool.imap_unordered(process_molecule_safe, remaining_tasks, chunksize=1)

        with tqdm.tqdm(total=len(remaining_tasks)) as pbar:
            for _idx, status, payload in iterator:
                stats[status] += 1

                if status == "OK":
                    # NOTE(review): presumably cloned so buffered records own
                    # their storage rather than memory shared with the worker
                    # (sharing strategy 'file_system'). Confirm.
                    payload.h = payload.h.clone()
                    payload.pos = payload.pos.clone()
                    payload.edge_index = payload.edge_index.clone()
                    payload.edge_type = payload.edge_type.clone()
                    if payload.totalenergy is not None:
                        payload.totalenergy = payload.totalenergy.clone()
                    if payload.boltzmannweight is not None:
                        payload.boltzmannweight = payload.boltzmannweight.clone()
                    aligned_buffer.append(payload)

                # pbar.n is updated just below, so n + 1 is the count handled so far.
                pbar.set_postfix({
                    "OK": f"{(stats['OK'] / (pbar.n + 1)) * 100:.1f}%",
                    "Fail": stats['MISMATCH'] + stats['ERROR'] + stats['TIMEOUT']
                })
                pbar.update(1)

                if len(aligned_buffer) >= BATCH_SIZE:
                    flush(aligned_buffer, batch_counter)
                    aligned_buffer = []
                    batch_counter += 1

    if aligned_buffer:
        flush(aligned_buffer, batch_counter)


def _merge_parts(part_paths):
    """Load the readable part files.

    Returns ``(records sorted by sample_id, paths that loaded)``.
    """
    final_list = []
    merged_paths = []
    for p in tqdm.tqdm(sorted(part_paths), desc="Merging"):
        try:
            chunk = _load_records(p)
        except Exception as e:
            # Same policy as the resume scan. A part the scan could not read
            # was not counted as done, so its molecules were re-queued. Keep
            # the file (it is not deleted) for inspection.
            print(f"Warning: Could not read {p}: {e}")
            continue
        final_list.extend(chunk)
        merged_paths.append(p)
    final_list.sort(key=lambda x: int(x.sample_id))
    return final_list, merged_paths


def _process_split(split, base_dir, num_workers):
    """Align one dataset split, resuming from part files left by earlier runs."""
    pkl_path = os.path.join(base_dir, f'{split}.pkl')

    if MAX_SAMPLES:
        print(f"\n--- DEBUG MODE: {MAX_SAMPLES} SAMPLES ---")
        out_path = os.path.join(base_dir, f'{split}_aligned_debug_{MAX_SAMPLES}.pt')
        temp_pattern = os.path.join(base_dir, f'{split}_debug_{MAX_SAMPLES}_part_*.pt')
        part_prefix = f'{split}_debug_{MAX_SAMPLES}_part_'
    else:
        out_path = os.path.join(base_dir, f'{split}_aligned.pt')
        temp_pattern = os.path.join(base_dir, f'{split}_part_*.pt')
        part_prefix = f'{split}_part_'

    if os.path.exists(out_path):
        print(f"Skipping {split}, output exists at {out_path}")
        return

    print(f"\nLoading {split} from {pkl_path}...")
    if not os.path.exists(pkl_path):
        print("File not found.")
        return

    # NOTE: pickle.load can execute arbitrary code; only use trusted dataset files.
    with open(pkl_path, 'rb') as f:
        raw_data = pickle.load(f)

    # Records without an RDKit molecule are skipped.
    all_tasks = [
        (i, d.rdmol.ToBinary(), d.pos, d.smiles, d.totalenergy, d.boltzmannweight)
        for i, d in enumerate(raw_data)
        if getattr(d, 'rdmol', None) is not None
    ]
    del raw_data

    if MAX_SAMPLES:
        all_tasks = all_tasks[:MAX_SAMPLES]

    print("Scanning existing partial files for completed indices...")
    processed_indices = _scan_completed(glob.glob(temp_pattern))

    # Only successes are written to part files, so failed molecules are
    # retried on every resume.
    remaining_tasks = [t for t in all_tasks if t[0] not in processed_indices]
    # Presumably shuffled to spread expensive molecules across workers.
    random.shuffle(remaining_tasks)

    if not remaining_tasks:
        print("All tasks completed. Merging...")
    else:
        _run_pending(remaining_tasks, num_workers, base_dir, part_prefix)

    print("\nProcessing Complete/Resumed. Merging all parts...")
    final_list, merged_parts = _merge_parts(glob.glob(temp_pattern))

    print(f"Saving combined file to {out_path} ({len(final_list)} items)...")
    _atomic_torch_save(final_list, out_path)

    if not MAX_SAMPLES:
        print("Cleaning up partial files...")
        for p in merged_parts:
            os.remove(p)
        # Stray temp files from runs killed mid-write.
        for p in glob.glob(temp_pattern + '.tmp'):
            os.remove(p)

    print("Final Summary:")
    print(f"Total Extracted: {len(final_list)}")


if __name__ == "__main__":
    # force=True overrides any start method already set; the except is a
    # harmless safety net.
    try:
        set_start_method('spawn', force=True)
    except RuntimeError:
        pass
    torch.multiprocessing.set_sharing_strategy('file_system')

    splits = ['val_data_5k_fixed', 'train_data_40k_fixed']
    base_dir = 'GEOM_Drugs'
    # os.cpu_count() returns None when the count cannot be determined.
    num_workers = max(1, (os.cpu_count() or 1) - 2)

    for split in splits:
        _process_split(split, base_dir, num_workers)
