import logging
import re
from functools import lru_cache

import numpy as np
from rdkit import Chem, RDLogger
from rdkit.Chem import QED, Crippen, Descriptors, FilterCatalog, Lipinski
from rdkit.Chem.SaltRemover import SaltRemover

# Configure root logging at import time so the INFO/WARNING/ERROR messages
# emitted by the functions in this module are visible.
# NOTE(review): calling basicConfig on import also configures the host
# application's root logger; consider moving this to the entry-point script.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Raise RDKit's own logger threshold to CRITICAL so its lower-severity
# messages (e.g. SMILES parse errors) are not printed.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Shared salt remover (constructed with no arguments, i.e. RDKit's default
# salt definitions); used by filter_smiles_str to strip counter-ions.
remover = SaltRemover()

# PAINS substructure catalog, built once at import time and reused for every
# molecule checked by filter_smiles_str.
params = FilterCatalog.FilterCatalogParams()
params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS)
pains_catalog = FilterCatalog.FilterCatalog(params)

_tag_re = re.compile(r"</?(bos|eos|pad)>")

# Upper bound on memoized SMILES. Generated-molecule runs can produce an
# unbounded number of distinct strings, so an unbounded cache would leak.
_CLEAN_SMILES_CACHE_SIZE = 1 << 17


@lru_cache(maxsize=_CLEAN_SMILES_CACHE_SIZE)
def clean_smiles(s: str):
    """Strip model special tokens from *s* and return canonical SMILES.

    ``<bos>``, ``<eos>`` and ``<pad>`` (and their ``</...>`` forms) are
    removed, surrounding whitespace is stripped, and the remainder is parsed
    with RDKit and re-emitted as canonical, non-isomeric SMILES
    (stereochemistry is dropped).

    Args:
        s: Raw SMILES string, possibly containing special tokens.

    Returns:
        The canonical SMILES string, or ``None`` if *s* is not a string, is
        empty after tag removal, cannot be parsed, or parses to a molecule
        with no atoms.
    """
    if not isinstance(s, str):
        return None
    s = _tag_re.sub("", s).strip()
    if not s:
        # Nothing but special tokens/whitespace: not a molecule. Without this
        # guard RDKit parses "" into an empty Mol and "" could be returned.
        return None
    m = Chem.MolFromSmiles(s)
    if m is None or m.GetNumAtoms() == 0:
        return None
    return Chem.MolToSmiles(m, canonical=True, isomericSmiles=False)


def _molecule_properties(mol):
    """Return every per-molecule descriptor used by ``molstats`` for *mol*.

    All values are computed up front, so the caller can record the molecule
    only if every descriptor succeeded. MW, logP, HBD and HBA are computed
    once and reused for the Lipinski violation count.
    """
    mw = Descriptors.MolWt(mol)
    logp = Crippen.MolLogP(mol)
    hba = Lipinski.NumHAcceptors(mol)
    hbd = Lipinski.NumHDonors(mol)
    # Lipinski rule-of-5: one violation for each threshold exceeded.
    lipinski_violations = (
        int(mw > 500) + int(logp > 5) + int(hbd > 5) + int(hba > 10)
    )
    return {
        "qed": QED.qed(mol),
        "mw": mw,
        "logp": logp,
        "tpsa": Descriptors.TPSA(mol),
        "hba": hba,
        "hbd": hbd,
        "rotatable_bonds": Lipinski.NumRotatableBonds(mol),
        "num_rings": Lipinski.RingCount(mol),
        "fraction_csp3": Lipinski.FractionCSP3(mol),
        "lipinski_violations": lipinski_violations,
    }


def molstats(smiles_batch):
    """Compute drug-likeness summary statistics for a batch of SMILES.

    Entries that are not strings, are empty, fail to parse, or whose
    descriptor calculation raises are counted as invalid. For the valid
    molecules, mean/std/min/max are reported for QED, molecular weight,
    Crippen logP, TPSA, H-bond acceptors/donors, rotatable bonds, ring count,
    fraction Csp3 and the number of Lipinski rule-of-5 violations. The
    percentages passing the rule of 5 (<=1 violation) and strictly passing
    it (0 violations) are also reported.

    Args:
        smiles_batch: Any iterable of SMILES strings (list, tuple, generator,
            numpy array, ...). ``None`` is treated as an empty batch.

    Returns:
        dict with ``total_molecules``, ``valid_molecules``,
        ``validity_percent``, one ``<property>_stats`` dict per property, and
        the two Lipinski pass-rate percentages. When there are no valid
        molecules every stat is NaN and both pass rates are 0.0.
    """
    if smiles_batch is None:
        logging.warning("Input SMILES batch is None; treating as empty list.")
        smiles_batch = []

    # Materialize once: this supports generators (no len()) and avoids the
    # ambiguous truth value of numpy arrays / pandas Series.
    smiles_batch = list(smiles_batch)
    total_count = len(smiles_batch)

    processed_properties = {
        "qed": [],
        "mw": [],
        "logp": [],
        "tpsa": [],
        "hba": [],
        "hbd": [],
        "rotatable_bonds": [],
        "num_rings": [],
        "fraction_csp3": [],
        "lipinski_violations": [],
    }

    for smi in smiles_batch:
        if not isinstance(smi, str) or not smi:
            continue

        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue

        try:
            Chem.SanitizeMol(mol)
            props = _molecule_properties(mol)
        except Exception as e:
            logging.error(f"Error calculating properties for SMILES '{smi}': {e}")
            continue

        # Append only after every descriptor succeeded so all property lists
        # stay the same length and valid_count matches the Lipinski arrays.
        for key, value in props.items():
            processed_properties[key].append(value)

    valid_count = len(processed_properties["qed"])

    results = {
        "total_molecules": total_count,
        "valid_molecules": valid_count,
        "validity_percent": (valid_count / total_count) * 100 if total_count else 0,
    }

    if valid_count == 0:
        nan_stats = {"mean": np.nan, "std": np.nan, "min": np.nan, "max": np.nan}
        for key in processed_properties:
            results[f"{key}_stats"] = nan_stats.copy()
        results["lipinski_rule_of_5_passed_percent (<=1 violation)"] = 0.0
        results["lipinski_strict_rule_of_5_passed_percent (0 violations)"] = 0.0

        if total_count == 0:
            logging.warning("SMILES list is empty – returning default statistics.")
        else:
            logging.warning("No valid molecules found – returning default statistics.")
        return results

    for key, values in processed_properties.items():
        arr = np.asarray(values)
        results[f"{key}_stats"] = {
            "mean": np.mean(arr),
            "std": np.std(arr),
            "min": np.min(arr),
            "max": np.max(arr),
        }

    v = np.asarray(processed_properties["lipinski_violations"])
    results["lipinski_rule_of_5_passed_percent (<=1 violation)"] = (
        (v <= 1).sum() / valid_count * 100
    )
    results["lipinski_strict_rule_of_5_passed_percent (0 violations)"] = (
        (v == 0).sum() / valid_count * 100
    )

    logging.info(f"Processed batch. Validity: {results['validity_percent']:.2f}%")
    return results


def print_molstats(stats):
    """Pretty-print a statistics dictionary as produced by ``molstats``.

    Prints the batch counts and validity, then a mean/std/min/max section for
    each known property. A property whose ``<key>_stats`` entry is absent is
    reported as "Not Calculated"; one whose mean is NaN is reported as
    "N/A". The section ends with the Lipinski rule-of-5 pass rates, which
    default to 0.0 when missing. Passing ``None`` prints a short notice
    instead.
    """
    if stats is None:
        print("No statistics available.")
        return

    separator = "-" * 28
    # (display label, key prefix in the stats dict), in printing order.
    metric_table = (
        ("QED", "qed"),
        ("Molecular Weight (MW)", "mw"),
        ("LogP (Crippen)", "logp"),
        ("Topological Polar Surface Area (TPSA)", "tpsa"),
        ("H-Bond Acceptors (HBA)", "hba"),
        ("H-Bond Donors (HBD)", "hbd"),
        ("Rotatable Bonds", "rotatable_bonds"),
        ("Number of Rings", "num_rings"),
        ("Fraction Csp3", "fraction_csp3"),
        ("Lipinski Violations", "lipinski_violations"),
    )

    print("\n--- Molecular Statistics ---")
    print(f"Total Molecules Processed: {stats['total_molecules']}")
    print(f"Valid Molecules Found:     {stats['valid_molecules']}")
    print(f"Validity Percentage:       {stats['validity_percent']:.2f}%")
    print(separator)

    for label, prefix in metric_table:
        entry_key = f"{prefix}_stats"
        if entry_key not in stats:
            print(f"{label}: Not Calculated")
            continue
        summary = stats[entry_key]
        if np.isnan(summary["mean"]):
            # molstats fills NaN placeholders when no molecule was valid.
            print(f"{label}: N/A (No valid molecules)")
            continue
        print(f"{label}:")
        print(f"  Mean: {summary['mean']:.3f}")
        print(f"  Std:  {summary['std']:.3f}")
        print(f"  Min:  {summary['min']:.3f}")
        print(f"  Max:  {summary['max']:.3f}")

    loose_pass = stats.get("lipinski_rule_of_5_passed_percent (<=1 violation)", 0.0)
    strict_pass = stats.get(
        "lipinski_strict_rule_of_5_passed_percent (0 violations)", 0.0
    )

    print(separator)
    print("Drug-Likeness (Lipinski Rule of 5):")
    print(f"  % Passing (<=1 Violation): {loose_pass:.2f}%")
    print(f"  % Strictly Passing (0 Violations): {strict_pass:.2f}%")
    print("--- End Statistics ---\n")


def filter_smiles_str(example, colname="smiles"):
    """Return True if the SMILES in ``example[colname]`` passes the filter.

    *example* only needs a ``.get`` method (presumably a dataset row dict;
    confirm against callers). The molecule is parsed and salts are stripped
    with the module-level ``remover``; ``dontRemoveEverything=True`` keeps
    the molecule if every fragment matches a salt. It must then satisfy:

        150 <= MW <= 500, -1 <= Crippen logP <= 5, HBD <= 5, HBA <= 10,
        20 <= TPSA <= 140, rotatable bonds <= 10, 1 <= ring count <= 6,
        no PAINS match, and QED > 0.4.

    Returns False for missing, empty or non-string values, for unparsable
    SMILES, and when any RDKit call raises. The filter is deliberately
    best-effort and never raises for a bad molecule.
    """
    smiles = example.get(colname)
    if not isinstance(smiles, str) or not smiles:
        return False

    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False

        # 1. Strip salts.
        mol = remover.StripMol(mol, dontRemoveEverything=True)
        if mol is None:
            return False

        # 2. Property constraints. Each descriptor is computed only after the
        #    previous checks passed, so rejected molecules skip the rest.
        if not 150 <= Descriptors.MolWt(mol) <= 500:
            return False
        if not -1.0 <= Crippen.MolLogP(mol) <= 5.0:
            return False
        if Lipinski.NumHDonors(mol) > 5:
            return False
        if Lipinski.NumHAcceptors(mol) > 10:
            return False
        if not 20 <= Descriptors.TPSA(mol) <= 140:
            return False
        if Lipinski.NumRotatableBonds(mol) > 10:
            return False
        if not 1 <= Lipinski.RingCount(mol) <= 6:
            return False

        # 3. Structural alerts.
        if pains_catalog.HasMatch(mol):
            return False

        # 4. QED is evaluated last, only for molecules that passed all else.
        return bool(QED.qed(mol) > 0.4)

    except Exception:
        # Best-effort: any RDKit failure simply rejects the molecule.
        return False
