from multiprocessing import Pool
from typing import Dict, List, Optional, Tuple, Set

import numpy as np
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from scipy.spatial.distance import jensenshannon
from scipy.stats import gaussian_kde
from scipy.stats import wasserstein_distance

from moses.metrics.metrics import (
    compute_intermediate_statistics,
    fraction_valid,
    fraction_unique,
    internal_diversity,
    remove_invalid,
    SNNMetric,
    novelty
)
from moses.metrics.utils import mapper, get_mol
from moses.utils import disable_rdkit_log, enable_rdkit_log

def get_all_metrics(gen, k=None, n_jobs=1,
                    device='cpu', batch_size=512, pool=None,
                    test=None, ptest=None, prop_test=None,
                    train=None,
                    distribution_method='wasserstein'):
    """
    Computes specified metrics between test and generated sets of SMILES.

    Parameters:
        gen: list of generated SMILES
        k: int or list with values for unique@k. Will calculate number of
            unique molecules in the first k molecules. Default [1000, 10000]
        n_jobs: number of workers for parallel processing
        device: 'cpu' or 'cuda:n', where n is GPU device number
        batch_size: batch size for compute_intermediate_statistics from MOSES
        pool: optional multiprocessing pool to use for parallelization.
            A caller-supplied pool is used but never closed here. A pool
            created internally (pool is None and n_jobs != 1) is closed
            and joined before returning, even if a metric raises.
        test (list): test SMILES. Required; ValueError is raised if None.
        ptest (None or dict): precalculated MOSES intermediate statistics of
            the test set (the 'SNN' entry is used). If None, they are
            computed from ``test``.
        prop_test (None or dict): precalculated molecular properties of the
            test set, in the format returned by compute_molecular_properties.
            If None, they are computed from ``test``.
        train (None or list): train SMILES. If None, the 'Novel' metric is
            not computed.
        distribution_method: 'wasserstein' (Wasserstein distance) or 'kde'
            (Jensen-Shannon distance between KDEs) for distribution comparison.

    Returns:
        dict mapping metric name to its value.

    Raises:
        ValueError: if ``test`` is None or ``distribution_method`` is not
            'wasserstein' or 'kde'.

    Available metrics:
        * valid: fraction of valid SMILES
        * p-valid: fraction of valid polymer SMILES
        * unique@k: fraction of unique molecules in first k molecules
        * Novel: fraction of novel molecules not in training set
        * SNN: similarity to nearest neighbour
        * IntDiv: internal diversity
        * Distribution-based metrics (Molar mass, Aromatic fraction,
          Rotatable bond fraction, Heteroatom fraction, TPSA)
    """
    if test is None:
        raise ValueError(
            "You should specify test set")

    # Validate up front so a typo fails fast instead of silently using KDE.
    if distribution_method == 'wasserstein':
        compare = compare_distributions_wasserstein
    elif distribution_method == 'kde':
        compare = compare_distributions_kde
    else:
        raise ValueError(
            f"Unknown distribution_method {distribution_method!r}; "
            "expected 'wasserstein' or 'kde'")

    if k is None:
        k = [1000, 10000]

    # Track ownership so that only a pool created here is shut down;
    # a pool passed in by the caller stays usable after this call.
    owns_pool = pool is None and n_jobs != 1
    if owns_pool:
        pool = Pool(n_jobs)

    disable_rdkit_log()
    try:
        metrics = {}

        metrics['valid'] = fraction_valid(gen, n_jobs=n_jobs)
        # p-valid is evaluated on the raw input, before invalid SMILES are removed.
        metrics['p-valid'] = np.mean(p_valid(gen))

        gen = remove_invalid(gen, canonize=True)
        k_list = k if isinstance(k, (list, tuple)) else [k]
        for _k in k_list:
            metrics[f'unique@{_k}'] = fraction_unique(gen, _k, n_jobs)

        if ptest is None:
            ptest = compute_intermediate_statistics(test, n_jobs=n_jobs,
                                                    device=device,
                                                    batch_size=batch_size,
                                                    pool=pool)

        if prop_test is None:
            prop_test = compute_molecular_properties(test, n_jobs=n_jobs,
                                                     device=device,
                                                     batch_size=batch_size,
                                                     pool=pool)

        mols = mapper(n_jobs)(get_mol, gen)
        metrics['SNN'] = SNNMetric(n_jobs=n_jobs, device=device,
                                   batch_size=batch_size)(gen=mols, pref=ptest['SNN'])
        metrics['IntDiv'] = internal_diversity(mols, n_jobs, device=device)

        # Distribution-based properties
        properties = ('molar_mass', 'aromatic_fraction',
                      'rotatable_bond_fraction', 'heteroatom_fraction', 'tpsa')
        prop_gen = compute_molecular_properties(gen, n_jobs=n_jobs,
                                                device=device,
                                                batch_size=batch_size,
                                                pool=pool)
        for prop in properties:
            metrics[prop] = compare(prop_gen[prop], prop_test[prop])

        if train is not None:
            metrics['Novel'] = novelty(mols, train, n_jobs)
    finally:
        # Always restore logging and release our own workers, even on error.
        enable_rdkit_log()
        if owns_pool:
            pool.close()
            pool.join()

    return metrics


def p_valid(data: List[str]) -> List[bool]:
    """
    Flag which SMILES strings are valid polymer repeat units.

    A string passes when all of the following hold:
      * it contains exactly two ``*`` characters;
      * RDKit can parse it;
      * every ``*`` atom has exactly one neighbour;
      * all bonds attached to ``*`` atoms share a single bond type.

    Args:
        data: SMILES strings to check.

    Returns:
        One boolean per input string, in the same order.
    """

    def _check(smi: str) -> bool:
        # Cheap textual pre-filter before invoking the RDKit parser.
        if smi.count('*') != 2:
            return False

        mol = Chem.MolFromSmiles(smi)
        if not mol:
            return False

        attachment_bond_types = set()
        for atom in mol.GetAtoms():
            if atom.GetSymbol() != '*':
                continue
            # An attachment point must hang off exactly one atom.
            if len(atom.GetNeighbors()) != 1:
                return False
            attachment_bond_types.update(b.GetBondType() for b in atom.GetBonds())

        # Both ends must connect with the same bond order.
        return len(attachment_bond_types) == 1

    return [_check(smi) for smi in data]


def compute_molecular_properties(smiles_list: List[str],
                                 n_jobs: int = 1,
                                 device: str = 'cpu',
                                 batch_size: int = 512,
                                 pool: Pool = None) -> Dict[str, List[float]]:
    """
    Compute per-molecule descriptors for a list of SMILES, optionally in parallel.

    Args:
        smiles_list: List of SMILES strings.
        n_jobs: Number of worker processes. With ``n_jobs == 1`` all work runs
            in the current process and ``pool`` is ignored.
        device: Unused; accepted for signature parity with the MOSES helpers.
        batch_size: Unused; accepted for signature parity with the MOSES helpers.
        pool: Optional multiprocessing Pool. If None and ``n_jobs != 1``, a
            temporary pool is created and torn down before returning (even
            if processing raises). A caller-supplied pool is never closed.

    Returns:
        Dictionary mapping each property name ('molar_mass',
        'aromatic_fraction', 'rotatable_bond_fraction', 'heteroatom_fraction',
        'tpsa') to a list of values. Entries for which
        ``_process_single_smiles`` returned None are dropped, so all lists are
        index-aligned with each other but may be shorter than ``smiles_list``.
    """
    property_names = ('molar_mass', 'aromatic_fraction',
                      'rotatable_bond_fraction', 'heteroatom_fraction', 'tpsa')

    # Single-job mode always runs serially, ignoring any supplied pool.
    if n_jobs == 1:
        pool = None
    owns_pool = pool is None and n_jobs != 1
    if owns_pool:
        pool = Pool(n_jobs)

    try:
        if pool is not None:
            results = pool.map(_process_single_smiles, smiles_list)
        else:
            results = [_process_single_smiles(smiles) for smiles in smiles_list]
    finally:
        # Release worker processes we started, even if map() raised.
        if owns_pool:
            pool.terminate()
            pool.join()

    valid_results = [r for r in results if r is not None]
    return {name: [r[name] for r in valid_results] for name in property_names}


def _process_single_smiles(smiles: str) -> Optional[Dict[str, float]]:
    """
    Helper function to compute the descriptors of a single SMILES string.

    Fractions are taken over heavy atoms (atomic number > 1). Dummy ``*``
    attachment atoms (atomic number 0) and hydrogens are therefore excluded
    from both numerator and denominator.

    Args:
        smiles: SMILES string; may contain ``*`` attachment points.

    Returns:
        Dict with keys 'molar_mass', 'aromatic_fraction',
        'rotatable_bond_fraction', 'heteroatom_fraction' and 'tpsa'.
        Returns None if the SMILES does not parse, has no heavy atoms or no
        bonds, or a descriptor raises. In the last case the error is printed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    try:
        # Same set GetNumHeavyAtoms() counts; using it for the numerators too
        # keeps '*' atoms from being miscounted as heteroatoms.
        heavy_atoms = [atom for atom in mol.GetAtoms() if atom.GetAtomicNum() > 1]
        num_atoms = len(heavy_atoms)
        num_bonds = mol.GetNumBonds()
        if num_atoms == 0 or num_bonds == 0:
            return None

        num_aromatic = sum(1 for atom in heavy_atoms if atom.GetIsAromatic())
        num_hetero = sum(1 for atom in heavy_atoms if atom.GetAtomicNum() != 6)

        return {
            'molar_mass': rdMolDescriptors.CalcExactMolWt(mol),
            'aromatic_fraction': num_aromatic / num_atoms,
            'rotatable_bond_fraction': rdMolDescriptors.CalcNumRotatableBonds(mol) / num_bonds,
            'heteroatom_fraction': num_hetero / num_atoms,
            'tpsa': rdMolDescriptors.CalcTPSA(mol)
        }
    except Exception as e:
        # Best-effort: a single bad molecule must not abort the whole batch.
        print(f"Error processing SMILES '{smiles}': {e}")
        return None

def compare_distributions_wasserstein(gen, prop_test):
    """
    Measure how far a generated property distribution is from a reference one.

    Parameters:
    - gen: array-like, property samples of the generated set
    - prop_test: array-like, property samples of the reference (e.g. test) set

    Returns:
    - float: 1-D Wasserstein (earth mover's) distance from
      ``scipy.stats.wasserstein_distance``. Returns ``np.nan`` if either input
      is empty, or if SciPy raises; in that case the error is printed, not
      propagated.
    """
    # Reference sample is checked first, matching the short-circuit order.
    for sample in (prop_test, gen):
        if len(sample) == 0:
            return np.nan

    try:
        distance = wasserstein_distance(prop_test, gen)
    except Exception as err:
        print(f"Error computing Wasserstein distance: {err}")
        distance = np.nan
    return distance


def compare_distributions_kde(gen, prop_test, n_points=1000):
    """
    Compare two 1-D samples via Gaussian KDEs and the Jensen-Shannon distance.

    Parameters:
    - gen: array-like, generated distribution to compare
    - prop_test: array-like, reference distribution (e.g., test set)
    - n_points: number of grid points at which both KDEs are evaluated

    Returns:
    - float: Jensen-Shannon *distance* (``scipy.spatial.distance.jensenshannon``
      with natural log, i.e. the square root of the JS divergence) between the
      two KDEs on a shared grid. Returns ``np.nan`` when either sample has
      fewer than two values, or when a KDE cannot be fitted (e.g. all values
      identical).
    """
    # gaussian_kde raises ValueError for fewer than two samples.
    if len(prop_test) < 2 or len(gen) < 2:
        return np.nan

    # Shared evaluation grid spanning both samples.
    lo = min(np.min(prop_test), np.min(gen))
    hi = max(np.max(prop_test), np.max(gen))
    x = np.linspace(lo, hi, n_points)

    try:
        kde_test = gaussian_kde(prop_test)(x)
        kde_gen = gaussian_kde(gen)(x)
    except np.linalg.LinAlgError:
        # Zero-variance data yields a singular covariance matrix.
        return np.nan

    # jensenshannon rescales both vectors to sum to 1, so explicit
    # normalisation (previously via the deprecated np.trapz) is unnecessary.
    return jensenshannon(kde_test, kde_gen)
