import contextlib
import io
import os
import re
import tempfile
from collections import defaultdict
from functools import lru_cache

import jax.numpy as jnp
import networkx as nx
import numpy as np
import pandas as pd

from cgspan_mining import cgSpan as gSpan
from rdkit import Chem
from rdkit.Chem import rdSubstructLibrary as SSL
from rdkit.Chem.Scaffolds import MurckoScaffold
from scipy.stats import fisher_exact
from tqdm.auto import tqdm

from ..tools.stem import clean_smiles

@contextlib.contextmanager
def suppress_output(to_file=False):
    """Temporarily redirect ``sys.stdout`` and ``sys.stderr``.

    Args:
        to_file: When true, both streams are captured in one in-memory
            ``io.StringIO`` buffer that is yielded to the caller. When false,
            everything written is discarded via ``os.devnull``.

    Yields:
        The capture buffer when ``to_file`` is true, otherwise ``None``.
    """
    # ExitStack unwinds in reverse order, so the streams are restored first and
    # the devnull handle is closed afterwards -- also when the body raises.
    with contextlib.ExitStack() as stack:
        if to_file:
            sink = io.StringIO()
            yielded = sink
        else:
            sink = stack.enter_context(open(os.devnull, "w"))
            yielded = None

        stack.enter_context(contextlib.redirect_stdout(sink))
        stack.enter_context(contextlib.redirect_stderr(sink))
        yield yielded


def keep_motif(smarts_pattern: str) -> bool:
    """Return ``True`` when RDKit can parse ``smarts_pattern`` as SMARTS.

    A pattern is rejected exactly when ``Chem.MolFromSmarts`` returns ``None``.
    """
    return Chem.MolFromSmarts(smarts_pattern) is not None


def smiles_to_nx(smi: str):
    """Convert a SMILES string into a labelled ``networkx.Graph``.

    The SMILES is passed through ``clean_smiles`` before parsing. Nodes are
    keyed by RDKit atom index and carry the atomic number as ``label``; edges
    carry an integer bond code as ``label`` (1 single, 2 double, 3 triple,
    4 aromatic).

    Returns:
        The graph, or ``None`` when RDKit cannot parse the cleaned SMILES.

    Raises:
        KeyError: If the molecule contains a bond type other than the four
            listed above.
    """
    molecule = Chem.MolFromSmiles(clean_smiles(smi))
    if molecule is None:
        return None

    bond_codes = {
        Chem.BondType.SINGLE: 1,
        Chem.BondType.DOUBLE: 2,
        Chem.BondType.TRIPLE: 3,
        Chem.BondType.AROMATIC: 4,
    }

    graph = nx.Graph()
    graph.add_nodes_from(
        (atom.GetIdx(), {"label": atom.GetAtomicNum()})
        for atom in molecule.GetAtoms()
    )
    graph.add_edges_from(
        (
            bond.GetBeginAtomIdx(),
            bond.GetEndAtomIdx(),
            {"label": bond_codes[bond.GetBondType()]},
        )
        for bond in molecule.GetBonds()
    )
    return graph


def graphs_to_file(graphs):
    """Write graphs to a temporary text file and return its path.

    Each graph is emitted as a ``t # <gid>`` header line followed by one
    ``v <node> <label>`` line per node and one ``e <u> <v> <label>`` line per
    edge. Lines are newline-separated with no trailing newline.

    Args:
        graphs: Iterable of ``(gid, graph)`` pairs. Each graph must provide
            ``nodes(data=True)`` and ``edges(data=True)`` whose attribute
            dicts contain a ``"label"`` key.

    Returns:
        Path of the created ``.data`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    lines = []
    for gid, graph in graphs:
        lines.append(f"t # {gid}")
        lines.extend(
            f"v {node} {attrs['label']}" for node, attrs in graph.nodes(data=True)
        )
        lines.extend(
            f"e {u} {v} {attrs['label']}" for u, v, attrs in graph.edges(data=True)
        )

    # The context manager closes the handle even if the write fails; with
    # delete=False the file itself is left on disk for the caller.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".data", mode="w") as tmp:
        tmp.write("\n".join(lines))
    return tmp.name


def get_pattern_graph(pattern):
    """Return the graph object behind a mined pattern.

    The pattern is probed in this order:

    1. a ``graph`` attribute, returned as is;
    2. a ``to_graph`` method, called with ``is_undirected=True``;
    3. a ``vertices`` attribute, in which case ``pattern`` itself is returned.

    Raises:
        TypeError: If none of the three attributes is present.
    """
    missing = object()

    graph = getattr(pattern, "graph", missing)
    if graph is not missing:
        return graph

    builder = getattr(pattern, "to_graph", missing)
    if builder is not missing:
        return builder(is_undirected=True)

    if getattr(pattern, "vertices", missing) is missing:
        raise TypeError(f"Unrecognised pattern type: {type(pattern)}")
    return pattern


def dfs_to_smarts(obj):
    """Render a mined pattern as a SMARTS string.

    ``obj`` is resolved to a graph by checking, in order, for a ``vertices``
    attribute (``obj`` is used directly), a ``graph`` attribute, or a
    ``to_graph`` method (called with ``is_undirected=True``).

    Vertex labels are read as atomic numbers and edge labels as bond codes
    (1 single, 2 double, 3 triple, 4 aromatic); any other edge label is
    written as a single bond.

    Raises:
        TypeError: If ``obj`` exposes none of the three attributes.
    """
    if hasattr(obj, "vertices"):
        pattern_graph = obj
    elif hasattr(obj, "graph"):
        pattern_graph = obj.graph
    elif hasattr(obj, "to_graph"):
        pattern_graph = obj.to_graph(is_undirected=True)
    else:
        raise TypeError(f"Unknown motif type: {type(obj)}")

    code_to_bond = {
        1: Chem.BondType.SINGLE,
        2: Chem.BondType.DOUBLE,
        3: Chem.BondType.TRIPLE,
        4: Chem.BondType.AROMATIC,
    }

    editable = Chem.RWMol()
    atom_index = {
        vertex.vid: editable.AddAtom(Chem.Atom(int(vertex.vlb)))
        for vertex in pattern_graph.vertices.values()
    }

    for source, vertex in pattern_graph.vertices.items():
        for target, edge in vertex.edges.items():
            # Only the (low, high) orientation of each edge is turned into a
            # bond, so an edge listed from both endpoints is added once.
            if source >= target:
                continue
            bond_type = code_to_bond.get(int(edge.elb), Chem.BondType.SINGLE)
            editable.AddBond(atom_index[source], atom_index[target], bond_type)

    return Chem.MolToSmarts(editable, isomericSmiles=False)


def compile_cluster_smiles(
    dataset,
    sae_kit,
    layer_id,
    cluster_neurons,
    preprocess_fn,
    min_avg_act=0.05,
    min_specificity=2.0,
    max_per_cluster=10_000,
):
    """Collect, per neuron cluster, the inputs that activate it most specifically.

    For every batch, ``preprocess_fn(batch, sae_kit, layer_id)`` must return
    ``(masked_acts, raw_inputs_batch)``. Activations are reduced with a max
    over axis 1 (presumably the sequence axis -- confirm against
    ``preprocess_fn``), giving one activation vector per input. For each
    cluster the mean activation inside the cluster is compared with the mean
    over all remaining latents.

    Args:
        dataset: Iterable of batches.
        sae_kit: Object exposing ``sae_configs[layer_id].latent_size``; also
            forwarded to ``preprocess_fn``.
        layer_id: Key into ``sae_kit.sae_configs``; forwarded to
            ``preprocess_fn``.
        cluster_neurons: Mapping ``cluster id -> latent indices`` (array or
            sequence). Indices ``>= latent_size`` are dropped; clusters left
            without indices are ignored.
        preprocess_fn: Callable described above.
        min_avg_act: Minimum mean in-cluster activation for an input to count.
        min_specificity: Minimum ratio of in-cluster to out-of-cluster mean
            activation.
        max_per_cluster: Maximum number of inputs kept per cluster.

    Returns:
        ``defaultdict(list)`` mapping cluster id to the joined input strings,
        ordered by descending ``specificity * mean in-cluster activation``.
        Only clusters with at least one qualifying input appear.
    """
    latent_size = sae_kit.sae_configs[layer_id].latent_size
    epsilon = 1e-9

    # The usable indices and their complement mask depend only on the cluster
    # definition, so build them once instead of once per molecule.
    prepared_clusters = {}
    for cid, neuron_ids in cluster_neurons.items():
        if isinstance(neuron_ids, np.ndarray):
            neuron_idx = neuron_ids
        else:
            neuron_idx = np.array(neuron_ids)
        if not neuron_idx.size:
            continue

        neuron_idx = neuron_idx[neuron_idx < latent_size]
        if not neuron_idx.size:
            continue

        outside_mask = np.ones(latent_size, dtype=bool)
        outside_mask[neuron_idx] = False
        # None marks a cluster that covers every latent (nothing outside it).
        prepared_clusters[cid] = (
            neuron_idx,
            outside_mask if np.any(outside_mask) else None,
        )

    tmp_smiles_scores = defaultdict(list)

    for batch in tqdm(dataset, desc="Scanning dataset"):
        masked_acts, raw_inputs_batch = preprocess_fn(batch, sae_kit, layer_id)
        seq_max_acts_np = np.array(jnp.max(masked_acts, axis=1))

        for i, mol_activations in enumerate(seq_max_acts_np):
            smi_original = "".join(raw_inputs_batch[i])

            for cid, (neuron_idx, outside_mask) in prepared_clusters.items():
                avg_act_in_cluster = np.mean(mol_activations[neuron_idx])
                if avg_act_in_cluster < min_avg_act:
                    continue

                if outside_mask is None:
                    avg_act_out_cluster = 1e-9
                else:
                    avg_act_out_cluster = np.mean(mol_activations[outside_mask])

                # epsilon keeps the ratio finite when nothing fires outside.
                specificity_ratio = avg_act_in_cluster / (
                    avg_act_out_cluster + epsilon
                )
                if specificity_ratio >= min_specificity:
                    composite_score = specificity_ratio * avg_act_in_cluster
                    tmp_smiles_scores[cid].append((composite_score, smi_original))

    cluster_smiles = defaultdict(list)
    for cid in sorted(tmp_smiles_scores):
        ranked = sorted(tmp_smiles_scores[cid], key=lambda x: x[0], reverse=True)
        cluster_smiles[cid] = [smi for _, smi in ranked[:max_per_cluster]]

    return cluster_smiles


@lru_cache(maxsize=None)
def smarts_query(motif):
    """Parse ``motif`` as SMARTS, memoising the result per pattern string.

    Returns whatever ``Chem.MolFromSmarts`` returns, i.e. ``None`` for a
    pattern RDKit cannot parse. The cache is unbounded.
    """
    return Chem.MolFromSmarts(motif)


def get_motif_stats(motifs, in_smiles, out_smiles, in_lib, out_lib):
    """Score SMARTS motifs by how enriched they are inside a cluster.

    Args:
        motifs: Iterable of SMARTS strings; unparsable ones are skipped.
        in_smiles: Molecules of the cluster; only its length is used.
        out_smiles: Molecules outside the cluster; only its length is used.
            May be empty, in which case the background hit rate is taken as 0.
        in_lib: Substructure library over the cluster molecules (must provide
            ``CountMatches``).
        out_lib: Substructure library over the background molecules.

    Returns:
        List of tuples ``(motif, hits_in, coverage, enrichment, odds_ratio,
        p_value, score)`` for every motif with at least one in-cluster hit.
        ``p_value`` comes from a one-sided (``"greater"``) Fisher exact test
        and ``score`` is ``coverage * log2(enrichment + 1)``.
    """
    n_in = len(in_smiles)
    n_out = len(out_smiles)

    results = []
    for motif in motifs:
        query_molecule = smarts_query(motif)
        if query_molecule is None:
            continue

        hits_in_cluster = in_lib.CountMatches(query_molecule)
        if hits_in_cluster == 0:
            # Absent from the cluster: skip before searching the background.
            continue
        hits_out_cluster = out_lib.CountMatches(query_molecule)

        coverage = hits_in_cluster / n_in
        # An empty background would otherwise divide by zero.
        background_rate = hits_out_cluster / n_out if n_out else 0.0
        enrichment = coverage / max(background_rate, 1e-9)

        odds_ratio, p_value = fisher_exact(
            [
                [hits_in_cluster, n_in - hits_in_cluster],
                [hits_out_cluster, n_out - hits_out_cluster],
            ],
            alternative="greater",
        )
        motif_score = coverage * np.log2(enrichment + 1)
        results.append(
            (
                motif,
                hits_in_cluster,
                coverage,
                enrichment,
                odds_ratio,
                p_value,
                motif_score,
            )
        )

    return results


def clean_scaffolds(smiles_list, max_graphs, rng):
    """Pick one representative molecule per Murcko scaffold.

    Each SMILES is passed through ``clean_smiles``; the first cleaned SMILES
    seen for a given non-isomeric scaffold SMILES becomes that scaffold's
    representative. Inputs that clean to an empty value, that RDKit cannot
    parse, or whose scaffold SMILES is empty are skipped.

    Args:
        smiles_list: Iterable of SMILES strings.
        max_graphs: Maximum number of representatives to return.
        rng: Random generator providing
            ``choice(seq, size, replace=False)`` (e.g.
            ``numpy.random.Generator``); used only when there are more than
            ``max_graphs`` representatives.

    Returns:
        A list of cleaned SMILES strings, at most ``max_graphs`` long.
    """
    scaffold_representative = {}
    for smiles in smiles_list:
        cleaned = clean_smiles(smiles)
        if not cleaned:
            continue

        molecule = Chem.MolFromSmiles(cleaned)
        if molecule is None:
            # Unparsable after cleaning; the scaffold call would fail on None.
            continue

        scaffold = MurckoScaffold.GetScaffoldForMol(molecule)
        scaffold_smiles = Chem.MolToSmiles(scaffold, isomericSmiles=False)
        if scaffold_smiles:
            scaffold_representative.setdefault(scaffold_smiles, cleaned)

    representatives = list(scaffold_representative.values())
    if len(representatives) > max_graphs:
        # rng.choice returns an array; convert back so the return type is a
        # list of plain strings on both paths.
        sampled = rng.choice(representatives, max_graphs, replace=False)
        representatives = [str(smi) for smi in sampled]

    return representatives


def make_global_library(cluster_smiles):
    """Build one substructure library covering the molecules of all clusters.

    Every SMILES is passed through ``clean_smiles`` and de-duplicated on the
    cleaned string, so a molecule shared by several clusters is stored once.
    SMILES that clean to an empty value or that RDKit cannot parse are
    skipped.

    Args:
        cluster_smiles: Mapping ``cluster id -> iterable of SMILES``.

    Returns:
        Tuple ``(lib, cluster_indices, all_indices)``: the
        ``SubstructLibrary``, a ``defaultdict(set)`` mapping each cluster id
        to the library indices of its molecules, and the set of all indices
        that were added.
    """
    mol_holder = SSL.CachedTrustedSmilesMolHolder()
    fp_holder = SSL.PatternHolder()
    cluster_indices = defaultdict(set)
    smiles_seen = {}
    unparsable = set()

    for cid, smi_list in cluster_smiles.items():
        for smi in smi_list:
            cs = clean_smiles(smi)
            if not cs or cs in unparsable:
                continue

            idx = smiles_seen.get(cs)
            if idx is None:
                # Parse before touching either holder so a failure cannot
                # leave the molecule and fingerprint holders out of step.
                mol = Chem.MolFromSmiles(cs)
                if mol is None:
                    unparsable.add(cs)
                    continue
                idx = mol_holder.AddSmiles(cs)
                fp_holder.AddMol(mol)
                smiles_seen[cs] = idx

            cluster_indices[cid].add(idx)

    lib = SSL.SubstructLibrary(mol_holder, fp_holder)
    return lib, cluster_indices, set(smiles_seen.values())


def motif_stats_global(motifs, lib, cluster_indices, all_idx):
    """Score SMARTS motifs for every cluster against one shared library.

    Each motif is searched in ``lib`` once; per-cluster counts are then
    derived by intersecting the hit set with the cluster's index set.

    Args:
        motifs: Iterable of SMARTS strings; unparsable ones are skipped.
        lib: ``SubstructLibrary`` holding every molecule.
        cluster_indices: Mapping ``cluster id -> set of library indices``.
        all_idx: Set of all library indices under consideration.

    Returns:
        ``defaultdict(list)`` mapping cluster id to tuples ``(smarts,
        hits_in, coverage, enrichment, odds_ratio, p_value, score)`` for
        every motif with at least one hit in that cluster. ``p_value`` comes
        from a one-sided (``"greater"``) Fisher exact test and ``score`` is
        ``coverage * log2(enrichment + 1)``.
    """
    results = defaultdict(list)
    total = len(all_idx)

    # GetMatches caps its result at ``maxResults`` (1000 by default in the
    # Python binding), which would silently truncate hit sets for larger
    # libraries; lift the cap to the library size.
    max_results = max(len(lib), 1)

    # Pre-compute the hit sets: one RDKit search per motif.
    motif_hits = {}
    for smarts in motifs:
        query = smarts_query(smarts)
        if query is not None:
            motif_hits[smarts] = set(lib.GetMatches(query, maxResults=max_results))

    for cid, in_set in cluster_indices.items():
        in_size = len(in_set)
        out_size = total - in_size
        out_set = all_idx - in_set

        for smarts, hit_set in motif_hits.items():
            hits_in = len(hit_set & in_set)
            if hits_in == 0:  # skip motifs absent in the cluster
                continue
            hits_out = len(hit_set & out_set)

            coverage = hits_in / in_size
            # A cluster spanning the whole library has no background; treat
            # the background hit rate as 0 instead of dividing by zero.
            background_rate = hits_out / out_size if out_size else 0.0
            enrichment = coverage / max(background_rate, 1e-9)

            odds, p = fisher_exact(
                [[hits_in, in_size - hits_in], [hits_out, out_size - hits_out]],
                alternative="greater",
            )
            score = coverage * np.log2(enrichment + 1)
            results[cid].append((smarts, hits_in, coverage, enrichment, odds, p, score))

    return results


def mine_subgraphs(
    cluster_smiles, num_clusters=None, min_support=100, max_graphs=1000, seed=2002
):
    """Mine frequent subgraph motifs per cluster and score their enrichment.

    For each cluster the molecules are reduced to one representative per
    Murcko scaffold (at most ``max_graphs``), converted to labelled graphs and
    mined with gSpan. The resulting motifs are converted to SMARTS and scored
    against a library built from the molecules of all clusters.

    Args:
        cluster_smiles: Mapping ``cluster id -> list of SMILES``.
        num_clusters: Mine only the first ``num_clusters`` clusters (in
            mapping order); ``None`` mines all of them.
        min_support: Minimum support passed to the miner.
        max_graphs: Maximum number of graphs mined per cluster.
        seed: Seed for the generator used when subsampling scaffolds.

    Returns:
        Dict mapping cluster id to a ``pandas.DataFrame`` with columns
        ``SMARTS, hits_in, coverage, enrichment, odds_ratio, p_value, score``,
        sorted by descending ``score``. Clusters without any valid graph are
        omitted.
    """
    random_generator = np.random.default_rng(seed=seed)

    lib, cluster_idx, all_idx = make_global_library(cluster_smiles)

    print("Collecting SMILES from clusters...")

    # Slicing with None keeps every cluster, so one expression covers both
    # the limited and the unlimited case.
    clusters_to_mine = list(cluster_smiles)[:num_clusters]

    result_columns = [
        "SMARTS",
        "hits_in",
        "coverage",
        "enrichment",
        "odds_ratio",
        "p_value",
        "score",
    ]
    pattern_results = {}

    for cluster_id in tqdm(clusters_to_mine, desc="Mining clusters"):
        scaffold_smiles_list = clean_scaffolds(
            cluster_smiles[cluster_id], max_graphs, random_generator
        )

        molecular_graphs = []
        for index, smiles in enumerate(scaffold_smiles_list):
            graph = smiles_to_nx(smiles)
            if graph is not None:
                molecular_graphs.append((index, graph))

        if not molecular_graphs:
            print(f"Cluster {cluster_id}: no valid graphs after filtering")
            continue

        print(
            f"Cluster {cluster_id}: mining {len(molecular_graphs)} graphs (minsup {min_support})"
        )

        database_file = graphs_to_file(molecular_graphs)
        try:
            graph_miner = gSpan(
                database_file_name=database_file,
                min_support=min_support,
                is_undirected=True,
                verbose=False,
            )
            with suppress_output():
                graph_miner.run()
        finally:
            # Remove the temporary database even when mining fails.
            os.remove(database_file)

        raw_patterns = graph_miner._frequent_subgraphs
        print(f"Cluster {cluster_id}: found {len(raw_patterns)} raw motifs")

        molecular_motifs = [dfs_to_smarts(pattern) for pattern in raw_patterns]
        valid_motifs = [motif for motif in molecular_motifs if keep_motif(motif)]

        # Only this cluster's statistics are needed, so hand over just its
        # index set rather than scoring every cluster on every iteration.
        own_indices = {cluster_id: cluster_idx.get(cluster_id, set())}
        cluster_results = motif_stats_global(valid_motifs, lib, own_indices, all_idx)[
            cluster_id
        ]
        results_df = pd.DataFrame(cluster_results, columns=result_columns).sort_values(
            "score", ascending=False
        )

        pattern_results[cluster_id] = results_df

    return pattern_results
