## For Single-Objective Molecule Optimization Benchmark

import os
import json
import numpy as np
import pandas as pd
from tqdm import tqdm

from rdkit import DataStructs
from rdkit.Chem import Descriptors
import rdkit.Chem as Chem
from rdkit.Chem import rdFMCS
from rdkit.Chem import AllChem
from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol, MurckoScaffoldSmiles

from tdc import oracles
from itertools import combinations

from chem_metrics.eval_mol_edit import check_edit_add_valid, check_edit_del_valid, check_edit_sub_valid

# Plotting imports for the (commented-out) four-subplot accuracy-distribution figure below.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def mol_prop(prop):
    """Build a function that computes molecular property ``prop`` from a SMILES string.

    Supported ``prop`` names fall into five groups (see the tables below):

    * RDKit descriptors: 'logp', 'weight', 'qed', 'TPSA', 'HBA', 'HBD',
      'rot_bonds', 'num_rotatable_bonds', 'ring_count', 'mr', 'MR',
      'balabanJ', 'hall_kier_alpha', 'logD'.
    * 'validity': ``True`` for any parseable SMILES.
    * Bond counts by bond type: 'num_single_bonds', 'num_double_bonds',
      'num_triple_bonds', 'num_aromatic_bonds'.
    * Element counts: 'num_carbon', 'num_nitrogen', 'num_oxygen', ...
    * Functional-group counts via SMARTS: 'num_benzene_ring', 'num_hydroxyl',
      'num_ester', 'num_sulfide', ...

    Args:
        prop: Property name.

    Returns:
        ``inner(smol)``, which returns the property value for a parseable
        SMILES and ``None`` for one RDKit cannot parse. If ``prop`` is not
        supported, ``inner`` raises ``ValueError`` when it is called with a
        parseable SMILES.
    """
    # Scalar RDKit descriptors, by attribute name on ``Descriptors``.
    descriptor_names = {
        'logp': 'MolLogP',
        'weight': 'MolWt',
        'qed': 'qed',
        'TPSA': 'TPSA',
        'HBA': 'NumHAcceptors',  # Hydrogen Bond Acceptor
        'HBD': 'NumHDonors',  # Hydrogen Bond Donor
        'rot_bonds': 'NumRotatableBonds',
        'num_rotatable_bonds': 'NumRotatableBonds',
        'ring_count': 'RingCount',
        'mr': 'MolMR',  # Molar Refractivity
        'MR': 'MolMR',
        'balabanJ': 'BalabanJ',
        'hall_kier_alpha': 'HallKierAlpha',
        'logD': 'MolLogP',  # NOTE: same descriptor as 'logp'
    }

    # Bond counts: number of bonds of the given RDKit bond type.
    bond_types = {
        'num_single_bonds': Chem.rdchem.BondType.SINGLE,
        'num_double_bonds': Chem.rdchem.BondType.DOUBLE,
        'num_triple_bonds': Chem.rdchem.BondType.TRIPLE,
        'num_aromatic_bonds': Chem.rdchem.BondType.AROMATIC,
    }

    # Element counts: number of (heavy) atoms with the given atomic number.
    atomic_numbers = {
        'num_boron': 5,
        'num_carbon': 6,
        'num_nitrogen': 7,
        'num_oxygen': 8,
        'num_fluorine': 9,
        'num_silicon': 14,
        'num_phosphorus': 15,
        'num_sulfur': 16,
        'num_chlorine': 17,
        'num_arsenic': 33,
        'num_selenium': 34,
        'num_bromine': 35,
        'num_antimony': 51,
        'num_tellurium': 52,
        'num_iodine': 53,
        'num_bismuth': 83,
        'num_polonium': 84,
    }

    # Functional groups: number of (unique) SMARTS matches.
    smarts_patterns = {
        'num_benzene_ring': '[cR1]1[cR1][cR1][cR1][cR1][cR1]1',
        # Hydroxyl including phenol, alcohol, and carboxylic acid.
        'num_hydroxyl': '[OX2H]',
        'num_anhydride': '[CX3](=[OX1])[OX2][CX3](=[OX1])',
        'num_aldehyde': '[CX3H1](=O)[#6]',
        'num_ketone': '[#6][CX3](=O)[#6]',
        'num_carboxyl': '[CX3](=O)[OX2H1]',
        # Ester. Also hits anhydrides but won't hit formic anhydride.
        'num_ester': '[#6][CX3](=O)[OX2H0][#6]',
        'num_amide': '[NX3][CX3](=[OX1])[#6]',
        # Primary or secondary amine, not amide.
        'num_amine': '[NX3;H2,H1;!$(NC=O)]',
        'num_nitro': '[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]',
        'num_halo': '[F,Cl,Br,I]',
        'num_thioether': '[SX2][CX4]',
        'num_nitrile': '[NX1]#[CX2]',
        'num_thiol': '[#16X2H]',
        # Divalent, H-free sulfur (won't hit thiols) that is not part of a
        # disulfide. Previously the count subtracted one match per S-S bond,
        # but each disulfide contributes two sulfurs, so e.g. CSSC scored 1.
        'num_sulfide': '[#16X2H0;!$([#16X2H0][#16X2H0])]',
        'num_disulfide': '[#16X2H0][#16X2H0]',
        'num_sulfoxide': '[$([#16X3]=[OX1]),$([#16X3+][OX1-])]',
        'num_sulfone': '[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]',
        'num_borane': '[BX3]',
    }

    # Resolve the requested property once here, instead of on every molecule
    # (in particular, compile the SMARTS query a single time).
    descriptor_fn = getattr(Descriptors, descriptor_names[prop]) if prop in descriptor_names else None
    bond_type = bond_types.get(prop)
    atomic_num = atomic_numbers.get(prop)
    pattern = Chem.MolFromSmarts(smarts_patterns[prop]) if prop in smarts_patterns else None

    def inner(smol):
        try:
            mol = Chem.MolFromSmiles(smol)
        except Exception:
            # Inputs RDKit refuses outright (e.g. non-string values) count as invalid.
            return None
        # MolFromSmiles returns None for SMILES it cannot parse.
        if mol is None:
            return None

        if descriptor_fn is not None:
            return descriptor_fn(mol)
        if prop == 'validity':
            return True
        if bond_type is not None:
            return sum(bond.GetBondType() == bond_type for bond in mol.GetBonds())
        if atomic_num is not None:
            return sum(atom.GetAtomicNum() == atomic_num for atom in mol.GetAtoms())
        if pattern is not None:
            return len(mol.GetSubstructMatches(pattern))
        raise ValueError(f'Property {prop} not supported')

    return inner

def calculate_solubility(smiles):
    """Estimate aqueous solubility (logS) of a molecule with a simple linear model.

    The model combines four RDKit descriptors: Crippen logP (``MolLogP``),
    molecular weight, and H-bond donor / acceptor counts.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        The estimated logS as a float, or ``0.0`` when the SMILES cannot be
        parsed or a descriptor computation fails (best-effort metric).
    """
    # NOTE(review): the intercept and the logP / MW coefficients match
    # Delaney's ESOL equation (logS = 0.16 - 0.63*clogP - 0.0062*MW
    # + 0.066*RB - 0.74*AP), but the last two terms here use H-bond
    # donors/acceptors instead of rotatable bonds / aromatic proportion.
    # Kept unchanged so benchmark numbers stay comparable -- confirm intent.
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return 0.0
        logp = Descriptors.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
        h_bond_donors = Descriptors.NumHDonors(mol)
        h_bond_acceptors = Descriptors.NumHAcceptors(mol)
    except Exception:
        # Unparseable or failing inputs score 0.0 rather than aborting the evaluation.
        return 0.0

    return 0.16 - 0.63 * logp - 0.0062 * mw + 0.066 * h_bond_donors - 0.074 * h_bond_acceptors

def compute_statistics(numbers, prop):
    """Summarize a collection of per-molecule property improvements.

    Args:
        numbers: Iterable of numeric improvement values (any iterable: list,
            tuple, NumPy array, generator).
        prop: Property name selecting the ``best_rate`` threshold; one of
            'gsk3b', 'qed', 'drd2', 'jnk3' (threshold 0.3) or 'logp',
            'solubility' (threshold 0.5).

    Returns:
        Dict with ``mean``, ``variance`` (population variance), ``min``,
        ``max``, ``success_rate`` (fraction of values > 0, i.e. the property
        increased) and ``best_rate`` (fraction of values >= the property's
        threshold). Every entry is 0 for empty input.

    Raises:
        KeyError: If ``numbers`` is non-empty and ``prop`` has no threshold.
    """
    # Materialize once so tuples, arrays and iterators all work; the old
    # ``numbers == []`` check let e.g. ``()`` through to a division by zero.
    values = list(numbers)
    if not values:
        return {
            "mean": 0,
            "variance": 0,
            "min": 0,
            "max": 0,
            "success_rate": 0,  # success opt that increase the property
            "best_rate": 0,  # rate of best property mol-opt
        }

    easy_thres, hard_thres = 0.5, 0.3
    threshold_dict = {'gsk3b': hard_thres, 'qed': hard_thres, 'drd2': hard_thres, 'jnk3': hard_thres, 'logp': easy_thres, 'solubility': easy_thres}
    threshold = threshold_dict[prop]

    n = len(values)
    mean = sum(values) / n
    # Population variance: 1/N * sum((x_i - mean)^2)
    variance = sum((x - mean) ** 2 for x in values) / n
    success_rate = sum(1 for x in values if x > 0) / n
    best_rate = sum(1 for x in values if x >= threshold) / n

    return {
        "mean": mean,
        "variance": variance,
        "min": min(values),
        "max": max(values),
        "success_rate": success_rate,  # success opt that increase the property
        "best_rate": best_rate,  # rate of best property mol-opt
    }

class mol_opt_evaluater():
    """Evaluate single-objective molecule-optimization outputs for one property.

    Args:
        prop: Target property. 'gsk3b', 'drd2' and 'jnk3' are scored with TDC
            oracles, 'solubility' with ``calculate_solubility`` and 'qed' /
            'logp' with ``mol_prop``.

    Raises:
        ValueError: If ``prop`` is not one of the supported names.
    """

    def __init__(self, prop=None) -> None:
        self.prop = prop
        if prop in ['gsk3b', 'drd2', 'jnk3']:
            self.property_oracle = oracles.Oracle(name=prop)
        elif prop == 'solubility':
            self.property_oracle = calculate_solubility
        elif prop in ['qed', 'logp']:
            self.property_oracle = mol_prop(prop)
        else:
            raise ValueError(f"Unknown property: {prop}")

    def property_improvement(self, src_mol_list, tgt_mol_list, total_num):
        """Score the property change of each optimized molecule and summarize it.

        The improvement for a pair is ``oracle(target) - oracle(source)``. Pairs
        where either SMILES fails to parse contribute 0, and the list is padded
        with 0.0 up to ``total_num`` entries (presumably for samples that
        produced no output) before ``compute_statistics`` is applied.

        Args:
            src_mol_list: Source SMILES strings.
            tgt_mol_list: Optimized SMILES strings, aligned with ``src_mol_list``.
            total_num: Number of samples the statistics are computed over.

        Returns:
            The statistics dict from ``compute_statistics`` for ``self.prop``.

        Raises:
            AssertionError: If the two lists differ in length.
            ValueError: If the property oracle fails on a parseable pair; the
                original error is chained as the cause.
        """
        assert len(src_mol_list) == len(tgt_mol_list)
        pairs = list(zip(src_mol_list, tgt_mol_list))

        # Indices of pairs where either molecule cannot be parsed (set: O(1) lookups below).
        invalid_idx = set()
        for i, (src, tgt) in enumerate(pairs):
            try:
                tgt_mol = Chem.MolFromSmiles(tgt)
                src_mol = Chem.MolFromSmiles(src)
            except Exception:
                invalid_idx.add(i)
                continue
            if tgt_mol is None or src_mol is None:
                invalid_idx.add(i)

        try:
            prop_improve_list = [
                0 if i in invalid_idx
                else self.property_oracle(str(tgt)) - self.property_oracle(str(src))
                for i, (src, tgt) in enumerate(pairs)
            ]
        except Exception as err:
            # Dump the batch to help locate the offending pair, then re-raise
            # with the underlying error attached.
            print(src_mol_list)
            print(tgt_mol_list)
            print(self.prop)
            print(sorted(invalid_idx))
            raise ValueError(f"Failed to compute '{self.prop}' property improvement") from err

        prop_improve_list = prop_improve_list + [0.0] * (total_num - len(src_mol_list))
        return compute_statistics(prop_improve_list, self.prop)

    def scaffold_consistency(self, src_mol_list, tgt_mol_list):
        """Compare the Murcko scaffolds of each source/target pair.

        Per pair: identical scaffold SMILES score 1.0 and count as "same";
        otherwise the score is the Tanimoto similarity of radius-2, 1024-bit
        Morgan fingerprints of the two scaffolds, or 0.0 when their maximum
        common substructure is empty. Pairs that fail to parse score 0.0.

        Args:
            src_mol_list: Source SMILES strings.
            tgt_mol_list: Optimized SMILES strings, aligned with ``src_mol_list``.

        Returns:
            ``(count_same, total_score)``: the number of pairs with identical
            scaffolds and the sum of per-pair scores; ``(0.0, 0.0)`` for empty
            input.

        Raises:
            AssertionError: If the two lists differ in length.
        """
        assert len(src_mol_list) == len(tgt_mol_list)
        if len(tgt_mol_list) == 0:
            return 0.0, 0.0

        count_same = 0
        scaffold_score = []
        for src_smiles, tgt_smiles in zip(src_mol_list, tgt_mol_list):
            try:
                src_mol = Chem.MolFromSmiles(src_smiles)
                tgt_mol = Chem.MolFromSmiles(tgt_smiles)
            except Exception:
                scaffold_score.append(0.0)
                continue
            if src_mol is None or tgt_mol is None:
                scaffold_score.append(0.0)
                continue

            # A falsy SMILES string gets no scaffold, so the pair scores 0.0.
            src_scaffold = MurckoScaffoldSmiles(src_smiles) if src_smiles else None
            tgt_scaffold = MurckoScaffoldSmiles(tgt_smiles) if tgt_smiles else None
            if src_scaffold is None or tgt_scaffold is None:
                scaffold_score.append(0.0)
                continue

            if src_scaffold == tgt_scaffold:
                scaffold_score.append(1.0)
                count_same += 1
                continue

            # Different scaffolds: Morgan fingerprint similarity of the scaffolds,
            # gated on the scaffolds sharing a non-empty common substructure.
            scaffold_mols = [Chem.MolFromSmiles(src_scaffold), Chem.MolFromSmiles(tgt_scaffold)]
            mcs = rdFMCS.FindMCS(scaffold_mols)
            mcs_mol = Chem.MolFromSmarts(mcs.smartsString) if mcs.numAtoms > 0 else None
            if mcs_mol:
                # Fingerprint-based Tanimoto similarity.
                fp1 = AllChem.GetMorganFingerprintAsBitVect(scaffold_mols[0], 2, nBits=1024)
                fp2 = AllChem.GetMorganFingerprintAsBitVect(scaffold_mols[1], 2, nBits=1024)
                similarity = DataStructs.TanimotoSimilarity(fp1, fp2)
            else:
                similarity = 0.0
            scaffold_score.append(similarity)

        return count_same, sum(scaffold_score)

    def CoT_explanation(self, src_mol_list, tgt_mol_list, cot_list):
        """Placeholder for evaluating chain-of-thought explanations (format and
        correctness); not implemented yet and returns ``None``."""
        pass

def smiles_similarity(src_mol_list, tgt_mol_list):
    """Return the highest fingerprint similarity over aligned source/target pairs.

    Each pair is compared with radius-2, 2048-bit Morgan fingerprints via
    ``DataStructs.FingerprintSimilarity``; pairs that fail to parse score 0.0.

    Args:
        src_mol_list: Source SMILES strings.
        tgt_mol_list: Target SMILES strings, aligned with ``src_mol_list``.

    Returns:
        The maximum per-pair similarity, or 0.0 for empty input (previously
        ``max([])`` raised ``ValueError``).

    Raises:
        AssertionError: If the two lists differ in length.
    """
    # NOTE(review): this returns the max over pairs, not the mean -- confirm intended.
    assert len(src_mol_list) == len(tgt_mol_list)
    similarity_score = []
    for src_smiles, tgt_smiles in zip(src_mol_list, tgt_mol_list):
        try:
            mol1 = Chem.MolFromSmiles(src_smiles)
            mol2 = Chem.MolFromSmiles(tgt_smiles)
            if mol1 is None or mol2 is None:
                similarity_score.append(0.0)
                continue

            fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, radius=2, nBits=2048)
            fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, radius=2, nBits=2048)
            similarity_score.append(DataStructs.FingerprintSimilarity(fp1, fp2))
        except Exception:
            # Best-effort: any RDKit failure on a pair scores 0.0.
            similarity_score.append(0.0)

    return max(similarity_score, default=0.0)

def calculate_kendall_tau(gts, preds):
    """
    Computes Kendall's Tau between predicted and ground truth sequences.

    Pairs of ground-truth steps are pooled across all sequences (a
    micro-average, not a mean of per-sequence scores). The score is
    ``(concordant - non_concordant) / total_pairs``. A pair involving a step
    that is missing from the prediction counts as non-concordant instead of
    raising ``KeyError``. If a step occurs more than once, its last position
    is used.

    Args:
        gts (List[List[str]]): Ground truth step sequences.
        preds (List[List[str]]): Predicted step sequences.

    Returns:
        float: Pooled Kendall's Tau in [-1, 1], or 0 if there are no pairs.
    """
    total_pairs = 0
    concordant_pairs = 0

    for gt, pr in zip(gts, preds):
        gt_rank = {step: i for i, step in enumerate(gt)}
        pr_rank = {step: i for i, step in enumerate(pr)}

        for a, b in combinations(gt_rank, 2):
            total_pairs += 1
            if a not in pr_rank or b not in pr_rank:
                # Step omitted by the prediction: penalize as non-concordant.
                continue
            if (gt_rank[a] - gt_rank[b]) * (pr_rank[a] - pr_rank[b]) > 0:
                concordant_pairs += 1

    if total_pairs == 0:
        return 0
    return (2 * concordant_pairs - total_pairs) / total_pairs

def compute_classification_metrics(preds, gts):
    """
    Computes accuracy, precision, recall, and F1 score.

    ``False`` is the positive class: precision/recall/F1 measure how well the
    predictions flag incorrect items. Labels are compared by equality rather
    than identity, so NumPy booleans and 0/1 integers behave like
    ``False``/``True`` (with ``is`` they were silently ignored); any other
    value (e.g. ``None`` or a string) counts toward neither class.

    Args:
        preds (List[bool]): Predicted labels.
        gts (List[bool]): Ground truth labels.

    Returns:
        dict: Dictionary with accuracy, precision, recall, and F1 (all 0 for
        empty input).
    """
    def _is_false(label):
        return label == False  # noqa: E712 -- equality on purpose, see docstring

    def _is_true(label):
        return label == True  # noqa: E712 -- equality on purpose, see docstring

    pairs = list(zip(preds, gts))
    # len() rather than truthiness so NumPy arrays are accepted.
    n = len(preds)

    TP = sum(_is_false(p) and _is_false(g) for p, g in pairs)  # Correctly predicted incorrect
    FP = sum(_is_false(p) and _is_true(g) for p, g in pairs)   # Incorrectly flagged as incorrect
    FN = sum(_is_true(p) and _is_false(g) for p, g in pairs)   # Missed incorrect

    accuracy = sum(p == g for p, g in pairs) / n if n > 0 else 0
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }
### Plot per-task accuracy histograms (disabled)
# mol_und = [d['eval_result']['acc'] for d in data if d['task'] == 'mol_und']
# mol_edit = [d['eval_result']['acc'] for d in data if d['task'] == 'mol_edit']
# reaction = [d['eval_result']['acc'] for d in data if d['task'] == 'reaction']
# mol_opt = [d['eval_result']['acc'] for d in data if d['task'] == 'mol_opt']



# data = [mol_und, mol_edit, reaction, mol_opt]
# labels = ['mol_und', 'mol_edit', 'reaction', 'mol_opt']

# # Four subplots, one accuracy histogram per task
# fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# fig.suptitle('Accuracy Distribution by Task', fontsize=16, fontweight='bold')

# for i, task in enumerate(labels):
#     ax = axes[i // 2, i % 2]
#     accuracies = data[i]

#     counts, bins_edges, patches = ax.hist(
#         accuracies, 
#         bins=10, 
#         alpha=0.7, 
#         color=plt.cm.Set3(i),
#         edgecolor='black',
#         linewidth=0.5
#     )

#     ax.set_title(task, fontsize=14, fontweight='bold')
#     ax.set_xlabel('Accuracy', fontsize=12, fontweight='bold')
#     ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
#     ax.set_ylim(0, max(counts) * 1.1 if counts.size > 0 else 1)

#     mean_acc = np.mean(accuracies)
#     std_acc = np.std(accuracies)
#     ax.axvline(mean_acc, color='red', linestyle='--', alpha=0.8, label=f'Mean: {mean_acc:.3f}')
#     ax.text(0.02, 0.98, f'Mean: {mean_acc:.3f}\nStd: {std_acc:.3f}\nN: {len(accuracies)}', transform=ax.transAxes, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
#     ax.grid(True, alpha=0.3)

# plt.tight_layout()
# plt.savefig('classify_acc.png')


if __name__ == "__main__":
    # Smoke test: print TDC's logP oracle score for two related molecules.
    logp_oracle = oracles.Oracle(name="logp")
    demo_smiles = (
        "O=C1N(C)C(Cc2ccc(OC)cc2)C(=O)N(C)C1COC",
        "O=C1NC(Cc2ccc(O)cc2)C(=O)NC1CO",
    )
    print(*(logp_oracle(s) for s in demo_smiles))
