
import re


import numpy as np

# Third-party dependencies used by the metrics below (BLEU, Levenshtein, RDKit, SELFIES, FCD)

from nltk.translate.bleu_score import corpus_bleu, sentence_bleu

from Levenshtein import distance as lev

from rdkit import Chem
from rdkit.Chem import MACCSkeys
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.rdchem import KekulizeException

from rdkit import RDLogger
import selfies as sf

from fcd import get_fcd, load_ref_model, canonical_smiles

# Disable RDKit log output for every logger matching "rdApp.*".
# NOTE(review): presumably this keeps parse errors from invalid predictions
# out of the console during evaluation — confirm this is the intent.
RDLogger.DisableLog("rdApp.*")

def filter_selfies(s):
    """Keep only the bracketed tokens found in ``s``.

    Every ``[...]`` group (together with a ``.`` directly after it, if
    present) is extracted in order and the pieces are concatenated; any text
    between the groups is dropped. Doubled brackets remaining in the result
    are then collapsed to single ones.
    """
    tokens = [hit.group(1) for hit in re.finditer(r'(\[[^\]]+\]\.?)', s)]
    joined = ''.join(tokens)
    joined = joined.replace('[[', '[')
    return joined.replace(']]', ']')


def parse_molecules(gt_smi, ot_smi):
    """Parse paired SMILES strings into RDKit molecule objects.

    Args:
        gt_smi: Iterable of ground-truth SMILES strings.
        ot_smi: Iterable of predicted SMILES strings, aligned with ``gt_smi``.

    Returns:
        A list of ``(gt, out, m_gt, m_out)`` tuples, one per zipped pair.
        ``m_gt`` and ``m_out`` hold whatever ``Chem.MolFromSmiles`` returned
        for the respective string. If parsing either string of a pair raises,
        both molecule slots of that pair are set to ``None``.
    """
    parsed_pairs = []

    for gt, out in zip(gt_smi, ot_smi):
        try:
            m_gt = Chem.MolFromSmiles(gt)
            m_out = Chem.MolFromSmiles(out)
        except Exception:
            # Ordinary errors mark the whole pair as invalid. A bare
            # ``except:`` would also trap KeyboardInterrupt/SystemExit, which
            # must keep propagating so a long evaluation can be interrupted.
            m_gt, m_out = None, None

        parsed_pairs.append((gt, out, m_gt, m_out))

    return parsed_pairs

def _handle_invalid_mol(m_gt, m_out, default_score=0):
    """Return default score if either molecule is invalid"""
    return default_score if m_gt is None or m_out is None else None

def is_same_mol(mol_a, mol_b):
    """Return True when both molecules yield the same InChI string."""
    inchi_a = Chem.MolToInchi(mol_a)
    inchi_b = Chem.MolToInchi(mol_b)
    return inchi_a == inchi_b

def evaluate(mol_pairs):
    """Score each (ground truth, prediction) pair on text and identity metrics.

    Args:
        mol_pairs: Iterable of ``(gt, out, m_gt, m_out)`` tuples, where ``gt``
            and ``out`` are the raw strings and ``m_gt`` / ``m_out`` the
            parsed molecules (``None`` when unavailable).

    Returns:
        A tuple ``(corpus_bleu, bleu_scores, exact_matches, levs,
        validity_scores)``: the corpus-level BLEU over all pairs, followed by
        four per-pair lists (sentence BLEU, 0/1 InChI match, Levenshtein
        distance between the raw strings, 0/1 validity).
    """
    references, hypotheses = [], []
    per_pair_bleu, matches, distances, valid_flags = [], [], [], []

    for gt, out, m_gt, m_out in mol_pairs:
        # BLEU is computed on character-level tokens of the raw strings.
        reference = [list(gt)]
        hypothesis = list(out)
        references.append(reference)
        hypotheses.append(hypothesis)
        per_pair_bleu.append(sentence_bleu(reference, hypothesis))

        # A pair with a missing molecule is recorded as invalid and takes the
        # fallback score; otherwise the pair matches when the InChI strings agree.
        fallback = _handle_invalid_mol(m_gt, m_out)
        if fallback is None:
            try:
                same = 1 if Chem.MolToInchi(m_gt) == Chem.MolToInchi(m_out) else 0
            except KekulizeException:
                # InChI generation failed: count as a valid pair that does not match.
                same = 0
            matches.append(same)
            valid_flags.append(1)
        else:
            matches.append(fallback)
            valid_flags.append(0)

        distances.append(lev(out, gt))

    overall_bleu = corpus_bleu(references, hypotheses)
    return overall_bleu, per_pair_bleu, matches, distances, valid_flags


def evaluate_fingerprint(mol_pairs, morgan_r, verbose=False):
    """Compute per-pair Tanimoto similarities on three fingerprint types.

    Args:
        mol_pairs: Iterable of ``(gt, out, m_gt, m_out)`` tuples; only the two
            molecule objects are used.
        morgan_r: Radius passed to the Morgan fingerprint.
        verbose: When true, print a progress line every 100 pairs.

    Returns:
        ``(MACCS_sims, RDK_sims, morgan_sims)`` — three lists aligned with
        ``mol_pairs``. A pair with a missing molecule receives the
        invalid-pair default score in all three lists.
    """
    maccs_scores, rdk_scores, morgan_scores = [], [], []

    for idx, (_, _, ref_mol, pred_mol) in enumerate(mol_pairs):
        if verbose and idx % 100 == 0:
            print(idx, 'processed.')

        fallback = _handle_invalid_mol(ref_mol, pred_mol)
        if fallback is not None:
            # Nothing to fingerprint: record the fallback for every metric.
            maccs_scores.append(fallback)
            rdk_scores.append(fallback)
            morgan_scores.append(fallback)
            continue

        maccs_ref = MACCSkeys.GenMACCSKeys(ref_mol)
        maccs_pred = MACCSkeys.GenMACCSKeys(pred_mol)
        maccs_scores.append(
            DataStructs.FingerprintSimilarity(
                maccs_ref, maccs_pred, metric=DataStructs.TanimotoSimilarity
            )
        )

        rdk_ref = Chem.RDKFingerprint(ref_mol)
        rdk_pred = Chem.RDKFingerprint(pred_mol)
        rdk_scores.append(
            DataStructs.FingerprintSimilarity(
                rdk_ref, rdk_pred, metric=DataStructs.TanimotoSimilarity
            )
        )

        morgan_ref = AllChem.GetMorganFingerprint(ref_mol, morgan_r)
        morgan_pred = AllChem.GetMorganFingerprint(pred_mol, morgan_r)
        morgan_scores.append(DataStructs.TanimotoSimilarity(morgan_ref, morgan_pred))

    return maccs_scores, rdk_scores, morgan_scores


def evaluate_fcd(mol_pairs, verbose=False):
    """Return the FCD score for the whole set (a single value).

    Args:
        mol_pairs: Sequence of ``(gt, out, m_gt, m_out)`` tuples; it is
            traversed more than once, so it must be re-iterable. Only the
            molecule objects are used.
        verbose: When true, print the computed score.

    Returns:
        ``0`` when no pair has both molecules present (the reference model is
        not loaded in that case). Otherwise the value returned by ``get_fcd``
        for the canonical SMILES of all present ground-truth molecules versus
        all present predicted molecules.
    """
    # FCD is only reported when at least one pair parsed on both sides.
    has_valid_pair = any(
        m_gt is not None and m_out is not None for _, _, m_gt, m_out in mol_pairs
    )
    if not has_valid_pair:
        return 0

    model = load_ref_model()
    # Each side keeps every molecule present on that side, independent of its
    # partner, so the two lists may differ in length.
    canon_gt_smis = [Chem.MolToSmiles(m_gt) for _, _, m_gt, _ in mol_pairs if m_gt is not None]
    canon_ot_smis = [Chem.MolToSmiles(m_out) for _, _, _, m_out in mol_pairs if m_out is not None]

    # Both lists are non-empty here: the guard above found a fully valid pair.
    fcd_sim_score = get_fcd(canon_gt_smis, canon_ot_smis, model)

    if verbose:
        print('FCD Similarity:', fcd_sim_score)

    return fcd_sim_score


def create_score_indie(truth_smis, ot_smis, mol_type="SMILES", gt_smiles=False, verbose=True):
    """Compute per-pair molecule metrics and optionally print their averages.

    Args:
        truth_smis: Sequence of ground-truth molecule strings.
        ot_smis: Sequence of predicted molecule strings; must have the same
            length as ``truth_smis``.
        mol_type: ``"SELFIES"`` decodes the strings to SMILES with
            ``selfies.decoder`` before scoring; any other value scores the
            strings as given.
        gt_smiles: Only consulted when ``mol_type == "SELFIES"``. When true,
            ``truth_smis`` is used as-is instead of being decoded.
        verbose: When true, print the averaged scores (including the FCD and
            corpus-level BLEU values).

    Returns:
        A dict mapping each metric name (``'bleu'``, ``'exact_match'``,
        ``'levenshtein'``, ``'validity'``, ``'maccs_similarity'``,
        ``'rdk_similarity'``, ``'morgan_similarity'``) to its list of
        per-pair scores. The averages are printed, not returned.

    Raises:
        AssertionError: If the two inputs differ in length.
    """
    assert len(truth_smis) == len(
        ot_smis
    ), "Different number of ground truth and predictions."

    if mol_type == "SELFIES":
        decoded_truth = []
        decoded_pred = []
        for truth, pred in zip(truth_smis, ot_smis):
            gt_mol = truth if gt_smiles else sf.decoder(truth.replace(" ", ''))
            clean_pred = pred.replace(" ", '')
            try:
                pd_mol = sf.decoder(clean_pred)
            except Exception:
                # The raw prediction did not decode: retry using only its
                # bracketed tokens before giving up on it.
                try:
                    pd_mol = sf.decoder(filter_selfies(clean_pred))
                except Exception:
                    # Placeholder for an undecodable prediction.
                    # NOTE(review): presumably "E" fails SMILES parsing so the
                    # pair is scored as invalid downstream — confirm.
                    pd_mol = "E"
            decoded_truth.append(gt_mol)
            decoded_pred.append(pd_mol)
        truth_smis = decoded_truth
        ot_smis = decoded_pred

    # Parse molecules
    mol_pairs = parse_molecules(truth_smis, ot_smis)

    # Get individual scores
    corpus_bleu_score, bleu_scores, exact_matches, levs, validity_scores = evaluate(mol_pairs)
    maccs_sims, rdk_sims, morgan_sims = evaluate_fingerprint(mol_pairs, morgan_r=2)
    fcd_score = evaluate_fcd(mol_pairs)

    # Create individual scores dict
    individual_scores = {
        'bleu': bleu_scores,
        'exact_match': exact_matches,
        'levenshtein': levs,
        'validity': validity_scores,
        'maccs_similarity': maccs_sims,
        'rdk_similarity': rdk_sims,
        'morgan_similarity': morgan_sims,
    }

    # Set-level values (FCD, corpus BLEU) only appear in the printed averages.
    averages = {k: np.mean(v) for k, v in individual_scores.items()}
    averages['fcd'] = fcd_score
    averages['corpus_bleu'] = corpus_bleu_score
    if verbose:
        print(f"Average scores: {averages}")

    return individual_scores


def create_retrosynthesis_score(gt, pd):
    """Print the mean fingerprint similarities and validity for predictions.

    Args:
        gt: Ground-truth SMILES strings.
        pd: Predicted SMILES strings, aligned with ``gt``.

    The function parses both inputs, computes MACCS / RDK / Morgan (radius 2)
    similarities plus a 0/1 validity flag per pair, and prints the mean of
    each. Nothing is returned.
    """
    mol_pairs = parse_molecules(gt, pd)
    maccs_sims, rdk_sims, morgan_sims = evaluate_fingerprint(mol_pairs, morgan_r=2)

    # A pair counts as valid only when both of its molecules are present.
    validity = []
    for _, _, m_gt, m_out in mol_pairs:
        validity.append(0 if m_gt is None or m_out is None else 1)

    per_metric = {
        "maccs": maccs_sims,
        "rdk": rdk_sims,
        "morgan": morgan_sims,
        "validity": validity,
    }
    averages = {}
    for metric_name, values in per_metric.items():
        averages[metric_name] = np.mean(values)
    print(f"Average scores: {averages}")