import os
import time
from functools import lru_cache, partial
from tqdm.auto import tqdm

import pandas as pd
import numpy as np
from jax import random
import jax.numpy as jnp

from scipy.stats import ttest_ind

from rdkit import Chem
from rdkit.Chem import Descriptors, GraphDescriptors
from rdkit.Chem.Fragments import fr_benzene, fr_amide, fr_unbrch_alkane, fr_ether
from rdkit.Chem import Fragments as Frag
from rdkit.Chem.rdMolDescriptors import CalcNumSpiroAtoms, CalcNumBridgeheadAtoms

from ..tools import train_utils
from .sae import decode_latent, normalize
from ..impl import transformer, hooks as hooks_lib
from ..tools.stem import clean_smiles, molstats

# Registry of RDKit descriptors used by ``screen`` and resolved by name in
# ``steer_with``, grouped by category. Each inner dict maps the name used as a
# CSV column / summary "Descriptor" value to a callable that takes an RDKit
# ``Mol``. Iteration order determines the descriptor column order of the
# correlation, p-value and molecule-data CSVs written by ``screen``.
descriptor_categories = {
    "Basic Properties": {
        "MolWt": Descriptors.MolWt,
        "ExactMolWt": Descriptors.ExactMolWt,
        "HeavyAtomCount": Descriptors.HeavyAtomCount,
        "NumHeteroatoms": Descriptors.NumHeteroatoms,
        "NumRotatableBonds": Descriptors.NumRotatableBonds,
        "RingCount": Descriptors.RingCount,
        "FractionCSP3": Descriptors.FractionCSP3,
    },
    "Polarity & Lipophilicity": {
        "LogP": Descriptors.MolLogP,
        "MR": Descriptors.MolMR,
        "TPSA": Descriptors.TPSA,
        "LabuteASA": Descriptors.LabuteASA,
    },
    "HBonding & Ionization": {
        "NumHDonors": Descriptors.NumHDonors,
        "NumHAcceptors": Descriptors.NumHAcceptors,
        "NHOH_Count": Descriptors.NHOHCount,
        "NOCount": Descriptors.NOCount,
    },
    "Rings & Aromaticity": {
        "NumAromaticRings": Descriptors.NumAromaticRings,
        "NumSaturatedRings": Descriptors.NumSaturatedRings,
        "NumAromaticHeterocycles": Descriptors.NumAromaticHeterocycles,
        "NumAromaticCarbocycles": Descriptors.NumAromaticCarbocycles,
        "NumAliphaticCarbocycles": Descriptors.NumAliphaticCarbocycles,
        "NumAliphaticHeterocycles": Descriptors.NumAliphaticHeterocycles,
        "NumSpiroAtoms": CalcNumSpiroAtoms,
        "NumBridgeheadAtoms": CalcNumBridgeheadAtoms,
    },
    "Complexity": {
        "BertzCT": Descriptors.BertzCT,
        "Kappa1": GraphDescriptors.Kappa1,
        "HallKierAlpha": GraphDescriptors.HallKierAlpha,
        "Chi0v": GraphDescriptors.Chi0v,
        "Chi1v": GraphDescriptors.Chi1v,
        "Kappa2": GraphDescriptors.Kappa2,
        "Kappa3": GraphDescriptors.Kappa3,
        "BalabanJ": GraphDescriptors.BalabanJ,
        "Chi2v": GraphDescriptors.Chi2v,
        "Chi3v": GraphDescriptors.Chi3v,
        "EState_VSA_4": Descriptors.EState_VSA4,
    },
    "Drug-likeness": {
        "QED": Descriptors.qed,
        "MaxAbsPartialCharge": Descriptors.MaxAbsPartialCharge,
        "MinAbsPartialCharge": Descriptors.MinAbsPartialCharge,
        "MaxPartialCharge": Descriptors.MaxPartialCharge,
        "MinPartialCharge": Descriptors.MinPartialCharge,
    },
    # Fragment counts from ``rdkit.Chem.Fragments``.
    "Common Fragments": {
        # ── rings / hetero-rings ────────────────────────────────────────────
        "fr_benzene": Frag.fr_benzene,
        "fr_phenol": Frag.fr_phenol,
        "fr_pyridine": Frag.fr_pyridine,
        "fr_imidazole": Frag.fr_imidazole,
        "fr_thiazole": Frag.fr_thiazole,
        "fr_tetrazole": Frag.fr_tetrazole,
        # ── carbonyl chemistry ──────────────────────────────────────────────
        "fr_ketone": Frag.fr_ketone,
        "fr_aldehyde": Frag.fr_aldehyde,
        "fr_ester": Frag.fr_ester,
        "fr_amide": Frag.fr_amide,
        "fr_lactam": Frag.fr_lactam,
        "fr_acid": Frag.fr_COO,  # carboxylic acid (RDKit name: fr_COO)
        # ── sulfur-containing groups ────────────────────────────────────────
        "fr_sulfone": Frag.fr_sulfone,
        "fr_sulfonamide": Frag.fr_sulfonamd,
        "fr_thiophene": Frag.fr_thiophene,
        # ── halogens / leaving groups ───────────────────────────────────────
        "fr_halogen": Frag.fr_halogen,
        "fr_alkyl_halide": Frag.fr_alkyl_halide,
        # ── nitro / azo ─────────────────────────────────────────────────────
        "fr_nitro": Frag.fr_nitro,
        "fr_azo": Frag.fr_azo,
        # ── chains / ethers ─────────────────────────────────────────────────
        "fr_unbrch_alkane": Frag.fr_unbrch_alkane,
        "fr_ether": Frag.fr_ether,
        # ── charged nitrogen ────────────────────────────────────────────────
        "fr_quatN": Frag.fr_quatN,
    },
}


def screen(sae_kit, layer_ids, dataset, process_fn, num_batches=50):
    """Correlate SAE latent activations with RDKit molecular descriptors.

    For each layer in ``layer_ids``, up to ``num_batches`` batches are drawn
    from ``dataset`` and passed to ``process_fn(batch, sae_kit, layer_id)``,
    which must return ``(masked_acts, raw_inputs)``. Each ``raw_inputs[i]`` is
    joined into a SMILES string and cleaned with ``clean_smiles``. When every
    descriptor in ``descriptor_categories`` can be computed for it, the
    molecule is kept together with ``np.max(masked_acts[i], axis=0)``, its
    per-latent max activation.

    Files written per layer into ``descriptor_analysis/``:
      * ``layer{L}_correlations.csv`` / ``layer{L}_pvalues.csv``: the full
        neuron x descriptor Pearson r / p-value matrices. Pairs that could
        not be evaluated have r = 0 and p = 1.
      * ``layer{L}_summary.csv``: up to 5 neurons per descriptor with
        |r| > 0.2.
      * ``layer{L}_molecule_data.csv``: SMILES, descriptor values and the
        activations of every neuron that appears in the summary.

    Returns:
        ``(layer_results, molecule_data)``.
        ``layer_results[layer_id]`` holds the correlation, p-value and
        summary DataFrames plus ``descriptor_names`` and ``category_map``.
        ``molecule_data`` is the per-molecule column dict of the LAST
        processed layer; it is empty if ``layer_ids`` is empty.
    """
    os.makedirs("descriptor_analysis", exist_ok=True)

    all_descriptors = {}
    category_map = {}
    for category, descriptors in descriptor_categories.items():
        for name, func in descriptors.items():
            all_descriptors[name] = func
            category_map[name] = category

    descriptor_names = list(all_descriptors.keys())

    @lru_cache(maxsize=10000)
    def calculate_all_descriptors(smiles):
        """Return all registered descriptors for ``smiles``.

        Returns None if parsing or any single descriptor fails.
        """
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return None
        try:
            return {name: func(mol) for name, func in all_descriptors.items()}
        except Exception:
            # Best-effort: skip molecules RDKit cannot fully describe.
            return None

    layer_results = {}
    # Bound up front so the return statement is valid for empty ``layer_ids``.
    molecule_data = {}

    for layer_id in layer_ids:
        print(f"\n{'=' * 50}")
        print(f"Processing Layer {layer_id}")
        print(f"{'=' * 50}")

        start_time = time.time()

        latent_size = sae_kit.sae_configs[layer_id].latent_size
        # Buffers assume at most 1024 molecules per batch; extra molecules
        # are dropped once the buffers are full.
        max_molecules = num_batches * 1024

        all_activations = np.full(
            (max_molecules, latent_size), np.nan, dtype=np.float32
        )
        all_descriptors_values = {
            name: np.full(max_molecules, np.nan) for name in descriptor_names
        }
        all_smiles = [""] * max_molecules

        total_valid = 0

        print("Collecting molecule data and neuron activations...")
        for batch_idx, batch in enumerate(tqdm(dataset, total=num_batches)):
            # Stop when the batch budget is spent or the buffers are full;
            # further batches could not contribute anything.
            if batch_idx >= num_batches or total_valid >= max_molecules:
                break

            masked_acts, raw_inputs = process_fn(batch, sae_kit, layer_id)

            for i, tokens in enumerate(raw_inputs):
                if total_valid >= max_molecules:
                    break

                clean_smi = clean_smiles("".join(tokens))
                if not clean_smi:
                    continue

                descriptors_dict = calculate_all_descriptors(clean_smi)
                if not descriptors_dict:
                    continue

                # Max over axis 0 of this molecule's activations (presumably
                # the token axis; confirm against process_fn).
                all_activations[total_valid] = np.max(masked_acts[i], axis=0)
                for desc_name in descriptor_names:
                    all_descriptors_values[desc_name][total_valid] = (
                        descriptors_dict[desc_name]
                    )
                all_smiles[total_valid] = clean_smi
                total_valid += 1

        all_activations = all_activations[:total_valid]
        all_smiles = all_smiles[:total_valid]
        for desc_name in descriptor_names:
            all_descriptors_values[desc_name] = all_descriptors_values[desc_name][
                :total_valid
            ]

        print(
            f"Collected data for {total_valid} valid molecules in {time.time() - start_time:.2f} seconds"
        )

        print("Calculating correlations with descriptors...")
        corr_start = time.time()
        correlation_matrix, pvalue_matrix = _correlate_descriptors(
            all_activations, all_descriptors_values, descriptor_names
        )
        print(
            f"Correlation calculation completed in {time.time() - corr_start:.2f} seconds"
        )

        neuron_names = [f"neuron_{n}" for n in range(latent_size)]

        correlation_df = pd.DataFrame(
            correlation_matrix, index=neuron_names, columns=descriptor_names
        )
        pvalue_df = pd.DataFrame(
            pvalue_matrix, index=neuron_names, columns=descriptor_names
        )

        # Save correlation matrices
        correlation_df.to_csv(f"descriptor_analysis/layer{layer_id}_correlations.csv")
        pvalue_df.to_csv(f"descriptor_analysis/layer{layer_id}_pvalues.csv")

        summary_df = _summarize_correlations(
            correlation_matrix, pvalue_matrix, descriptor_names, category_map
        )
        summary_df.to_csv(
            f"descriptor_analysis/layer{layer_id}_summary.csv", index=False
        )

        layer_results[layer_id] = {
            "correlation_df": correlation_df,
            "pvalue_df": pvalue_df,
            "summary_df": summary_df,
            "descriptor_names": descriptor_names,
            "category_map": category_map,
        }

        molecule_data = {"SMILES": all_smiles}
        for desc_name in descriptor_names:
            molecule_data[desc_name] = all_descriptors_values[desc_name]
        for n in summary_df["Neuron"].unique():
            molecule_data[f"neuron_{n}"] = all_activations[:, n]

        pd.DataFrame(molecule_data).to_csv(
            f"descriptor_analysis/layer{layer_id}_molecule_data.csv", index=False
        )

        end_time = time.time()
        print(
            f"Layer {layer_id} analysis completed in {end_time - start_time:.2f} seconds"
        )

    return layer_results, molecule_data


# Columns of the per-layer summary table. They are declared explicitly so an
# empty summary (no neuron above threshold) still has a sortable schema.
_SUMMARY_COLUMNS = [
    "Category",
    "Descriptor",
    "Neuron",
    "Correlation",
    "Abs_Correlation",
    "P_Value",
    "Significant",
]


def _correlate_descriptors(activations, descriptor_values, descriptor_names):
    """Compute Pearson r and p for every (neuron, descriptor) pair.

    Args:
        activations: 2-D array of shape (molecules, latent_size).
        descriptor_values: Mapping name -> 1-D array with one value per molecule.
        descriptor_names: Column order of the returned matrices.

    Returns:
        ``(correlation_matrix, pvalue_matrix)``, each of shape
        (latent_size, len(descriptor_names)).

    Molecules whose descriptor value is not finite (e.g. NaN partial charges)
    are excluded for that descriptor only. Pairs that cannot be evaluated keep
    r = 0 and p = 1 ("no evidence of correlation"). This covers fewer than two
    usable molecules and near-constant neurons or descriptors, which
    ``pearsonr`` would reject or return NaN for.
    """
    from scipy.stats import pearsonr

    latent_size = activations.shape[1]
    correlation_matrix = np.zeros((latent_size, len(descriptor_names)))
    pvalue_matrix = np.ones((latent_size, len(descriptor_names)))

    for d_idx, desc_name in enumerate(tqdm(descriptor_names)):
        desc_values = descriptor_values[desc_name]
        finite = np.isfinite(desc_values)
        if finite.sum() < 2:
            continue
        # Use a plain slice when nothing is masked to avoid a per-column copy.
        rows = slice(None) if finite.all() else finite
        desc_values = desc_values[rows]

        if np.var(desc_values) < 1e-10:
            continue

        for n_idx in range(latent_size):
            neuron_acts = activations[rows, n_idx]
            if np.var(neuron_acts) < 1e-10:
                continue

            r, p = pearsonr(neuron_acts, desc_values)
            correlation_matrix[n_idx, d_idx] = r
            pvalue_matrix[n_idx, d_idx] = p

    return correlation_matrix, pvalue_matrix


def _summarize_correlations(
    correlation_matrix,
    pvalue_matrix,
    descriptor_names,
    category_map,
    top_k=5,
    min_abs_corr=0.2,
    alpha=0.01,
):
    """Tabulate the strongest neurons for each descriptor.

    For each descriptor, the ``top_k`` neurons with the largest |r| are
    considered. Those with |r| > ``min_abs_corr`` become rows with columns
    ``_SUMMARY_COLUMNS``. ``Significant`` means p < ``alpha``. Rows are sorted
    by descending |r|.
    """
    rows = []
    for desc_idx, desc in enumerate(descriptor_names):
        # NaN r values would sort after every real value in argsort and crowd
        # genuine top neurons out of the top-k, so rank them as 0.
        abs_corrs = np.abs(np.nan_to_num(correlation_matrix[:, desc_idx], nan=0.0))

        for idx in np.argsort(abs_corrs)[-top_k:][::-1]:
            if abs_corrs[idx] > min_abs_corr:
                rows.append(
                    {
                        "Category": category_map.get(desc, "Other"),
                        "Descriptor": desc,
                        "Neuron": idx,
                        "Correlation": correlation_matrix[idx, desc_idx],
                        "Abs_Correlation": abs_corrs[idx],
                        "P_Value": pvalue_matrix[idx, desc_idx],
                        "Significant": pvalue_matrix[idx, desc_idx] < alpha,
                    }
                )

    summary_df = pd.DataFrame(rows, columns=_SUMMARY_COLUMNS)
    return summary_df.sort_values("Abs_Correlation", ascending=False)

def steer_with(
    layer_results,
    layer_id,
    sae_kit,
    collectors,
    num_top_neurons=5,
    steer_factor=4.0,
    batch_size=1000,
    num_samples=1000,
):
    """Steer generation with descriptor-correlated SAE latents and measure the shift.

    Descriptors are selected from ``layer_results[layer_id]["summary_df"]`` (as
    produced by ``screen``). The ``num_top_neurons`` descriptors with the
    strongest |correlation| are kept, each with its single strongest neuron.
    For each one, molecules are generated while ``+/- steer_factor *
    collectors[layer_id].std[neuron]`` is added to that latent inside an SAE
    encode/decode callback at ``layer_id``. If the collector has no ``std``,
    ``+/- 1.0`` is used instead. The steered descriptor values are then
    compared with a baseline generated without steering.

    Args:
        layer_results: Mapping from layer id to ``screen`` results.
        layer_id: Layer whose SAE latents are steered.
        sae_kit: Supplies tokenizer, LM config/params and SAE configs/params.
        collectors: Per-layer statistics; ``collectors[layer_id].std`` is
            indexed by neuron id when present.
        num_top_neurons: Number of descriptors (each with its top neuron) to test.
        steer_factor: Multiplier applied to the neuron's ``std``.
        batch_size: Generation batch size for baseline and steered runs.
        num_samples: Number of molecules per generation run.

    Returns:
        DataFrame with one row per (descriptor, direction) that produced
        comparable values. Columns cover baseline and steered mean/std, effect
        size, percent change, the Welch t-test statistic and p-value, clipped
        display values, and sample counts.
    """
    import zlib

    os.makedirs("descriptor_steering", exist_ok=True)

    # Resolve descriptor names the same way ``screen`` computed them. Many
    # registry keys ("LogP", "QED", "NHOH_Count", "fr_phenol", "fr_acid", ...)
    # are not attribute names on ``Descriptors``, so a plain getattr lookup
    # would silently drop them. Fall back to RDKit attribute lookup for
    # summaries that use raw RDKit names.
    registry = {
        name: func
        for group in descriptor_categories.values()
        for name, func in group.items()
    }

    def resolve_descriptor(name):
        func = registry.get(name)
        if func is None:
            module = Frag if name.startswith("fr_") else Descriptors
            func = getattr(module, name, None)
        return func if callable(func) else None

    summary_df = layer_results[layer_id]["summary_df"]

    descriptor_fns = {}
    for desc in summary_df["Descriptor"].unique():
        func = resolve_descriptor(desc)
        if func is not None:
            descriptor_fns[desc] = func

    filtered_summary = summary_df[summary_df["Descriptor"].isin(list(descriptor_fns))]
    top_descriptors = filtered_summary.drop_duplicates("Descriptor").nlargest(
        num_top_neurons, "Abs_Correlation"
    )

    print(f"Testing {len(top_descriptors)} descriptors with strong neuron correlations")
    print(f"Using steering factor: {steer_factor}")

    print("\nGenerating baseline molecules...")
    start_time = time.time()

    _, baseline_mols = train_utils.get_eval_fn(
        tokenizer=sae_kit.tokenizer,
        model_config=sae_kit.lm_config,
        num_samples=num_samples,
        batch_size=batch_size,
        metrics_fn=molstats,
        max_tokens=256,
        temp=0.6,
        log_metrics=False,
        return_results=True,
        run_fn=transformer.run,
    )(random.key(42), step=-1, params=sae_kit.lm_params)

    valid_baseline = [s for s in (clean_smiles(m) for m in baseline_mols) if s]

    print(
        f"Generated {len(valid_baseline)} valid baseline molecules in {time.time() - start_time:.2f} seconds"
    )

    @lru_cache(maxsize=None)
    def calculate_descriptor(smiles, descriptor_name):
        """Return ``descriptor_name`` for ``smiles``, or None if it cannot be computed."""
        func = descriptor_fns.get(descriptor_name)
        if func is None:
            return None
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return None
        try:
            return func(mol)
        except Exception:
            # Best-effort: a failing descriptor only drops this molecule.
            return None

    def finite_descriptor_values(smiles_list, descriptor_name):
        """Return the finite values of ``descriptor_name`` over ``smiles_list``.

        NaN or inf values (possible for e.g. partial charges) would poison the
        means and the t-test, so they are dropped.
        """
        values = []
        for smiles in smiles_list:
            value = calculate_descriptor(smiles, descriptor_name)
            if value is not None and np.isfinite(value):
                values.append(value)
        return values

    print("Pre-calculating baseline descriptor values...")
    baseline_descriptors = {}
    # Only the descriptors that will actually be tested need a baseline.
    for desc_name in top_descriptors["Descriptor"]:
        values = finite_descriptor_values(valid_baseline, desc_name)
        if values:
            baseline_descriptors[desc_name] = values

    cfg = sae_kit.sae_configs[layer_id]
    prm = sae_kit.sae_params[layer_id]

    def make_sae_cb(add_vec):
        """Build a hook that re-encodes ``x`` with the SAE, adds ``add_vec``
        to the latent code and decodes it back, preserving ``x.dtype``."""

        def _cb(x):
            if cfg.rescale_inputs:
                x_norm, x_mean, x_std = normalize(x)
            else:
                x_norm, x_mean, x_std = x, 0.0, 1.0

            x_enc = x_norm - prm["b_dec"] if cfg.pre_enc_bias else x_norm
            z_added = x_enc @ prm["W_enc"] + add_vec

            x_added = decode_latent(
                z_added, params=prm, config=cfg, x_mean=x_mean, x_std=x_std
            )
            return x_added.astype(x.dtype)

        return _cb

    results_rows = []

    for _, row in top_descriptors.iterrows():
        descriptor = row["Descriptor"]
        neuron_id = row["Neuron"]
        correlation = row["Correlation"]
        category = row.get("Category", "Unknown")

        if descriptor not in baseline_descriptors:
            print(f"Skipping {descriptor} - not available in baseline molecules")
            continue

        print(f"\n{'=' * 50}")
        print(
            f"Testing steering with neuron {neuron_id} (r={correlation:.4f}) for {descriptor}"
        )
        print(f"{'=' * 50}")

        # Scale the steering offset by the neuron's activation spread.
        if not hasattr(collectors[layer_id], "std"):
            print(
                "WARNING: collectors[layer_id].std not found. Using default steering value."
            )
            steer_magnitude = 1.0
        else:
            stat = collectors[layer_id].std[neuron_id]
            steer_magnitude = steer_factor * stat
            print(
                f"Neuron {neuron_id} statistic: {stat:.6f}, steering magnitude: {steer_magnitude:.6f}"
            )

        # Try both directions, since the sign of the effect is not known a priori.
        for direction_name, steer_val in (
            ("Positive", steer_magnitude),
            ("Negative", -steer_magnitude),
        ):
            print(f"\nTrying {direction_name} steering (value={steer_val:.6f})...")

            add_vec = jnp.zeros((cfg.latent_size,)).at[neuron_id].set(steer_val)

            editor = hooks_lib.ActivationEditor(
                edits=(
                    hooks_lib.Edit(
                        layer=layer_id,
                        kind=cfg.placement,
                        op="call",
                        callback=make_sae_cb(add_vec),
                    ),
                )
            )
            steered_run = partial(transformer.run, editor=editor)

            print("Generating steered molecules...")
            start_time = time.time()

            # Derive a distinct seed per (neuron, direction, descriptor).
            # crc32 is stable across processes; the built-in str hash() is
            # salted per run (PYTHONHASHSEED), which made results irreproducible.
            seed_label = f"{neuron_id}_{direction_name}_{descriptor}"
            random_seed = zlib.crc32(seed_label.encode("utf-8"))

            _, steered_mols = train_utils.get_eval_fn(
                tokenizer=sae_kit.tokenizer,
                model_config=sae_kit.lm_config,
                num_samples=num_samples,
                batch_size=batch_size,
                metrics_fn=molstats,
                max_tokens=256,
                temp=0.6,
                log_metrics=False,
                return_results=True,
                run_fn=steered_run,
            )(random.key(random_seed), step=-1, params=sae_kit.lm_params)

            valid_steered = [s for s in (clean_smiles(m) for m in steered_mols) if s]

            print(
                f"Generated {len(valid_steered)} valid steered molecules in {time.time() - start_time:.2f} seconds"
            )

            steered_values = finite_descriptor_values(valid_steered, descriptor)
            baseline_values = baseline_descriptors.get(descriptor, [])
            if not (baseline_values and steered_values):
                continue

            results_rows.append(
                {
                    "Category": category,
                    "Descriptor": descriptor,
                    "Neuron": neuron_id,
                    "Direction": direction_name,
                    "Steer_Value": steer_val,
                    "Correlation": correlation,
                    **_compare_to_baseline(steered_values, baseline_values),
                }
            )

    return pd.DataFrame(results_rows)


def _compare_to_baseline(steered_values, baseline_values):
    """Summarize a steered descriptor distribution against the baseline.

    Effect size is the mean shift in units of the baseline std. Percent change
    is relative to ``|baseline mean|``. A zero denominator yields +inf or -inf
    according to the sign of the numerator, or 0.0 when there is no shift.
    ``Display_*`` values clip those to [-1000, 1000]. The significance test is
    Welch's t-test at 0.05; if it raises, the statistic and p-value are NaN and
    the result is not significant.
    """
    baseline_mean = np.mean(baseline_values)
    baseline_std = np.std(baseline_values)
    steered_mean = np.mean(steered_values)
    steered_std = np.std(steered_values)

    try:
        t_stat, p_value = ttest_ind(steered_values, baseline_values, equal_var=False)
        significant = p_value < 0.05
    except Exception:
        t_stat, p_value = np.nan, np.nan
        significant = False

    def _signed_inf(x):
        return float("inf") if x > 0 else (float("-inf") if x < 0 else 0.0)

    if baseline_std > 0:
        effect_size = (steered_mean - baseline_mean) / baseline_std
    else:
        effect_size = _signed_inf(steered_mean - baseline_mean)

    if baseline_mean != 0:
        percent_change = (steered_mean - baseline_mean) / abs(baseline_mean) * 100
    else:
        percent_change = _signed_inf(steered_mean)

    return {
        "Baseline_Mean": baseline_mean,
        "Baseline_Std": baseline_std,
        "Steered_Mean": steered_mean,
        "Steered_Std": steered_std,
        "Effect_Size": effect_size,
        "Percent_Change": percent_change,
        "T_Statistic": t_stat,
        "P_Value": p_value,
        "Significant": significant,
        "Display_Effect": min(max(effect_size, -1000), 1000),
        "Display_Percent": min(max(percent_change, -1000), 1000),
        "Baseline_Count": len(baseline_values),
        "Steered_Count": len(steered_values),
    }