from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem, MACCSkeys
from rdkit import DataStructs

RDLogger.DisableLog('rdApp.*')
from mini_moses.metrics.metrics import internal_diversity
from mini_moses.metrics.utils import get_mol, mapper

import re
import time
import random
random.seed(0)
import numpy as np
from multiprocessing import Pool

import copy
import wandb

def calculate_fingerprints(smiles):
    """Return a 512-bit Morgan (radius 2) bit-vector fingerprint for *smiles*.

    ``None`` is returned when *smiles* is ``None`` or when ``get_mol`` cannot
    turn it into a molecule.
    """
    mol = None if smiles is None else get_mol(smiles)
    if mol is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=512)

def calculate_similarity(pair):
    """Return the Tanimoto similarity of a ``(generated_fp, reference_fp)`` pair.

    Gives ``0`` when either fingerprint is missing (``None``).
    """
    gen_fp, ref_fp = pair
    if gen_fp is None or ref_fp is None:
        return 0
    return DataStructs.TanimotoSimilarity(gen_fp, ref_fp)

def calculate_context_same_match(target_similar_smiles, smiles_gen, print_info=False, check=False):
    """Count generated SMILES that exactly match a member of their target group.

    Args:
        target_similar_smiles: Sequence of target groups (each a container of
            SMILES strings), aligned index-by-index with ``smiles_gen``.
        smiles_gen: Sequence of generated SMILES strings.
        print_info: If true, print all valid pairs plus every mismatching pair.
        check: If true, print every mismatching pair.

    Pairs where either the group or the generated SMILES is ``None`` are
    skipped and do not count toward the denominator.

    Returns:
        ``(matches, match_ratio)``; ``(0, 0.0)`` when no valid pair exists.
    """
    valid_pairs = [
        (t_list, g)
        for t_list, g in zip(target_similar_smiles, smiles_gen)
        if t_list is not None and g is not None
    ]
    if not valid_pairs:
        return 0, 0.0

    if print_info:
        print("Valid pairs:", valid_pairs)
    # Both flags request the same mismatch report; emit it once, not twice.
    if print_info or check:
        for i, (t_list, g) in enumerate(valid_pairs):
            if g not in t_list:
                print(f"Gen/Tar:{g}, Target group {i}: length={len(t_list)}, elements={t_list}")

    matches = sum(1 for t_list, g in valid_pairs if g in t_list)
    return matches, matches / len(valid_pairs)

def calculate_rdkit_fingerprints(smiles):
    """Compute the RDKit topological fingerprint of a SMILES string.

    A ``None`` input, or SMILES that ``get_mol`` cannot parse, yields ``None``.
    A fresh fingerprint generator is created on every call.
    """
    molecule = None if smiles is None else get_mol(smiles)
    if molecule is None:
        return None
    generator = AllChem.GetRDKitFPGenerator()
    return generator.GetFingerprint(molecule)

def calculate_rdkit_similarity(pair):
    """Tanimoto similarity for a ``(generated_fp, reference_fp)`` pair of RDKit fingerprints.

    If either side is ``None`` the result is ``0``.
    """
    gen_fp, ref_fp = pair
    if any(fp is None for fp in (gen_fp, ref_fp)):
        return 0
    return DataStructs.TanimotoSimilarity(gen_fp, ref_fp)

def calculate_maccs_fingerprints(smiles):
    """Build the MACCS-keys fingerprint for a SMILES string.

    Returns ``None`` for a ``None`` input or when ``get_mol`` fails to parse it.
    """
    mol = get_mol(smiles) if smiles is not None else None
    return None if mol is None else MACCSkeys.GenMACCSKeys(mol)

def calculate_maccs_similarity(pair):
    """Tanimoto similarity between two MACCS fingerprints given as a pair.

    A missing (``None``) fingerprint on either side produces ``0``.
    """
    gen_fp, ref_fp = pair
    both_present = gen_fp is not None and ref_fp is not None
    return DataStructs.TanimotoSimilarity(gen_fp, ref_fp) if both_present else 0

def remove_stereochemistry_smiles(smiles):
    """Strip stereochemistry from a SMILES string and keep its largest fragment.

    Args:
        smiles: SMILES string, or ``None``.

    Returns:
        ``None`` if *smiles* is ``None`` or RDKit cannot parse it. Strings of
        length 0 or 1 are returned unchanged (not parsed or canonicalized).
        Otherwise, the canonical SMILES with stereo information removed. When
        that SMILES has several dot-separated fragments, only the longest
        fragment string is returned.
    """
    # The None check must come before len(); otherwise a None input raises
    # TypeError instead of returning None.
    if smiles is None:
        return None
    if len(smiles) <= 1:
        return smiles
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    Chem.rdmolops.RemoveStereochemistry(mol)
    smiles_no_stereo = Chem.MolToSmiles(mol)
    # "Largest" connected component is approximated by the longest fragment
    # string; with no "." the split yields the whole SMILES unchanged.
    return max(smiles_no_stereo.split("."), key=len)

class BasicMolecularMetrics(object):
    """Decode molecular graphs to SMILES and report generation metrics.

    Args:
        tokenizer: Object whose ``decode(node_types, bond_adj, position_adj)``
            returns a 3-tuple; the first two items are tried as SMILES.
        known_smiles: Reference SMILES used by ``calculate_novelty``.
        stat_ref: Stored on the instance; not used by this class.
        task_evaluator: Stored on the instance; not used by this class.
        n_jobs: Worker processes for parallel steps; ``1`` runs sequentially.
        device: Forwarded to ``internal_diversity``.
        batch_size: Stored on the instance; not used by this class.
    """

    def __init__(self, tokenizer, known_smiles=None, stat_ref=None, task_evaluator=None, n_jobs=8, device='cpu', batch_size=512):
        self.known_smiles_list = known_smiles
        self.tokenizer = tokenizer
        self.n_jobs = n_jobs
        self.device = device
        self.batch_size = batch_size
        self.stat_ref = stat_ref
        self.task_evaluator = task_evaluator

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator``, or ``0.0`` when ``denominator`` is 0."""
        return numerator / denominator if denominator else 0.0

    @staticmethod
    def _print_section(title, content, indent=0):
        """Print an indented, dash-underlined ``title`` followed by each line of ``content``."""
        pad = " " * indent
        print(f"\n{pad}{title}")
        print(f"{pad}" + "-" * len(title))
        for line in content:
            print(f"{pad}{line}")

    def calculate_novelty(self, all_smiles_gen):
        """Calculate novelty of generated molecules compared to the known dataset.

        Args:
            all_smiles_gen: Generated SMILES, with ``None`` marking invalid ones.

        Returns:
            ``(novelty_ratio, flags)``.

            - ``novelty_ratio`` is the fraction of non-``None`` SMILES absent
              from ``known_smiles``.
            - ``flags[i]`` is ``True`` iff ``all_smiles_gen[i]`` is non-``None``
              and novel.
            - With no known set or no generated SMILES, returns ``-1.0`` and a
              list of ``None``.
            - When every generated SMILES is ``None``, returns ``0.0`` and a
              list of ``None``.
        """
        n_total = len(all_smiles_gen) if all_smiles_gen else 0
        if not self.known_smiles_list or n_total == 0:
            return -1.0, [None] * n_total

        valid_indices = [i for i, s in enumerate(all_smiles_gen) if s is not None]
        if not valid_indices:
            return 0.0, [None] * n_total

        # Build the lookup set once. Membership tests against a list would make
        # this O(len(generated) * len(known)).
        known = set(self.known_smiles_list)
        flags = [False] * n_total
        n_novel = 0
        for i in valid_indices:
            if all_smiles_gen[i] not in known:
                flags[i] = True
                n_novel += 1
        return n_novel / len(valid_indices), flags

    def graph_to_valid_smiles(self, generated):
        """Decode graphs into SMILES, keeping the largest fragment of each molecule.

        Args:
            generated: Iterable of ``(node_types, bond_adj, position_adj)``
                triples. Each item must support ``.tolist()``, and the adjacency
                items must support ``- 1``. These are presumably numpy arrays or
                torch tensors; confirm with the caller.

        Returns:
            ``(valid_list, num_components, all_smiles, covered_atoms)``:

            - ``valid_list``: SMILES of successfully decoded molecules.
            - ``num_components``: numpy array of their fragment counts.
            - ``all_smiles``: one SMILES per input, ``None`` where decoding
              failed.
            - ``covered_atoms``: set of element symbols seen in the kept
              fragments.
        """
        valid_list = []
        num_components = []
        all_smiles = []
        covered_atoms = set()
        for node_types, bond_adj, position_adj in generated:
            # Shift edge labels so that the "no edge" class becomes -1.
            smiles1, smiles2, _ = self.tokenizer.decode(
                node_types.tolist(),
                (bond_adj - 1).tolist(),
                (position_adj - 1).tolist(),
            )
            # Fall back to the second decoding only if the first does not parse.
            mol = get_mol(smiles1)
            if mol is None:
                mol = get_mol(smiles2)

            smiles = mol2smiles(mol)  # None when mol is None or unsanitizable
            if not smiles:
                all_smiles.append(None)
                continue

            components = smiles.split('.')
            largest_mol = Chem.MolFromSmiles(max(components, key=len))
            largest_smiles = mol2smiles(largest_mol)
            if not largest_smiles:
                # Re-parsing the fragment failed. Record it as undecodable
                # instead of crashing on ``None.GetAtoms()``.
                all_smiles.append(None)
                continue

            num_components.append(len(components))
            covered_atoms.update(atom.GetSymbol() for atom in largest_mol.GetAtoms())
            valid_list.append(largest_smiles)
            all_smiles.append(largest_smiles)

        return valid_list, np.array(num_components), all_smiles, covered_atoms

    @staticmethod
    def _group_max_similarity(gen_sm, target_group):
        """Max MACCS similarity of ``gen_sm`` to any parseable SMILES in ``target_group``.

        Returns 0 if ``gen_sm`` is missing or unparseable, or if no target parses.
        """
        if gen_sm is None or not target_group:
            return 0
        gen_fp = calculate_maccs_fingerprints(gen_sm)
        if gen_fp is None:
            return 0
        target_fps = [calculate_maccs_fingerprints(t_sm) for t_sm in target_group if t_sm is not None]
        sims = [calculate_maccs_similarity((gen_fp, t_fp)) for t_fp in target_fps if t_fp is not None]
        return max(sims, default=0)

    def calculate_fp_similarity(self, target_smiles, generated_smiles, group_target=True):
        """Mean MACCS Tanimoto similarity between generated and target molecules.

        Args:
            target_smiles: If ``group_target`` is true, a list of target groups
                (each a list of SMILES); otherwise a flat list of SMILES.
            generated_smiles: Generated SMILES (``None`` for invalid ones),
                aligned index-by-index with ``target_smiles``.
            group_target: Score each generated molecule by its maximum
                similarity over its target group instead of one target.

        Returns:
            Mean similarity over all pairs. Missing or unparseable entries
            contribute 0. Returns ``0.0`` for empty input.

        Raises:
            AssertionError: If the two lists differ in length.
        """
        assert len(generated_smiles) == len(target_smiles), "generated/target length mismatch"
        n_pairs = len(generated_smiles)
        if n_pairs == 0:
            return 0.0

        if group_target:
            similarities = [
                self._group_max_similarity(gen_sm, target_group)
                for gen_sm, target_group in zip(generated_smiles, target_smiles)
            ]
        elif self.n_jobs == 1:
            generated_fps = [calculate_maccs_fingerprints(sm) for sm in generated_smiles]
            target_fps = [calculate_maccs_fingerprints(sm) for sm in target_smiles]
            similarities = [calculate_maccs_similarity(pair) for pair in zip(generated_fps, target_fps)]
        else:
            with Pool(self.n_jobs) as pool:
                generated_fps = pool.map(calculate_maccs_fingerprints, generated_smiles)
                target_fps = pool.map(calculate_maccs_fingerprints, target_smiles)
                similarities = pool.map(calculate_maccs_similarity, zip(generated_fps, target_fps))

        return sum(similarities) / n_pairs

    def evaluate(self, generated, targets, target_similar_smiles):
        """Evaluates molecular generation metrics.

        Decodes ``generated`` and ``targets`` with ``graph_to_valid_smiles`` and
        prints a report. When any generated molecule is valid, it also computes
        uniqueness, novelty, internal diversity and MACCS similarity.

        Args:
            generated: Generated graphs (see ``graph_to_valid_smiles``).
            targets: Target graphs, in the same format as ``generated``.
            target_similar_smiles: One group (list of SMILES, or ``None``) per
                target. Each SMILES is stereo-stripped. The last entry of each
                non-empty group is then replaced by the decoded target SMILES.

        Returns:
            ``(unique_valid_gen_smiles, component_stats, smiles_dict,
            dist_metrics, None)``.

            - ``component_stats`` has keys ``nc_min``/``nc_max``/``nc_mu``.
            - ``smiles_dict`` has keys ``generation``/``target``/``valid_novel``.
            - ``dist_metrics`` is empty when no generated molecule is valid.
        """
        # Build fresh lists so the caller's groups are never mutated below.
        target_similar_smiles = [
            None if t_list is None
            else [None if sm is None else remove_stereochemistry_smiles(sm) for sm in t_list]
            for t_list in target_similar_smiles
        ]

        valid_list_gen, num_components_gen, all_smiles_gen, covered_atoms_gen = self.graph_to_valid_smiles(generated)
        valid_list_target, num_components_target, all_smiles_target, covered_atoms_target = self.graph_to_valid_smiles(targets)

        has_components = len(num_components_gen) > 0
        metrics = {
            'validity': {
                'generated': {
                    'total': self._safe_ratio(len(valid_list_gen), len(generated)),
                },
                'target': {
                    'total': self._safe_ratio(len(valid_list_target), len(targets)),
                }
            },
            'components': {
                'min': num_components_gen.min() if has_components else 0,
                'mean': num_components_gen.mean() if has_components else 0,
                'max': num_components_gen.max() if has_components else 0
            }
        }

        sections = {
            'Coverage': [
                f"Generated: {len(covered_atoms_gen)} atoms {covered_atoms_gen}",
                f"Target:    {len(covered_atoms_target)} atoms {covered_atoms_target}"
            ],
            'Validity': [
                f"Generated ({len(generated)} molecules): {metrics['validity']['generated']['total']*100:.2f}%",
                f"Target ({len(targets)} molecules): {metrics['validity']['target']['total']*100:.2f}%",
            ],
            'Components': [
                f"Min:  {metrics['components']['min']:.2f}",
                f"Mean: {metrics['components']['mean']:.2f}",
                f"Max:  {metrics['components']['max']:.2f}"
            ]
        }

        print("\n=== Molecule Generation Report ===")
        for title, content in sections.items():
            self._print_section(title, content, indent=2)

        if metrics['validity']['generated']['total'] > 0:
            start_time = time.time()
            # The integer 1 is passed to ``mapper``/``internal_diversity`` for
            # sequential mode.
            pool = Pool(self.n_jobs) if self.n_jobs != 1 else 1
            try:
                valid_mol_gen = mapper(pool)(get_mol, valid_list_gen)
                valid_mol_target = mapper(pool)(get_mol, valid_list_target)

                # Put the decoded target SMILES in the last slot of its group.
                # It is appended when the group is empty and skipped when the
                # group is None.
                for group, target_sm in zip(target_similar_smiles, all_smiles_target):
                    if group is None:
                        continue
                    if group:
                        group[-1] = target_sm
                    else:
                        group.append(target_sm)

                novelty_ratio, valid_novel_checklist = self.calculate_novelty(all_smiles_gen)
                dist_metrics = {
                    'target_unique': self._safe_ratio(len(set(valid_list_target)), len(valid_list_target)),
                    'target_intdiv': internal_diversity(valid_mol_target, pool, device=self.device) if valid_mol_target else 0.0,
                    'gen_validity': metrics['validity']['generated']['total'],
                    'gen_unique': self._safe_ratio(len(set(valid_list_gen)), len(valid_list_gen)),
                    'gen_novelty': novelty_ratio,
                    'gen_intdiv': internal_diversity(valid_mol_gen, pool, device=self.device),
                    'sim/maccs': self.calculate_fp_similarity(target_similar_smiles, all_smiles_gen),
                }
            finally:
                # Release worker processes even if a metric raised.
                if self.n_jobs != 1:
                    pool.close()
                    pool.join()

            # Sorted so that the printed sample is deterministic.
            nuv_smiles = sorted({
                smiles for smiles, is_valid_novel in zip(all_smiles_gen, valid_novel_checklist) if is_valid_novel
            })
            print(f"\nDetailed Metrics ({len(valid_list_gen)}/{len(generated)} molecules) \n N.U.V {len(nuv_smiles)} molecules: {nuv_smiles[:5]}")
            print(f"Time: {time.time() - start_time:.2f}s")
            print("-" * 40)

            max_key = max(len(k) for k in dist_metrics)
            for i, (key, val) in enumerate(dist_metrics.items()):
                if isinstance(val, (int, float, np.floating, np.integer)):
                    print(f"{key:>{max_key}}: {val:<7.4f}")
                if (i + 1) % 4 == 0:
                    print()
        else:
            dist_metrics = {}
            valid_list_gen = []
            valid_novel_checklist = []

        return (
            list(set(valid_list_gen)),
            {'nc_min': metrics['components']['min'],
            'nc_max': metrics['components']['max'],
            'nc_mu': metrics['components']['mean']},
            {"generation": all_smiles_gen, "target": all_smiles_target, "valid_novel": valid_novel_checklist},
            dist_metrics,
            None
        )
 
def mol2smiles(mol):
    """Sanitize *mol* in place and return its canonical SMILES.

    Returns ``None`` for a ``None`` molecule, or when ``Chem.SanitizeMol``
    raises ``ValueError``.
    """
    if mol is not None:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            return None
        return Chem.MolToSmiles(mol)
    return None

def compute_molecular_metrics(molecule_list, targets, target_similar_smiles, known_smiles, stat_ref, dataset_info, task_evaluator, comput_config):
    """Evaluate generated molecules and log scalar metrics to Weights & Biases.

    Args:
        molecule_list: Generated graphs passed to ``BasicMolecularMetrics.evaluate``.
        targets: Target graphs passed to ``evaluate``.
        target_similar_smiles: Per-target SMILES groups passed to ``evaluate``.
        known_smiles: Reference SMILES used for novelty.
        stat_ref: Forwarded to ``BasicMolecularMetrics``.
        dataset_info: Object providing the ``tokenizer`` attribute.
        task_evaluator: Forwarded to ``BasicMolecularMetrics``.
        comput_config: Mapping of extra keyword arguments for
            ``BasicMolecularMetrics``.

    Returns:
        ``(unique_smiles, all_smiles_dict, all_metrics, targets_log)``. These
        are the first and the last three items of ``evaluate``'s result.
    """
    tokenizer = dataset_info.tokenizer
    evaluator = BasicMolecularMetrics(tokenizer, known_smiles, stat_ref, task_evaluator, **comput_config)
    result = evaluator.evaluate(molecule_list, targets, target_similar_smiles)
    unique_smiles = result[0]
    all_smiles_dict, all_metrics, targets_log = result[-3:]

    if wandb.run:
        # Only scalar numeric values are logged.
        scalar_metrics = {
            name: val
            for name, val in all_metrics.items()
            if isinstance(val, (int, float, np.integer, np.floating))
        }
        wandb.log(scalar_metrics)

    return unique_smiles, all_smiles_dict, all_metrics, targets_log

if __name__ == "__main__":
    # Running this module directly performs no work; it only defines metric helpers.
    pass