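# pip install rdkit  (if needed; "rdkit-pypi" is the older, deprecated package name)
# SA score: this module imports RDKit's contrib scorer directly as
#   `from rdkit.Contrib.SA_Score import sascorer`, so that contrib directory
#   (sascorer.py + fpscores.pkl.gz) must be importable; it is not optional here.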

from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Dict
import math, random
from collections import Counter

from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs
from rdkit.Contrib.SA_Score import sascorer
from chem import calculate_qed, calculate_logp, calculate_mw, brics_decomposition_connectivity
import re
from rdkit import Chem
from typing import Dict
from optimizer import optimize
from run import multi_agent_molecule_generation_llm_exact

import random
import json
from tqdm import tqdm
import os
from datetime import datetime

# ---------- Optional SA score ----------
def _get_sascorer():
    """
    Return the synthetic-accessibility scorer as a callable ``mol -> float``.

    The scorer is taken from the module-level ``sascorer`` import (RDKit
    contrib SA_Score). If that module does not expose a callable
    ``calculateScore``, ``None`` is returned so callers can skip SA scoring.
    """
    scorer = getattr(sascorer, "calculateScore", None)
    return scorer if callable(scorer) else None


SA_SCORER = _get_sascorer()

# ---------- Helpers ----------
def canonical_smiles(mol: Chem.Mol) -> str:
    """Return the canonical SMILES of ``mol``, keeping stereo/isotope information."""
    smiles = Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True)
    return smiles

def morgan_fp(mol: Chem.Mol, radius: int = 2, n_bits: int = 2048):
    """
    Compute a Morgan (circular) bit-vector fingerprint of ``mol``.

    ``radius`` sets the neighbourhood size (2 is the ECFP4-like default) and
    ``n_bits`` the folded fingerprint length.
    """
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    return fingerprint

def pairwise_sample_indices(n: int, max_pairs: int, rng: random.Random) -> List[Tuple[int,int]]:
    """
    Pick up to ``max_pairs`` distinct index pairs ``(a, b)`` with ``a < b`` from ``range(n)``.

    When all ``n*(n-1)/2`` pairs fit within ``max_pairs`` they are all returned
    in lexicographic order. Otherwise pairs are drawn with ``rng`` by rejection
    sampling (self-pairs and repeats are redrawn) and returned in draw order.
    """
    if n < 2:
        return []

    n_all_pairs = n * (n - 1) // 2
    if n_all_pairs <= max_pairs:
        return [(a, b) for a in range(n) for b in range(a + 1, n)]

    picked = set()
    ordered: List[Tuple[int, int]] = []
    while len(ordered) < max_pairs:
        # Two independent draws, first then second, exactly as consumed from rng.
        x, y = rng.randrange(n), rng.randrange(n)
        if x == y:
            continue
        pair = (min(x, y), max(x, y))
        if pair not in picked:
            picked.add(pair)
            ordered.append(pair)
    return ordered

@dataclass
class MetricResults:
    """Aggregate metrics produced by ``evaluate_molecule_set``."""
    n_input: int                       # number of input SMILES strings
    n_valid: int                       # inputs that RDKit parsed successfully
    validity: float                    # n_valid / n_input (0.0 when there is no input)
    n_unique_valid: int                # distinct canonical SMILES among valid inputs
    uniqueness: float                  # n_unique_valid / n_valid (0.0 when nothing is valid)
    novelty: Optional[float]           # share of unique valid SMILES absent from the reference set; None if no reference given
    sa_mean: Optional[float]           # mean SA score over valid inputs; None if no scorer or no valid input
    sa_median: Optional[float]         # median SA score; same availability as sa_mean
    diversity_mean: Optional[float]    # mean (1 - Tanimoto) over sampled unique-molecule pairs; None if < 2 unique valid
    diversity_median: Optional[float]  # median of the same pairwise diversities

def _median_of_sorted(values: List[float]) -> float:
    """Return the median of an already-sorted, non-empty list of floats."""
    mid = len(values) // 2
    if len(values) % 2 == 1:
        return values[mid]
    return 0.5 * (values[mid - 1] + values[mid])


def evaluate_molecule_set(
    gen_smiles: Iterable[str],
    *,
    reference_smiles: Optional[Iterable[str]] = None,
    radius: int = 2,
    n_bits: int = 2048,
    max_pairs: int = 50_000,
    seed: int = 0,
) -> Tuple[MetricResults, List[Dict]]:
    """
    Compute validity, uniqueness, novelty, SA score and diversity for SMILES.

    - Validity: fraction of inputs that RDKit parses.
    - Uniqueness: distinct canonical SMILES / valid inputs.
    - Novelty (only if ``reference_smiles`` is given): fraction of unique valid
      canonical SMILES absent from the canonicalized reference set
      (unparsable reference entries are ignored).
    - SA score (only if ``SA_SCORER`` is available): mean/median over all valid
      inputs, duplicates included; each valid record gets its ``sa_score``.
    - Diversity: 1 - Tanimoto similarity of Morgan bit-vector fingerprints
      (``radius`` / ``n_bits``; ECFP4-like at the defaults) over pairs of
      unique valid molecules, either all pairs or up to ``max_pairs`` pairs
      sampled with ``seed``.

    Returns:
        ``(metrics, per_molecule_records)``; each record has keys
        ``input_smiles``, ``valid``, ``canonical_smiles`` and ``sa_score`` and
        the records follow input order.
    """
    rng = random.Random(seed)
    gen_smiles = list(gen_smiles)
    per_mol: List[Dict] = []

    # Parse & canonicalize exactly once. The parsed Mol is kept alongside its
    # record so later stages never re-parse the canonical SMILES (which could
    # in principle fail to round-trip and yield None).
    valid_mols: List[Chem.Mol] = []
    valid_cans: List[str] = []
    valid_recs: List[Dict] = []
    for s in gen_smiles:
        mol = Chem.MolFromSmiles(s)
        if mol is None:
            per_mol.append({"input_smiles": s, "valid": False, "canonical_smiles": None, "sa_score": None})
            continue
        can = canonical_smiles(mol)
        rec = {"input_smiles": s, "valid": True, "canonical_smiles": can, "sa_score": None}
        valid_mols.append(mol)
        valid_cans.append(can)
        valid_recs.append(rec)
        per_mol.append(rec)

    n_input = len(gen_smiles)
    n_valid = len(valid_mols)
    validity = n_valid / n_input if n_input > 0 else 0.0

    # Uniqueness among valid (by canonical SMILES)
    unique_valid_cans = sorted(set(valid_cans))
    n_unique_valid = len(unique_valid_cans)
    uniqueness = (n_unique_valid / n_valid) if n_valid > 0 else 0.0

    # Novelty vs reference (unique valid only)
    novelty = None
    if reference_smiles is not None:
        ref_cans = set()
        for s in reference_smiles:
            m = Chem.MolFromSmiles(s)
            if m is not None:
                ref_cans.add(canonical_smiles(m))
        if n_unique_valid > 0:
            n_novel = sum(1 for s in unique_valid_cans if s not in ref_cans)
            novelty = n_novel / n_unique_valid
        else:
            novelty = 0.0

    # SA score over every valid input (duplicates included), if a scorer exists.
    sa_mean = sa_median = None
    if SA_SCORER is not None:
        sa_scores = []
        for rec, mol in zip(valid_recs, valid_mols):
            score = float(SA_SCORER(mol))
            rec["sa_score"] = score
            sa_scores.append(score)
        if sa_scores:
            sa_scores.sort()
            sa_mean = sum(sa_scores) / len(sa_scores)
            sa_median = _median_of_sorted(sa_scores)

    # Diversity over unique valid molecules (avoid duplicate bias)
    diversity_mean = diversity_median = None
    if n_unique_valid >= 2:
        # First parsed Mol per canonical SMILES, in first-seen order (the
        # order determines which pairs the seeded sampler picks).
        can_to_mol: Dict[str, Chem.Mol] = {}
        for can, mol in zip(valid_cans, valid_mols):
            can_to_mol.setdefault(can, mol)
        fps = [morgan_fp(m, radius, n_bits) for m in can_to_mol.values()]

        pairs = pairwise_sample_indices(len(fps), max_pairs=max_pairs, rng=rng)
        # diversity = 1 - similarity
        divs = sorted(1.0 - DataStructs.TanimotoSimilarity(fps[i], fps[j]) for i, j in pairs)
        if divs:
            diversity_mean = sum(divs) / len(divs)
            diversity_median = _median_of_sorted(divs)
        else:
            diversity_mean = diversity_median = 0.0

    metrics = MetricResults(
        n_input=n_input,
        n_valid=n_valid,
        validity=validity,
        n_unique_valid=n_unique_valid,
        uniqueness=uniqueness,
        novelty=novelty,
        sa_mean=sa_mean,
        sa_median=sa_median,
        diversity_mean=diversity_mean,
        diversity_median=diversity_median,
    )
    return metrics, per_mol

# === Utility Functions ===

def extract_smiles(output_text):
    """
    Return the stripped text inside the first ``<SMILES>...</SMILES>`` tag pair.

    Whitespace (including newlines) around the payload is tolerated, e.g.
    ``"<SMILES>\\nCCO\\n</SMILES>"`` yields ``"CCO"``; the payload itself must
    sit on one line. Returns ``None`` when ``output_text`` is empty/None or
    contains no such tag pair.
    """
    if not output_text:
        return None
    # \s* lets newlines surround the payload, while `.` (no DOTALL) still
    # stops the payload from spanning lines or swallowing a later tag.
    match = re.search(r"<SMILES>\s*(.+?)\s*</SMILES>", output_text)
    return match.group(1).strip() if match else None

def is_valid_smiles(smiles):
    """
    Return True if ``smiles`` is a non-blank string that RDKit parses.

    Blank strings are rejected explicitly because RDKit parses ``""`` into an
    empty, zero-atom molecule rather than returning ``None``. Non-strings,
    parse failures (``None``) and RDKit argument errors all give False.
    """
    if not isinstance(smiles, str) or not smiles.strip():
        return False
    try:
        return Chem.MolFromSmiles(smiles) is not None
    except Exception:
        # Narrower than a bare except: lets KeyboardInterrupt/SystemExit through.
        return False

def optimize_best(input_text, query_qed, query_logp, query_mw, num_samples=10, max_attempts=20):
    """
    Sample candidates from ``optimize`` and pick the best per target property.

    ``optimize(input_text)`` is called up to ``max_attempts`` times. The
    ``<SMILES>`` payload of each reply (spaces removed) is kept if it parses,
    until ``num_samples`` candidates have been collected. For every kept
    candidate, the absolute QED / LogP / MW errors against the query targets
    are computed, plus the normalized total ``qed/1 + logp/10 + mw/700``.
    Exceptions from a single attempt are printed and that attempt is skipped.

    Returns:
        ``[best_qed, best_logp, best_mw, best_overall]``: the SMILES with the
        smallest QED, LogP, MW and normalized error respectively (the same
        SMILES may appear more than once). Returns ``None`` if no valid
        candidate was produced.
    """
    # Scales that bring the three absolute errors onto comparable ranges.
    qed_scale, logp_scale, mw_scale = 1.0, 10.0, 700.0

    results = []
    for _ in range(max_attempts):
        try:
            output_text = optimize(input_text)
            # extract_smiles returns None when the reply has no <SMILES> tag;
            # treat that as an empty answer instead of raising AttributeError.
            smiles = (extract_smiles(output_text) or "").replace(' ', '')
            if not smiles or not is_valid_smiles(smiles):
                continue

            qed = calculate_qed(smiles)
            logp = calculate_logp(smiles)
            mw = calculate_mw(smiles)

            qed_error = abs(qed - query_qed)
            logp_error = abs(logp - query_logp)
            mw_error = abs(mw - query_mw)
            total_normalized_error = qed_error / qed_scale + logp_error / logp_scale + mw_error / mw_scale

            results.append({
                "smiles": smiles,
                "qed": qed,
                "logp": logp,
                "mw": mw,
                "qed_error": qed_error,
                "logp_error": logp_error,
                "mw_error": mw_error,
                "normalized_error": total_normalized_error
            })

            if len(results) >= num_samples:
                break

        except Exception as e:
            # Best-effort sampling: one failed LLM call or property computation
            # should not abort the whole search.
            print("Skipped due to error:", e)

    if not results:
        print("❌ No valid SMILES generated.")
        return None

    # Select best per-property and best-overall
    best_qed = min(results, key=lambda r: r["qed_error"])
    best_logp = min(results, key=lambda r: r["logp_error"])
    best_mw = min(results, key=lambda r: r["mw_error"])
    best_overall = min(results, key=lambda r: r["normalized_error"])

    return [
        best_qed["smiles"],
        best_logp["smiles"],
        best_mw["smiles"],
        best_overall["smiles"]
    ]


# def extract_target_properties(query: str) -> Dict[str, float]:
#     qed_match = re.search(r'qed\s*=?\s*([\d.]+)', query, re.I)
#     logp_match = re.search(r'logp\s*=?\s*([\d.]+)', query, re.I)
#     mw_match = re.search(r'(mw|molecular weight)\s*=?\s*([\d.]+)', query, re.I)

#     return {
#         'qed': float(qed_match.group(1)) if qed_match else None,
#         'logp': float(logp_match.group(1)) if logp_match else None,
#         'mw': float(mw_match.group(2)) if mw_match else None,
#     }

def extract_target_properties(user_query: str) -> dict:
    """
    Extract target QED, LogP and molecular weight from a free-text query.

    Each property is matched case-insensitively as ``<name>[=]<number>`` with
    optional surrounding whitespace; MW may be written ``mw`` or
    ``molecular weight``. Signed values (``logp=-1.2``) and leading-dot
    decimals (``qed=.5``) are accepted.

    Returns:
        A dict containing only the properties found, under the keys
        ``'qed'``, ``'logp'`` and ``'mw'``, each rounded to 3 decimals.
    """

    def extract_value(pattern: str, text: str):
        match = re.search(pattern, text, re.IGNORECASE)
        if match is None:
            return None
        try:
            return float(match.group(1))
        except ValueError:
            # Defensive only: the number sub-pattern always yields a float literal.
            return None

    # Optional sign, optional integer part, optional dot, at least one digit.
    number = r'([-+]?[0-9]*\.?[0-9]+)'
    patterns = {
        'qed': r'\bqed\s*=?\s*' + number,
        'logp': r'\blogp\s*=?\s*' + number,
        'mw': r'\b(?:molecular weight|mw)\s*=?\s*' + number,
    }

    constraints = {}
    for key, pattern in patterns.items():
        value = extract_value(pattern, user_query)
        if value is not None:
            constraints[key] = round(value, 3)

    return constraints

def determine_direction(current: float, target: float) -> str:
    """
    Say which way ``current`` must move to reach ``target``.

    Returns ``"unknown"`` if either value is None, ``"higher"`` if the target
    is strictly greater, and ``"lower"`` otherwise (equality included).
    """
    if current is None or target is None:
        return "unknown"
    if target > current:
        return "higher"
    return "lower"

# === Main Evaluation Function ===

def _abs_error(value, target):
    """Return ``abs(value - target)``, or ``None`` when ``target`` is missing."""
    return abs(value - target) if target is not None else None


def _format_error(error):
    """Render an error for the prompt with 3 decimals, or ``N/A`` when missing."""
    return f"{error:.3f}" if error is not None else "N/A"


def _normalized_total_error(qed_error, logp_error, mw_error):
    """
    Combine per-property errors as ``qed + logp/10 + mw/700``.

    Returns ``None`` if any component is ``None`` (its target was not given).
    """
    if qed_error is None or logp_error is None or mw_error is None:
        return None
    return qed_error + logp_error / 10.0 + mw_error / 700.0


def evaluate_single_query(user_query: str, molecule_json_path: str):
    """
    Run one end-to-end optimization query and report property errors.

    Steps:
      1. get a locally optimal molecule (and intermediate steps) from
         ``multi_agent_molecule_generation_llm_exact``;
      2. decompose it with ``brics_decomposition_connectivity``;
      3. compute its QED / LogP / MW;
      4. parse target properties from ``user_query``;
      5. build a fragment-edit prompt giving, per property, the current
         absolute error and the direction ("higher"/"lower"/"unknown");
      6. sample edits with ``optimize_best``;
      7. score each returned SMILES against the targets.

    Targets missing from the query are recorded as ``None``. Their errors, and
    any normalized total that depends on them, are ``None`` too, and the prompt
    shows ``N/A`` for them. A target of 0 is treated as a real target.

    Returns:
        A list of result dicts, one per SMILES returned by ``optimize_best``,
        or an empty list if ``optimize_best`` produced nothing.
    """
    # Step 1: Get local optimal SMILES (plus the generation steps)
    local_optimal_smiles, steps = multi_agent_molecule_generation_llm_exact(user_query, molecule_json_path)
    print(local_optimal_smiles)

    # Step 2: Get fragments from BRICS (bond connectivity is not used here)
    fragments, _ = brics_decomposition_connectivity(local_optimal_smiles)
    print(fragments)

    # Step 3: Get local optimal properties
    qed_local = calculate_qed(local_optimal_smiles)
    logp_local = calculate_logp(local_optimal_smiles)
    mw_local = calculate_mw(local_optimal_smiles)

    # Step 4: Extract user target values. extract_target_properties only adds
    # keys it found, so .get() maps absent targets to None instead of KeyError.
    target_props = extract_target_properties(user_query)
    print(target_props)
    target_qed = target_props.get('qed')
    target_logp = target_props.get('logp')
    target_mw = target_props.get('mw')

    # Step 5: Determine optimization direction and current error.
    qed_dir = determine_direction(qed_local, target_qed)
    logp_dir = determine_direction(logp_local, target_logp)
    mw_dir = determine_direction(mw_local, target_mw)

    # Compare against None rather than truthiness: e.g. logp=0 is a valid target.
    qed_error_local = _abs_error(qed_local, target_qed)
    logp_error_local = _abs_error(logp_local, target_logp)
    mw_error_local = _abs_error(mw_local, target_mw)

    # Step 6: Format prompt using relative direction
    prompt = f"""Given the intermediate molecule SMILES <SMILES>{local_optimal_smiles}</SMILES>, \
which is composed of fragments {fragments}. Propose a single replace, add or remove step on fragment level \
that makes the new molecule's QED <QED>{_format_error(qed_error_local)}</QED> {qed_dir}, LogP <LogP>{_format_error(logp_error_local)}</LogP> {logp_dir}, \
and Molecular Weight <MW>{_format_error(mw_error_local)}</MW> {mw_dir}."""

    # Step 7: Optimize with LLM (optimize_best returns None when nothing valid came back)
    outputs = optimize_best(prompt, target_qed, target_logp, target_mw)
    if not outputs:
        return []

    local_total_error = _normalized_total_error(qed_error_local, logp_error_local, mw_error_local)

    results = []
    for generated_smiles in outputs:
        # Step 8: Compute new properties
        qed_new = calculate_qed(generated_smiles)
        logp_new = calculate_logp(generated_smiles)
        mw_new = calculate_mw(generated_smiles)

        # Step 9: Compute error from target
        qed_error_target = _abs_error(qed_new, target_qed)
        logp_error_target = _abs_error(logp_new, target_logp)
        mw_error_target = _abs_error(mw_new, target_mw)

        # === Final Result ===
        # NOTE: the 'normlized_*' key spelling is kept for downstream consumers.
        results.append({
            'original_smiles': local_optimal_smiles,
            'chain_of_thought': steps,
            'generated_smiles': generated_smiles,
            'target_qed': target_qed,
            'target_logp': target_logp,
            'target_mw': target_mw,
            'local_qed': qed_local,
            'local_logp': logp_local,
            'local_mw': mw_local,
            'new_qed': qed_new,
            'new_logp': logp_new,
            'new_mw': mw_new,
            'qed_error_vs_target': qed_error_target,
            'logp_error_vs_target': logp_error_target,
            'mw_error_vs_target': mw_error_target,
            'qed_error_vs_local': qed_error_local,
            'logp_error_vs_local': logp_error_local,
            'mw_error_vs_local': mw_error_local,
            'normlized_total_error_vs_target': _normalized_total_error(qed_error_target, logp_error_target, mw_error_target),
            'normlized_total_error_vs_local': local_total_error,
            'prompt': prompt,
            # optimize_best returns SMILES only, so this equals generated_smiles.
            'raw_output': generated_smiles
        })

    return results

from typing import Dict, Optional

# def extract_target_properties(query: str) -> Dict[str, Optional[float]]:
#     qed_match = re.search(r'\bqed\s*=\s*([\d.]+)', query, re.IGNORECASE)
#     logp_match = re.search(r'\blogp\s*=\s*([\d.]+)', query, re.IGNORECASE)
#     mw_match = re.search(r'\b(?:mw|molecular weight)\s*=\s*([\d.]+)', query, re.IGNORECASE)

#     return {
#         'qed': float(qed_match.group(1)) if qed_match else None,
#         'logp': float(logp_match.group(1)) if logp_match else None,
#         'mw': float(mw_match.group(1)) if mw_match else None,
#     }

def generate_random_query(rng: Optional[random.Random] = None):
    """
    Build a random property-targeting query string.

    Targets are drawn uniformly:
      - QED in [0.4, 0.9], rounded to 2 dp;
      - LogP in [-2, 6], rounded to 2 dp;
      - molecular weight in [200, 600], rounded to 1 dp.

    The values are written in the ``name=value`` form that
    ``extract_target_properties`` parses.

    Args:
        rng: Optional ``random.Random`` for reproducible queries. When omitted,
            the module-level ``random`` functions are used (the original
            behavior).
    """
    source = random if rng is None else rng
    qed = round(source.uniform(0.4, 0.9), 2)
    logp = round(source.uniform(-2, 6), 2)
    mw = round(source.uniform(200, 600), 1)
    return f"Please help me generate a new valid molecule with qed={qed}, logp={logp}, molecular weight={mw}"