# main.py

import openai
import os
from rdkit import Chem
import json
import random

# Local imports
from decomposer import QuestionDecomposerLLM
from retriever import Retriever
from action import ActionAgentLLM
from evaluator import RDKitEvaluator
from rdkit.Chem import QED, Crippen
from rdkit.Chem import Descriptors
from rdkit import Chem
import pandas as pd

import json, os, math, re
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
import pandas as pd

# your project imports (must exist in your env)
from optimizer import optimize                # single-shot optimizer
from chem import (
    calculate_qed, calculate_logp, calculate_mw,
    brics_decomposition_connectivity
)


def multi_agent_molecule_generation_llm(user_query: str, dataset):
    """
    Generate a molecule for *user_query* through a decompose / retrieve / act / evaluate loop.

    Pipeline:
      1. ``QuestionDecomposerLLM`` (model "gpt-4") turns the query into constraints.
      2. ``Retriever`` pulls context from *dataset* for those constraints; the
         returned list is shuffled in place.
      3. ``ActionAgentLLM`` (model "gpt-4o") proposes or edits a SMILES each iteration,
         starting from the sentinel string 'start from scratch'.
      4. ``RDKitEvaluator`` checks the candidate. Any feedback other than
         'All constraints are satisfied.' is sent back to the agent and appended
         to the running chain-of-thought.

    Returns:
        (smiles, step): *step* is the 0-based iteration whose candidate satisfied
        all constraints. Returns (None, 10) if no candidate did within 10 iterations.
    """
    decomposer = QuestionDecomposerLLM(model_name="gpt-4")
    constraints = decomposer.decompose(user_query)
    print(constraints)

    context = Retriever(dataset).retrieve(constraints)
    random.shuffle(context)
    print('relevant molecule number:', len(context))

    agent = ActionAgentLLM(dataset, context, constraints, model_name="gpt-4o")
    evaluator = RDKitEvaluator()

    # The same list object is passed to the agent every iteration and grown in place.
    thoughts = []
    smiles = 'start from scratch'
    max_iterations = 10  # hard cap on generate/evaluate rounds

    for step in range(max_iterations):
        smiles, new_thoughts = agent.run_generation(smiles, thoughts)
        thoughts.extend(new_thoughts)

        feedback = evaluator.evaluate(smiles, constraints)
        if feedback == 'All constraints are satisfied.':
            print(f"\nSUCCESS at step {step + 1}: {smiles}")
            print("Final chain-of-thought (hidden from end-users):", thoughts)
            return smiles, step

        agent.receive_feedback(feedback)
        thoughts.append(f"Evaluator says: {feedback}")

    print("\nFAILED to generate a valid molecule within iteration limit.")
    return None, max_iterations


import random

def multi_agent_molecule_generation_llm_exact(user_query: str, dataset, tol_range=0.05, retrieval_num=30, iter_num=6):
    """
    Exact-target variant of the decompose / retrieve / act / evaluate loop.

    Args:
        user_query: natural-language request, parsed by
            ``QuestionDecomposerLLM.decompose_exact_hardcoded``.
        dataset: passed to both ``Retriever`` and ``ActionAgentLLM``.
        tol_range: tolerance forwarded to ``Retriever.retrieve_exact``.
        retrieval_num: maximum number of (shuffled) retrieved items given to the agent.
        iter_num: maximum number of generate/evaluate rounds.

    Returns:
        - On success: (smiles, step), where *step* is the 0-based iteration index.
        - Otherwise, if any iteration produced a SMILES that RDKit could parse:
          (that SMILES, its 1-based iteration number). "Best" is really "most
          recent parseable" here; the value is overwritten each valid round.
        - Otherwise: (last candidate, iter_num).

    NOTE(review): the success path returns a 0-based step while the fallback
    returns a 1-based one. Confirm which convention callers expect.
    """
    decomposer = QuestionDecomposerLLM(model_name="gpt-4")
    constraints = decomposer.decompose_exact_hardcoded(user_query)
    print("Constraints:", constraints)

    context = Retriever(dataset).retrieve_exact(constraints, tol_range)
    random.shuffle(context)
    print('Relevant molecule', context)
    if len(context) > retrieval_num:
        context = context[:retrieval_num]

    agent = ActionAgentLLM(dataset, context, constraints, model_name="gpt-4o")
    evaluator = RDKitEvaluator()

    thoughts = []
    smiles = 'start from scratch'
    last_valid, last_valid_step = None, -1

    for step in range(iter_num):
        smiles, new_thoughts = agent.run_generation_exact(smiles, thoughts)
        thoughts.extend(new_thoughts)
        feedback = evaluator.evaluate_exact(smiles, constraints)

        # Remember the latest candidate RDKit can parse, as a fallback result.
        if Chem.MolFromSmiles(smiles):
            last_valid, last_valid_step = smiles, step

        if feedback == 'All constraints are satisfied.':
            print(f"\n✅ SUCCESS at step {step + 1}: {smiles}")
            print("Final chain-of-thought (hidden from end-users):", thoughts)
            return smiles, step

        agent.receive_feedback(feedback)
        thoughts.append(f"Evaluator says: {feedback}")

    print("\n❌ FAILED to generate a fully valid molecule within iteration limit.")
    if not last_valid:
        print("No valid molecule was ever produced.")
        return smiles, iter_num
    print(f"Returning best valid molecule seen at step {last_valid_step + 1}: {last_valid}")
    return last_valid, last_valid_step + 1



def _has_element(smiles: str, element_symbol) -> bool:
    """
    Report whether any atom in *smiles* carries the element symbol *element_symbol*.

    Returns False (rather than raising) when RDKit cannot parse *smiles*.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return False  # unparseable SMILES
    return any(atom.GetSymbol() == element_symbol for atom in molecule.GetAtoms())

# --- Notebook cell 2: helpers ---

from typing import Dict, List, Any, Tuple, Optional

def norm_total_error(qed, logp, mw, t_qed, t_logp, t_mw, *,
                     qed_scale=1.0, logp_scale=10.0, mw_scale=700.0):
    """
    Return the normalized L1 distance between current and target properties.

        |qed - t_qed| / qed_scale
      + |logp - t_logp| / logp_scale
      + |mw - t_mw| / mw_scale

    Each absolute error is divided by a per-property scale before summing.
    The scales are presumably chosen so the three terms have comparable
    magnitude. The defaults (1.0, 10.0, 700.0) reproduce the original
    hard-coded normalization. Pass other keyword values to reweight the terms.

    Args:
        qed, logp, mw: current property values.
        t_qed, t_logp, t_mw: target property values.
        qed_scale, logp_scale, mw_scale: divisors for each error term (keyword-only).

    Returns:
        The non-negative weighted error (0 when every property matches its target).
    """
    return (
        abs(qed - t_qed) / qed_scale
        + abs(logp - t_logp) / logp_scale
        + abs(mw - t_mw) / mw_scale
    )

def build_prompt_for_second_pass(curr_smiles, t_qed, t_logp, t_mw):
    """
    Build a single-edit optimization prompt for *curr_smiles*.

    The prompt embeds:
      - the current SMILES;
      - its BRICS fragments (an empty list if the decomposition raises);
      - for each of QED, LogP and MW, the absolute gap to the target, plus
        "higher" when the target is strictly above the current value and
        "lower" otherwise (ties included).

    Returns:
        (prompt, (qed, logp, mw, qed_gap, logp_gap, mw_gap)): the prompt text,
        the current property values, and the absolute gaps to the targets.
    """
    cur_qed = calculate_qed(curr_smiles)
    cur_logp = calculate_logp(curr_smiles)
    cur_mw = calculate_mw(curr_smiles)

    gap_qed = abs(cur_qed - t_qed)
    gap_logp = abs(cur_logp - t_logp)
    gap_mw = abs(cur_mw - t_mw)

    def direction(target, current):
        # Strict comparison: an exact match is reported as "lower".
        return "higher" if target > current else "lower"

    # Fragments are best-effort context for the model; a failed decomposition
    # must not abort prompt construction.
    try:
        fragments, _unused = brics_decomposition_connectivity(curr_smiles)
    except Exception:
        fragments = []

    prompt = (
        f"Given the intermediate molecule SMILES <SMILES>{curr_smiles}</SMILES>, "
        f"which is composed of fragments {fragments}. "
        f"Propose a single replace, add or remove step on fragment level "
        f"that makes the new molecule's QED <QED>{gap_qed:.3f}</QED> {direction(t_qed, cur_qed)}, "
        f"LogP <LogP>{gap_logp:.3f}</LogP> {direction(t_logp, cur_logp)}, and Molecular Weight "
        f"<MW>{gap_mw:.3f}</MW> {direction(t_mw, cur_mw)}."
    )
    return prompt, (cur_qed, cur_logp, cur_mw, gap_qed, gap_logp, gap_mw)

def evaluate_smiles(smiles, t_qed, t_logp, t_mw):
    """
    Score *smiles* against target QED / LogP / MW.

    Returns:
        dict with keys "smiles", "qed", "logp", "mw" and "norm_err", where
        "norm_err" is ``norm_total_error`` of the computed properties versus the
        targets. Exceptions from the property calculators are not caught here.
    """
    props = {
        "qed": calculate_qed(smiles),
        "logp": calculate_logp(smiles),
        "mw": calculate_mw(smiles),
    }
    err = norm_total_error(props["qed"], props["logp"], props["mw"], t_qed, t_logp, t_mw)
    return {"smiles": smiles, **props, "norm_err": err}

from typing import Dict, List, Any, Tuple

def optimize_once(prompt, t_qed, t_logp, t_mw) -> Tuple[List[Dict[str, Any]], str]:
    """
    Run ``optimize`` once on *prompt* and score the SMILES it proposes.

    The first ``<SMILES>...</SMILES>`` span in the raw output is extracted,
    stripped, and cleared of spaces. The result is then scored with
    ``evaluate_smiles``.

    Returns:
        (candidates, raw_output):
          - ``candidates`` is ``[candidate_dict]`` on success, and ``[]`` when no
            tag is found, the tag holds only whitespace, or scoring raises.
          - ``raw_output`` is the model output, or ``""`` when ``optimize``
            returned a falsy value.
    """
    raw = optimize(prompt) or ""
    match = re.search(r"<SMILES>(.+?)</SMILES>", raw)
    if not match:
        return [], raw
    smi = match.group(1).strip().replace(" ", "")
    if not smi:
        # "<SMILES> </SMILES>" matches ".+?" but normalizes to "". Depending on
        # the property calculators, that could otherwise be scored as a bogus
        # zero-atom "candidate".
        return [], raw
    try:
        cand = evaluate_smiles(smi, t_qed, t_logp, t_mw)
    except Exception:
        # Best-effort sampling: an unscorable SMILES just means "no candidate".
        return [], raw
    return [cand], raw

def optimize_sampled(
    prompt, t_qed, t_logp, t_mw, tries: int = 20, keep_n: int = 10
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Repeatedly call ``optimize_once`` and collect distinct scored candidates.

    Sampling stops after *tries* calls, or as soon as *keep_n* distinct SMILES
    have been collected, whichever comes first. When the same SMILES appears
    more than once, the entry with the lowest "norm_err" is kept.

    Returns:
        kept: up to *keep_n* distinct candidates, sorted by ascending "norm_err".
        trials: one ``{"raw_output": str, "candidate": dict | None}`` record per
            call, in sampling order.
    """
    best_by_smiles: Dict[str, Dict[str, Any]] = {}
    trials: List[Dict[str, Any]] = []

    attempts = 0
    while attempts < tries and len(best_by_smiles) < keep_n:
        candidates, raw = optimize_once(prompt, t_qed, t_logp, t_mw)
        attempts += 1
        first = candidates[0] if candidates else None
        trials.append({"raw_output": raw, "candidate": first})
        for cand in candidates:
            key = cand["smiles"]
            incumbent = best_by_smiles.get(key)
            if incumbent is None or cand["norm_err"] < incumbent["norm_err"]:
                best_by_smiles[key] = cand

    ranked = sorted(best_by_smiles.values(), key=lambda c: c["norm_err"])
    return ranked[:keep_n], trials

def _hop_step(hop, src_smiles, dst_smiles, prompt, raw_output):
    """Return one path record describing a hop from *src_smiles* to *dst_smiles*."""
    return {
        "hop": hop,
        "from": src_smiles,
        "to": dst_smiles,
        "prompt": prompt,
        "raw_output": raw_output,
    }


def _raw_output_for(trials, cand):
    """Return the raw_output of the first trial whose candidate matches *cand*, else None."""
    for trial in trials:
        c = trial.get("candidate")
        if (
            c
            and c["smiles"] == cand["smiles"]
            and abs(c["norm_err"] - cand["norm_err"]) < 1e-12
        ):
            return trial["raw_output"]
    return None


def multi_hop_optimize_smiles(
    start_smiles: str,
    t_qed: float,
    t_logp: float,
    t_mw: float,
    *,
    hops: int = 3,
    tries_per_hop: int = 20,
    keep_n_per_hop: int = 10,
    early_stop_delta: float = 0.0,
    patience: int = 1,
    verbose: bool = True,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Iteratively edit *start_smiles* toward target QED / LogP / MW.

    Each hop builds a prompt from the current working molecule and samples
    candidates with ``optimize_sampled``. It then:
      - records the hop's best candidate if it beats the best error seen so far;
      - accepts that candidate as the new working molecule only if
        ``best_hop_err + early_stop_delta < current_err``.

    Args:
        start_smiles: starting molecule; scored with ``evaluate_smiles``.
        t_qed, t_logp, t_mw: property targets.
        hops: maximum number of hops.
        tries_per_hop, keep_n_per_hop: forwarded to ``optimize_sampled`` as
            ``tries`` / ``keep_n``.
        early_stop_delta: required error improvement for a hop to be accepted.
        patience: stop once the run of consecutive non-accepted hops exceeds
            this value. With the default of 1, the run stops after two
            non-accepted hops in a row.
        verbose: print progress lines.

    Returns:
        best_overall: {"smiles", "qed", "logp", "mw", "norm_err"} with the lowest
            error seen (the starting molecule if no hop beat it).
        best_path: the accepted hops leading up to the hop that produced
            best_overall, plus that hop. Each entry looks like
            ``{"hop", "from", "to", "prompt", "raw_output"}``. The list is empty
            if no hop ever beat the starting error.

    Raises:
        ValueError: if scoring *start_smiles* raises (original error chained).
    """
    try:
        current = evaluate_smiles(start_smiles, t_qed, t_logp, t_mw)
    except Exception as e:
        raise ValueError(f"Invalid start_smiles for optimization: {e}") from e

    best_overall = dict(current)
    accepted_chain: List[Dict[str, Any]] = []  # accepted transitions only
    best_path: List[Dict[str, Any]] = []       # accepted_chain at that time + the hop that set best_overall
    no_improve_streak = 0

    if verbose:
        print(f"[Init] SMILES={current['smiles']},"
              f" QED={current['qed']:.3f}, LogP={current['logp']:.3f}, MW={current['mw']:.3f},"
              f" NTE={current['norm_err']:.3f}")

    for h in range(1, hops + 1):
        # Only the prompt is used; the property/delta tuple is not needed here.
        prompt, _ = build_prompt_for_second_pass(
            current["smiles"], t_qed, t_logp, t_mw
        )

        cands, trials = optimize_sampled(
            prompt, t_qed, t_logp, t_mw,
            tries=tries_per_hop,
            keep_n=keep_n_per_hop,
        )
        if not cands:
            if verbose:
                print(f"[Hop {h}] No valid candidates; stopping.")
            break

        # optimize_sampled returns candidates sorted by norm_err, so [0] is the hop's best.
        best_hop = cands[0]
        best_raw = _raw_output_for(trials, best_hop)

        if verbose:
            print(f"[Hop {h}] Best NTE={best_hop['norm_err']:.3f} "
                  f"(QED={best_hop['qed']:.3f}, LogP={best_hop['logp']:.3f}, MW={best_hop['mw']:.3f})")

        # New global best: its path is the accepted chain so far plus this hop.
        if best_hop["norm_err"] < best_overall["norm_err"]:
            best_overall = best_hop
            best_path = accepted_chain + [
                _hop_step(h, current["smiles"], best_hop["smiles"], prompt, best_raw)
            ]

        # Move the working molecule only on a sufficient improvement over current.
        if best_hop["norm_err"] + early_stop_delta < current["norm_err"]:
            accepted_chain.append(
                _hop_step(h, current["smiles"], best_hop["smiles"], prompt, best_raw)
            )
            current = best_hop
            no_improve_streak = 0
            continue

        no_improve_streak += 1
        if verbose:
            print(f"[Hop {h}] No sufficient improvement (Δ<{early_stop_delta}); "
                  f"streak={no_improve_streak}/{patience}.")
        if no_improve_streak > patience:
            if verbose:
                print(f"[Stop] Early stop triggered after {h} hops (patience exceeded).")
            break

    if verbose:
        print(f"[Done] Best overall NTE={best_overall['norm_err']:.3f} "
              f"(QED={best_overall['qed']:.3f}, LogP={best_overall['logp']:.3f}, MW={best_overall['mw']:.3f})")

    return best_overall, best_path


