import os
import pandas as pd
import ast
import torch
from rdkit import Chem
from torch_geometric.data import Data, InMemoryDataset
from torch.utils.data import Dataset
import os, torch, pandas as pd
from rdkit import Chem
import os, torch, pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from torch_geometric.data import Data, InMemoryDataset
from rdkit.Chem import AllChem, rdDistGeom
from pathlib import Path
from tqdm.auto import tqdm 



class SmilesGraphDataset(InMemoryDataset):
    """
    PyTorch Geometric dataset for SMILES-based molecule graphs and regression labels.

    Every CSV row with non-null ``smiles`` and ``permeability`` is parsed with
    RDKit; rows whose SMILES cannot be parsed are skipped. Each graph stores:

    * ``z``          -- atomic numbers, shape ``(num_atoms, 1)``, long
    * ``edge_index`` -- bidirectional bond list, shape ``(2, E)``, long
    * ``edge_attr``  -- bond-type id per directed edge (see ``BOND2ID``), shape ``(E,)``
    * ``y``          -- regression target, shape ``(1,)``, float
    * ``smiles``     -- the source SMILES string
    """
    # Categorical id for each RDKit bond type, used as the edge feature.
    BOND2ID = {
        Chem.rdchem.BondType.SINGLE:      0,
        Chem.rdchem.BondType.DOUBLE:      1,
        Chem.rdchem.BondType.TRIPLE:      2,
        Chem.rdchem.BondType.AROMATIC:    3,
        Chem.rdchem.BondType.IONIC:       4,
        Chem.rdchem.BondType.DATIVE:      5,
        Chem.rdchem.BondType.QUADRUPLE:   6,
        Chem.rdchem.BondType.HYDROGEN:    7,
        Chem.rdchem.BondType.UNSPECIFIED: 8,
    }
    # Id used for RDKit bond types not listed above, so one exotic bond does
    # not abort the whole dataset build with a KeyError.
    _UNKNOWN_BOND_ID = BOND2ID[Chem.rdchem.BondType.UNSPECIFIED]

    def __init__(self, csv_path: str, vocab: dict = None, transform=None):
        """
        Args:
            csv_path: CSV file with ``smiles`` and ``permeability`` columns.
            vocab: Stored as ``self.vocab``; not used by this class.
            transform: Optional PyG transform, forwarded to ``InMemoryDataset``.
        """
        df = pd.read_csv(csv_path).dropna(subset=['smiles', 'permeability'])
        self.vocab = vocab

        data_list = []
        for smiles, target in zip(df['smiles'], df['permeability']):
            data = self.smiles_to_pyg(smiles, target)
            if data is not None:
                data_list.append(data)

        super().__init__(os.path.dirname(csv_path), transform)
        self.data, self.slices = self.collate(data_list)

    @staticmethod
    def smiles_to_pyg(smiles: str, target: float):
        """
        Convert a SMILES string and target value to a PyG ``Data`` object (atom features only).

        Args:
            smiles: SMILES string to parse with RDKit.
            target: Regression label; converted with ``float``.

        Returns:
            A ``Data`` object, or ``None`` if RDKit cannot parse ``smiles``.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None

        atom_nums = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
        x = torch.tensor(atom_nums, dtype=torch.long).unsqueeze(1)

        edge_index, edge_attr = [], []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            bond_type_id = SmilesGraphDataset.BOND2ID.get(
                bond.GetBondType(), SmilesGraphDataset._UNKNOWN_BOND_ID)
            # Store each bond in both directions (undirected graph).
            edge_index += [[i, j], [j, i]]
            edge_attr += [bond_type_id, bond_type_id]

        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        else:
            # Bond-free molecule (e.g. a single atom): torch.tensor([]).t() would
            # be 1-D with shape (0,), which breaks collation; keep shape (2, 0).
            edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.tensor(edge_attr, dtype=torch.long)  # (E,)

        y = torch.tensor([float(target)], dtype=torch.float)
        data = Data(z=x, edge_index=edge_index, y=y, edge_attr=edge_attr,
                    num_nodes=x.size(0))
        data.smiles = smiles
        return data



class PeptideGraphDataset(InMemoryDataset):
    """
    PyTorch Geometric dataset built from HELM peptide strings and regression labels.

    Nodes are the monomers of the first polymer (the text between the first
    ``{`` and ``}``, split on ``.``). Each node's feature ``z`` is its vocabulary
    index, with shape ``(num_monomers, 1)``. Consecutive monomers are linked in
    both directions. Connections listed in the HELM connection section add
    further bidirectional edges.
    """

    def __init__(self, csv_path: str, vocab: dict = None, transform=None):
        """
        Args:
            csv_path: CSV file with ``helm`` and ``permeability`` columns.
            vocab: Monomer -> index mapping; built from the CSV's ``helm``
                column when ``None``.
            transform: Optional PyG transform, forwarded to ``InMemoryDataset``.
        """
        # Drop incomplete rows, as SmilesGraphDataset does. A NaN ``helm`` would
        # crash string parsing, and a NaN target would be stored as the label.
        df = pd.read_csv(csv_path).dropna(subset=['helm', 'permeability'])

        if vocab is None:
            vocab = self._build_vocab(df['helm'])
        self.vocab = vocab

        data_list = []
        for helm, target in zip(df['helm'], df['permeability']):
            data = self.seq_to_graph(helm, target)
            if data is not None:
                data_list.append(data)

        super().__init__(os.path.dirname(csv_path), transform)
        self.data, self.slices = self.collate(data_list)

    def _build_vocab(self, sequences):
        """
        Build a monomer vocabulary from HELM strings.

        Index 0 is reserved for ``'<unk>'``. Every distinct monomer token of the
        first polymer is numbered from 1, in sorted order.
        """
        tokens = set()
        for seq in sequences:
            core = seq.partition('{')[2].partition('}')[0]
            tokens.update(core.split('.'))

        vocab = {'<unk>': 0}
        for idx, aa in enumerate(sorted(tokens), start=1):
            vocab[aa] = idx
        return vocab

    def seq_to_graph(self, seq_str, target):
        """
        Convert one HELM string and its target into a PyG ``Data`` object.

        Monomers missing from ``self.vocab`` map to index 0 (``'<unk>'``).

        Returns:
            ``Data`` with ``z``, ``edge_index`` (shape ``(2, E)``), ``y`` and
            ``num_nodes``, or ``None`` when no ``{...}`` monomer block is found.
        """
        core = seq_str.partition('{')[2].partition('}')[0]
        if not core:
            return None

        aa_list = core.split('.')
        idxs = [self.vocab.get(aa, 0) for aa in aa_list]
        x = torch.tensor(idxs, dtype=torch.long).unsqueeze(1)

        # Backbone: link each monomer to the next, in both directions.
        edge_index = []
        for i in range(len(idxs) - 1):
            edge_index.append((i, i + 1))
            edge_index.append((i + 1, i))

        # Extra bonds (e.g. cyclisation) declared in the HELM connection section.
        edge_index.extend(self._parse_helm_links(seq_str, len(idxs)))

        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        else:
            # Single-monomer peptide with no links: keep the (2, 0) shape PyG
            # expects instead of the 1-D (0,) tensor torch.tensor([]) gives.
            edge_index = torch.empty((2, 0), dtype=torch.long)

        y = torch.tensor([float(target)], dtype=torch.float)
        return Data(z=x, edge_index=edge_index, y=y, num_nodes=x.size(0))

    def _parse_helm_links(self, seq_str: str, n: int):
        """
        Parse the HELM connection section and return bidirectional edges.

        The section is the text between the first and second ``$`` after the
        first ``}``. An example is
        ``'PEPTIDE1,PEPTIDE1,1:R1-6:R2|PEPTIDE1,PEPTIDE1,2:R3-5:R3'``.
        Connections are separated by ``|``. Within a connection, each
        comma-separated field containing both ``:`` and ``-`` (e.g.
        ``'1:R1-6:R2'``) gives the 1-based monomer positions of its two ends.

        Args:
            seq_str: Full HELM string.
            n: Number of monomers; endpoints outside ``[0, n)`` are ignored.

        Returns:
            List of 0-based ``(i, j)`` pairs, each connection in both directions.
            Malformed connection specs are skipped rather than raising.
        """
        after_brace = seq_str.split('}', 1)[-1]  # everything after first '}'
        parts = after_brace.split('$')
        if len(parts) < 2 or not parts[1]:
            return []

        links = []
        # Split on '|' before ',': splitting on ',' first glues the end of one
        # connection to the polymer id of the next ('1:R1-6:R2|PEPTIDE1'),
        # which previously crashed the 'left, right = ...' unpacking.
        for connection in parts[1].split('|'):
            for field in connection.split(','):
                if ':' not in field or '-' not in field:
                    continue
                ends = field.split('-')
                if len(ends) != 2:
                    continue
                try:
                    i = int(ends[0].split(':')[0]) - 1
                    j = int(ends[1].split(':')[0]) - 1
                except ValueError:
                    continue
                if 0 <= i < n and 0 <= j < n:
                    links.append((i, j))
                    links.append((j, i))
        return links




def gaussian_rbf(dist: torch.Tensor,
                 num_centers: int = 32,
                 cutoff: float = 6.0) -> torch.Tensor:
    """
    Expand distances in a basis of Gaussian radial functions.

    ``num_centers`` centres ``mu_k`` are spaced evenly on ``[0, cutoff]``. Every
    Gaussian uses a width equal to that spacing ``delta``:
    ``exp(-0.5 * (d - mu_k) ** 2 / delta ** 2)``.

    Args:
        dist: Tensor of distances, any shape (typically ``(E,)``).
        num_centers: Number of Gaussian centres; must be at least 2.
        cutoff: Position of the last centre; must be positive.

    Returns:
        Tensor of shape ``dist.shape + (num_centers,)``.

    Raises:
        ValueError: if ``num_centers < 2`` or ``cutoff <= 0``. In both cases
            the centre spacing would be undefined or zero.
    """
    if num_centers < 2:
        raise ValueError(f"num_centers must be >= 2, got {num_centers}")
    if cutoff <= 0:
        raise ValueError(f"cutoff must be positive, got {cutoff}")

    centers = torch.linspace(0.0, cutoff, num_centers, device=dist.device)
    gamma = -0.5 / (centers[1] - centers[0]) ** 2
    # unsqueeze(-1) broadcasts against the centres for inputs of any rank;
    # for 1-D input this equals dist[:, None] - centers[None, :].
    return torch.exp(gamma * (dist.unsqueeze(-1) - centers) ** 2)




class Smiles3DGraphDataset(InMemoryDataset):
    """
    PyG dataset of 3D molecular graphs built from SMILES via RDKit conformer embedding.

    Each CSV row with non-null ``smiles`` and ``permeability`` is processed as
    follows:

    1. Hydrogens are added and conformers are embedded with ETKDG.
    2. The first conformer is UFF-optimised.
    3. Hydrogens are removed again.
    4. Every heavy-atom pair within ``cutoff`` becomes a bidirectional edge,
       with a Gaussian RBF encoding of its distance (``gaussian_rbf``).

    ``cutoff`` is in the conformer's coordinate units (presumably Å; confirm).
    The processed result is cached as ``<csv stem>_smiles3d.pt`` in
    ``processed_dir``.
    """

    def __init__(self,
                 csv_path: str,
                 cutoff: float = 6.0,
                 transform=None,
                 pre_transform=None):
        # These must exist before super().__init__, because the base class may
        # call process(), which reads them.
        self.csv_path = csv_path
        self.cutoff = cutoff
        self.csv_stem = Path(csv_path).stem
        super().__init__(root=os.path.dirname(csv_path),
                         transform=transform,
                         pre_transform=pre_transform)

        # The cache holds pickled PyG objects written by process(). Recent torch
        # releases default to weights_only=True, which refuses to load them.
        self.data, self.slices = torch.load(self.processed_paths[0],
                                            weights_only=False)

    @property
    def processed_file_names(self):
        return [f"{self.csv_stem}_smiles3d.pt"]

    def _smiles_to_graph(self, smi: str, y_val: float) -> Data | None:
        """
        Build one 3D graph from a SMILES string.

        Returns ``None`` if RDKit cannot parse ``smi`` or cannot embed any
        conformer.
        """
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # Check before AddHs: passing None to AddHs raises, so an invalid
            # SMILES previously crashed instead of being skipped.
            return None
        mol = Chem.AddHs(mol)

        # Prefer the newest available ETKDG variant.
        params = getattr(rdDistGeom, 'ETKDGv3',
                         getattr(rdDistGeom, 'ETKDGv2', rdDistGeom.ETKDG))()
        params.numThreads = 16
        params.randomSeed = 42
        params.useRandomCoords = True

        ids = rdDistGeom.EmbedMultipleConfs(mol, numConfs=5, params=params)
        if not ids:
            print("Embed failed, no conformers generated for SMILES:", smi)
            return None
        conf_id = ids[0]

        AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=500)

        mol = Chem.RemoveHs(mol)

        conf = mol.GetConformer(conf_id)
        pos = torch.tensor(conf.GetPositions(), dtype=torch.float)
        z = torch.tensor([a.GetAtomicNum() for a in mol.GetAtoms()], dtype=torch.long)

        # Vectorised radius graph. triu_indices yields every pair i < j in
        # row-major order, the same order as a nested i/j loop. Pairs within
        # the cutoff are kept and emitted as (i, j), (j, i).
        n = z.size(0)
        row, col = torch.triu_indices(n, n, offset=1)
        dist = (pos[row] - pos[col]).norm(dim=-1)
        keep = dist <= self.cutoff
        row, col, dist = row[keep], col[keep], dist[keep]

        src = torch.stack([row, col], dim=1).reshape(-1)
        dst = torch.stack([col, row], dim=1).reshape(-1)
        edge_index = torch.stack([src, dst])          # (2, E)
        edge_attr = gaussian_rbf(dist.repeat_interleave(2), cutoff=self.cutoff)
        y = torch.tensor([y_val], dtype=torch.float)

        return Data(z=z.unsqueeze(1), pos=pos,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=y,
                    num_nodes=n)

    def process(self):
        """Embed every valid CSV row, apply ``pre_transform``, and cache the collated result."""
        print(f"[INFO] Processing: {self.csv_path}")
        df = (pd.read_csv(self.csv_path)
                .dropna(subset=['smiles', 'permeability']))

        data_list = []
        for smi, y in tqdm(zip(df['smiles'], df['permeability']),
                           total=len(df),
                           desc="Embedding SMILES→Graph"):
            g = self._smiles_to_graph(smi, float(y))
            if g is not None:
                data_list.append(g)

        # Honour the pre_transform given to __init__. It is applied once,
        # before caching, following the InMemoryDataset convention.
        if self.pre_transform is not None:
            data_list = [self.pre_transform(g) for g in data_list]

        if not data_list:
            raise ValueError(f"No valid molecular graphs could be built from {self.csv_path}")

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])