import io
import os
import ase
import abc
import json
import torch
import ase.io
import numpy as np
from openbabel import pybel
from collections import Counter
from rdkit import Chem, RDLogger
from rdkit.Geometry import Point3D
from contextlib import contextmanager
from motiflow.utils import rigid_utils as ru
from typing import Sequence, Any, List, Tuple, Union

# Import-time side effect: turn off every RDKit logger matching 'rdApp.*'
# for the whole process.
RDLogger.DisableLog('rdApp.*')


# Element vocabularies. MoleculeScorer passes these lists to Metrics, where the
# position of a symbol in the list becomes its histogram bin index.
_SYMBOLS_QM9 = ["H", "C", "N", "O", "F"]
_SYMBOLS_GEOM = ['H', 'B', 'C', 'N', 'O', 'F', 'Al', 'Si', 'P', 'S', 'Cl', 'As', 'Br', 'I', 'Hg', 'Bi']

# NOTE(review): SEED, the two URLs and the N_* counts are not referenced in the
# visible part of this file; presumably they belong to QM9 download /
# preprocessing code elsewhere — confirm before removing.
SEED = 0
URL_GDB9 = "https://ndownloader.figshare.com/files/3195389"
URL_UNCHARACTERIZED = "https://ndownloader.figshare.com/files/3195404"

N_GDB9 = 133885
N_UNCHARACTERIZED = 3054

# Bond lengths from:
# http://www.wiredchemist.com/chemistry/data/bond_energies_lengths.html
# And:
# http://chemistry-reference.com/tables/Bond%20Lengths%20and%20Enthalpies.pdf

# Reference single-bond lengths, indexed as BONDS1[symbol_a][symbol_b].
# get_bond_order() multiplies the input distance by 100 before comparing it
# with these values, so they are expressed in 1/100 of the input distance unit
# (presumably picometres for positions given in Angstrom — TODO confirm).
# A pair that is absent here is reported as "no bond" when check_exists=True.
BONDS1 = {'H': {'H': 74, 'C': 109, 'N': 101, 'O': 96, 'F': 92,
                'B': 119, 'Si': 148, 'P': 144, 'As': 152, 'S': 134,
                'Cl': 127, 'Br': 141, 'I': 161},
          'C': {'H': 109, 'C': 154, 'N': 147, 'O': 143, 'F': 135,
                'Si': 185, 'P': 184, 'S': 182, 'Cl': 177, 'Br': 194,
                'I': 214},
          'N': {'H': 101, 'C': 147, 'N': 145, 'O': 140, 'F': 136,
                'Cl': 175, 'Br': 214, 'S': 168, 'I': 222, 'P': 177},
          'O': {'H': 96, 'C': 143, 'N': 140, 'O': 148, 'F': 142,
                'Br': 172, 'S': 151, 'P': 163, 'Si': 163, 'Cl': 164,
                'I': 194},
          'F': {'H': 92, 'C': 135, 'N': 136, 'O': 142, 'F': 142,
                'S': 158, 'Si': 160, 'Cl': 166, 'Br': 178, 'P': 156,
                'I': 187},
          'B': {'H':  119, 'Cl': 175},
          'Si': {'Si': 233, 'H': 148, 'C': 185, 'O': 163, 'S': 200,
                 'F': 160, 'Cl': 202, 'Br': 215, 'I': 243 },
          'Cl': {'Cl': 199, 'H': 127, 'C': 177, 'N': 175, 'O': 164,
                 'P': 203, 'S': 207, 'B': 175, 'Si': 202, 'F': 166,
                 'Br': 214},
          'S': {'H': 134, 'C': 182, 'N': 168, 'O': 151, 'S': 204,
                'F': 158, 'Cl': 207, 'Br': 225, 'Si': 200, 'P': 210,
                'I': 234},
          'Br': {'Br': 228, 'H': 141, 'C': 194, 'O': 172, 'N': 214,
                 'Si': 215, 'S': 225, 'F': 178, 'Cl': 214, 'P': 222},
          'P': {'P': 221, 'H': 144, 'C': 184, 'O': 163, 'Cl': 203,
                'S': 210, 'F': 156, 'N': 177, 'Br': 222},
          'I': {'H': 161, 'C': 214, 'Si': 243, 'N': 222, 'O': 194,
                'S': 234, 'F': 187, 'I': 266},
          'As': {'H': 152}
          }

# Reference double-bond lengths (same indexing and units as BONDS1). Only pairs
# listed here can be classified as a double bond by get_bond_order().
BONDS2 = {'C': {'C': 134, 'N': 129, 'O': 120, 'S': 160},
          # NOTE(review): an earlier comment on this line read
          # "CHANGED 125 -> 126 for N=N", but the N=N value in use is 125 —
          # confirm which of the two is intended.
          'N': {'C': 129, 'N': 125, 'O': 121},
          'O': {'C': 120, 'N': 121, 'O': 121, 'P': 150},
          'P': {'O': 150, 'S': 186},
          'S': {'P': 186}}


# Reference triple-bond lengths (same indexing and units as BONDS1). A pair is
# only tested against this table after it already qualified as a double bond.
BONDS3 = {'C': {'C': 120, 'N': 116, 'O': 113},
          'N': {'C': 116, 'N': 110},
          'O': {'C': 113}}

# NOTE(review): STDV is not referenced in the visible part of this file;
# its meaning cannot be confirmed from here.
STDV = {"H": 5, "C": 1, "N": 1, "O": 2, "F": 3}
# Tolerances added to the BONDS1 / BONDS2 / BONDS3 reference lengths (same
# units as those tables) when get_bond_order() classifies a distance.
MARGIN1, MARGIN2, MARGIN3 = 10, 5, 3

# Summed bond order an atom must have to count as stable in check_stability().
# An int means exactly that value; a list means any of the listed values.
ALLOWED_BONDS = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'B': 3, 'Al': 3,
                 'Si': 4, 'P': [3, 5],
                 'S': 4, 'Cl': 1, 'As': 3, 'Br': 1, 'I': 1, 'Hg': [1, 2],
                 'Bi': [3, 5]}

# Lookup from integer bond order (used as the list index) to the RDKit bond
# type; index 0 is a placeholder because order 0 means "no bond".
BOND_LIST = [
    None,
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
]

def read_json(json_path: str):
    """Parse the UTF-8 encoded JSON file at ``json_path`` and return its content."""
    with open(json_path, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

def save_json(json_dict: dict, json_path: str):
    """Write ``json_dict`` to ``json_path`` as UTF-8 encoded JSON.

    NumPy arrays and NumPy scalar values are converted to plain Python
    lists / numbers while serializing, at any nesting depth. The caller's
    dictionary is left untouched.

    Args:
        json_dict: Mapping to serialize.
        json_path: Destination file; it is created or overwritten.

    Raises:
        TypeError: If a value is neither JSON-serializable nor a NumPy
            array / scalar.
    """
    def _to_builtin(obj):
        # json.dump only calls this hook for objects it cannot serialize itself.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    with open(json_path, encoding="utf-8", mode="w") as fp:
        json.dump(json_dict, fp, default=_to_builtin)
        
@contextmanager
def suppress_stderr():
    """Temporarily point file descriptor 2 at ``os.devnull``.

    The redirection happens at the OS level (``os.dup2``), so it also hides
    output written to fd 2 by non-Python code. The previous target of fd 2 is
    restored when the ``with`` body finishes, whether or not it raised.
    """
    with open(os.devnull, "w") as sink:
        saved_fd = os.dup(2)  # keep a handle on the current stderr target
        os.dup2(sink.fileno(), 2)
        try:
            yield
        finally:
            os.dup2(saved_fd, 2)
            os.close(saved_fd)

def discrete_histogram(values: Sequence[Any], encoder: dict[Any, int], norm: bool = False) -> np.ndarray:
    """Count occurrences of ``values`` into bins chosen by ``encoder``.

    Args:
        values: Items to count; every item must be a key of ``encoder``.
        encoder: Maps an item to its bin index. The histogram has
            ``max(encoder.values()) + 1`` bins. Items that share a bin index
            are accumulated into the same bin.
        norm: If True, divide by the total count so the bins sum to 1. An
            all-zero histogram (e.g. empty ``values``) is returned unchanged
            instead of being divided by zero.

    Returns:
        1-D float array of counts, or of frequencies when ``norm`` is True.

    Raises:
        KeyError: If an item of ``values`` is missing from ``encoder``.
        ValueError: If ``encoder`` is empty.
    """
    histogram = np.zeros(max(encoder.values()) + 1)
    for key, count in Counter(values).items():
        # += rather than =, so several keys may feed the same bin.
        histogram[encoder[key]] += count

    if norm:
        total = histogram.sum()
        if total > 0:
            histogram /= total

    return histogram

class Metrics(abc.ABC):
    """Minimal interface for a metric accumulator.

    Every hook raises ``NotImplementedError`` until a subclass overrides it.
    """

    def __call__(self, atoms):
        """Forward to :meth:`update`, so the object can be called directly."""
        result = self.update(atoms)
        return result

    def update(self, atoms):
        """Consume ``atoms``; must be overridden."""
        raise NotImplementedError

    def summarize(self) -> dict:
        """Return the collected statistics; must be overridden."""
        raise NotImplementedError

    def reset(self):
        """Discard collected state; must be overridden."""
        raise NotImplementedError
    
def read_xyz_file(file_handle, center: bool = True) -> Tuple[ase.Atoms, str]:
    """Read one molecule plus a trailing text field from a binary XYZ-like file.

    Args:
        file_handle: Object whose ``readlines()`` yields ``bytes`` lines.
            The first line holds the atom count.
        center: If True, subtract the mean position so the molecule is
            centred on the origin.

    Returns:
        ``(atoms, text)`` where ``text`` is the first whitespace-separated
        token of the second-to-last line (the SMILES, per the original
        author's comment).
    """
    lines = [raw.decode('UTF-8') for raw in file_handle.readlines()]
    n_atoms = int(lines[0])

    # Keep only the header (2 lines) and the coordinate rows, and turn the
    # '*^' exponent marker into 'e' so the numbers parse as floats.
    xyz_text = "".join(lines[:n_atoms + 2]).replace('*^', 'e')
    with io.StringIO(xyz_text) as buffer:
        atoms = ase.io.read(buffer, format='xyz')

    if center:
        atoms.positions -= atoms.positions.mean(axis=0)

    smiles = lines[-2].split()[0]
    return atoms, smiles

def read_uncharacterized(fpath: str) -> List[int]:
    """Parse a list of 1-based indices from ``fpath`` and return them 0-based.

    The first 9 and the last 2 newline-separated entries of the file are
    skipped, as are blank lines. For each remaining line the first
    whitespace-separated token is read as an integer and decremented by one.
    """
    with open(fpath) as handle:
        rows = handle.read().split("\n")

    indices = []
    for row in rows[9:-2]:
        if not row.strip():
            continue
        first_token = row.split()[0]
        indices.append(int(first_token) - 1)
    return indices
    
def get_bond_order(atom1, atom2, distance, check_exists=True, single_bond=False):
    """Classify the bond between two elements from their distance.

    Args:
        atom1, atom2: Element symbols, used as keys into BONDS1/2/3.
        distance: Inter-atomic distance; it is multiplied by 100 before being
            compared with the reference tables.
        check_exists: If True, a pair missing from BONDS1 prints a message and
            yields 0. If False, a missing pair raises ``KeyError``.
        single_bond: If True, double and triple bonds are reported as 1.

    Returns:
        0 (no bond), 1, 2 or 3.
    """
    scaled = 100 * distance  # bring the distance onto the scale of the tables

    # Large molecules can contain pairs without a tabulated bond length.
    if check_exists:
        if atom1 not in BONDS1:
            print(f"Atom {atom1} not in bonds1")
            return 0
        if atom2 not in BONDS1[atom1]:
            print(f"Atom {atom2} not in bonds1[{atom1}]")
            return 0

    # Written as "not (a < b)" so that a NaN distance still means "no bond".
    if not (scaled < BONDS1[atom1][atom2] + MARGIN1):
        return 0

    def _shorter_than(table, margin):
        """True if the pair is tabulated and closer than reference + margin."""
        return (
            atom1 in table
            and atom2 in table[atom1]
            and scaled < table[atom1][atom2] + margin
        )

    order = 1
    if _shorter_than(BONDS2, MARGIN2):
        # A triple bond is only considered once the double-bond test passed.
        order = 3 if _shorter_than(BONDS3, MARGIN3) else 2

    return 1 if single_bond else order

def check_stability(
        atoms: ase.Atoms,
        bond_order_per_atom: np.ndarray,
):
    """Count atoms whose summed bond order matches ALLOWED_BONDS.

    Args:
        atoms: Molecule; only ``len(atoms)`` and ``atoms.symbols`` are used.
        bond_order_per_atom: Summed bond order of each atom, in atom order.

    Returns:
        ``(molecule_stable, n_stable_atoms, n_atoms)``; the molecule is
        stable when every atom is.

    Raises:
        KeyError: If a symbol is missing from ALLOWED_BONDS.
    """
    total = len(atoms)
    stable_count = 0
    for element, valence in zip(atoms.symbols, bond_order_per_atom):
        allowed = ALLOWED_BONDS[element]
        if isinstance(allowed, int):
            matches = valence == allowed
        else:
            # several accepted valences for this element
            matches = valence in allowed
        stable_count += int(matches)

    return stable_count == total, stable_count, total

def make_mol_rdkit_qm9(
        atoms: ase.Atoms,
        with_conformer: bool = False,
        known_bonds: dict = None
):
    """Build an RDKit molecule from 3-D coordinates.

    Bond orders come from ``known_bonds`` where it has a non-zero entry for
    the pair ``(i, j)`` with ``i < j``; every other pair is classified from
    its distance by ``get_bond_order``.

    Args:
        atoms: Source molecule (symbols and positions).
        with_conformer: If True, attach the input coordinates as a conformer.
        known_bonds: Optional mapping ``(i, j) -> bond order``.

    Returns:
        ``(mol, bond_order_per_atom)``: the unsanitized ``Chem.RWMol`` and an
        int array holding the summed bond order of each atom.
    """
    pairwise = atoms.get_all_distances()
    n_atoms = len(atoms)
    valence = np.zeros(n_atoms, dtype="int")

    rw_mol = Chem.RWMol()
    for element in atoms.symbols:
        rw_mol.AddAtom(Chem.Atom(element))

    # Visit every unordered pair once (first < second).
    for first in range(n_atoms):
        element_first = atoms.symbols[first]
        for second in range(first + 1, n_atoms):
            bond = 0
            # Prefer a supplied bond order ...
            if known_bonds is not None and (first, second) in known_bonds:
                bond = known_bonds[(first, second)]
            # ... and fall back to the distance-based estimate otherwise.
            if bond == 0:
                bond = get_bond_order(
                    element_first,
                    atoms.symbols[second],
                    pairwise[first, second],
                    single_bond=False,
                )

            valence[first] += bond
            valence[second] += bond
            if bond > 0:
                rw_mol.AddBond(first, second, BOND_LIST[bond])

    if with_conformer:
        n_mol_atoms = rw_mol.GetNumAtoms()
        conformer = Chem.Conformer(n_mol_atoms)
        for idx in range(n_mol_atoms):
            px, py, pz = atoms.positions[idx]
            conformer.SetAtomPosition(idx, Point3D(px, py, pz))
        rw_mol.AddConformer(conformer)

    return rw_mol, valence

def check_validity(mol: Chem.Mol):
    """Sanitize ``mol`` and report validity and connectivity.

    Returns:
        ``(mol, smiles, valid, connected)``:

        * If the input fails sanitization: ``(None, None, 0, 0)``.
        * Otherwise ``mol`` / ``smiles`` describe the fragment with the most
          atoms (or are ``None`` if that fragment fails sanitization),
          ``valid`` is 1 when that fragment sanitized, and ``connected`` is 1
          when the input consisted of a single fragment.
    """
    def _sanitized_smiles(candidate: Chem.Mol):
        """Return ``(candidate, canonical SMILES)`` or ``(None, None)`` on failure."""
        try:
            Chem.SanitizeMol(candidate)
            return candidate, Chem.MolToSmiles(candidate, canonical=True)
        except ValueError:
            return None, None

    mol, smi = _sanitized_smiles(mol)
    if smi is None:
        return mol, smi, 0, 0

    fragments = Chem.rdmolops.GetMolFrags(mol, asMols=True)
    connected = int(len(fragments) == 1)

    biggest = max(fragments, default=mol, key=lambda frag: frag.GetNumAtoms())
    mol, smi = _sanitized_smiles(biggest)
    valid = int(smi is not None)

    return mol, smi, valid, connected

def reconstruct_atoms(generated_data, library_dict):
    """
    Convert generated rigid fragments back to atom positions and
    extract intra-fragment bonds where the library provides them.

    Args:
        generated_data: Iterable of mappings, one per molecule, each with the
            keys 'frag_ids' (indexable; elements must support ``.item()``) and
            'rigids' (a tensor, or data convertible to a float32 tensor, that
            ``ru.Rigid.from_tensor_7`` accepts — presumably 7 numbers per
            fragment; TODO confirm against ``rigid_utils``).
        library_dict: Mapping from integer fragment id to an entry with
            'exemplar_pos' and 'exemplar_z', and optionally
            'exemplar_edge_index' and 'exemplar_edge_attr'.

    Returns:
        ``(pos_list, z_list, bond_overrides_list)`` with one item per molecule:
        concatenated atom positions, concatenated atomic numbers, and a dict
        mapping ``(atom_i, atom_j)`` (both directions) to an integer bond order.

    Raises:
        ValueError: If a molecule ends up with no real (``z != 0``) atoms.
    """
    pos_list = []
    z_list = []
    bond_overrides_list = []
    
    for mol_entry in generated_data:
        frag_ids = mol_entry['frag_ids'] 
        rigids_data = mol_entry['rigids'] 
        # Bring the rigid transforms to a float32 tensor before decoding them.
        if not torch.is_tensor(rigids_data):
            rigids_tensor = torch.tensor(rigids_data, dtype=torch.float32)
        else:
            rigids_tensor = rigids_data.to(dtype=torch.float32).cpu()
        rigids = ru.Rigid.from_tensor_7(rigids_tensor)
        # One rotation matrix and one translation vector per fragment (indexed by k below).
        rot_mats = rigids.get_rots().get_rot_mats().detach().numpy()
        trans_vecs = rigids.get_trans().detach().numpy()
        
        mol_atoms_pos = []
        mol_atoms_z = []
        mol_known_bonds = {} # Map (global_idx_u, global_idx_v) -> order
        
        # Number of real atoms already emitted for this molecule; it is the
        # index the next fragment's first real atom receives.
        current_atom_offset = 0

        num_frags = len(frag_ids)
        for k in range(num_frags):
            cid = int(frag_ids[k].item())
            
            entry = library_dict[cid]
            ex_pos = entry['exemplar_pos']
            ex_z = entry['exemplar_z']
            
            # Normalize to numpy
            if torch.is_tensor(ex_pos): ex_pos = ex_pos.cpu().numpy()
            if torch.is_tensor(ex_z): ex_z = ex_z.cpu().numpy()
            
            # Apply Transformation: x_world = x_local @ R^T + t
            current_pos = np.dot(ex_pos, rot_mats[k].T) + trans_vecs[k]
            
            # Keep only real atoms (entries with z == 0 are dropped;
            # presumably padding — TODO confirm)
            real_mask = (ex_z != 0)
            if not np.any(real_mask):
                # Nothing to add, so the offset stays unchanged.
                continue
            
            # Indices of real atoms in the exemplar
            real_indices_local = np.where(real_mask)[0]
            
            # Store data
            mol_atoms_pos.append(current_pos[real_mask])
            mol_atoms_z.append(ex_z[real_mask])
            
            # --- Extract topology if the library entry carries it ---
            if 'exemplar_edge_index' in entry and 'exemplar_edge_attr' in entry:
                e_idx = entry['exemplar_edge_index']
                e_attr = entry['exemplar_edge_attr']
                
                if torch.is_tensor(e_idx): e_idx = e_idx.cpu().numpy()
                if torch.is_tensor(e_attr): e_attr = e_attr.cpu().numpy()
                
                # Map: exemplar-local atom index -> index in the assembled molecule.
                # 'real_indices_local' lists the local indices of the kept atoms in
                # order, so the i-th kept atom lands at current_atom_offset + i.
                local_to_global = {}
                for i_seq, loc_idx in enumerate(real_indices_local):
                    local_to_global[loc_idx] = current_atom_offset + i_seq
                
                # e_idx is read as a (2, num_edges) array: row 0 = source, row 1 = target.
                num_edges = e_idx.shape[1]
                for e in range(num_edges):
                    u_local = e_idx[0, e]
                    v_local = e_idx[1, e]
                    # NOTE(review): the edge attribute is used directly as the bond
                    # order; confirm its encoding matches what make_mol_rdkit_qm9 expects.
                    order = int(e_attr[e])
                    
                    # Only add if both ends are real atoms
                    if u_local in local_to_global and v_local in local_to_global:
                        u_glob = local_to_global[u_local]
                        v_glob = local_to_global[v_local]
                        
                        # Store both directions for easier lookup in make_mol
                        mol_known_bonds[(u_glob, v_glob)] = order
                        mol_known_bonds[(v_glob, u_glob)] = order
            
            # Increment offset by number of real atoms added
            current_atom_offset += len(real_indices_local)
            
        if len(mol_atoms_pos) > 0:
            full_pos = np.concatenate(mol_atoms_pos, axis=0)
            full_z = np.concatenate(mol_atoms_z, axis=0)

            pos_list.append(full_pos)
            z_list.append(full_z)
            bond_overrides_list.append(mol_known_bonds)
        else:
            # No fragment contributed a real atom: abort instead of emitting an
            # empty molecule (the message says "Warning" but this is an error).
            raise ValueError("Warning: Molecule with zero fragments encountered during reconstruction.")
            
    return pos_list, z_list, bond_overrides_list

class Metrics(Metrics):
    """Running validity / stability / uniqueness statistics for molecules.

    NOTE: this class reuses the name of the abstract base class defined
    earlier in this module and shadows it from this point on.

    Molecules are added with :meth:`update`, aggregated with
    :meth:`summarize` and cleared with :meth:`reset`. Statistics accumulate
    across ``update`` calls until ``reset`` is called.
    """

    def __init__(self,
                 atom_types_str: Sequence[str] = _SYMBOLS_QM9,
                 max_num_atoms: int = 29,
                 json_path: str = None,
                 summarize_hidden: bool = False,
                 hidden_prefix: str = "_"):
        """
        Args:
            atom_types_str: Element symbols; the position of a symbol is its
                bin in the atom-type histogram.
            max_num_atoms: Largest molecule size covered by the atom-count
                histogram of the hidden summary.
            json_path: Optional JSON file with reference statistics. The keys
                "smiles" and "atom_hist" are read when present.
            summarize_hidden: If True, :meth:`summarize` also emits the
                histograms and the SMILES list under prefixed keys.
            hidden_prefix: Prefix of those extra keys.
        """
        self.encoder = {s: idx for idx, s in enumerate(atom_types_str)}
        self.max_num_atoms = max_num_atoms

        if json_path:
            dataset_infos = read_json(json_path=json_path)
            # A missing key means "no reference data" rather than a crash.
            ref_smiles = set(dataset_infos.get("smiles") or [])
            ref_hist = dataset_infos.get("atom_hist")
            ref_atom_hist = None if ref_hist is None else np.array(ref_hist)
        else:
            ref_smiles = set()
            ref_atom_hist = None

        self.ref_smiles = ref_smiles
        self.ref_atom_hist = ref_atom_hist

        self.summarize_hidden = summarize_hidden
        self.hidden_prefix = hidden_prefix

        # Creates every per-molecule accumulator list.
        self.reset()

    def update(self, atoms, bond_overrides_list: List[dict] = None):
        """Evaluate one molecule or a sequence of molecules and store the results.

        Args:
            atoms: A single ``ase.Atoms`` or a sequence of them.
            bond_overrides_list: Optional per-molecule bond dictionaries that
                are forwarded to ``make_mol_rdkit_qm9`` as ``known_bonds``.
                Molecules and overrides are paired with ``zip``, so the
                shorter of the two determines how many molecules are
                processed.
        """
        if isinstance(atoms, ase.Atoms):
            atoms = [atoms]

        # Without overrides every bond is inferred from distances.
        if bond_overrides_list is None:
            bond_overrides_list = [None] * len(atoms)

        for a, known_bonds in zip(atoms, bond_overrides_list):
            raw_mol, bond_order_per_atom = make_mol_rdkit_qm9(a, known_bonds=known_bonds)
            mol, smi, v, c = check_validity(raw_mol)
            molecule_stable, atom_stable, n_atoms = check_stability(a, bond_order_per_atom)

            self.smiles.append(smi)
            self.mols.append(mol)
            self.valid.append(v)
            self.valid_connected.append(c)
            self.n_atoms.append(n_atoms)
            self.molecule_stable.append(molecule_stable)
            self.atom_stable.append(atom_stable)
            self.atom_hist.append(discrete_histogram(a.symbols, encoder=self.encoder))

    def summarize(self, return_molecules=False) -> Union[dict, Tuple[dict, dict]]:
        """Aggregate everything stored since the last :meth:`reset`.

        Args:
            return_molecules: If True, also return
                ``{'SMILES': [...], 'molecules': [...]}`` holding the
                per-molecule SMILES and RDKit molecules.

        Returns:
            The summary dict, or ``(summary, gen_mols)`` when
            ``return_molecules`` is True. "valid_unique_novel" equals
            "valid_unique" when no reference SMILES were loaded; "tv_atom" is
            only present when a reference atom histogram was loaded.

        Raises:
            ZeroDivisionError: If nothing has been added via :meth:`update`.
        """
        assert len(self.valid) == len(self.valid_connected)
        assert len(self.valid) == len(self.molecule_stable)
        assert len(self.valid) == len(self.smiles)

        n_samples = len(self.valid)
        n_atoms = sum(self.n_atoms)

        summary = {}

        summary["atom_stable"] = sum(self.atom_stable) / n_atoms
        summary["molecule_stable"] = sum(self.molecule_stable) / n_samples

        summary["valid"] = sum(self.valid) / n_samples
        summary["valid_connected"] = sum(self.valid_connected) / n_samples

        valid_unique_smiles = {smiles for (v, smiles) in zip(self.valid, self.smiles) if v}
        summary["valid_unique"] = len(valid_unique_smiles) / n_samples
        if self.ref_smiles is not None:
            vun_smiles = valid_unique_smiles.difference(self.ref_smiles)
            summary["valid_unique_novel"] = len(vun_smiles) / n_samples

        atom_hist = np.sum(np.stack(self.atom_hist, axis=0), axis=0)
        atom_hist = (atom_hist / atom_hist.sum())
        if self.ref_atom_hist is not None:
            # Sum of absolute differences between the two frequency vectors.
            summary["tv_atom"] = np.sum(np.abs(self.ref_atom_hist - atom_hist)).item()

        if self.summarize_hidden:
            summary[f"{self.hidden_prefix}atom_hist"] = atom_hist
            summary[f"{self.hidden_prefix}num_atoms_hist"] = discrete_histogram(
                self.n_atoms,
                encoder={idx: idx for idx in range(self.max_num_atoms + 1)},
                norm=True,
            )
            summary[f"{self.hidden_prefix}smiles"] = list(valid_unique_smiles)

        gen_mols = {'SMILES': self.smiles,
                    'molecules': self.mols,}
        if return_molecules:
            return summary, gen_mols
        return summary

    def reset(self):
        """Drop all accumulated per-molecule results."""
        self.smiles = []
        # mols is cleared together with smiles so the two lists returned by
        # summarize(return_molecules=True) always stay index-aligned.
        self.mols = []
        self.valid = []
        self.valid_connected = []
        self.n_atoms = []
        self.molecule_stable = []
        self.atom_hist = []
        self.atom_stable = []


class MoleculeScorer:
    """Scores generated molecules given as coordinates plus atomic numbers.

    ``score_qm9`` and ``score_geom_drugs`` feed the molecules into a dataset
    specific ``Metrics`` object and return its summary. Those ``Metrics``
    objects are never reset here, so their statistics accumulate over
    successive calls. ``score_composition`` and ``score_structure`` are
    stateless.
    """

    def __init__(self):
        """Create one ``Metrics`` accumulator per supported dataset."""
        # Extended decoder covering both QM9 and GEOM-Drugs elements
        self.atom_decoder = {
            1: "H", 5: "B", 6: "C", 7: "N", 8: "O", 9: "F",
            13: "Al", 14: "Si", 15: "P", 16: "S", 17: "Cl",
            33: "As", 35: "Br", 53: "I", 80: "Hg", 83: "Bi"
        }
        self.qm9_metrics = Metrics(atom_types_str=_SYMBOLS_QM9, max_num_atoms=29)
        self.geom_metrics = Metrics(atom_types_str=_SYMBOLS_GEOM, max_num_atoms=181)
        # Accumulator used by _score; chosen by score_qm9 / score_geom_drugs.
        self.metrics = None

    def _score(self, pos_list: Union[List[torch.Tensor], List[np.ndarray]],
               z_list: Union[List[torch.Tensor], List[np.ndarray]],
               bond_overrides_list: List[dict] = None,
               return_molecules=False) -> Tuple[dict, Union[dict, None]]:
        """Convert the inputs to ASE molecules and run ``self.metrics`` on them.

        Args:
            pos_list: Per-molecule coordinate arrays / tensors.
            z_list: Per-molecule atomic numbers. Numbers missing from
                ``atom_decoder`` become the symbol "X".
            bond_overrides_list: Optional per-molecule known bonds, forwarded
                to ``Metrics.update``.
            return_molecules: Whether to also fetch the per-molecule output.

        Returns:
            ``(summary, gen_mols)``; ``gen_mols`` is ``None`` unless
            ``return_molecules`` is True.

        Raises:
            RuntimeError: If no metrics object has been selected yet.
        """
        metrics = getattr(self, "metrics", None)
        if metrics is None:
            raise RuntimeError(
                "No metrics selected; call score_qm9() or score_geom_drugs() "
                "instead of _score() directly."
            )

        ase_mols = []
        for (pos, z) in zip(pos_list, z_list):
            # Convert to numpy
            if isinstance(pos, torch.Tensor):
                pos = pos.cpu().numpy()
            if isinstance(z, torch.Tensor):
                z = z.cpu().numpy()

            # Map atomic numbers to symbols
            symbols = [self.atom_decoder.get(int(num), "X") for num in z]
            ase_mols.append(ase.Atoms(symbols=symbols, positions=pos))

        metrics.update(ase_mols, bond_overrides_list=bond_overrides_list)
        if return_molecules:
            summary, gen_mols = metrics.summarize(return_molecules)
        else:
            summary = metrics.summarize(return_molecules)
            gen_mols = None
        return summary, gen_mols

    def score_composition(self, z_list: List[np.ndarray], targets: np.ndarray) -> dict:
        """
        Task 1 Metric: Composition Matching Rate.

        Args:
            z_list: Atomic numbers of each generated sample (NumPy arrays,
                torch tensors or plain sequences). Elements other than
                H, C, N, O, F are ignored when counting.
            targets: Per-sample target counts in the order [H, C, N, O, F];
                values are rounded to the nearest integer before comparison.

        Returns:
            ``{"composition_match_rate": fraction of samples whose counts
            equal the target}``; 0.0 for an empty ``z_list``.
        """
        # Mapping for QM9: H=1, C=6, N=7, O=8, F=9 -> indices 0..4, i.e. the
        # target vector is assumed to follow ascending atomic number.
        z_to_idx = {1: 0, 6: 1, 7: 2, 8: 3, 9: 4}

        total = len(z_list)
        matches = 0
        for i, z_gen in enumerate(z_list):
            counts_gen = np.zeros(5, dtype=int)
            for atom_z in z_gen:
                # int() makes the lookup work for NumPy scalars and for 0-d
                # torch tensors alike (tensors hash by identity, so a bare
                # dict lookup would never find them).
                idx = z_to_idx.get(int(atom_z))
                if idx is not None:
                    counts_gen[idx] += 1

            # Round (not truncate) so that e.g. 1.9999999 is compared as 2.
            target_counts = np.rint(np.asarray(targets[i])).astype(int)

            if np.array_equal(counts_gen, target_counts):
                matches += 1

        return {"composition_match_rate": matches / total if total > 0 else 0.0}

    def score_structure(self, pos_list: List[np.ndarray], z_list: List[np.ndarray], target_fps: np.ndarray) -> dict:
        """
        Task 2 Metric: Tanimoto Similarity.

        Each generated molecule is written as an XYZ block, parsed with
        OpenBabel, fingerprinted with FP2 and compared with the matching row
        of ``target_fps``.

        Args:
            pos_list: Per-molecule coordinates, indexable as ``pos[j, 0..2]``.
            z_list: Per-molecule atomic numbers; anything other than
                H, C, N, O, F is written as carbon.
            target_fps: Per-molecule binary fingerprint vectors (NumPy),
                expected to have 1024 entries each.

        Returns:
            ``{"tanimoto_similarity": mean similarity}``. A molecule that
            OpenBabel cannot parse or fingerprint contributes 0 to the mean.
        """
        if pybel is None:
            print("Warning: OpenBabel not installed. Skipping structure scoring.")
            return {"tanimoto_similarity": 0.0}

        total_sim = 0.0
        total = len(pos_list)

        # Helper map
        z_to_sym = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F'}

        for i, (pos, z) in enumerate(zip(pos_list, z_list)):
            # 1. Build an XYZ block for the generated molecule.
            lines = [str(len(z)), "Generated"]
            for j in range(len(z)):
                sym = z_to_sym.get(int(z[j]), 'C')
                lines.append(f"{sym} {pos[j,0]:.4f} {pos[j,1]:.4f} {pos[j,2]:.4f}")
            xyz_block = "\n".join(lines)

            # 2. Fingerprint it. Only this step is best-effort: a garbage
            #    geometry simply scores 0, while problems with target_fps
            #    below are caller errors and are allowed to surface.
            try:
                mol = pybel.readstring("xyz", xyz_block)
                set_bits = np.asarray(mol.calcfp("FP2").bits, dtype=int)
            except Exception:
                continue

            # 3. Expand the set-bit list into a binary vector of length 1024.
            # NOTE(review): pybel appears to report bit positions starting at
            # 1, so position 1024 is dropped and index 0 is never set here —
            # confirm target_fps was built with the same convention.
            gen_vec = np.zeros(1024, dtype=int)
            gen_vec[set_bits[set_bits < 1024]] = 1

            target_vec = target_fps[i].astype(int)

            # Tanimoto = |A & B| / |A | B|
            intersection = np.sum(gen_vec & target_vec)
            union = np.sum(gen_vec | target_vec)
            total_sim += intersection / union if union > 0 else 0.0

        return {"tanimoto_similarity": total_sim / total if total > 0 else 0.0}

    def _score_with(self, metrics, pos_list, z_list, bond_overrides_list, return_molecules):
        """Select ``metrics`` as the active accumulator and score the batch."""
        self.metrics = metrics
        score, gen_mols = self._score(pos_list, z_list, bond_overrides_list, return_molecules)
        if return_molecules:
            return score, gen_mols
        return score

    def score_qm9(self, pos_list, z_list, bond_overrides_list=None,
                  return_molecules=False) -> Union[dict, Tuple[dict, dict]]:
        """Score a batch of QM9 molecules.

        Returns the summary dict, or ``(summary, gen_mols)`` when
        ``return_molecules`` is True.
        """
        return self._score_with(self.qm9_metrics, pos_list, z_list,
                                bond_overrides_list, return_molecules)

    def score_geom_drugs(self, pos_list, z_list, bond_overrides_list=None,
                         return_molecules=False) -> Union[dict, Tuple[dict, dict]]:
        """Score a batch of GEOM-Drugs molecules.

        Returns the summary dict, or ``(summary, gen_mols)`` when
        ``return_molecules`` is True.
        """
        return self._score_with(self.geom_metrics, pos_list, z_list,
                                bond_overrides_list, return_molecules)
