import re
import logging
from typing import Dict, List, Union, Set, Tuple, Optional
import rdkit
from rdkit import Chem
from collections import Counter
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from examples.evaluations.text_translation_metrics import text_evaluate
from examples.evaluations.mol_translation_metrics import mol_evaluate
from examples.evaluations.fingerprint_metrics import molfinger_evaluate
import numpy as np
from transformers import AutoTokenizer
import ast

# Importing this module configures the root logger at INFO via basicConfig.
# NOTE(review): library-style modules usually leave logging configuration to the
# entry-point script; consider moving this call there.
logging.basicConfig(level=logging.INFO)
# Module logger used by all reward/metric helpers below.
logger = logging.getLogger(__name__)

def canonicalize_smiles(smiles: str) -> Tuple[str, bool]:
    """Canonicalize a SMILES string with RDKit.

    Returns:
        ``(canonical_smiles, True)`` when RDKit parses *smiles*; otherwise
        ``(smiles, False)`` with the input returned unchanged. Parsing errors
        (including non-string input) are reported as invalid, not raised.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return smiles, False
        return Chem.MolToSmiles(mol, canonical=True), True
    except Exception:
        # Best-effort: any RDKit failure means "invalid". Catch Exception rather than
        # using a bare except so KeyboardInterrupt/SystemExit still propagate.
        return smiles, False

def format_reward_general(predict: str) -> float:
    """Reward the ``<think>...</think><answer>...</answer>`` output layout.

    Leading/trailing whitespace is ignored and whitespace may separate the two
    blocks; nothing else may appear outside them.

    Returns:
        1.0 if the whole (stripped) string has that layout, else 0.0.
    """
    layout = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    if layout.fullmatch(predict.strip()) is None:
        return 0.0
    return 1.0

def format_reward_captioning(predict: str) -> float:
    """Reward the captioning layout: a think block, then an answer starting "The molecule is".

    Leading/trailing whitespace is ignored and whitespace may separate the two
    blocks.

    Returns:
        1.0 if the whole (stripped) string has that layout, else 0.0.
    """
    layout = re.compile(r"<think>.*?</think>\s*<answer>The molecule is .*?</answer>", re.DOTALL)
    if layout.fullmatch(predict.strip()) is None:
        return 0.0
    return 1.0

def extract_answer_content(text: str) -> str:
    """Return the stripped content of the first ``<answer>...</answer>`` block.

    Tag matching is case-insensitive and the content may span lines.

    Returns:
        The stripped content, or ``""`` when no answer block is present.
    """
    found = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL | re.IGNORECASE)
    return found.group(1).strip() if found else ""

def extract_answer_content_class(text: str) -> str:
    """Extract a binary ``"Yes"``/``"No"`` label from a model completion.

    If the ``<answer>`` block is exactly yes/no (any case), that is used.
    Otherwise falls back to the LAST standalone "yes"/"no" word anywhere in
    *text* (including outside the answer block).

    Returns:
        ``"Yes"``, ``"No"``, or ``""`` when no label can be found.
    """
    tagged = extract_answer_content(text)
    if tagged.lower() in ("yes", "no"):
        return tagged.capitalize()
    # findall with a single group yields the captured words in order of appearance.
    labels = re.findall(r'\b(Yes|No)\b', text, flags=re.IGNORECASE)
    return labels[-1].capitalize() if labels else ""

def extract_captioning_content(answer: str) -> str:
    """Extract the molecule description from a captioning completion.

    Takes the first ``<answer>...</answer>`` block (case-sensitive tags) and
    drops a leading "The molecule is" prefix if present.

    Returns:
        The stripped description; if there is no answer block, the whole
        stripped input is returned instead.
    """
    prefix = "The molecule is"
    cleaned = answer.strip()
    hit = re.search(r"<answer>(.*?)</answer>", cleaned, re.DOTALL)
    if hit is None:
        return cleaned
    body = hit.group(1).strip()
    if not body.startswith(prefix):
        return body
    return body[len(prefix):].strip()

def evaluate_binary_classification(predict: str, ground_truth: str) -> float:
    """Score a Yes/No prediction against *ground_truth*.

    Returns:
        1.0 when the label from ``extract_answer_content_class`` equals
        *ground_truth* ignoring case and surrounding whitespace; 0.0 otherwise,
        including when no label could be extracted.
    """
    label = extract_answer_content_class(predict)
    if label and label.strip().lower() == ground_truth.strip().lower():
        return 1.0
    return 0.0

def evaluate_name(predict: str, ground_truth: str) -> float:
    """Score a name answer (used for smiles2iupac) by exact text match.

    Returns:
        1.0 when the ``<answer>`` content equals *ground_truth* ignoring case
        and surrounding whitespace; 0.0 otherwise, including when no answer
        block was found.
    """
    name = extract_answer_content(predict)
    if name and name.strip().lower() == ground_truth.strip().lower():
        return 1.0
    return 0.0

def evaluate_smiles(predict: str, ground_truth: str) -> float:
    """Evaluate a single SMILES prediction.

    Both the ``<answer>`` content and *ground_truth* are canonicalized with
    RDKit before comparison, so different spellings of the same molecule match.

    Returns:
        1.0 on a canonical match; 0.0 if the answer is missing, either SMILES
        is invalid, or they differ.
    """
    answer = extract_answer_content(predict)
    if not answer:
        return 0.0
    pred_canon, pred_ok = canonicalize_smiles(answer)
    if not pred_ok:
        return 0.0
    ref_canon, ref_ok = canonicalize_smiles(ground_truth)
    return 1.0 if ref_ok and pred_canon == ref_canon else 0.0

def evaluate_multi_smiles(predict: str, ground_truth: str) -> float:
    """Evaluate a multi-component SMILES answer (components separated by ".").

    Each component of the ``<answer>`` content and of *ground_truth* is
    canonicalized separately. The two sides match when they contain the same
    components with the same multiplicities, in any order.

    Returns:
        1.0 on a match; 0.0 if the answer is missing, any component on either
        side is invalid, or the component multisets differ.
    """
    def _canonical_parts(smiles: str) -> Optional[List[str]]:
        # Canonicalize every "."-separated component; None if any is invalid.
        parts = []
        for piece in smiles.split("."):
            text, ok = canonicalize_smiles(piece.strip())
            if not ok:
                return None
            parts.append(text)
        return parts

    answer = extract_answer_content(predict)
    if not answer:
        return 0.0
    predicted = _canonical_parts(answer)
    if predicted is None:
        return 0.0
    expected = _canonical_parts(ground_truth)
    if expected is None:
        return 0.0
    return 1.0 if Counter(predicted) == Counter(expected) else 0.0

def evaluate_ligand_selection(predict: str, ground_truth_list: List[str]) -> float:
    """Score a ligand answer against a list of acceptable SMILES.

    Invalid entries in *ground_truth_list* are skipped. Candidates are
    canonicalized lazily and scanning stops at the first match.

    Returns:
        1.0 if the canonicalized ``<answer>`` SMILES equals any valid candidate;
        0.0 if the answer is missing or invalid, or nothing matches.
    """
    answer = extract_answer_content(predict)
    if not answer:
        return 0.0
    pred_canon, pred_ok = canonicalize_smiles(answer)
    if not pred_ok:
        return 0.0
    hits = (
        ok and canon == pred_canon
        for canon, ok in map(canonicalize_smiles, ground_truth_list)
    )
    return 1.0 if any(hits) else 0.0

# Default location of the BERT tokenizer used for Molecule_Captioning text metrics.
_CAPTIONING_TOKENIZER_PATH = (
    "/mnt/shared-storage-user/wangweida/models--bert-base-uncased/snapshots/"
    "86b5e0934494bd15c9632b12f734a8a67f723594"
)

# Task name -> per-task metric key. Every result dict carries all of these keys,
# in this order; keys for tasks other than the sample's own are NaN.
_TASK_METRIC_KEYS: Dict[str, str] = {
    "Molecule_Captioning": "molecule_captioning_score",
    "Molecule_Design": "molecule_design_score",
    "Property_prediction": "property_prediction_score",
    "Yield_Prediction": "yield_prediction_score",
    "Reactant_Selection": "reactant_selection_score",
    "Solvent_Selection": "solvent_selection_score",
    "Retrosynthesis": "retrosynthesis_score",
    "Reaction_Prediction": "reaction_prediction_score",
    "Ligand_Selection": "ligand_selection_score",
    "smiles2iupac": "smiles2iupac_score",
    "iupac2smiles": "iupac2smiles_score",
}

# Tasks scored by a single evaluator(predict, ground_truth) call; the returned
# score is both the reported metric and the accuracy term of the reward.
_DIRECT_EVALUATORS = {
    "Property_prediction": evaluate_binary_classification,
    "Yield_Prediction": evaluate_binary_classification,
    "Reactant_Selection": evaluate_smiles,
    "Solvent_Selection": evaluate_smiles,
    "Retrosynthesis": evaluate_multi_smiles,
    "Reaction_Prediction": evaluate_multi_smiles,
    "smiles2iupac": evaluate_name,
    "iupac2smiles": evaluate_smiles,
}

# Tokenizers already loaded by _load_tokenizer, keyed by path.
_TOKENIZER_CACHE: Dict[str, object] = {}


def _load_tokenizer(path: str):
    """Return the tokenizer at *path*, loading it from disk only on first use.

    A failed load raises and is not cached, so a later call retries.
    """
    tokenizer = _TOKENIZER_CACHE.get(path)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        _TOKENIZER_CACHE[path] = tokenizer
    return tokenizer


def _score_task(task, predict: str, ground_truth, tokenizer) -> Tuple[float, float]:
    """Score one whitespace-normalized prediction for *task*.

    Returns:
        ``(task_metric, accuracy)``. *task_metric* is reported under the task's
        metric key; it may be NaN if it could not be computed. *accuracy* is
        always finite and feeds the reward. Unknown tasks (e.g. the TOMG tasks)
        yield ``(nan, 0.0)``.
    """
    if task == "Molecule_Captioning":
        metric = np.nan
        if tokenizer is not None:
            model_answer = extract_captioning_content(predict)
            try:
                bleu2, bleu4, rouge1, rouge2, rougel, meteor = text_evaluate(tokenizer, ground_truth, model_answer)
                metric = (bleu2 + bleu4 + rouge1 + rouge2 + rougel + meteor) / 6
            except Exception as e:
                logger.warning(f"Text evaluation failed: {str(e)}")
        # Report NaN when unavailable, but never let it poison the reward.
        return metric, (0.0 if np.isnan(metric) else metric)

    if task == "Molecule_Design":
        model_answer = extract_answer_content(predict)
        try:
            bleu, exact_match, _levenshtein, validity = mol_evaluate(ground_truth, model_answer)
        except Exception as e:
            # An unevaluable answer scores zero accuracy; the sample keeps its format credit.
            logger.warning(f"Molecule evaluation failed: {str(e)}")
            return 0.0, 0.0
        return exact_match, 0.4 * bleu + 0.2 * exact_match + 0.4 * validity

    if task == "Ligand_Selection":
        candidates = ground_truth
        if isinstance(candidates, str):
            # The ground truth may arrive as the string form of a Python list.
            try:
                candidates = ast.literal_eval(candidates)
            except (ValueError, SyntaxError):
                logger.warning(f"Could not parse ground_truth string for Ligand_Selection: {ground_truth}")
        # NOTE(review): array-like ground truths (e.g. list columns loaded as
        # numpy arrays) are accepted as candidate lists rather than scored 0.
        if isinstance(candidates, (tuple, np.ndarray)):
            candidates = list(candidates)
        logger.debug("Ligand_Selection ground truth (%s): %r", type(candidates).__name__, candidates)
        score = evaluate_ligand_selection(predict, candidates) if isinstance(candidates, list) else 0.0
        return score, score

    evaluator = _DIRECT_EVALUATORS.get(task)
    if evaluator is None:
        return np.nan, 0.0
    score = evaluator(predict, ground_truth)
    return score, score


def compute_score(
    predicts: List[str],
    ground_truths: List[Union[str, List[str]]],
    tasks: Optional[List[str]] = None,
    tokenizer_path: str = _CAPTIONING_TOKENIZER_PATH,
) -> List[Dict[str, float]]:
    """Compute per-sample rewards and per-task metrics for a batch of completions.

    Args:
        predicts: Raw model completions, expected in the form
            ``<think>...</think><answer>...</answer>``.
        ground_truths: One reference per prediction (a string; for
            Ligand_Selection a list of acceptable SMILES or its string form).
            When *tasks* is None this may instead be a dict with ``"task"`` and
            ``"ground_truth"`` entries.
        tasks: Task name per prediction; defaults to ``"smiles2iupac"`` for all.
        tokenizer_path: Tokenizer used for Molecule_Captioning text metrics.
            It is loaded only when the batch contains that task.

    Returns:
        One dict per sample with ``overall`` (0.9 * accuracy + 0.1 * format),
        ``format``, and one ``*_score`` key per known task. Keys for tasks other
        than the sample's own are NaN. If scoring a sample raises, its
        ``overall``, ``format`` and own task key are 0.0.

    Raises:
        ValueError: If predicts, ground_truths and tasks differ in length.
    """
    if tasks is None:
        if isinstance(ground_truths, dict) and "task" in ground_truths:
            tasks = ground_truths["task"]
            ground_truths = ground_truths["ground_truth"]
        else:
            tasks = ["smiles2iupac"] * len(predicts)

    if not (len(predicts) == len(ground_truths) == len(tasks)):
        raise ValueError(f"Length mismatch: predicts({len(predicts)}), ground_truths({len(ground_truths)}), tasks({len(tasks)})")

    tokenizer = None
    if "Molecule_Captioning" in tasks:
        try:
            tokenizer = _load_tokenizer(tokenizer_path)
        except Exception as e:
            logger.warning(f"Could not load captioning tokenizer from {tokenizer_path}: {e}")

    scores = []
    for predict, ground_truth, task in zip(predicts, ground_truths, tasks):
        metrics = dict.fromkeys(_TASK_METRIC_KEYS.values(), np.nan)
        metric_key = _TASK_METRIC_KEYS.get(task)
        try:
            # Drop whitespace around '<', '>' and '/' so spaced-out tags still parse.
            predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
            if task == "Molecule_Captioning":
                format_score = format_reward_captioning(predict)
            else:
                format_score = format_reward_general(predict)
            task_metric, accuracy = _score_task(task, predict, ground_truth, tokenizer)
            overall = 0.9 * accuracy + 0.1 * format_score
        except Exception:
            logger.exception(f"Scoring failed for task {task!r}; assigning zero reward")
            task_metric, format_score, overall = 0.0, 0.0, 0.0
        if metric_key is not None:
            metrics[metric_key] = task_metric
        scores.append({"overall": overall, "format": format_score, **metrics})

    return scores
